Skip to content

feat: Support setting the default base URL in clients via BaseURL.setDefaultBaseUrls() #715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions google/genai/_base_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from typing import Optional

from .types import HttpOptions

_default_base_gemini_url = None
_default_base_vertex_url = None


class BaseUrlParameters:
"""Parameters for setting the base URLs for the Gemini API and Vertex AI API."""

gemini_url: Optional[str]
vertex_url: Optional[str]

def __init__(
self,
gemini_url: Optional[str],
vertex_url: Optional[str],
):
self.gemini_url = gemini_url
self.vertex_url = vertex_url


def set_default_base_urls(base_url_params: BaseUrlParameters) -> None:
"""Overrides the base URLs for the Gemini API and Vertex AI API."""
global _default_base_gemini_url, _default_base_vertex_url
_default_base_gemini_url = base_url_params.gemini_url
_default_base_vertex_url = base_url_params.vertex_url


def get_default_base_urls() -> BaseUrlParameters:
"""Returns the base URLs for the Gemini API and Vertex AI API."""
return BaseUrlParameters(
gemini_url=_default_base_gemini_url, vertex_url=_default_base_vertex_url
)


def get_base_url(
vertexai: bool,
http_options: Optional[HttpOptions] = None,
) -> Optional[str]:
"""Returns the default base URL based on the following priority.

1. Base URLs set via HttpOptions.
2. Base URLs set via the latest call to setDefaultBaseUrls.
3. Base URLs set via environment variables.
"""
if http_options and http_options.base_url:
return http_options.base_url

if vertexai:
return _default_base_vertex_url or os.getenv('GOOGLE_VERTEX_BASE_URL')
else:
return _default_base_gemini_url or os.getenv('GOOGLE_GEMINI_BASE_URL')
29 changes: 20 additions & 9 deletions google/genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pydantic

from ._api_client import BaseApiClient
from ._base_url import get_base_url
from ._replay_api_client import ReplayApiClient
from .batches import AsyncBatches, Batches
from .caches import AsyncCaches, Caches
Expand Down Expand Up @@ -78,6 +79,7 @@ def live(self) -> AsyncLive:
def operations(self) -> AsyncOperations:
return self._operations


class DebugConfig(pydantic.BaseModel):
"""Configuration options that change client network behavior when testing."""

Expand Down Expand Up @@ -114,26 +116,28 @@ class Client:
Attributes:
api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
use for authentication. Applies to the Gemini Developer API only.
vertexai: Indicates whether the client should use the Vertex AI
API endpoints. Defaults to False (uses Gemini Developer API endpoints).
vertexai: Indicates whether the client should use the Vertex AI API
endpoints. Defaults to False (uses Gemini Developer API endpoints).
Applies to the Vertex AI API only.
credentials: The credentials to use for authentication when calling the
Vertex AI APIs. Credentials can be obtained from environment variables and
default credentials. For more information, see
`Set up Application Default Credentials
default credentials. For more information, see `Set up Application Default
Credentials
<https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
Applies to the Vertex AI API only.
project: The `Google Cloud project ID <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to
use for quota. Can be obtained from environment variables (for example,
project: The `Google Cloud project ID
<https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to use
for quota. Can be obtained from environment variables (for example,
``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only.
location: The `location <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
location: The `location
<https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
to send API requests to (for example, ``us-central1``). Can be obtained
from environment variables. Applies to the Vertex AI API only.
debug_config: Config settings that control network behavior of the client.
This is typically used when running test code.
http_options: Http options to use for the client. These options will be
applied to all requests made by the client. Example usage:
`client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`.
applied to all requests made by the client. Example usage: `client =
genai.Client(http_options=types.HttpOptions(api_version='v1'))`.

Usage for the Gemini Developer API:

Expand Down Expand Up @@ -198,6 +202,13 @@ def __init__(
if isinstance(http_options, dict):
http_options = HttpOptions(**http_options)

base_url = get_base_url(vertexai or False, http_options)
if base_url:
if http_options:
http_options.base_url = base_url
else:
http_options = HttpOptions(base_url=base_url)

self._api_client = self._get_api_client(
vertexai=vertexai,
api_key=api_key,
Expand Down
167 changes: 136 additions & 31 deletions google/genai/tests/client/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@

"""Tests for client initialization."""

import logging
import os
import ssl

import certifi
import google.auth
from google.auth import credentials
import logging
import os
import pytest
import ssl

from ... import _api_client as api_client
from ... import _base_url as base_url
from ... import _replay_api_client as replay_api_client
from ... import Client

Expand Down Expand Up @@ -274,6 +276,7 @@ def test_invalid_vertexai_constructor_empty(monkeypatch):
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "")
monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "")
monkeypatch.setenv("GOOGLE_API_KEY", "")

def mock_auth_default(scopes=None):
return None, None

Expand Down Expand Up @@ -319,10 +322,7 @@ def test_invalid_vertexai_constructor3(monkeypatch):
m.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
project_id = "fake_project_id"
with pytest.raises(ValueError):
Client(
vertexai=True,
project=project_id
)
Client(vertexai=True, project=project_id)


def test_vertexai_explicit_arg_precedence1(monkeypatch):
Expand Down Expand Up @@ -578,7 +578,7 @@ def test_vertexai_global_endpoint(monkeypatch):


def test_client_logs_to_logger_instance(monkeypatch, caplog):
caplog.set_level(logging.DEBUG, logger='google_genai._api_client')
caplog.set_level(logging.DEBUG, logger="google_genai._api_client")

project_id = "fake_project_id"
location = "fake-location"
Expand All @@ -588,75 +588,180 @@ def test_client_logs_to_logger_instance(monkeypatch, caplog):

_ = Client(vertexai=True, api_key=api_key)

assert 'INFO' in caplog.text
assert 'The user provided Vertex AI API key will take precedence' in caplog.text
assert "INFO" in caplog.text
assert (
"The user provided Vertex AI API key will take precedence" in caplog.text
)


def test_client_ssl_context_implicit_initialization():
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
api_client.HttpOptions())
api_client.HttpOptions()
)

