-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathstock_price.py
30 lines (21 loc) · 902 Bytes
/
stock_price.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from langchain.tools import BaseTool
from typing import Optional, Type
from pydantic import BaseModel, Field
from yf_tool import get_stock_price
class StockPriceCheckInput(BaseModel):
    """Input for Stock price check."""

    # NOTE: the docstring above is kept verbatim; pydantic surfaces a model's
    # docstring as its schema description, so the text is runtime-visible.

    # Required string field: the Ellipsis default marks it as having no
    # default value, so validation fails when it is omitted.
    stockticker: str = Field(
        default=...,
        description="Ticker symbol for stock or index",
    )
class StockPriceTool(BaseTool):
    """Tool that looks up a price for a ticker by calling ``get_stock_price``.

    The tool is registered under the name ``get_stock_ticker_price`` and
    validates its input against ``StockPriceCheckInput`` (a single
    ``stockticker`` string). Only synchronous execution is implemented.
    """

    # Field overrides carry explicit type annotations: pydantic v2 raises an
    # error for a base-class field overridden by a non-annotated attribute,
    # while pydantic v1 accepts the annotated form unchanged.
    name: str = "get_stock_ticker_price"
    description: str = (
        "Useful for when you need to find out the price of stock. "
        "You should input the stock ticker used on the yfinance API"
    )
    # Schema used to validate and describe the tool's arguments.
    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput

    def _run(self, stockticker: str):
        """Return whatever ``get_stock_price`` yields for ``stockticker``.

        Args:
            stockticker: Ticker symbol to look up.

        Returns:
            The unmodified result of ``get_stock_price(stockticker)``; its
            exact type is defined by ``yf_tool`` and is not visible here.
        """
        return get_stock_price(stockticker)

    def _arun(self, stockticker: str):
        """Reject asynchronous use of this tool.

        Args:
            stockticker: Ticker symbol (unused).

        Raises:
            NotImplementedError: Always; the tool has no async implementation.
        """
        raise NotImplementedError("This tool does not support async")