Skip to content

Commit

Permalink
estuary-cdk: support authorization code grant type
Browse files Browse the repository at this point in the history
Add an additional method of authentication using the authorization code flow grant type in the request to get access tokens. This is required by Monday.com. So this update supports the `source-monday` connector OAuth2 authentication once that is enabled.
  • Loading branch information
= authored and JustinASmith committed Feb 24, 2025
1 parent 6d738c8 commit bb1b4ad
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 18 deletions.
1 change: 1 addition & 0 deletions estuary-cdk/estuary_cdk/capture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CaptureBinding,
ClientCredentialsOAuth2Credentials,
ClientCredentialsOAuth2Spec,
AuthorizationCodeFlowOAuth2Credentials,
LongLivedClientCredentialsOAuth2Credentials,
OAuth2Spec,
ValidationError,
Expand Down
37 changes: 37 additions & 0 deletions estuary-cdk/estuary_cdk/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,43 @@ class ClientCredentialsOAuth2Credentials(abc.ABC, BaseModel):
)


class AuthorizationCodeFlowOAuth2Credentials(abc.ABC, BaseModel):
credentials_title: Literal["OAuth Credentials"] = Field(
default="OAuth Credentials", json_schema_extra={"type": "string"}
)
client_id: str = Field(
title="Client Id",
json_schema_extra={"secret": True},
)
client_secret: str = Field(
title="Client Secret",
json_schema_extra={"secret": True},
)

@abc.abstractmethod
def _you_must_build_oauth2_credentials_for_a_provider(self): ...

@staticmethod
def for_provider(
provider: str,
) -> type["AuthorizationCodeFlowOAuth2Credentials"]:
"""
Builds an OAuth2Credentials model for the given OAuth2 `provider`.
This routine is only available in Pydantic V2 environments.
"""
from pydantic import ConfigDict

class _OAuth2Credentials(AuthorizationCodeFlowOAuth2Credentials):
model_config = ConfigDict(
json_schema_extra={"x-oauth2-provider": provider},
title="OAuth",
)

def _you_must_build_oauth2_credentials_for_a_provider(self): ...

return _OAuth2Credentials


