From 22d171338fb9c3740f16788026221479a223fa94 Mon Sep 17 00:00:00 2001 From: = <=> Date: Fri, 21 Feb 2025 13:11:24 -0600 Subject: [PATCH] estuary-cdk: support authorization code grant type Add an additional method of authentication that uses the authorization code grant type in the request to get access tokens. This is required by Monday.com, so this update supports OAuth2 authentication for the `source-monday` connector once that is enabled. --- estuary-cdk/estuary_cdk/capture/common.py | 1 + estuary-cdk/estuary_cdk/flow.py | 37 +++++++++++++ estuary-cdk/estuary_cdk/http.py | 64 ++++++++++++++++------- 3 files changed, 84 insertions(+), 18 deletions(-) diff --git a/estuary-cdk/estuary_cdk/capture/common.py b/estuary-cdk/estuary_cdk/capture/common.py index 13079f3fbf..55744c3b3e 100644 --- a/estuary-cdk/estuary_cdk/capture/common.py +++ b/estuary-cdk/estuary_cdk/capture/common.py @@ -25,6 +25,7 @@ CaptureBinding, ClientCredentialsOAuth2Credentials, ClientCredentialsOAuth2Spec, + AuthorizationCodeFlowOAuth2Credentials, LongLivedClientCredentialsOAuth2Credentials, OAuth2Spec, ValidationError, diff --git a/estuary-cdk/estuary_cdk/flow.py b/estuary-cdk/estuary_cdk/flow.py index 04cbd6909a..1e7d0b8259 100644 --- a/estuary-cdk/estuary_cdk/flow.py +++ b/estuary-cdk/estuary_cdk/flow.py @@ -142,6 +142,43 @@ class ClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): ) +class AuthorizationCodeFlowOAuth2Credentials(abc.ABC, BaseModel): + credentials_title: Literal["OAuth Credentials"] = Field( + default="OAuth Credentials", json_schema_extra={"type": "string"} + ) + client_id: str = Field( + title="Client Id", + json_schema_extra={"secret": True}, + ) + client_secret: str = Field( + title="Client Secret", + json_schema_extra={"secret": True}, + ) + + @abc.abstractmethod + def _you_must_build_oauth2_credentials_for_a_provider(self): ... 
+ + @staticmethod + def for_provider( + provider: str, + ) -> type["AuthorizationCodeFlowOAuth2Credentials"]: + """ + Builds an OAuth2Credentials model for the given OAuth2 `provider`. + This routine is only available in Pydantic V2 environments. + """ + from pydantic import ConfigDict + + class _OAuth2Credentials(AuthorizationCodeFlowOAuth2Credentials): + model_config = ConfigDict( + json_schema_extra={"x-oauth2-provider": provider}, + title="OAuth", + ) + + def _you_must_build_oauth2_credentials_for_a_provider(self): ... + + return _OAuth2Credentials + + class LongLivedClientCredentialsOAuth2Credentials(abc.ABC, BaseModel): credentials_title: Literal["OAuth Credentials"] = Field( default="OAuth Credentials", diff --git a/estuary-cdk/estuary_cdk/http.py b/estuary-cdk/estuary_cdk/http.py index e38aefe285..b688f4eb23 100644 --- a/estuary-cdk/estuary_cdk/http.py +++ b/estuary-cdk/estuary_cdk/http.py @@ -14,6 +14,7 @@ AccessToken, BasicAuth, BaseOAuth2Credentials, + AuthorizationCodeFlowOAuth2Credentials, ClientCredentialsOAuth2Credentials, ClientCredentialsOAuth2Spec, LongLivedClientCredentialsOAuth2Credentials, @@ -24,16 +25,19 @@ StreamedObject = TypeVar("StreamedObject", bound=BaseModel) + class HTTPError(RuntimeError): """ - HTTPError is an custom error class that provides the HTTP status code + HTTPError is an custom error class that provides the HTTP status code as a distinct attribute. """ + def __init__(self, message: str, code: int): super().__init__(message) self.code = code self.message = message + class HTTPSession(abc.ABC): """ HTTPSession is an abstract base class for an HTTP client implementation. @@ -60,7 +64,7 @@ async def request( json: dict[str, Any] | None = None, form: dict[str, Any] | None = None, _with_token: bool = True, # Unstable internal API. 
- headers: dict[str, Any] = {} + headers: dict[str, Any] = {}, ) -> bytes: """Request a url and return its body as bytes""" @@ -102,7 +106,7 @@ async def request_lines( yield buffer return - + async def request_stream( self, log: Logger, @@ -114,9 +118,7 @@ async def request_stream( ) -> AsyncGenerator[bytes, None]: """Request a url and and return the raw response as a stream of bytes""" - return self._request_stream( - log, url, method, params, json, form, True - ) + return self._request_stream(log, url, method, params, json, form, True) @abc.abstractmethod def _request_stream( @@ -128,7 +130,7 @@ def _request_stream( json: dict[str, Any] | None, form: dict[str, Any] | None, _with_token: bool, - headers: dict[str, Any] = {} + headers: dict[str, Any] = {}, ) -> AsyncGenerator[bytes, None]: ... # TODO(johnny): This is an unstable API. @@ -138,7 +140,6 @@ def _request_stream( @dataclass class TokenSource: - class AccessTokenResponse(BaseModel): access_token: str token_type: str @@ -147,13 +148,22 @@ class AccessTokenResponse(BaseModel): scope: str = "" oauth_spec: OAuth2Spec | ClientCredentialsOAuth2Spec | None - credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials | LongLivedClientCredentialsOAuth2Credentials | AccessToken | BasicAuth + credentials: ( + BaseOAuth2Credentials + | ClientCredentialsOAuth2Credentials + | AuthorizationCodeFlowOAuth2Credentials + | LongLivedClientCredentialsOAuth2Credentials + | AccessToken + | BasicAuth + ) authorization_header: str = DEFAULT_AUTHORIZATION_HEADER _access_token: AccessTokenResponse | None = None _fetched_at: int = 0 async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str]: - if isinstance(self.credentials, AccessToken) or isinstance(self.credentials, LongLivedClientCredentialsOAuth2Credentials): + if isinstance(self.credentials, AccessToken) or isinstance( + self.credentials, LongLivedClientCredentialsOAuth2Credentials + ): return ("Bearer", self.credentials.access_token) elif 
isinstance(self.credentials, BasicAuth): return ( @@ -163,7 +173,11 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str ).decode(), ) - assert isinstance(self.credentials, BaseOAuth2Credentials) or isinstance(self.credentials, ClientCredentialsOAuth2Credentials) + assert ( + isinstance(self.credentials, BaseOAuth2Credentials) + or isinstance(self.credentials, ClientCredentialsOAuth2Credentials) + or isinstance(self.credentials, AuthorizationCodeFlowOAuth2Credentials) + ) current_time = time.time() if self._access_token is not None: @@ -184,7 +198,12 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str return ("Bearer", self._access_token.access_token) async def _fetch_oauth2_token( - self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials + self, + log: Logger, + session: HTTPSession, + credentials: BaseOAuth2Credentials + | ClientCredentialsOAuth2Credentials + | AuthorizationCodeFlowOAuth2Credentials, ) -> AccessTokenResponse: assert self.oauth_spec @@ -204,10 +223,17 @@ async def _fetch_oauth2_token( "grant_type": "client_credentials", } headers = { - "Authorization": "Basic " + base64.b64encode( + "Authorization": "Basic " + + base64.b64encode( f"{credentials.client_id}:{credentials.client_secret}".encode() ).decode() } + case AuthorizationCodeFlowOAuth2Credentials(): + form = { + "grant_type": "authorization_code", + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } case _: raise TypeError(f"Unsupported credentials type: {type(credentials)}.") @@ -266,7 +292,6 @@ def error_ratio(self) -> float: # HTTPMixin is an opinionated implementation of HTTPSession. 
class HTTPMixin(Mixin, HTTPSession): - inner: aiohttp.ClientSession rate_limiter: RateLimiter token_source: TokenSource | None = None @@ -292,13 +317,17 @@ async def _request_stream( headers: dict[str, Any] = {}, ) -> AsyncGenerator[bytes, None]: while True: - cur_delay = self.rate_limiter.delay await asyncio.sleep(cur_delay) if _with_token and self.token_source is not None: token_type, token = await self.token_source.fetch_token(log, self) - header_value = f"{token_type} {token}" if self.token_source.authorization_header == DEFAULT_AUTHORIZATION_HEADER else f"{token}" + header_value = ( + f"{token_type} {token}" + if self.token_source.authorization_header + == DEFAULT_AUTHORIZATION_HEADER + else f"{token}" + ) headers[self.token_source.authorization_header] = header_value async with self.inner.request( @@ -309,7 +338,6 @@ async def _request_stream( params=params, url=url, ) as resp: - self.rate_limiter.update(cur_delay, resp.status == 429) if resp.status == 429: @@ -333,7 +361,7 @@ async def _request_stream( body = await resp.read() raise HTTPError( f"Encountered HTTP error status {resp.status} which cannot be retried.\nURL: {url}\nResponse:\n{body.decode('utf-8')}", - resp.status + resp.status, ) else: resp.raise_for_status()