Skip to content

No public description #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,16 +442,50 @@ def __init__(
else:
if self._http_options.headers is not None:
_append_library_version_headers(self._http_options.headers)
# Initialize the httpx client.
# Unlike requests, the httpx package does not automatically pull in the
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
# enabled explicitly.
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),

client_args, async_client_args = self._ensure_ssl_ctx(self._http_options)
self._httpx_client = SyncHttpxClient(**client_args)
self._async_httpx_client = AsyncHttpxClient(**async_client_args)

@staticmethod
def _ensure_ssl_ctx(options: HttpOptions) -> (
Tuple[dict[str, Any], dict[str, Any]]):
"""Ensures the SSL context is present in the client args.

Create a default SSL context is not provided.

Args:
options: The http options to update.

Returns:
A tuple of sync and async httpx client options.
"""

verify = 'verify'
args = options.client_args
async_args = options.async_client_args
ctx = (
args.get(verify) if args else None
or async_args.get(verify) if async_args else None
)

if not ctx:
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

def _maybe_set(args: dict[str, Any], ctx: ssl.SSLContext) -> dict[str, Any]:
"""Sets the SSL context in the client args if not set by making a copy."""
if not args or not args.get(verify):
args = (args or {}).copy()
args[verify] = ctx
return args

return (
_maybe_set(args, ctx),
_maybe_set(async_args, ctx),
)
self._httpx_client = SyncHttpxClient(verify=ctx)
self._async_httpx_client = AsyncHttpxClient(verify=ctx)

def _websocket_base_url(self):
url_parts = urlparse(self._http_options.base_url)
Expand Down
73 changes: 73 additions & 0 deletions google/genai/tests/client/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

"""Tests for client initialization."""

import certifi
import google.auth
from google.auth import credentials
import logging
import os
import pytest
import ssl

from ... import _api_client as api_client
from ... import _replay_api_client as replay_api_client
Expand Down Expand Up @@ -587,3 +590,73 @@ def test_client_logs_to_logger_instance(monkeypatch, caplog):

assert 'INFO' in caplog.text
assert 'The user provided Vertex AI API key will take precedence' in caplog.text

def test_client_ssl_context_implicit_initialization():
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
api_client.HttpOptions())

assert client_args["verify"]
assert async_client_args["verify"]
assert isinstance(client_args["verify"], ssl.SSLContext)
assert isinstance(async_client_args["verify"], ssl.SSLContext)

def test_client_ssl_context_explicit_initialization_same_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

options = api_client.HttpOptions(
client_args={"verify": ctx}, async_client_args={"verify": ctx})
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == ctx

def test_client_ssl_context_explicit_initialization_separate_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

async_ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

options = api_client.HttpOptions(
client_args={"verify": ctx}, async_client_args={"verify": async_ctx})
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == async_ctx

def test_client_ssl_context_explicit_initialization_sync_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

options = api_client.HttpOptions(
client_args={"verify": ctx})
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == ctx

def test_client_ssl_context_explicit_initialization_async_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

options = api_client.HttpOptions(
async_client_args={"verify": ctx})
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == ctx
2 changes: 2 additions & 0 deletions google/genai/tests/client/test_http_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def test_patch_http_options_with_copies_all_fields():
api_version='v1',
headers={'X-Custom-Header': 'custom_value'},
timeout=10000,
client_args={'http2': True},
async_client_args={'http1': True},
)
options = types.HttpOptions()
patched = _api_client._patch_http_options(options, patch_options)
Expand Down
14 changes: 14 additions & 0 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,14 @@ class HttpOptions(_common.BaseModel):
timeout: Optional[int] = Field(
default=None, description="""Timeout for the request in milliseconds."""
)
client_args: Optional[dict[str, Any]] = Field(
default=None,
description="""Args passed directly to the sync HTTP client.""",
)
async_client_args: Optional[dict[str, Any]] = Field(
default=None,
description="""Args passed directly to the async HTTP client.""",
)


class HttpOptionsDict(TypedDict, total=False):
Expand All @@ -837,6 +845,12 @@ class HttpOptionsDict(TypedDict, total=False):
timeout: Optional[int]
"""Timeout for the request in milliseconds."""

client_args: Optional[dict[str, Any]]
"""Args passed directly to the sync HTTP client."""

async_client_args: Optional[dict[str, Any]]
"""Args passed directly to the async HTTP client."""


HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]

Expand Down
Loading