class LongLivedClientCredentialsOAuth2Credentials(abc.ABC, BaseModel):
credentials_title: Literal["OAuth Credentials"] = Field(
default="OAuth Credentials",
Expand Down
64 changes: 46 additions & 18 deletions estuary-cdk/estuary_cdk/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AccessToken,
BasicAuth,
BaseOAuth2Credentials,
AuthorizationCodeFlowOAuth2Credentials,
ClientCredentialsOAuth2Credentials,
ClientCredentialsOAuth2Spec,
LongLivedClientCredentialsOAuth2Credentials,
Expand All @@ -24,16 +25,19 @@

StreamedObject = TypeVar("StreamedObject", bound=BaseModel)


class HTTPError(RuntimeError):
"""
HTTPError is an custom error class that provides the HTTP status code
HTTPError is an custom error class that provides the HTTP status code
as a distinct attribute.
"""

def __init__(self, message: str, code: int):
super().__init__(message)
self.code = code
self.message = message


class HTTPSession(abc.ABC):
"""
HTTPSession is an abstract base class for an HTTP client implementation.
Expand All @@ -60,7 +64,7 @@ async def request(
json: dict[str, Any] | None = None,
form: dict[str, Any] | None = None,
_with_token: bool = True, # Unstable internal API.
headers: dict[str, Any] = {}
headers: dict[str, Any] = {},
) -> bytes:
"""Request a url and return its body as bytes"""

Expand Down Expand Up @@ -102,7 +106,7 @@ async def request_lines(
yield buffer

return

async def request_stream(
self,
log: Logger,
Expand All @@ -114,9 +118,7 @@ async def request_stream(
) -> AsyncGenerator[bytes, None]:
"""Request a url and and return the raw response as a stream of bytes"""

return self._request_stream(
log, url, method, params, json, form, True
)
return self._request_stream(log, url, method, params, json, form, True)

@abc.abstractmethod
def _request_stream(
Expand All @@ -128,7 +130,7 @@ def _request_stream(
json: dict[str, Any] | None,
form: dict[str, Any] | None,
_with_token: bool,
headers: dict[str, Any] = {}
headers: dict[str, Any] = {},
) -> AsyncGenerator[bytes, None]: ...

# TODO(johnny): This is an unstable API.
Expand All @@ -138,7 +140,6 @@ def _request_stream(

@dataclass
class TokenSource:

class AccessTokenResponse(BaseModel):
access_token: str
token_type: str
Expand All @@ -147,13 +148,22 @@ class AccessTokenResponse(BaseModel):
scope: str = ""

oauth_spec: OAuth2Spec | ClientCredentialsOAuth2Spec | None
credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials | LongLivedClientCredentialsOAuth2Credentials | AccessToken | BasicAuth
credentials: (
BaseOAuth2Credentials
| ClientCredentialsOAuth2Credentials
| AuthorizationCodeFlowOAuth2Credentials
| LongLivedClientCredentialsOAuth2Credentials
| AccessToken
| BasicAuth
)
authorization_header: str = DEFAULT_AUTHORIZATION_HEADER
_access_token: AccessTokenResponse | None = None
_fetched_at: int = 0

async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str]:
if isinstance(self.credentials, AccessToken) or isinstance(self.credentials, LongLivedClientCredentialsOAuth2Credentials):
if isinstance(self.credentials, AccessToken) or isinstance(
self.credentials, LongLivedClientCredentialsOAuth2Credentials
):
return ("Bearer", self.credentials.access_token)
elif isinstance(self.credentials, BasicAuth):
return (
Expand All @@ -163,7 +173,11 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str
).decode(),
)

assert isinstance(self.credentials, BaseOAuth2Credentials) or isinstance(self.credentials, ClientCredentialsOAuth2Credentials)
assert (
isinstance(self.credentials, BaseOAuth2Credentials)
or isinstance(self.credentials, ClientCredentialsOAuth2Credentials)
or isinstance(self.credentials, AuthorizationCodeFlowOAuth2Credentials)
)
current_time = time.time()

if self._access_token is not None:
Expand All @@ -184,7 +198,12 @@ async def fetch_token(self, log: Logger, session: HTTPSession) -> tuple[str, str
return ("Bearer", self._access_token.access_token)

async def _fetch_oauth2_token(
self, log: Logger, session: HTTPSession, credentials: BaseOAuth2Credentials | ClientCredentialsOAuth2Credentials
self,
log: Logger,
session: HTTPSession,
credentials: BaseOAuth2Credentials
| ClientCredentialsOAuth2Credentials
| AuthorizationCodeFlowOAuth2Credentials,
) -> AccessTokenResponse:
assert self.oauth_spec

Expand All @@ -204,10 +223,17 @@ async def _fetch_oauth2_token(
"grant_type": "client_credentials",
}
headers = {
"Authorization": "Basic " + base64.b64encode(
"Authorization": "Basic "
+ base64.b64encode(
f"{credentials.client_id}:{credentials.client_secret}".encode()
).decode()
}
case AuthorizationCodeFlowOAuth2Credentials():
form = {
"grant_type": "authorization_code",
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
}
case _:
raise TypeError(f"Unsupported credentials type: {type(credentials)}.")

Expand Down Expand Up @@ -266,7 +292,6 @@ def error_ratio(self) -> float:

# HTTPMixin is an opinionated implementation of HTTPSession.
class HTTPMixin(Mixin, HTTPSession):

inner: aiohttp.ClientSession
rate_limiter: RateLimiter
token_source: TokenSource | None = None
Expand All @@ -292,13 +317,17 @@ async def _request_stream(
headers: dict[str, Any] = {},
) -> AsyncGenerator[bytes, None]:
while True:

cur_delay = self.rate_limiter.delay
await asyncio.sleep(cur_delay)

if _with_token and self.token_source is not None:
token_type, token = await self.token_source.fetch_token(log, self)
header_value = f"{token_type} {token}" if self.token_source.authorization_header == DEFAULT_AUTHORIZATION_HEADER else f"{token}"
header_value = (
f"{token_type} {token}"
if self.token_source.authorization_header
== DEFAULT_AUTHORIZATION_HEADER
else f"{token}"
)
headers[self.token_source.authorization_header] = header_value

async with self.inner.request(
Expand All @@ -309,7 +338,6 @@ async def _request_stream(
params=params,
url=url,
) as resp:

self.rate_limiter.update(cur_delay, resp.status == 429)

if resp.status == 429:
Expand All @@ -333,7 +361,7 @@ async def _request_stream(
body = await resp.read()
raise HTTPError(
f"Encountered HTTP error status {resp.status} which cannot be retried.\nURL: {url}\nResponse:\n{body.decode('utf-8')}",
resp.status
resp.status,
)
else:
resp.raise_for_status()
Expand Down

0 comments on commit bb1b4ad

Please sign in to comment.