Skip to content

Commit

Permalink
Fix arg ordering after switch from bytecode wrapping (bytecode wrapping made args[0] be object self/cls ref)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed Jan 22, 2025
1 parent 037e5b5 commit c53a169
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 56 deletions.
100 changes: 50 additions & 50 deletions ddtrace/contrib/internal/openai/_endpoint_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@ class _EndpointHook:
OPERATION_ID = "" # Each endpoint hook must provide an operationID as specified in the OpenAI API specs:
# https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml

def _record_request(self, pin, integration, span, args, kwargs):
def _record_request(self, pin, integration, instance, span, args, kwargs):
"""
Set base-level openai tags, as well as request params from args and kwargs.
All inherited EndpointHook classes should include a super call to this method before performing
endpoint-specific request tagging logic.
"""
endpoint = self.ENDPOINT_NAME
if endpoint is None:
endpoint = "%s" % args[0].OBJECT_NAME
endpoint = "%s" % getattr(instance, "OBJECT_NAME", "")
span.set_tag_str("openai.request.endpoint", "/%s/%s" % (API_VERSION, endpoint))
span.set_tag_str("openai.request.method", self.HTTP_METHOD_TYPE)

if self._request_arg_params and len(self._request_arg_params) > 1:
for idx, arg in enumerate(self._request_arg_params, 1):
for idx, arg in enumerate(self._request_arg_params):
if idx >= len(args):
break
if arg is None or args[idx] is None:
Expand All @@ -74,8 +74,8 @@ def _record_request(self, pin, integration, span, args, kwargs):
else:
span.set_tag_str("openai.request.%s" % kw_attr, str(kwargs[kw_attr]))

def handle_request(self, pin, integration, span, args, kwargs):
self._record_request(pin, integration, span, args, kwargs)
def handle_request(self, pin, integration, instance, span, args, kwargs):
self._record_request(pin, integration, instance, span, args, kwargs)
resp, error = yield
if hasattr(resp, "parse"):
# Users can request the raw response, in which case we need to process on the parsed response
Expand Down Expand Up @@ -186,8 +186,8 @@ class _CompletionHook(_BaseCompletionHook):
HTTP_METHOD_TYPE = "POST"
OPERATION_ID = "createCompletion"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
if integration.is_pc_sampled_span(span):
prompt = kwargs.get("prompt", "")
if isinstance(prompt, str):
Expand Down Expand Up @@ -241,8 +241,8 @@ class _ChatCompletionHook(_BaseCompletionHook):
HTTP_METHOD_TYPE = "POST"
OPERATION_ID = "createChatCompletion"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
for idx, m in enumerate(kwargs.get("messages", [])):
role = getattr(m, "role", "")
name = getattr(m, "name", "")
Expand Down Expand Up @@ -305,12 +305,12 @@ class _EmbeddingHook(_EndpointHook):
HTTP_METHOD_TYPE = "POST"
OPERATION_ID = "createEmbedding"

def _record_request(self, pin, integration, span, args, kwargs):
def _record_request(self, pin, integration, instance, span, args, kwargs):
"""
Embedding endpoint allows multiple inputs, each of which we specify a request tag for, so have to
manually set them in _pre_response().
"""
super()._record_request(pin, integration, span, args, kwargs)
super()._record_request(pin, integration, instance, span, args, kwargs)
embedding_input = kwargs.get("input", "")
if integration.is_pc_sampled_span(span):
if isinstance(embedding_input, str) or isinstance(embedding_input[0], int):
Expand Down Expand Up @@ -340,8 +340,8 @@ class _ListHook(_EndpointHook):
HTTP_METHOD_TYPE = "GET"
OPERATION_ID = "list"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
endpoint = span.get_tag("openai.request.endpoint")
if endpoint.endswith("/models"):
span.resource = "listModels"
Expand Down Expand Up @@ -399,15 +399,15 @@ class _RetrieveHook(_EndpointHook):
HTTP_METHOD_TYPE = "GET"
OPERATION_ID = "retrieve"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
endpoint = span.get_tag("openai.request.endpoint")
if endpoint.endswith("/models"):
span.resource = "retrieveModel"
span.set_tag_str("openai.request.model", args[1] if len(args) >= 2 else kwargs.get("model", ""))
span.set_tag_str("openai.request.model", args[0] if len(args) >= 1 else kwargs.get("model", ""))
elif endpoint.endswith("/files"):
span.resource = "retrieveFile"
span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", ""))
span.set_tag_str("openai.request.file_id", args[0] if len(args) >= 1 else kwargs.get("file_id", ""))
span.set_tag_str("openai.request.endpoint", "%s/*" % endpoint)

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
Expand All @@ -434,9 +434,9 @@ class _ModelRetrieveHook(_RetrieveHook):
ENDPOINT_NAME = "models"
OPERATION_ID = "retrieveModel"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
span.set_tag_str("openai.request.model", args[1] if len(args) >= 2 else kwargs.get("model", ""))
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
span.set_tag_str("openai.request.model", args[0] if len(args) >= 1 else kwargs.get("model", ""))


