diff --git a/estuary-cdk/estuary_cdk/capture/common.py b/estuary-cdk/estuary_cdk/capture/common.py index 13079f3fb..55744c3b3 100644 --- a/estuary-cdk/estuary_cdk/capture/common.py +++ b/estuary-cdk/estuary_cdk/capture/common.py @@ -25,6 +25,7 @@ CaptureBinding, ClientCredentialsOAuth2Credentials, ClientCredentialsOAuth2Spec, + AuthorizationCodeFlowOAuth2Credentials, LongLivedClientCredentialsOAuth2Credentials, OAuth2Spec, ValidationError, diff --git a/estuary-cdk/estuary_cdk/flow.py b/estuary-cdk/estuary_cdk/flow.py index 04cbd6909..1e7d0b825 100644 --- a/estuary-cdk/estuary_cdk/flow.py +++ b/estuary-cdk/estuary_cdk/flow.py @@ -142,6 +142,43 @@ class ClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): ) +class AuthorizationCodeFlowOAuth2Credentials(abc.ABC, BaseModel): + credentials_title: Literal["OAuth Credentials"] = Field( + default="OAuth Credentials", json_schema_extra={"type": "string"} + ) + client_id: str = Field( + title="Client Id", + json_schema_extra={"secret": True}, + ) + client_secret: str = Field( + title="Client Secret", + json_schema_extra={"secret": True}, + ) + + @abc.abstractmethod + def _you_must_build_oauth2_credentials_for_a_provider(self): ... + + @staticmethod + def for_provider( + provider: str, + ) -> type["AuthorizationCodeFlowOAuth2Credentials"]: + """ + Builds an OAuth2Credentials model for the given OAuth2 `provider`. + This routine is only available in Pydantic V2 environments. + """ + from pydantic import ConfigDict + + class _OAuth2Credentials(AuthorizationCodeFlowOAuth2Credentials): + model_config = ConfigDict( + json_schema_extra={"x-oauth2-provider": provider}, + title="OAuth", + ) + + def _you_must_build_oauth2_credentials_for_a_provider(self): ... + + return _OAuth2Credentials + + class LongLivedClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): credentials_title: Literal["OAuth Credentials"] = Field( default="OAuth Credentials", diff --git a/estuary-cdk/estuary_cdk/http.py b/estuary-cdk/estuary_cdk/http.py index e38aefe28..b688f4eb2 100644 --- a/estuary-cdk/estuary_cdk/http.py +++ b/estuary-cdk/estuary_cdk/http.py @@ -14,6 +14,7 @@ AccessToken, BasicAuth, BaseOAuth2Credentials, + AuthorizationCodeFlowOAuth2Credentials, ClientCredentialsOAuth2Credentials, ClientCredentialsOAuth2Spec, LongLivedClientCredentialsOAuth2Credentials, @@ -24,16 +25,19 @@ StreamedObject = TypeVar("StreamedObject", bound=BaseModel) + class HTTPError(RuntimeError): """ - HTTPError is an custom error class that provides the HTTP status code + HTTPError is an custom error class that provides the HTTP status code as a distinct attribute. """ + def __init__(self, message: str, code: int): super().__init__(message) self.code = code self.message = message + class HTTPSession(abc.ABC): """ HTTPSession is an abstract base class for an HTTP client implementation. @@ -60,7 +64,7 @@ async def request( json: dict[str, Any] | None = None, form: dict[str, Any] | None = None, _with_token: bool = True, # Unstable internal API. - headers: dict[str, Any] = {} + headers: dict[str, Any] = {}, ) -> bytes: """Request a url and return its body as bytes""" @@ -102,7 +106,7 @@ async def request_lines( yield buffer return - + async def request_stream( self, log: Logger, @@ -114,9 +118,7 @@ async def request_stream( ) -> AsyncGenerator[bytes, None]: """Request a url and and return the raw response as a stream of bytes""" - return self._request_stream( - log, url, method, params, json, form, True - ) + return self._request_stream(log, url, method, params, json, form, True) @abc.abstractmethod def _request_stream( @@ -128,7 +130,7 @@ def _request_stream( json: dict[str, Any] | None, form: dict[str, Any] | None, _with_token: bool, - headers: dict[str, Any] = {} + headers: dict[str, Any] = {}, ) -> AsyncGenerator[bytes, None]: ... # TODO(johnny): This is an unstable API. @@ -138,7 +140,6 @@ def _request_stream( @dataclass class TokenSource: - class AccessTokenResponse(BaseModel): access_token: str token_type: str @@ -147,13 +148,22 @@ class AccessTokenResponse(BaseModel): scope: str = "" oauth_spec: OAuth2Spec | ClientCredentialsOAuth2Spec | None - credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials | LongLivedClientCredentialsOAuth2Credentials | AccessToken | BasicAuth + credentials: ( + BaseOAuth2Credentials + | ClientCredentialsOAuth2Credentials + | AuthorizationCodeFlowOAuth2Credentials + | LongLivedClientCredentialsOAuth2Credentials + | AccessToken + | BasicAuth + ) authorization_header: str = DEFAULT_AUTHORIZATION_HEADER _access_token: AccessTokenResponse | None = None _fetched_at: int = 0 async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str]: - if isinstance(self.credentials, AccessToken) or isinstance(self.credentials, LongLivedClientCredentialsOAuth2Credentials): + if isinstance(self.credentials, AccessToken) or isinstance( + self.credentials, LongLivedClientCredentialsOAuth2Credentials + ): return ("Bearer", self.credentials.access_token) elif isinstance(self.credentials, BasicAuth): return ( @@ -163,7 +173,11 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str ).decode(), ) - assert isinstance(self.credentials, BaseOAuth2Credentials) or isinstance(self.credentials, ClientCredentialsOAuth2Credentials) + assert ( + isinstance(self.credentials, BaseOAuth2Credentials) + or isinstance(self.credentials, ClientCredentialsOAuth2Credentials) + or isinstance(self.credentials, AuthorizationCodeFlowOAuth2Credentials) + ) current_time = time.time() if self._access_token is not None: @@ -184,7 +198,12 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str return ("Bearer", self._access_token.access_token) async def _fetch_oauth2_token( - self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials + self, + log: Logger, + session: HTTPSession, + credentials: BaseOAuth2Credentials + | ClientCredentialsOAuth2Credentials + | AuthorizationCodeFlowOAuth2Credentials, ) -> AccessTokenResponse: assert self.oauth_spec @@ -204,10 +223,17 @@ async def _fetch_oauth2_token( "grant_type": "client_credentials", } headers = { - "Authorization": "Basic " + base64.b64encode( + "Authorization": "Basic " + + base64.b64encode( f"{credentials.client_id}:{credentials.client_secret}".encode() ).decode() } + case AuthorizationCodeFlowOAuth2Credentials(): + form = { + "grant_type": "authorization_code", + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } case _: raise TypeError(f"Unsupported credentials type: {type(credentials)}.") @@ -266,7 +292,6 @@ def error_ratio(self) -> float: # HTTPMixin is an opinionated implementation of HTTPSession. class HTTPMixin(Mixin, HTTPSession): - inner: aiohttp.ClientSession rate_limiter: RateLimiter token_source: TokenSource | None = None @@ -292,13 +317,17 @@ async def _request_stream( headers: dict[str, Any] = {}, ) -> AsyncGenerator[bytes, None]: while True: - cur_delay = self.rate_limiter.delay await asyncio.sleep(cur_delay) if _with_token and self.token_source is not None: token_type, token = await self.token_source.fetch_token(log, self) - header_value = f"{token_type} {token}" if self.token_source.authorization_header == DEFAULT_AUTHORIZATION_HEADER else f"{token}" + header_value = ( + f"{token_type} {token}" + if self.token_source.authorization_header + == DEFAULT_AUTHORIZATION_HEADER + else f"{token}" + ) headers[self.token_source.authorization_header] = header_value async with self.inner.request( @@ -309,7 +338,6 @@ async def _request_stream( params=params, url=url, ) as resp: - self.rate_limiter.update(cur_delay, resp.status == 429) if resp.status == 429: @@ -333,7 +361,7 @@ async def _request_stream( body = await resp.read() raise HTTPError( f"Encountered HTTP error status {resp.status} which cannot be retried.\nURL: {url}\nResponse:\n{body.decode('utf-8')}", - resp.status + resp.status, ) else: resp.raise_for_status()