Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix instrumentation support for OpenAI client 0.14+ #531

Merged
merged 9 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.next-release.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@

* Add support for OpenAI client 0.14+ - #531
3 changes: 2 additions & 1 deletion custom/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ plugins {
}

val instrumentations = listOf<String>(
":instrumentation:openai-client-instrumentation"
":instrumentation:openai-client-instrumentation:instrumentation-0.2",
":instrumentation:openai-client-instrumentation:instrumentation-0.14"
)

dependencies {
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ant = "org.apache.ant:ant:1.10.15"
asm = "org.ow2.asm:asm:9.7"

# Instrumented libraries
openaiClient = "com.openai:openai-java:0.13.0"
openaiClient = "com.openai:openai-java:0.21.0"

[bundles]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
plugins {
id("elastic-otel.java-conventions")
}

dependencies {
compileOnly(catalog.openaiClient)
compileOnly("io.opentelemetry:opentelemetry-sdk")
compileOnly("io.opentelemetry.instrumentation:opentelemetry-instrumentation-api")
compileOnly("io.opentelemetry.javaagent:opentelemetry-javaagent-extension-api")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to Elasticsearch B.V. under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch B.V. licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package co.elastic.otel.openai.wrappers;

import com.openai.models.ChatCompletionAssistantMessageParam;
import com.openai.models.ChatCompletionContentPart;
import com.openai.models.ChatCompletionCreateParams;
import com.openai.models.ChatCompletionMessageParam;
import com.openai.models.ChatCompletionSystemMessageParam;
import com.openai.models.ChatCompletionToolMessageParam;
import com.openai.models.ChatCompletionUserMessageParam;
import java.util.function.Supplier;

/**
* Api Adapter to encapsulate breaking changes across openai-client versions. If e.g. methods are
* renamed we add a adapter method here, so that we can provide per-version implementations. These
* implementations have to be added to instrumentations as helpers, which also ensures muzzle works
* effectively.
*/
public abstract class ApiAdapter {

private static volatile ApiAdapter instance;

public static ApiAdapter get() {
return instance;
}

protected static void init(Supplier<ApiAdapter> implementation) {
if (instance == null) {
synchronized (ApiAdapter.class) {
if (instance == null) {
instance = implementation.get();
}
}
}
}

/**
* Extracts the concrete message object e.g. ({@link ChatCompletionUserMessageParam}) from the
* given encapsulating {@link ChatCompletionMessageParam}.
*
* @param base the encapsulating param
* @return the unboxed concrete message param type
*/
public abstract Object extractConcreteCompletionMessageParam(ChatCompletionMessageParam base);

/**
* @return the contained text, if the content is text. null otherwise.
*/
public abstract String asText(ChatCompletionToolMessageParam.Content content);

/**
* @return the contained text, if the content is text. null otherwise.
*/
public abstract String asText(ChatCompletionAssistantMessageParam.Content content);

/**
* @return the contained text, if the content is text. null otherwise.
*/
public abstract String asText(ChatCompletionSystemMessageParam.Content content);

/**
* @return the contained text, if the content is text. null otherwise.
*/
public abstract String asText(ChatCompletionUserMessageParam.Content content);

/**
* @return the text or refusal reason if either is available, otherwise null
*/
public abstract String extractTextOrRefusal(
ChatCompletionAssistantMessageParam.Content.ChatCompletionRequestAssistantMessageContentPart
part);

/**
* @return the text if available, otherwise null
*/
public abstract String extractText(ChatCompletionContentPart part);

/**
* @return the type if available, otherwise null
*/
public abstract String extractType(ChatCompletionCreateParams.ResponseFormat val);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import com.openai.models.ChatCompletion;
import com.openai.models.ChatCompletionAssistantMessageParam;
import com.openai.models.ChatCompletionContentPart;
import com.openai.models.ChatCompletionContentPartText;
import com.openai.models.ChatCompletionCreateParams;
import com.openai.models.ChatCompletionMessage;
Expand All @@ -40,6 +39,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class ChatCompletionEventsHelper {
Expand All @@ -54,24 +54,28 @@ public static void emitPromptLogEvents(
if (!settings.emitEvents) {
return;
}

for (ChatCompletionMessageParam msg : request.messages()) {
String eventType;
MapValueBuilder bodyBuilder = new MapValueBuilder();
if (msg.isChatCompletionSystemMessageParam()) {
ChatCompletionSystemMessageParam sysMsg = msg.asChatCompletionSystemMessageParam();
Object concreteMessageParam = ApiAdapter.get().extractConcreteCompletionMessageParam(msg);
if (concreteMessageParam instanceof ChatCompletionSystemMessageParam) {
ChatCompletionSystemMessageParam sysMsg =
(ChatCompletionSystemMessageParam) concreteMessageParam;
eventType = "gen_ai.system.message";
if (settings.captureMessageContent) {
putIfNotEmpty(bodyBuilder, "content", contentToString(sysMsg.content()));
}
} else if (msg.isChatCompletionUserMessageParam()) {
ChatCompletionUserMessageParam userMsg = msg.asChatCompletionUserMessageParam();
} else if (concreteMessageParam instanceof ChatCompletionUserMessageParam) {
ChatCompletionUserMessageParam userMsg =
(ChatCompletionUserMessageParam) concreteMessageParam;
eventType = "gen_ai.user.message";
if (settings.captureMessageContent) {
putIfNotEmpty(bodyBuilder, "content", contentToString(userMsg.content()));
}
} else if (msg.isChatCompletionAssistantMessageParam()) {
} else if (concreteMessageParam instanceof ChatCompletionAssistantMessageParam) {
ChatCompletionAssistantMessageParam assistantMsg =
msg.asChatCompletionAssistantMessageParam();
(ChatCompletionAssistantMessageParam) concreteMessageParam;
eventType = "gen_ai.assistant.message";
if (settings.captureMessageContent) {
assistantMsg
Expand All @@ -89,8 +93,9 @@ public static void emitPromptLogEvents(
bodyBuilder.put("tool_calls", Value.of(toolCallsJson));
});
}
} else if (msg.isChatCompletionToolMessageParam()) {
ChatCompletionToolMessageParam toolMsg = msg.asChatCompletionToolMessageParam();
} else if (concreteMessageParam instanceof ChatCompletionToolMessageParam) {
ChatCompletionToolMessageParam toolMsg =
(ChatCompletionToolMessageParam) concreteMessageParam;
eventType = "gen_ai.tool.message";
if (settings.captureMessageContent) {
putIfNotEmpty(bodyBuilder, "content", contentToString(toolMsg.content()));
Expand All @@ -110,8 +115,9 @@ private static void putIfNotEmpty(MapValueBuilder bodyBuilder, String key, Strin
}

private static String contentToString(ChatCompletionToolMessageParam.Content content) {
if (content.isTextContent()) {
return content.asTextContent();
String text = ApiAdapter.get().asText(content);
if (text != null) {
return text;
} else if (content.isArrayOfContentParts()) {
return content.asArrayOfContentParts().stream()
.map(ChatCompletionContentPartText::text)
Expand All @@ -122,28 +128,23 @@ private static String contentToString(ChatCompletionToolMessageParam.Content con
}

private static String contentToString(ChatCompletionAssistantMessageParam.Content content) {
if (content.isTextContent()) {
return content.asTextContent();
String text = ApiAdapter.get().asText(content);
if (text != null) {
return text;
} else if (content.isArrayOfContentParts()) {
return content.asArrayOfContentParts().stream()
.map(
cnt -> {
if (cnt.isChatCompletionContentPartText()) {
return cnt.asChatCompletionContentPartText().text();
} else if (cnt.isChatCompletionContentPartRefusal()) {
return cnt.asChatCompletionContentPartRefusal().refusal();
}
return "";
})
.map(ApiAdapter.get()::extractTextOrRefusal)
.filter(Objects::nonNull)
.collect(Collectors.joining());
} else {
throw new IllegalStateException("Unhandled content type for " + content);
}
}

private static String contentToString(ChatCompletionSystemMessageParam.Content content) {
if (content.isTextContent()) {
return content.asTextContent();
String text = ApiAdapter.get().asText(content);
if (text != null) {
return text;
} else if (content.isArrayOfContentParts()) {
return content.asArrayOfContentParts().stream()
.map(ChatCompletionContentPartText::text)
Expand All @@ -154,13 +155,13 @@ private static String contentToString(ChatCompletionSystemMessageParam.Content c
}

private static String contentToString(ChatCompletionUserMessageParam.Content content) {
if (content.isTextContent()) {
return content.asTextContent();
String text = ApiAdapter.get().asText(content);
if (text != null) {
return text;
} else if (content.isArrayOfContentParts()) {
return content.asArrayOfContentParts().stream()
.filter(ChatCompletionContentPart::isChatCompletionContentPartText)
.map(ChatCompletionContentPart::asChatCompletionContentPartText)
.map(ChatCompletionContentPartText::text)
.map(ApiAdapter.get()::extractText)
.filter(Objects::nonNull)
.collect(Collectors.joining());
} else {
throw new IllegalStateException("Unhandled content type for " + content);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ public void onStart(
.responseFormat()
.ifPresent(
val -> {
if (val.isResponseFormatText()) {
attributes.put(
GEN_AI_OPENAI_REQUEST_RESPONSE_FORMAT,
val.asResponseFormatText()._type().toString());
String typeString = ApiAdapter.get().extractType(val);
if (typeString != null) {
attributes.put(GEN_AI_OPENAI_REQUEST_RESPONSE_FORMAT, typeString);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,18 @@ plugins {

dependencies {
compileOnly(catalog.openaiClient)
testImplementation(catalog.openaiClient)
implementation(project(":instrumentation:openai-client-instrumentation:common"))

testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2")
testImplementation("com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.2")
testImplementation("org.slf4j:slf4j-simple:2.0.16")
testImplementation(catalog.wiremock)
testImplementation(catalog.openaiClient)
testImplementation(project(":instrumentation:openai-client-instrumentation:testing-common"))
}

muzzle {
pass {
val openaiClientLib = catalog.openaiClient.get()
group.set(openaiClientLib.group)
module.set(openaiClientLib.name)
versions.set("(,${openaiClientLib.version}]")
versions.set("(0.13.0,${openaiClientLib.version}]")
// no assertInverse.set(true) here because we don't want muzzle to fail for newer releases on our main branch
// instead, renovate will bump the version and failures will be automatically detected on that bump PR
}
Expand Down
Loading