Merge pull request #525 from tmc/depllm
tmc authored Jan 17, 2024
2 parents 1526a5f + cbab2c9 commit 098382d
Showing 28 changed files with 48 additions and 51 deletions.
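At a glance the change is mechanical: every public constructor that accepted the deprecated `llms.LLM` interface now accepts `llms.Model`, and `llms.LLM` survives only as a deprecated alias. A minimal sketch of caller code after the rename (assumes an `OPENAI_API_KEY` in the environment; error handling abbreviated):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	llm, err := openai.New() // *openai.LLM satisfies llms.Model
	if err != nil {
		log.Fatal(err)
	}

	// Before this commit callers typically wrote `var model llms.LLM`;
	// after it, llms.Model is the preferred spelling. Both compile,
	// because LLM is kept as a deprecated alias of Model.
	var model llms.Model = llm
	_ = model

	out, err := llm.Call(context.Background(), "Say hello in French.")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
```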
2 changes: 1 addition & 1 deletion agents/conversational.go
@@ -40,7 +40,7 @@ type ConversationalAgent struct {

var _ Agent = (*ConversationalAgent)(nil)

-func NewConversationalAgent(llm llms.LLM, tools []tools.Tool, opts ...CreationOption) *ConversationalAgent {
+func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...CreationOption) *ConversationalAgent {
options := conversationalDefaultOptions()
for _, opt := range opts {
opt(&options)
2 changes: 1 addition & 1 deletion agents/initialize.go
@@ -23,7 +23,7 @@ const (
// model, tools, agent type, and options. It returns an Executor or an error
// if there are any issues during the creation process.
func Initialize(
-llm llms.LLM,
+llm llms.Model,
tools []tools.Tool,
agentType AgentType,
opts ...CreationOption,
2 changes: 1 addition & 1 deletion agents/mrkl.go
@@ -41,7 +41,7 @@ var _ Agent = (*OneShotZeroAgent)(nil)
// NewOneShotAgent creates a new OneShotZeroAgent with the given LLM model, tools,
// and options. It returns a pointer to the created agent. The opts parameter
// represents the options for the agent.
-func NewOneShotAgent(llm llms.LLM, tools []tools.Tool, opts ...CreationOption) *OneShotZeroAgent {
+func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...CreationOption) *OneShotZeroAgent {
options := mrklDefaultOptions()
for _, opt := range opts {
opt(&options)
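The three agent constructors above now take `llms.Model`. A sketch of initializing an agent against the new signature (assumptions: the `Calculator` tool and the `ZeroShotReactDescription` agent type from this repository; error handling abbreviated):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/agents"
	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/tools"
)

func main() {
	llm, err := openai.New() // any llms.Model works here
	if err != nil {
		log.Fatal(err)
	}

	// Initialize accepts the model, tools, and an agent type, and
	// returns an Executor (see agents/initialize.go above).
	executor, err := agents.Initialize(
		llm,
		[]tools.Tool{tools.Calculator{}},
		agents.ZeroShotReactDescription,
	)
	if err != nil {
		log.Fatal(err)
	}

	out, err := chains.Run(context.Background(), executor, "What is 3 to the power of 5?")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
```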
4 changes: 2 additions & 2 deletions chains/api.go
@@ -37,9 +37,9 @@ type APIChain struct {

// NewAPIChain creates a new APIChain object.
//
-// It takes a LanguageModel(llm) and an HTTPRequest(request) as parameters.
+// It takes a language model (llm) and an HTTPRequest (request) as parameters.
// It returns an APIChain object.
-func NewAPIChain(llm llms.LLM, request HTTPRequest) APIChain {
+func NewAPIChain(llm llms.Model, request HTTPRequest) APIChain {
reqPrompt := prompts.NewPromptTemplate(_llmAPIURLPrompt, []string{"api_docs", "input"})
reqChain := NewLLMChain(llm, reqPrompt)

2 changes: 1 addition & 1 deletion chains/chains_test.go
@@ -58,7 +58,7 @@ func (l *testLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageC
panic("not implemented")
}

-var _ llms.LLM = &testLanguageModel{}
+var _ llms.Model = &testLanguageModel{}

func TestApply(t *testing.T) {
t.Parallel()
4 changes: 2 additions & 2 deletions chains/constitution/constitutional.go
@@ -37,7 +37,7 @@ type Constitutional struct {
critiqueChain chains.LLMChain
revisionChain chains.LLMChain
constitutionalPrinciples []ConstitutionalPrinciple
-llm llms.LLM
+llm llms.Model
returnIntermediateSteps bool
memory schema.Memory
}
@@ -58,7 +58,7 @@ func NewConstitutionalPrinciple(critique, revision string, names ...string) Cons
}

// NewConstitutional creates a new Constitutional chain.
-func NewConstitutional(llm llms.LLM, chain chains.LLMChain,
+func NewConstitutional(llm llms.Model, chain chains.LLMChain,
constitutionalPrinciples []ConstitutionalPrinciple, options map[string]*prompts.FewShotPrompt,
) *Constitutional {
CritiquePrompt, RevisionPrompt := initCritiqueRevision()
4 changes: 2 additions & 2 deletions chains/constitutional.go
@@ -375,7 +375,7 @@ type Constitutional struct {
critiqueChain LLMChain
revisionChain LLMChain
constitutionalPrinciples []ConstitutionalPrinciple
-llm llms.LLM
+llm llms.Model
returnIntermediateSteps bool
memory schema.Memory
}
@@ -557,7 +557,7 @@ func NewConstitutionalPrinciple(critique, revision string, names ...string) Cons
}

// NewConstitutional creates a new Constitutional chain.
-func NewConstitutional(llm llms.LLM, chain LLMChain, constitutionalPrinciples []ConstitutionalPrinciple,
+func NewConstitutional(llm llms.Model, chain LLMChain, constitutionalPrinciples []ConstitutionalPrinciple,
options map[string]*prompts.FewShotPrompt,
) *Constitutional {
CritiquePrompt, RevisionPrompt := initCritiqueRevision()
2 changes: 1 addition & 1 deletion chains/conversation.go
@@ -15,7 +15,7 @@ Current conversation:
Human: {{.input}}
AI:`

-func NewConversation(llm llms.LLM, memory schema.Memory) LLMChain {
+func NewConversation(llm llms.Model, memory schema.Memory) LLMChain {
return LLMChain{
Prompt: prompts.NewPromptTemplate(
_conversationTemplate,
2 changes: 1 addition & 1 deletion chains/conversational_retrieval_qa.go
@@ -74,7 +74,7 @@ func NewConversationalRetrievalQA(
}

func NewConversationalRetrievalQAFromLLM(
-llm llms.LLM,
+llm llms.Model,
retriever schema.Retriever,
memory schema.Memory,
) ConversationalRetrievalQA {
4 changes: 2 additions & 2 deletions chains/llm.go
@@ -15,7 +15,7 @@ const _llmChainDefaultOutputKey = "text"

type LLMChain struct {
Prompt prompts.FormatPrompter
-LLM llms.LLM
+LLM llms.Model
Memory schema.Memory
CallbacksHandler callbacks.Handler
OutputParser schema.OutputParser[any]
@@ -29,7 +29,7 @@ var (
)

// NewLLMChain creates a new LLMChain with an LLM and a prompt.
-func NewLLMChain(llm llms.LLM, prompt prompts.FormatPrompter, opts ...ChainCallOption) *LLMChain {
+func NewLLMChain(llm llms.Model, prompt prompts.FormatPrompter, opts ...ChainCallOption) *LLMChain {
opt := &chainCallOption{}
for _, o := range opts {
o(opt)
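`LLMChain` is the building block most of the other chains in this diff wrap, so its signature change is the one most callers will see. A sketch of the updated usage (the translation prompt is illustrative; `chains.Call` and the default "text" output key come from this repository):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/prompts"
)

func main() {
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	// NewLLMChain now accepts any llms.Model.
	chain := chains.NewLLMChain(llm, prompts.NewPromptTemplate(
		"Translate {{.text}} to {{.language}}.",
		[]string{"text", "language"},
	))

	out, err := chains.Call(context.Background(), chain, map[string]any{
		"text":     "good morning",
		"language": "French",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out["text"]) // _llmChainDefaultOutputKey, per the hunk above
}
```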
2 changes: 1 addition & 1 deletion chains/llm_math.go
@@ -25,7 +25,7 @@ type LLMMathChain struct {

var _ Chain = LLMMathChain{}

-func NewLLMMathChain(llm llms.LLM) LLMMathChain {
+func NewLLMMathChain(llm llms.Model) LLMMathChain {
p := prompts.NewPromptTemplate(_llmMathPrompt, []string{"question"})
c := NewLLMChain(llm, p)
return LLMMathChain{
6 changes: 3 additions & 3 deletions chains/prompt_selector.go
@@ -8,22 +8,22 @@ import (
// PromptSelector is the interface for selecting a formatter depending on the
// LLM given.
type PromptSelector interface {
-GetPrompt(llm llms.LLM) prompts.PromptTemplate
+GetPrompt(llm llms.Model) prompts.PromptTemplate
}

// ConditionalPromptSelector is a formatter selector that selects a prompt
// depending on conditionals.
type ConditionalPromptSelector struct {
DefaultPrompt prompts.PromptTemplate
Conditionals []struct {
-Condition func(llms.LLM) bool
+Condition func(llms.Model) bool
Prompt prompts.PromptTemplate
}
}

var _ PromptSelector = ConditionalPromptSelector{}

-func (s ConditionalPromptSelector) GetPrompt(llm llms.LLM) prompts.PromptTemplate {
+func (s ConditionalPromptSelector) GetPrompt(llm llms.Model) prompts.PromptTemplate {
for _, conditional := range s.Conditionals {
if conditional.Condition(llm) {
return conditional.Prompt
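Since `Conditionals` is a slice of an anonymous struct, callers repeat the struct type when building a selector, and the predicates now range over `llms.Model`. A hedged sketch (the `isChatModel` predicate and both prompts are hypothetical):

```go
package main

import (
	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/prompts"
)

func main() {
	defaultPrompt := prompts.NewPromptTemplate("Q: {{.question}}\nA:", []string{"question"})
	chatPrompt := prompts.NewPromptTemplate("You are a helpful assistant.\n{{.question}}", []string{"question"})

	// Hypothetical predicate; any func(llms.Model) bool fits.
	isChatModel := func(llms.Model) bool { return false }

	selector := chains.ConditionalPromptSelector{
		DefaultPrompt: defaultPrompt,
		Conditionals: []struct {
			Condition func(llms.Model) bool
			Prompt    prompts.PromptTemplate
		}{
			{Condition: isChatModel, Prompt: chatPrompt},
		},
	}

	var model llms.Model // a real model would come from a provider package
	prompt := selector.GetPrompt(model)
	_ = prompt
}
```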
10 changes: 5 additions & 5 deletions chains/question_answering.go
@@ -92,7 +92,7 @@ Follow Up Input: {{.question}}
Standalone question:`

// LoadCondenseQuestionGenerator chain is used to generate a new question for the sake of retrieval.
-func LoadCondenseQuestionGenerator(llm llms.LLM) *LLMChain {
+func LoadCondenseQuestionGenerator(llm llms.Model) *LLMChain {
condenseQuestionPromptTemplate := prompts.NewPromptTemplate(
_defaultCondenseQuestionTemplate,
[]string{"chat_history", "question"},
@@ -101,7 +101,7 @@ func LoadCondenseQuestionGenerator(llm llms.LLM) *LLMChain {
}

// LoadStuffQA loads a StuffDocuments chain with default prompts for the llm chain.
-func LoadStuffQA(llm llms.LLM) StuffDocuments {
+func LoadStuffQA(llm llms.Model) StuffDocuments {
defaultQAPromptTemplate := prompts.NewPromptTemplate(
_defaultStuffQATemplate,
[]string{"context", "question"},
@@ -118,7 +118,7 @@ func LoadStuffQA(llm llms.LLM) StuffDocuments {

// LoadRefineQA loads a refine documents chain for question answering. Inputs are
// "question" and "input_documents".
-func LoadRefineQA(llm llms.LLM) RefineDocuments {
+func LoadRefineQA(llm llms.Model) RefineDocuments {
questionPrompt := prompts.NewPromptTemplate(
_defaultStuffQATemplate,
[]string{"context", "question"},
@@ -136,7 +136,7 @@ func LoadRefineQA(llm llms.LLM) RefineDocuments {

// LoadMapReduceQA loads a refine documents chain for question answering. Inputs are
// "question" and "input_documents".
-func LoadMapReduceQA(llm llms.LLM) MapReduceDocuments {
+func LoadMapReduceQA(llm llms.Model) MapReduceDocuments {
getInfoPrompt := prompts.NewPromptTemplate(
_defaultMapReduceGetInformationQATemplate,
[]string{"question", "context"},
@@ -156,7 +156,7 @@ func LoadMapReduceQA(llm llms.LLM) MapReduceDocuments {

// LoadMapRerankQA loads a map rerank documents chain for question answering. Inputs are
// "question" and "input_documents".
-func LoadMapRerankQA(llm llms.LLM) MapRerankDocuments {
+func LoadMapRerankQA(llm llms.Model) MapRerankDocuments {
mapRerankPrompt := prompts.NewPromptTemplate(
_defaultMapRerankTemplate,
[]string{"context", "question"},
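All five QA loaders above share the "question" and "input_documents" inputs. A sketch of the stuff-documents variant under the new signature (the sample document and question are illustrative; the "text" output key is assumed to be the LLM chain default):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	qa := chains.LoadStuffQA(llm) // now takes llms.Model

	out, err := chains.Call(context.Background(), qa, map[string]any{
		"input_documents": []schema.Document{
			{PageContent: "Harrison worked at Kensho."},
		},
		"question": "Where did Harrison work?",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out["text"]) // assumed default output key
}
```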
2 changes: 1 addition & 1 deletion chains/retrieval_qa.go
@@ -50,7 +50,7 @@ func NewRetrievalQA(combineDocumentsChain Chain, retriever schema.Retriever) Ret

// NewRetrievalQAFromLLM loads a question answering combine documents chain
// from the llm and creates a new retrievalQA chain.
-func NewRetrievalQAFromLLM(llm llms.LLM, retriever schema.Retriever) RetrievalQA {
+func NewRetrievalQAFromLLM(llm llms.Model, retriever schema.Retriever) RetrievalQA {
return NewRetrievalQA(
LoadStuffQA(llm),
retriever,
2 changes: 1 addition & 1 deletion chains/sql_database.go
@@ -50,7 +50,7 @@ type SQLDatabaseChain struct {

// NewSQLDatabaseChain creates a new SQLDatabaseChain.
// The topK is the max number of results to return.
-func NewSQLDatabaseChain(llm llms.LLM, topK int, database *sqldatabase.SQLDatabase) *SQLDatabaseChain {
+func NewSQLDatabaseChain(llm llms.Model, topK int, database *sqldatabase.SQLDatabase) *SQLDatabaseChain {
p := prompts.NewPromptTemplate(_defaultSQLTemplate+_defaultSQLSuffix,
[]string{"dialect", "top_k", "table_info", "input"})
c := NewLLMChain(llm, p)
6 changes: 3 additions & 3 deletions chains/summarization.go
@@ -28,7 +28,7 @@ REFINED SUMMARY:`

// LoadStuffSummarization loads a summarization chain that stuffs all documents
// given into the prompt.
-func LoadStuffSummarization(llm llms.LLM) StuffDocuments {
+func LoadStuffSummarization(llm llms.Model) StuffDocuments {
llmChain := NewLLMChain(llm, prompts.NewPromptTemplate(
_stuffSummarizationTemplate, []string{"context"},
))
@@ -38,7 +38,7 @@ func LoadStuffSummarization(llm llms.LLM) StuffDocuments {

// LoadRefineSummarization loads a refine documents chain for summarization of
// documents.
-func LoadRefineSummarization(llm llms.LLM) RefineDocuments {
+func LoadRefineSummarization(llm llms.Model) RefineDocuments {
llmChain := NewLLMChain(llm, prompts.NewPromptTemplate(
_stuffSummarizationTemplate, []string{"context"},
))
@@ -51,7 +51,7 @@ func LoadRefineSummarization(llm llms.LLM) RefineDocuments {

// LoadMapReduceSummarization loads a map reduce documents chain for
// summarization of documents.
-func LoadMapReduceSummarization(llm llms.LLM) MapReduceDocuments {
+func LoadMapReduceSummarization(llm llms.Model) MapReduceDocuments {
mapChain := NewLLMChain(llm, prompts.NewPromptTemplate(
_stuffSummarizationTemplate, []string{"context"},
))
2 changes: 1 addition & 1 deletion llms/anthropic/anthropicllm.go
@@ -22,7 +22,7 @@ type LLM struct {
client *anthropicclient.Client
}

-var _ llms.LLM = (*LLM)(nil)
+var _ llms.Model = (*LLM)(nil)

// New returns a new Anthropic LLM.
func New(opts ...Option) (*LLM, error) {
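The single line that changes in each provider here is Go's compile-time interface check: assigning a typed nil to the blank identifier proves at build time that the concrete type still satisfies the interface, at zero runtime cost. A generic sketch of the idiom (the `Greeter` names are illustrative stand-ins):

```go
package main

// Greeter stands in for llms.Model in this sketch.
type Greeter interface {
	Greet() string
}

type English struct{}

func (English) Greet() string { return "hello" }

// Fails to compile if English ever stops satisfying Greeter.
var _ Greeter = (*English)(nil)

func main() {}
```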
2 changes: 1 addition & 1 deletion llms/cohere/coherellm.go
@@ -22,7 +22,7 @@ type LLM struct {
client *cohereclient.Client
}

-var _ llms.LLM = (*LLM)(nil)
+var _ llms.Model = (*LLM)(nil)

func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
return llms.CallLLM(ctx, o, prompt, options...)
3 changes: 1 addition & 2 deletions llms/ernie/erniellm.go
@@ -22,7 +22,7 @@ type LLM struct {
CallbacksHandler callbacks.Handler
}

-var _ llms.LLM = (*LLM)(nil)
+var _ llms.Model = (*LLM)(nil)

// New returns a new Ernie LLM.
func New(opts ...Option) (*LLM, error) {
@@ -59,7 +59,6 @@ doc: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2`, ernieclient.ErrNot
ernieclient.WithAKSK(opts.apiKey, opts.secretKey))
}

-// Call implements llms.LLM.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
return llms.CallLLM(ctx, o, prompt, options...)
}
2 changes: 1 addition & 1 deletion llms/huggingface/huggingfacellm.go
@@ -21,7 +21,7 @@ type LLM struct {
client *huggingfaceclient.Client
}

-var _ llms.LLM = (*LLM)(nil)
+var _ llms.Model = (*LLM)(nil)

// Call implements the LLM interface.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
2 changes: 1 addition & 1 deletion llms/llms.go
@@ -9,7 +9,7 @@ import (

// LLM is an alias for model, for backwards compatibility.
//
-// This alias may be removed in the future; please use Model
+// Deprecated: This alias may be removed in the future; please use Model
// instead.
type LLM = Model

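Because `type LLM = Model` is a Go type alias rather than a defined type, the two names denote the same type and are interchangeable everywhere, which is what keeps pre-rename callers compiling. A self-contained sketch of the semantics (the interface and types here are illustrative):

```go
package main

import "fmt"

// Model stands in for llms.Model.
type Model interface {
	Name() string
}

// LLM mirrors the deprecated alias in llms/llms.go.
type LLM = Model

type fake struct{}

func (fake) Name() string { return "fake" }

func main() {
	var old LLM = fake{}
	var current Model = old // identical types: no conversion needed
	fmt.Println(current.Name())
}
```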
5 changes: 1 addition & 4 deletions llms/local/localllm.go
@@ -26,10 +26,7 @@ type LLM struct {
client *localclient.Client
}

-// _ ensures that LLM implements the llms.LLM and language model interface.
-var (
-_ llms.LLM = (*LLM)(nil)
-)
+var _ llms.Model = (*LLM)(nil)

// Call calls the local LLM binary with the given prompt.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
2 changes: 1 addition & 1 deletion llms/ollama/ollamallm.go
@@ -22,7 +22,7 @@ type LLM struct {
options options
}

-var _ llms.LLM = (*LLM)(nil)
+var _ llms.Model = (*LLM)(nil)

// New creates a new ollama LLM implementation.
func New(opts ...Option) (*LLM, error) {
2 changes: 1 addition & 1 deletion llms/openai/multicontent_test.go
@@ -13,7 +13,7 @@ import (
"github.com/tmc/langchaingo/schema"
)

-func newTestClient(t *testing.T, opts ...Option) *LLM {
+func newTestClient(t *testing.T, opts ...Option) llms.Model {
t.Helper()
if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
2 changes: 1 addition & 1 deletion llms/openai/openaillm.go
@@ -24,7 +24,7 @@ const (
RoleFunction = "function"
)

-var _ llms.LLM = (*LLM)(nil)
+var _ llms.Model = (*LLM)(nil)

// New returns a new OpenAI LLM.
func New(opts ...Option) (*LLM, error) {
15 changes: 8 additions & 7 deletions llms/options.go
@@ -5,7 +5,7 @@ import "context"
// CallOption is a function that configures a CallOptions.
type CallOption func(*CallOptions)

-// CallOptions is a set of options for LLM.Call.
+// CallOptions is a set of options for calling models.
type CallOptions struct {
// Model is the model to use.
Model string `json:"model"`
@@ -66,42 +66,43 @@ const (
FunctionCallBehaviorAuto FunctionCallBehavior = "auto"
)

-// WithModel is an option for LLM.Call.
+// WithModel specifies which model name to use.
func WithModel(model string) CallOption {
return func(o *CallOptions) {
o.Model = model
}
}

-// WithMaxTokens is an option for LLM.Call.
+// WithMaxTokens specifies the max number of tokens to generate.
func WithMaxTokens(maxTokens int) CallOption {
return func(o *CallOptions) {
o.MaxTokens = maxTokens
}
}

-// WithTemperature is an option for LLM.Call.
+// WithTemperature specifies the model temperature, a hyperparameter that
+// regulates the randomness, or creativity, of the AI's responses.
func WithTemperature(temperature float64) CallOption {
return func(o *CallOptions) {
o.Temperature = temperature
}
}

-// WithStopWords is an option for LLM.Call.
+// WithStopWords specifies a list of words to stop generation on.
func WithStopWords(stopWords []string) CallOption {
return func(o *CallOptions) {
o.StopWords = stopWords
}
}

-// WithOptions is an option for LLM.Call.
+// WithOptions specifies options.
func WithOptions(options CallOptions) CallOption {
return func(o *CallOptions) {
(*o) = options
}
}

-// WithStreamingFunc is an option for LLM.Call that allows streaming responses.
+// WithStreamingFunc specifies the streaming function to use.
func WithStreamingFunc(streamingFunc func(ctx context.Context, chunk []byte) error) CallOption {
return func(o *CallOptions) {
o.StreamingFunc = streamingFunc
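The functional options above configure a single generation call regardless of provider. A sketch of typical usage (assumes an `OPENAI_API_KEY` in the environment; the model name is illustrative):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	// Each CallOption mutates the CallOptions struct before the request.
	out, err := llm.Call(context.Background(), "Write a haiku about Go.",
		llms.WithModel("gpt-3.5-turbo"), // illustrative model name
		llms.WithTemperature(0.7),
		llms.WithMaxTokens(64),
		llms.WithStopWords([]string{"---"}),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
```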
