diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/BaseMessage.java b/src/main/java/com/unfbx/chatgpt/entity/chat/BaseMessage.java index e9cccb2..22b5fa9 100644 --- a/src/main/java/com/unfbx/chatgpt/entity/chat/BaseMessage.java +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/BaseMessage.java @@ -35,10 +35,17 @@ public class BaseMessage implements Serializable { /** * The tool calls generated by the model, such as function calls. + * @since 1.1.2 */ @JsonProperty("tool_calls") private List toolCalls; + /** + * @since 1.1.2 + */ + @JsonProperty("tool_call_id") + private String toolCallId; + @Deprecated @JsonProperty("function_call") private FunctionCall functionCall; diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/Message.java b/src/main/java/com/unfbx/chatgpt/entity/chat/Message.java index ac1ec4d..606e09e 100644 --- a/src/main/java/com/unfbx/chatgpt/entity/chat/Message.java +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/Message.java @@ -37,11 +37,12 @@ public static Builder builder() { * @param content content * @param functionCall functionCall */ - public Message(String role, String name, String content, List toolCalls, FunctionCall functionCall) { + public Message(String role, String name, String content, List toolCalls, String toolCallId, FunctionCall functionCall) { this.content = content; super.setRole(role); super.setName(name); super.setToolCalls(toolCalls); + super.setToolCallId(toolCallId); super.setFunctionCall(functionCall); } @@ -54,12 +55,14 @@ private Message(Builder builder) { super.setName(builder.name); super.setFunctionCall(builder.functionCall); super.setToolCalls(builder.toolCalls); + super.setToolCallId(builder.toolCallId); } public static final class Builder { private String role; private String content; private String name; + private String toolCallId; private List toolCalls; private FunctionCall functionCall; @@ -96,6 +99,11 @@ public Builder toolCalls(List toolCalls) { return this; } + public Builder toolCallId(String toolCallId) { + this.toolCallId = toolCallId; + return this; + } + public Message build() { return new Message(this); } diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/MessagePicture.java b/src/main/java/com/unfbx/chatgpt/entity/chat/MessagePicture.java index 0640d63..93141cb 100644 --- a/src/main/java/com/unfbx/chatgpt/entity/chat/MessagePicture.java +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/MessagePicture.java @@ -39,11 +39,12 @@ public static Builder builder() { * @param content content * @param functionCall functionCall */ - public MessagePicture(String role, String name, List content, List toolCalls, FunctionCall functionCall) { + public MessagePicture(String role, String name, List content, List toolCalls, String toolCallId, FunctionCall functionCall) { this.content = content; super.setRole(role); super.setName(name); super.setToolCalls(toolCalls); + super.setToolCallId(toolCallId); super.setFunctionCall(functionCall); } @@ -56,12 +57,14 @@ private MessagePicture(Builder builder) { super.setName(builder.name); super.setFunctionCall(builder.functionCall); super.setToolCalls(builder.toolCalls); + super.setToolCallId(builder.toolCallId); } public static final class Builder { private String role; private List content; private String name; + private String toolCallId; private List toolCalls; private FunctionCall functionCall; @@ -98,6 +101,11 @@ public Builder toolCalls(List toolCalls) { return this; } + public Builder toolCallId(String toolCallId) { + this.toolCallId = toolCallId; + return this; + } + public MessagePicture build() { return new MessagePicture(this); } diff --git a/src/main/java/com/unfbx/chatgpt/entity/chat/tool/ToolCalls.java b/src/main/java/com/unfbx/chatgpt/entity/chat/tool/ToolCalls.java index 5b94a13..e51a3df 100644 --- a/src/main/java/com/unfbx/chatgpt/entity/chat/tool/ToolCalls.java +++ b/src/main/java/com/unfbx/chatgpt/entity/chat/tool/ToolCalls.java @@ -1,9 +1,6 @@ package com.unfbx.chatgpt.entity.chat.tool; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; +import lombok.*; import java.io.Serializable; @@ -29,4 +26,12 @@ public class ToolCalls implements Serializable { private String type; private ToolCallFunction function; + + @Getter + @AllArgsConstructor + public enum Type { + FUNCTION("function"), + ; + private final String name; + } } diff --git a/src/test/java/com/unfbx/chatgpt/v1_1_2/OpenAiClientTest.java b/src/test/java/com/unfbx/chatgpt/v1_1_2/OpenAiClientTest.java index 3b40bff..cb66676 100644 --- a/src/test/java/com/unfbx/chatgpt/v1_1_2/OpenAiClientTest.java +++ b/src/test/java/com/unfbx/chatgpt/v1_1_2/OpenAiClientTest.java @@ -2,14 +2,20 @@ import cn.hutool.json.JSONObject; +import cn.hutool.json.JSONUtil; import com.unfbx.chatgpt.FirstKeyStrategy; import com.unfbx.chatgpt.OpenAiClient; +import com.unfbx.chatgpt.OpenAiClientFunctionTest; import com.unfbx.chatgpt.entity.chat.*; +import com.unfbx.chatgpt.entity.chat.tool.ToolCallFunction; +import com.unfbx.chatgpt.entity.chat.tool.ToolCalls; import com.unfbx.chatgpt.entity.chat.tool.Tools; import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; import com.unfbx.chatgpt.interceptor.OpenAILogger; import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor; +import lombok.Builder; +import lombok.Data; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; import okhttp3.logging.HttpLoggingInterceptor; @@ -50,7 +56,6 @@ public void before() { client = OpenAiClient.builder() //支持多key传入,请求时候随机选择 // .apiKey(Arrays.asList("*********************")) - .apiKey(Arrays.asList("***************")) //自定义key的获取策略:默认KeyRandomStrategy //.keyStrategy(new KeyRandomStrategy()) .keyStrategy(new FirstKeyStrategy()) @@ -131,6 +136,76 @@ public void toolsChat() { ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0); log.info("构造的方法值:{}", chatChoice.getMessage().getToolCalls()); + + ToolCalls openAiReturnToolCalls = chatChoice.getMessage().getToolCalls().get(0); + WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class); + String oneWord = getOneWord(wordParam); + + + ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build(); + ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build(); + //构造tool call + Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build(); + String content + = "{ " + + "\"wordLength\": \"3\", " + + "\"language\": \"zh\", " + + "\"word\": \"" + oneWord + "\"," + + "\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" + + "}"; + Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build(); + List messageList = Arrays.asList(message, message2, message3); + ChatCompletion chatCompletionV2 = ChatCompletion + .builder() + .messages(messageList) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + ChatCompletionResponse chatCompletionResponseV2 = client.chatCompletion(chatCompletionV2); + log.info("自定义的方法返回值:{}",chatCompletionResponseV2.getChoices().get(0).getMessage().getContent()); + } + + + + + + /** + * 获取一个词语 + * @param wordParam + * @return + */ + public String getOneWord(WordParam wordParam) { + + List zh = Arrays.asList("大香蕉", "哈密瓜", "苹果"); + List en = Arrays.asList("apple", "banana", "cantaloupe"); + if (wordParam.getLanguage().equals("zh")) { + for (String e : zh) { + if (e.length() == wordParam.getWordLength()) { + return e; + } + } + } + if (wordParam.getLanguage().equals("en")) { + for (String e : en) { + if (e.length() == wordParam.getWordLength()) { + return e; + } + } + } + return "西瓜"; + } + + @Test + public void testInput() { + System.out.println(getOneWord(WordParam.builder().wordLength(2).language("zh").build())); + } + + @Data + @Builder + static class WordParam { + private int wordLength; + @Builder.Default + private String language = "zh"; + } }