class _FileRetrieveHook(_RetrieveHook):
Expand All @@ -447,9 +447,9 @@ class _FileRetrieveHook(_RetrieveHook):
ENDPOINT_NAME = "files"
OPERATION_ID = "retrieveFile"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", ""))
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
span.set_tag_str("openai.request.file_id", args[0] if len(args) >= 1 else kwargs.get("file_id", ""))


class _DeleteHook(_EndpointHook):
Expand All @@ -461,15 +461,15 @@ class _DeleteHook(_EndpointHook):
HTTP_METHOD_TYPE = "DELETE"
OPERATION_ID = "delete"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
endpoint = span.get_tag("openai.request.endpoint")
if endpoint.endswith("/models"):
span.resource = "deleteModel"
span.set_tag_str("openai.request.model", args[1] if len(args) >= 2 else kwargs.get("model", ""))
span.set_tag_str("openai.request.model", args[0] if len(args) >= 1 else kwargs.get("model", ""))
elif endpoint.endswith("/files"):
span.resource = "deleteFile"
span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", ""))
span.set_tag_str("openai.request.file_id", args[0] if len(args) >= 1 else kwargs.get("file_id", ""))
span.set_tag_str("openai.request.endpoint", "%s/*" % endpoint)

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
Expand Down Expand Up @@ -508,8 +508,8 @@ class _ImageHook(_EndpointHook):
ENDPOINT_NAME = "images"
HTTP_METHOD_TYPE = "POST"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
span.set_tag_str("openai.request.model", "dall-e")

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
Expand All @@ -526,10 +526,10 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
if "prompt" in self._request_kwarg_params:
attrs_dict.update({"prompt": kwargs.get("prompt", "")})
if "image" in self._request_kwarg_params:
image = args[1] if len(args) >= 2 else kwargs.get("image", "")
image = args[0] if len(args) >= 1 else kwargs.get("image", "")
attrs_dict.update({"image": image.name.split("/")[-1]})
if "mask" in self._request_kwarg_params:
mask = args[2] if len(args) >= 3 else kwargs.get("mask", "")
mask = args[1] if len(args) >= 2 else kwargs.get("mask", "")
attrs_dict.update({"mask": mask.name.split("/")[-1]})
integration.log(
span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict
Expand Down Expand Up @@ -560,12 +560,12 @@ class _ImageEditHook(_ImageHook):
ENDPOINT_NAME = "images/edits"
OPERATION_ID = "createImageEdit"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
if not integration.is_pc_sampled_span:
return
image = args[1] if len(args) >= 2 else kwargs.get("image", "")
mask = args[2] if len(args) >= 3 else kwargs.get("mask", "")
image = args[0] if len(args) >= 1 else kwargs.get("image", "")
mask = args[1] if len(args) >= 2 else kwargs.get("mask", "")
if image:
if hasattr(image, "name"):
span.set_tag_str("openai.request.image", integration.trunc(image.name.split("/")[-1]))
Expand All @@ -584,11 +584,11 @@ class _ImageVariationHook(_ImageHook):
ENDPOINT_NAME = "images/variations"
OPERATION_ID = "createImageVariation"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
if not integration.is_pc_sampled_span:
return
image = args[1] if len(args) >= 2 else kwargs.get("image", "")
image = args[0] if len(args) >= 1 else kwargs.get("image", "")
if image:
if hasattr(image, "name"):
span.set_tag_str("openai.request.image", integration.trunc(image.name.split("/")[-1]))
Expand All @@ -602,11 +602,11 @@ class _BaseAudioHook(_EndpointHook):
ENDPOINT_NAME = "audio"
HTTP_METHOD_TYPE = "POST"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
if not integration.is_pc_sampled_span:
return
audio_file = args[2] if len(args) >= 3 else kwargs.get("file", "")
audio_file = args[1] if len(args) >= 2 else kwargs.get("file", "")
if audio_file and hasattr(audio_file, "name"):
span.set_tag_str("openai.request.filename", integration.trunc(audio_file.name.split("/")[-1]))
else:
Expand All @@ -626,7 +626,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
if integration.is_pc_sampled_span(span):
span.set_tag_str("openai.response.text", integration.trunc(text))
if integration.is_pc_sampled_log(span):
file_input = args[2] if len(args) >= 3 else kwargs.get("file", "")
file_input = args[1] if len(args) >= 2 else kwargs.get("file", "")
integration.log(
span,
"info" if error is None else "error",
Expand Down Expand Up @@ -685,8 +685,8 @@ class _ModerationHook(_EndpointHook):
HTTP_METHOD_TYPE = "POST"
OPERATION_ID = "createModeration"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
Expand Down Expand Up @@ -723,9 +723,9 @@ class _FileCreateHook(_BaseFileHook):
HTTP_METHOD_TYPE = "POST"
OPERATION_ID = "createFile"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
fp = args[1] if len(args) >= 2 else kwargs.get("file", "")
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
fp = args[0] if len(args) >= 1 else kwargs.get("file", "")
if fp and hasattr(fp, "name"):
span.set_tag_str("openai.request.filename", fp.name.split("/")[-1])
else:
Expand All @@ -742,9 +742,9 @@ class _FileDownloadHook(_BaseFileHook):
OPERATION_ID = "downloadFile"
ENDPOINT_NAME = "files/*/content"

def _record_request(self, pin, integration, span, args, kwargs):
super()._record_request(pin, integration, span, args, kwargs)
span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", ""))
def _record_request(self, pin, integration, instance, span, args, kwargs):
super()._record_request(pin, integration, instance, span, args, kwargs)
span.set_tag_str("openai.request.file_id", args[0] if len(args) >= 1 else kwargs.get("file_id", ""))

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
Expand Down
8 changes: 4 additions & 4 deletions ddtrace/contrib/internal/openai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _patched_make_session(func, instance, args, kwargs):
return session


def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
def _traced_endpoint(endpoint_hook, integration, instance, pin, args, kwargs):
span = integration.trace(pin, endpoint_hook.OPERATION_ID)
openai_api_key = _format_openai_api_key(kwargs.get("api_key"))
err = None
Expand All @@ -247,7 +247,7 @@ def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs):
span.set_tag_str("openai.user.api_key", openai_api_key)
try:
# Start the hook
hook = endpoint_hook().handle_request(pin, integration, span, args, kwargs)
hook = endpoint_hook().handle_request(pin, integration, instance, span, args, kwargs)
hook.send(None)

resp, err = yield
Expand Down Expand Up @@ -275,7 +275,7 @@ def _patched_endpoint(openai, patch_hook):
@with_traced_module
def patched_endpoint(openai, pin, func, instance, args, kwargs):
integration = openai._datadog_integration
g = _traced_endpoint(patch_hook, integration, pin, args, kwargs)
g = _traced_endpoint(patch_hook, integration, instance, pin, args, kwargs)
g.send(None)
resp, err = None, None
try:
Expand All @@ -300,7 +300,7 @@ def _patched_endpoint_async(openai, patch_hook):
@with_traced_module
async def patched_endpoint(openai, pin, func, instance, args, kwargs):
integration = openai._datadog_integration
g = _traced_endpoint(patch_hook, integration, pin, args, kwargs)
g = _traced_endpoint(patch_hook, integration, instance, pin, args, kwargs)
g.send(None)
resp, err = None, None
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
interactions:
- request:
body: '{"model":"text-curie-001","prompt":"how does openai tokenize prompts?","max_tokens":150,"n":1,"stream":true,"temperature":0.8}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '126'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.59.7
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.59.7
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.1
method: POST
uri: https://api.openai.com/v1/completions
response:
body:
string: "{\n \"error\": {\n \"message\": \"Incorrect API key provided:
sk-wrong****-key. You can find your API key at https://platform.openai.com/account/api-keys.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\":
\"invalid_api_key\"\n }\n}\n"
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 9058b3cc3bcdd63c-IAD
Connection:
- keep-alive
Content-Length:
- '266'
Content-Type:
- application/json; charset=utf-8
Date:
- Tue, 21 Jan 2025 16:32:48 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=WUZdhCkUNTJUEkju8qgk4MKCHL7CFOaIUNvU0L9XmvA-1737477168-1.0.1.1-RJ7MOiDyJEfHrXSN0WQVgZFtkxlkwBL3p.5t3._uu77WPJSM8tYzI3wMHSu.yMwD9QkrbgR5yavkTN.RTWl_1A;
path=/; expires=Tue, 21-Jan-25 17:02:48 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=7KOfpy1ICNI532AjhDxBh2qtnyNpsjauHeWi6dEJgT4-1737477168271-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
alt-svc:
- h3=":443"; ma=86400
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
vary:
- Origin
x-request-id:
- req_c45bfc7515dca54ef87c667f8210af23
status:
code: 401
message: Unauthorized
version: 1
3 changes: 1 addition & 2 deletions tests/contrib/openai/test_openai_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,7 @@ def test_misuse(openai, snapshot_tracer):
)
def test_span_finish_on_stream_error(openai, openai_vcr, snapshot_tracer):
with openai_vcr.use_cassette("completion_stream_wrong_api_key.yaml"):
with pytest.raises(openai.APIConnectionError):
with pytest.raises(openai.AuthenticationError):
with pytest.raises((openai.APIConnectionError, openai.AuthenticationError)):
client = openai.OpenAI(api_key="sk-wrong-api-key")
client.completions.create(
model="text-curie-001",
Expand Down

0 comments on commit c53a169

Please sign in to comment.