Skip to content

Commit

Permalink
feature 1.1.2 新的api开发
Browse files Browse the repository at this point in the history
  • Loading branch information
Grt1228 authored and guorutao committed Nov 12, 2023
1 parent 30ddcc2 commit 25e2dbf
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 7 deletions.
7 changes: 7 additions & 0 deletions src/main/java/com/unfbx/chatgpt/entity/chat/BaseMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@ public class BaseMessage implements Serializable {

/**
* The tool calls generated by the model, such as function calls.
* @since 1.1.2
*/
@JsonProperty("tool_calls")
private List<ToolCalls> toolCalls;

/**
* @since 1.1.2
*/
@JsonProperty("tool_call_id")
private String toolCallId;

@Deprecated
@JsonProperty("function_call")
private FunctionCall functionCall;
Expand Down
10 changes: 9 additions & 1 deletion src/main/java/com/unfbx/chatgpt/entity/chat/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ public static Builder builder() {
* @param content content
* @param functionCall functionCall
*/
public Message(String role, String name, String content, List<ToolCalls> toolCalls, FunctionCall functionCall) {
public Message(String role, String name, String content, List<ToolCalls> toolCalls, String toolCallId, FunctionCall functionCall) {
this.content = content;
super.setRole(role);
super.setName(name);
super.setToolCalls(toolCalls);
super.setToolCallId(toolCallId);
super.setFunctionCall(functionCall);
}

Expand All @@ -54,12 +55,14 @@ private Message(Builder builder) {
super.setName(builder.name);
super.setFunctionCall(builder.functionCall);
super.setToolCalls(builder.toolCalls);
super.setToolCallId(builder.toolCallId);
}

public static final class Builder {
private String role;
private String content;
private String name;
private String toolCallId;
private List<ToolCalls> toolCalls;
private FunctionCall functionCall;

Expand Down Expand Up @@ -96,6 +99,11 @@ public Builder toolCalls(List<ToolCalls> toolCalls) {
return this;
}

public Builder toolCallId(String toolCallId) {
this.toolCallId = toolCallId;
return this;
}

public Message build() {
return new Message(this);
}
Expand Down
10 changes: 9 additions & 1 deletion src/main/java/com/unfbx/chatgpt/entity/chat/MessagePicture.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ public static Builder builder() {
* @param content content
* @param functionCall functionCall
*/
public MessagePicture(String role, String name, List<Content> content, List<ToolCalls> toolCalls, FunctionCall functionCall) {
public MessagePicture(String role, String name, List<Content> content, List<ToolCalls> toolCalls, String toolCallId, FunctionCall functionCall) {
this.content = content;
super.setRole(role);
super.setName(name);
super.setToolCalls(toolCalls);
super.setToolCallId(toolCallId);
super.setFunctionCall(functionCall);
}

Expand All @@ -56,12 +57,14 @@ private MessagePicture(Builder builder) {
super.setName(builder.name);
super.setFunctionCall(builder.functionCall);
super.setToolCalls(builder.toolCalls);
super.setToolCallId(builder.toolCallId);
}

public static final class Builder {
private String role;
private List<Content> content;
private String name;
private String toolCallId;
private List<ToolCalls> toolCalls;
private FunctionCall functionCall;

Expand Down Expand Up @@ -98,6 +101,11 @@ public Builder toolCalls(List<ToolCalls> toolCalls) {
return this;
}

public Builder toolCallId(String toolCallId) {
this.toolCallId = toolCallId;
return this;
}

public MessagePicture build() {
return new MessagePicture(this);
}
Expand Down
13 changes: 9 additions & 4 deletions src/main/java/com/unfbx/chatgpt/entity/chat/tool/ToolCalls.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package com.unfbx.chatgpt.entity.chat.tool;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.*;

import java.io.Serializable;

Expand All @@ -29,4 +26,12 @@ public class ToolCalls implements Serializable {
private String type;

private ToolCallFunction function;

@Getter
@AllArgsConstructor
public enum Type {
FUNCTION("function"),
;
private final String name;
}
}
77 changes: 76 additions & 1 deletion src/test/java/com/unfbx/chatgpt/v1_1_2/OpenAiClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@


import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.unfbx.chatgpt.FirstKeyStrategy;
import com.unfbx.chatgpt.OpenAiClient;

import com.unfbx.chatgpt.OpenAiClientFunctionTest;
import com.unfbx.chatgpt.entity.chat.*;
import com.unfbx.chatgpt.entity.chat.tool.ToolCallFunction;
import com.unfbx.chatgpt.entity.chat.tool.ToolCalls;
import com.unfbx.chatgpt.entity.chat.tool.Tools;
import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
import com.unfbx.chatgpt.interceptor.OpenAILogger;
import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
Expand Down Expand Up @@ -50,7 +56,6 @@ public void before() {
client = OpenAiClient.builder()
//支持多key传入,请求时候随机选择
// .apiKey(Arrays.asList("*********************"))
.apiKey(Arrays.asList("***************"))
//自定义key的获取策略:默认KeyRandomStrategy
//.keyStrategy(new KeyRandomStrategy())
.keyStrategy(new FirstKeyStrategy())
Expand Down Expand Up @@ -131,6 +136,76 @@ public void toolsChat() {

ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
log.info("构造的方法值:{}", chatChoice.getMessage().getToolCalls());

ToolCalls openAiReturnToolCalls = chatChoice.getMessage().getToolCalls().get(0);
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
String oneWord = getOneWord(wordParam);


ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
//构造tool call
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
String content
= "{ " +
"\"wordLength\": \"3\", " +
"\"language\": \"zh\", " +
"\"word\": \"" + oneWord + "\"," +
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
"}";
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
List<Message> messageList = Arrays.asList(message, message2, message3);
ChatCompletion chatCompletionV2 = ChatCompletion
.builder()
.messages(messageList)
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
ChatCompletionResponse chatCompletionResponseV2 = client.chatCompletion(chatCompletionV2);
log.info("自定义的方法返回值:{}",chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());

}






/**
* 获取一个词语
* @param wordParam
* @return
*/
public String getOneWord(WordParam wordParam) {

List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
if (wordParam.getLanguage().equals("zh")) {
for (String e : zh) {
if (e.length() == wordParam.getWordLength()) {
return e;
}
}
}
if (wordParam.getLanguage().equals("en")) {
for (String e : en) {
if (e.length() == wordParam.getWordLength()) {
return e;
}
}
}
return "西瓜";
}

@Test
public void testInput() {
System.out.println(getOneWord(WordParam.builder().wordLength(2).language("zh").build()));
}

@Data
@Builder
static class WordParam {
private int wordLength;
@Builder.Default
private String language = "zh";
}
}

0 comments on commit 25e2dbf

Please sign in to comment.