assert client_args["verify"]
assert async_client_args["verify"]
assert isinstance(client_args["verify"], ssl.SSLContext)
assert isinstance(async_client_args["verify"], ssl.SSLContext)


def test_client_ssl_context_explicit_initialization_same_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
capath=os.environ.get("SSL_CERT_DIR"),
)

options = api_client.HttpOptions(
client_args={"verify": ctx}, async_client_args={"verify": ctx})
client_args={"verify": ctx}, async_client_args={"verify": ctx}
)
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)
options
)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == ctx


def test_client_ssl_context_explicit_initialization_separate_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
capath=os.environ.get("SSL_CERT_DIR"),
)

async_ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
capath=os.environ.get("SSL_CERT_DIR"),
)

options = api_client.HttpOptions(
client_args={"verify": ctx}, async_client_args={"verify": async_ctx})
client_args={"verify": ctx}, async_client_args={"verify": async_ctx}
)
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)
options
)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == async_ctx


def test_client_ssl_context_explicit_initialization_sync_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
capath=os.environ.get("SSL_CERT_DIR"),
)

options = api_client.HttpOptions(
client_args={"verify": ctx})
options = api_client.HttpOptions(client_args={"verify": ctx})
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)
options
)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == ctx


def test_client_ssl_context_explicit_initialization_async_args():
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
cafile=os.environ.get("SSL_CERT_FILE", certifi.where()),
capath=os.environ.get("SSL_CERT_DIR"),
)

options = api_client.HttpOptions(
async_client_args={"verify": ctx})
options = api_client.HttpOptions(async_client_args={"verify": ctx})
client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx(
options)
options
)

assert client_args["verify"] == ctx
assert async_client_args["verify"] == ctx


def test_constructor_with_base_url_from_http_options():
mldev_http_options = {
"base_url": "https://placeholder-fake-url.com/",
}
vertexai_http_options = {
"base_url": (
"https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
),
}

mldev_client = Client(
api_key="google_api_key", http_options=mldev_http_options
)
assert not mldev_client.models._api_client.vertexai
assert (
mldev_client.models._api_client.get_read_only_http_options()["base_url"]
== "https://placeholder-fake-url.com/"
)

vertexai_client = Client(
vertexai=True,
project="fake_project_id",
location="fake-location",
http_options=vertexai_http_options,
)
assert vertexai_client.models._api_client.vertexai
assert (
vertexai_client.models._api_client.get_read_only_http_options()[
"base_url"
]
== "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
)


def test_constructor_with_base_url_from_set_default_base_urls():
base_url.set_default_base_urls(
base_url.BaseUrlParameters(
gemini_url="https://gemini-base-url.com/",
vertex_url="https://vertex-base-url.com/",
)
)
mldev_client = Client(api_key="google_api_key")
assert not mldev_client.models._api_client.vertexai
assert (
mldev_client.models._api_client.get_read_only_http_options()["base_url"]
== "https://gemini-base-url.com/"
)

vertexai_client = Client(
vertexai=True,
project="fake_project_id",
location="fake-location",
)
assert vertexai_client.models._api_client.vertexai
assert (
vertexai_client.models._api_client.get_read_only_http_options()[
"base_url"
]
== "https://vertex-base-url.com/"
)
base_url.set_default_base_urls(
base_url.BaseUrlParameters(
gemini_url=None,
vertex_url=None,
)
)


def test_constructor_with_base_url_from_environment_variables(monkeypatch):
monkeypatch.setenv("GOOGLE_GEMINI_BASE_URL", "https://gemini-base-url.com/")
monkeypatch.setenv("GOOGLE_VERTEX_BASE_URL", "https://vertex-base-url.com/")

mldev_client = Client(api_key="google_api_key")
assert not mldev_client.models._api_client.vertexai
assert (
mldev_client.models._api_client.get_read_only_http_options()["base_url"]
== "https://gemini-base-url.com/"
)

vertexai_client = Client(
vertexai=True,
project="fake_project_id",
location="fake-location",
)
assert vertexai_client.models._api_client.vertexai
assert (
vertexai_client.models._api_client.get_read_only_http_options()[
"base_url"
]
== "https://vertex-base-url.com/"
)