From c3d08efb9a321762a07646f13221ad147b093cd2 Mon Sep 17 00:00:00 2001 From: "Dr. Hans-Peter Stoerr" Date: Sun, 25 Feb 2024 22:23:58 +0100 Subject: [PATCH] allow several prompts concatenated --- .../commandline/AIGenPipeline.java | 24 ++++++++++++------- .../framework/task/AIGenerationTask.java | 24 +++++++++++++------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/aigenpipeline-commandline/src/main/java/net/stoerr/ai/aigenpipeline/commandline/AIGenPipeline.java b/aigenpipeline-commandline/src/main/java/net/stoerr/ai/aigenpipeline/commandline/AIGenPipeline.java index 76957ae7..9644e4c3 100644 --- a/aigenpipeline-commandline/src/main/java/net/stoerr/ai/aigenpipeline/commandline/AIGenPipeline.java +++ b/aigenpipeline-commandline/src/main/java/net/stoerr/ai/aigenpipeline/commandline/AIGenPipeline.java @@ -5,7 +5,9 @@ import java.io.PrintStream; import java.nio.file.Path; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Properties; @@ -18,9 +20,10 @@ public class AIGenPipeline { protected boolean help, verbose, dryRun, check, force, version; protected String output, explain; protected String url; - protected String key; + protected String apiKey; protected List inputFiles = new ArrayList<>(); protected List promptFiles = new ArrayList<>(); + protected Map keyValues = new LinkedHashMap<>(); protected String model = "gpt-4-turbo-preview"; protected AIGenerationTask task; protected File rootDir = new File("."); @@ -51,8 +54,8 @@ public AIChatBuilder makeChatBuilder() { if (null != url) { chatBuilder.url(url); } - if (null != key) { - chatBuilder.key(key); + if (null != apiKey) { + chatBuilder.key(apiKey); } if (null != model) { chatBuilder.model(model); @@ -71,7 +74,7 @@ protected void run() throws IOException { if (promptFiles.isEmpty()) { throw new IllegalArgumentException("At least one prompt file has to be given."); } - promptFiles.stream().map(this::toFile).forEach(task::addPrompt); + promptFiles.stream().map(this::toFile).forEach(f -> task.addPrompt(f, keyValues)); task.force(force); if (verbose) { logStream.println(task.toJson(this::makeChatBuilder, rootDir)); @@ -109,6 +112,7 @@ protected void printHelpAndExit(boolean onerror) { " --version Show the version of the AIGenPipeline tool and exit.\n" + " -o, --output Specify the output file where the generated content will be written. Mandatory.\n" + " -p, --prompt Reads a prompt from the given file.\n" + + " -k = Sets a key-value pair replacing ${key} in prompt files with the value. \n" + " -s, --sysmsg Optional: Reads a system message from the given file instead of using the default. \n" + " -v, --verbose Enable verbose output to stderr, providing more details about the process.\n" + " -n, --dry-run Enable dry-run mode, where the tool will only print to stderr what it would do without \n" + @@ -124,7 +128,7 @@ protected void printHelpAndExit(boolean onerror) { " -u, --url The URL of the AI server. Default is https://api.openai.com/v1/chat/completions .\n" + " In the case of OpenAI the API key is expected to be in the environment variable \n" + " OPENAI_API_KEY, or given as -k option.\n" + - " -k, --key The API key for the AI server. If not given, it's expected to be in the environment variable \n" + + " -a, --api-key The API key for the AI server. If not given, it's expected to be in the environment variable \n" + " OPENAI_API_KEY, or you could use a -u option to specify a different server that doesnt need\n" + " an API key. Used in \"Authorization: Bearer \" header.\n" + " -m, --model The model to use for the AI. Default is gpt-4-turbo-preview .\n" + @@ -168,6 +172,10 @@ protected void parseArguments(String[] args) throws IOException { case "--prompt": promptFiles.add(args[++i]); break; + case "-k": + String[] kv = args[++i].split("=", 2); + keyValues.put(kv[0], kv[1]); + break; case "-s": case "--sysmsg": task.setSystemMessage(new File(args[++i])); @@ -196,9 +204,9 @@ protected void parseArguments(String[] args) throws IOException { case "--url": url = args[++i]; break; - case "-k": - case "--key": - key = args[++i]; + case "-a": + case "--api-key": + apiKey = args[++i]; break; case "-m": case "--model": diff --git a/aigenpipeline-framework/src/main/java/net/stoerr/ai/aigenpipeline/framework/task/AIGenerationTask.java b/aigenpipeline-framework/src/main/java/net/stoerr/ai/aigenpipeline/framework/task/AIGenerationTask.java index 4291c736..57dd5e16 100644 --- a/aigenpipeline-framework/src/main/java/net/stoerr/ai/aigenpipeline/framework/task/AIGenerationTask.java +++ b/aigenpipeline-framework/src/main/java/net/stoerr/ai/aigenpipeline/framework/task/AIGenerationTask.java @@ -244,17 +244,26 @@ protected AIVersionMarker getRecordedOutputVersionMarker() { /** * The actual prompt to be executed. The prompt file content can contain placeholders that are replaced by the values given: placeholdersAndValues contain alternatingly placeholder names and values for them. * - * @return + * @return this */ public AIGenerationTask addPrompt(@Nonnull File promptFile, String... placeholdersAndValues) { + Map map = new LinkedHashMap<>(); + for (int i = 0; i < placeholdersAndValues.length; i += 2) { + map.put(placeholdersAndValues[i], placeholdersAndValues[i + 1]); + } + return addPrompt(promptFile, map); + } + + /** + * The actual prompt to be executed. The prompt file content can contain placeholders that are replaced by the values given. + * + * @return this + */ + public AIGenerationTask addPrompt(@Nonnull File promptFile, Map placeholdersAndValues) { String newPrompt = unclutter(getFileContent(promptFile)); requireNonNull(newPrompt, "Could not read prompt file " + promptFile); - if (placeholdersAndValues.length % 2 != 0) { - throw new IllegalArgumentException("Odd number of placeholdersAndValues"); - } - for (int i = 0; i < placeholdersAndValues.length; i += 2) { - newPrompt = newPrompt.replace(placeholdersAndValues[i], placeholdersAndValues[i + 1]); - this.placeholdersAndValues.put(placeholdersAndValues[i], placeholdersAndValues[i + 1]); + for (Map.Entry entry : placeholdersAndValues.entrySet()) { + newPrompt = newPrompt.replace(entry.getKey(), entry.getValue()); } if (this.prompt == null) { this.prompt = newPrompt; @@ -262,6 +271,7 @@ public AIGenerationTask addPrompt(@Nonnull File promptFile, String... placeholde this.prompt += "\n\n" + newPrompt; } this.promptFiles.add(promptFile); + this.placeholdersAndValues.putAll(placeholdersAndValues); return this; }