Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix all type hint and docstrings for callable #5035

Merged
merged 2 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/hyperparameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, name, validate=lambda _: True, validation_message="", data_ty
"""Args:

name (str): The name of this hyperparameter validate
(callable[object]->[bool]): A validation function or list of validation
(Callable[object]->[bool]): A validation function or list of validation
functions.

Each function validates an object and returns False if the object
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
chain.
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
serializes input data to text/csv.
deserializer (callable): Optional. Default parses JSON responses
deserializer (Callable): Optional. Default parses JSON responses
using ``json.load(...)``.
component_name (str): Optional. Name of the Amazon SageMaker inference
component corresponding the predictor.
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def create_model(
training cluster for distributed training. Default: False
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
predictor_cls (callable[string, sagemaker.session.Session]): A
Callable[[string, sagemaker.session.Session], Any]: A
function to call to create a predictor (default: None). If
specified, ``deploy()`` returns the result of invoking this
function on the created endpoint name.
Expand Down Expand Up @@ -591,7 +591,7 @@ def deploy(
training cluster for distributed training. Default: False
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
predictor_cls (callable[string, sagemaker.session.Session]): A
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
function to call to create a predictor (default: None). If
specified, ``deploy()`` returns the result of invoking this
function on the created endpoint name.
Expand All @@ -609,7 +609,7 @@ def deploy(
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests

Returns:
callable[string, sagemaker.session.Session] or ``None``:
Optional[Callable[[string, sagemaker.session.Session], Any]]:
If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
the created endpoint name. Otherwise, ``None``.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/automl/automlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def create_model(
training cluster for distributed training. Default: False
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
predictor_cls (callable[string, sagemaker.session.Session]): A
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
function to call to create a predictor (default: None). If
specified, ``deploy()`` returns the result of invoking this
function on the created endpoint name.
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def deploy(
training cluster for distributed training. Default: False
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
predictor_cls (callable[string, sagemaker.session.Session]): A
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
function to call to create a predictor (default: None). If
specified, ``deploy()`` returns the result of invoking this
function on the created endpoint name.
Expand All @@ -1148,7 +1148,7 @@ def deploy(
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests

Returns:
callable[string, sagemaker.session.Session] or ``None``:
Optional[Callable[[string, sagemaker.session.Session], Any]]:
If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
the created endpoint name. Otherwise, ``None``.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

import logging
from typing import Optional, Union, List, Dict
from typing import Callable, Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris, ModelMetrics
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
image_uri: Optional[Union[str, PipelineVariable]] = None,
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
predictor_cls: callable = ChainerPredictor,
predictor_cls: Optional[Callable] = ChainerPredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs,
):
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image_uri`` is provided.
predictor_cls (callable[str, sagemaker.session.Session]): A function
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
result of invoking this function on the created endpoint name.
Expand Down
12 changes: 6 additions & 6 deletions src/sagemaker/djl_inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

import logging
from typing import Optional, Dict, Any
from typing import Callable, Optional, Dict, Any

from sagemaker import image_uris
from sagemaker.model import Model
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
parallel_loading: bool = False,
model_loading_timeout: Optional[int] = None,
prediction_timeout: Optional[int] = None,
predictor_cls: callable = DJLPredictor,
predictor_cls: Optional[Callable] = DJLPredictor,
huggingface_hub_token: Optional[str] = None,
**kwargs,
):
Expand Down Expand Up @@ -97,10 +97,10 @@ def __init__(
None. If not provided, the default is 240 seconds.
prediction_timeout (int): The worker predict call (handler) timeout in seconds.
Defaults to None. If not provided, the default is 120 seconds.
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a
predictor with an endpoint name and SageMaker ``Session``. If specified,
``deploy()`` returns
the result of invoking this function on the created endpoint name.
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call
to create a predictor with an endpoint name and SageMaker ``Session``. If
specified, ``deploy()`` returns the result of invoking this function on the created
endpoint name.
huggingface_hub_token (str): The HuggingFace Hub token to use for downloading the model
artifacts for a model stored on the huggingface hub.
Defaults to None. If not provided, the token must be specified in the
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

import logging
from typing import Optional, Union, List, Dict
from typing import Callable, Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris, ModelMetrics
Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(
pytorch_version: Optional[str] = None,
py_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = HuggingFacePredictor,
predictor_cls: Optional[Callable] = HuggingFacePredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs,
):
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(
If not specified, a default image for PyTorch will be used. If ``framework_version``
or ``py_version`` are ``None``, then ``image_uri`` is required. If
also ``None``, then a ``ValueError`` will be raised.
predictor_cls (callable[str, sagemaker.session.Session]): A function
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
result of invoking this function on the created endpoint name.
Expand Down Expand Up @@ -304,7 +304,7 @@ def deploy(
- If a wrong type of object is provided as serverless inference config or async
inference config
Returns:
callable[string, sagemaker.session.Session] or None: Invocation of
Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
is not None. Otherwise, return None.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import


from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
from sagemaker import session
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand Down Expand Up @@ -817,7 +817,7 @@ def deploy(
explainer_config: Optional[ExplainerConfig] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
model_name: Optional[str] = None,
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
Expand Down Expand Up @@ -918,7 +918,7 @@ def deploy(
It can be null if this is being used to create a Model to pass
to a ``PipelineModel`` which has its own Role field. (Default:
None).
predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A
predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A
function to call to create a predictor (Default: None). If not
None, ``deploy`` will return the result of invoking this
function on the created endpoint name. (Default: None).
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import


from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
from sagemaker import (
environment_variables,
hyperparameters as hyperparameters_utils,
Expand Down Expand Up @@ -330,7 +330,7 @@ def get_deploy_kwargs(
explainer_config: Optional[ExplainerConfig] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
sagemaker_session: Optional[Session] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json


from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from sagemaker_core.shapes import ModelAccessConfig
from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
Expand Down Expand Up @@ -855,7 +855,7 @@ def get_init_kwargs(
image_uri: Optional[Union[str, PipelineVariable]] = None,
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
name: Optional[str] = None,
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import absolute_import

from typing import Dict, List, Optional, Any, Union
from typing import Callable, Dict, List, Optional, Any, Union
import pandas as pd
from botocore.exceptions import ClientError

Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
image_uri: Optional[Union[str, PipelineVariable]] = None,
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
name: Optional[str] = None,
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
Expand Down Expand Up @@ -149,7 +149,7 @@ def __init__(
It can be null if this is being used to create a Model to pass
to a ``PipelineModel`` which has its own Role field. (Default:
None).
predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A
predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A
function to call to create a predictor (Default: None). If not
None, ``deploy`` will return the result of invoking this
function on the created endpoint name. (Default: None).
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
from sagemaker.utils import (
Expand Down Expand Up @@ -2150,7 +2150,7 @@ def __init__(
image_uri: Optional[Union[str, Any]] = None,
model_data: Optional[Union[str, Any, dict]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, Any]]] = None,
name: Optional[str] = None,
vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
Expand Down Expand Up @@ -2698,7 +2698,7 @@ def __init__(
explainer_config: Optional[Any] = None,
image_uri: Optional[Union[str, Any]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, Any]]] = None,
model_name: Optional[str] = None,
vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
Expand Down
12 changes: 6 additions & 6 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
import re
import copy
from typing import List, Dict, Optional, Union, Any
from typing import Callable, List, Dict, Optional, Union, Any

import sagemaker
from sagemaker import (
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(
image_uri: Optional[Union[str, PipelineVariable]] = None,
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
role: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
name: Optional[str] = None,
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
Expand Down Expand Up @@ -186,7 +186,7 @@ def __init__(
It can be null if this is being used to create a Model to pass
to a ``PipelineModel`` which has its own Role field. (default:
None)
predictor_cls (callable[string, sagemaker.session.Session]): A
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
function to call to create a predictor (default: None). If not
None, ``deploy`` will return the result of invoking this
function on the created endpoint name.
Expand Down Expand Up @@ -1501,7 +1501,7 @@ def deploy(
inference config or
- If inference recommendation id is specified along with incompatible parameters
Returns:
callable[string, sagemaker.session.Session] or None: Invocation of
Callable[[string, sagemaker.session.Session], Any] or None: Invocation of
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
is not None. Otherwise, return None.
"""
Expand Down Expand Up @@ -1959,7 +1959,7 @@ def __init__(
role: Optional[str] = None,
entry_point: Optional[str] = None,
source_dir: Optional[str] = None,
predictor_cls: Optional[callable] = None,
predictor_cls: Optional[Callable] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
name: Optional[str] = None,
container_log_level: Union[int, PipelineVariable] = logging.INFO,
Expand Down Expand Up @@ -2012,7 +2012,7 @@ def __init__(
>>> |----- test.py

You can assign entry_point='inference.py', source_dir='src'.
predictor_cls (callable[string, sagemaker.session.Session]): A
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
function to call to create a predictor (default: None). If not
None, ``deploy`` will return the result of invoking this
function on the created endpoint name.
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def deploy(
Amazon SageMaker Model Monitoring. Default: None.

Returns:
callable[string, sagemaker.session.Session] or None: Invocation of
Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of
``self.predictor_cls`` on the created endpoint name,
if ``self.predictor_cls``
is not None. Otherwise, return None.
Expand Down
Loading
Loading