Added error-parsing logic to llm chain (#7)
* initial commit for yellowhammer Jupyter assistant

* minor changes for initial commit

* minor changes for initial commit

* pre-commit update

* add error parsing logic to llm chain

* pre-commit autoupdate
yue-here authored Nov 9, 2024
1 parent d91f04a commit 1da4ebd
Showing 3 changed files with 28 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -160,3 +160,9 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Test notebook
examples/Demo_yue.ipynb
examples/yue_examples
examples/FA-Mn-H2PO2.raw
src/yellowhammer/yue_test.ipynb
23 changes: 21 additions & 2 deletions src/yellowhammer/llm.py
@@ -25,6 +25,26 @@ class FinalResponse(BaseModel):
final_output: Union[code, ConversationalResponse]


def error_parser(output):
"""
Parse the API output to handle errors gracefully.
"""
if output["parsing_error"]:
raw_output = str(output["raw"].content)
error = output["parsing_error"]
out_string = f"Error parsing LLM output. Parse error: {error}. \n Raw output: {raw_output}"
return FinalResponse(final_output=ConversationalResponse(response=out_string))

elif not output["parsed"]:
raw_output = str(output["raw"].content)
out_string = f"Error in LLM response. \n Raw output: {raw_output}"
return FinalResponse(final_output=ConversationalResponse(response=out_string))

else:
# Return the parsed output (should be FinalResponse)
return output["parsed"]


def get_chain(
api_provider,
api_key,
@@ -62,7 +82,6 @@ def get_chain(
)

# Create a chain where the final output takes the FinalResponse schema
-chain = prompt | llm.with_structured_output(FinalResponse, include_raw=False)

+chain = prompt | llm.with_structured_output(FinalResponse, include_raw=True) | error_parser
# Returns a runnable chain which accepts datalab API documentation "context" and user question "messages"
return chain
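
For reference, error_parser handles the dict that LangChain's with_structured_output(..., include_raw=True) emits, which carries "raw" (the model message), "parsed" (the schema instance, or None), and "parsing_error" (an exception, or None). A minimal sketch of the three branches, assuming the names from llm.py above are in scope; the message contents are illustrative only:

from langchain_core.messages import AIMessage

# Branch 1: the model replied, but schema parsing raised an error
failed = {
    "raw": AIMessage(content="not valid JSON"),
    "parsed": None,
    "parsing_error": ValueError("unterminated string"),
}
assert "Error parsing LLM output" in error_parser(failed).final_output.response

# Branch 2: no exception, but nothing was parsed either
empty = {"raw": AIMessage(content=""), "parsed": None, "parsing_error": None}
assert "Error in LLM response" in error_parser(empty).final_output.response

# Branch 3: success; the parsed FinalResponse passes through unchanged
ok = FinalResponse(final_output=ConversationalResponse(response="Hello!"))
good = {"raw": AIMessage(content="Hello!"), "parsed": ok, "parsing_error": None}
assert error_parser(good) is ok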
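
The returned chain therefore degrades gracefully instead of raising on malformed model output. A usage sketch mirroring the comment above; `chain` is assumed to come from get_chain (its remaining parameters are collapsed in this diff), and datalab_docs and the question are placeholders:

from langchain_core.messages import HumanMessage

result = chain.invoke(
    {
        "context": datalab_docs,  # datalab API documentation (placeholder)
        "messages": [HumanMessage(content="How do I plot an XRD pattern?")],
    }
)
# Always a FinalResponse: the parsed output on success, or a
# ConversationalResponse wrapping the error text when parsing failed
print(result.final_output)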
2 changes: 1 addition & 1 deletion src/yellowhammer/magics.py
@@ -54,7 +54,7 @@ def __init__(self, shell):
"-T",
"--temp",
type=float,
-default=0.0,
+default=0.1,
help="""Temperature, float in [0,1]. Higher values push the algorithm
to generate more aggressive/"creative" output. [default=0.1].""",
)
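
The new 0.1 default makes output mildly non-deterministic out of the box and brings the code in line with the [default=0.1] already stated in the help text; the flag still overrides it per call. A hypothetical invocation, assuming the package loads as an IPython extension (the %%llm magic name is invented for illustration; only the -T/--temp flag is taken from the code above):

%load_ext yellowhammer

%%llm -T 0.0
Write pandas code to load data.csv and plot the first column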