Skip to content

Commit

Permalink
allow several prompts concatenated
Browse files Browse the repository at this point in the history
  • Loading branch information
stoerr committed Feb 25, 2024
1 parent e5b4e60 commit c3d08ef
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.io.PrintStream;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;

Expand All @@ -18,9 +20,10 @@ public class AIGenPipeline {
protected boolean help, verbose, dryRun, check, force, version;
protected String output, explain;
protected String url;
protected String key;
protected String apiKey;
protected List<String> inputFiles = new ArrayList<>();
protected List<String> promptFiles = new ArrayList<>();
protected Map<String, String> keyValues = new LinkedHashMap<>();
protected String model = "gpt-4-turbo-preview";
protected AIGenerationTask task;
protected File rootDir = new File(".");
Expand Down Expand Up @@ -51,8 +54,8 @@ public AIChatBuilder makeChatBuilder() {
if (null != url) {
chatBuilder.url(url);
}
if (null != key) {
chatBuilder.key(key);
if (null != apiKey) {
chatBuilder.key(apiKey);
}
if (null != model) {
chatBuilder.model(model);
Expand All @@ -71,7 +74,7 @@ protected void run() throws IOException {
if (promptFiles.isEmpty()) {
throw new IllegalArgumentException("At least one prompt file has to be given.");
}
promptFiles.stream().map(this::toFile).forEach(task::addPrompt);
promptFiles.stream().map(this::toFile).forEach(f -> task.addPrompt(f, keyValues));
task.force(force);
if (verbose) {
logStream.println(task.toJson(this::makeChatBuilder, rootDir));
Expand Down Expand Up @@ -109,6 +112,7 @@ protected void printHelpAndExit(boolean onerror) {
" --version Show the version of the AIGenPipeline tool and exit.\n" +
" -o, --output <file> Specify the output file where the generated content will be written. Mandatory.\n" +
" -p, --prompt <file> Reads a prompt from the given file.\n" +
" -k <key>=<value> Sets a key-value pair replacing ${key} in prompt files with the value. \n" +
" -s, --sysmsg <file> Optional: Reads a system message from the given file instead of using the default. \n" +
" -v, --verbose Enable verbose output to stderr, providing more details about the process.\n" +
" -n, --dry-run Enable dry-run mode, where the tool will only print to stderr what it would do without \n" +
Expand All @@ -124,7 +128,7 @@ protected void printHelpAndExit(boolean onerror) {
" -u, --url <url> The URL of the AI server. Default is https://api.openai.com/v1/chat/completions .\n" +
" In the case of OpenAI the API key is expected to be in the environment variable \n" +
" OPENAI_API_KEY, or given as -k option.\n" +
" -k, --key <key> The API key for the AI server. If not given, it's expected to be in the environment variable \n" +
" -a, --api-key <key> The API key for the AI server. If not given, it's expected to be in the environment variable \n" +
" OPENAI_API_KEY, or you could use a -u option to specify a different server that doesnt need\n" +
" an API key. Used in \"Authorization: Bearer <key>\" header.\n" +
" -m, --model <model> The model to use for the AI. Default is gpt-4-turbo-preview .\n" +
Expand Down Expand Up @@ -168,6 +172,10 @@ protected void parseArguments(String[] args) throws IOException {
case "--prompt":
promptFiles.add(args[++i]);
break;
case "-k":
String[] kv = args[++i].split("=", 2);
keyValues.put(kv[0], kv[1]);
break;
case "-s":
case "--sysmsg":
task.setSystemMessage(new File(args[++i]));
Expand Down Expand Up @@ -196,9 +204,9 @@ protected void parseArguments(String[] args) throws IOException {
case "--url":
url = args[++i];
break;
case "-k":
case "--key":
key = args[++i];
case "-a":
case "--api-key":
apiKey = args[++i];
break;
case "-m":
case "--model":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,24 +244,34 @@ protected AIVersionMarker getRecordedOutputVersionMarker() {
/**
* The actual prompt to be executed. The prompt file content can contain placeholders that are replaced by the values given: placeholdersAndValues contain alternatingly placeholder names and values for them.
*
* @return
* @return this
*/
public AIGenerationTask addPrompt(@Nonnull File promptFile, String... placeholdersAndValues) {
Map<String, String> map = new LinkedHashMap<>();
for (int i = 0; i < placeholdersAndValues.length; i += 2) {
map.put(placeholdersAndValues[i], placeholdersAndValues[i + 1]);
}
return addPrompt(promptFile, map);
}

/**
* The actual prompt to be executed. The prompt file content can contain placeholders that are replaced by the values given.
*
* @return this
*/
public AIGenerationTask addPrompt(@Nonnull File promptFile, Map<String, String> placeholdersAndValues) {
String newPrompt = unclutter(getFileContent(promptFile));
requireNonNull(newPrompt, "Could not read prompt file " + promptFile);
if (placeholdersAndValues.length % 2 != 0) {
throw new IllegalArgumentException("Odd number of placeholdersAndValues");
}
for (int i = 0; i < placeholdersAndValues.length; i += 2) {
newPrompt = newPrompt.replace(placeholdersAndValues[i], placeholdersAndValues[i + 1]);
this.placeholdersAndValues.put(placeholdersAndValues[i], placeholdersAndValues[i + 1]);
for (Map.Entry<String, String> entry : placeholdersAndValues.entrySet()) {
newPrompt = newPrompt.replace(entry.getKey(), entry.getValue());
}
if (this.prompt == null) {
this.prompt = newPrompt;
} else {
this.prompt += "\n\n" + newPrompt;
}
this.promptFiles.add(promptFile);
this.placeholdersAndValues.putAll(placeholdersAndValues);
return this;
}

Expand Down

0 comments on commit c3d08ef

Please sign in to comment.