Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Similar code to OLS downstream #368

Merged
merged 1 commit into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ols/app/endpoints/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from ols.utils import errors_parsing, suid
from ols.utils.token_handler import PromptTooLongError

KEYWORDS = keywords.KEYWORDS
INVALID_QUERY_RESP = prompts.INVALID_QUERY_RESP

logger = logging.getLogger(__name__)

router = APIRouter(tags=["query"])
Expand Down Expand Up @@ -98,7 +101,7 @@ def conversation_request(
if not valid:
# response containing info about query that can not be validated
summarizer_response = SummarizerResponse(
prompts.INVALID_QUERY_RESP,
INVALID_QUERY_RESP,
[],
False,
None,
Expand Down Expand Up @@ -599,7 +602,7 @@ def _validate_question_keyword(query: str) -> bool:
# Current implementation is without any tokenizer method, lemmatization/n-grams.
# Add valid keywords to keywords.py file.
query_temp = query.lower()
for kw in keywords.KEYWORDS:
for kw in KEYWORDS:
if kw in query_temp:
return True
# query_temp = {q_word.lower().strip(".?,") for q_word in query.split()}
Expand Down
4 changes: 3 additions & 1 deletion ols/app/endpoints/streaming_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from ols.utils import errors_parsing
from ols.utils.token_handler import PromptTooLongError

INVALID_QUERY_RESP = prompts.INVALID_QUERY_RESP

logger = logging.getLogger(__name__)

router = APIRouter(tags=["streaming_query"])
Expand Down Expand Up @@ -126,7 +128,7 @@ async def invalid_response_generator() -> AsyncGenerator[str, None]:
Yields:
str: The response indicating invalid query.
"""
yield prompts.INVALID_QUERY_RESP
yield INVALID_QUERY_RESP


def format_stream_data(d: dict) -> str:
Expand Down