diff --git a/.changes/unreleased/Features-20241217-181340.yaml b/.changes/unreleased/Features-20241217-181340.yaml new file mode 100644 index 000000000..a2c1c523f --- /dev/null +++ b/.changes/unreleased/Features-20241217-181340.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add IdpTokenAuthPlugin authentication method. +time: 2024-12-17T18:13:40.281494-08:00 +custom: + Author: versusfacit + Issue: "898" diff --git a/dbt/adapters/redshift/auth_providers.py b/dbt/adapters/redshift/auth_providers.py new file mode 100644 index 000000000..bd4a7a309 --- /dev/null +++ b/dbt/adapters/redshift/auth_providers.py @@ -0,0 +1,87 @@ +import requests +from abc import ABC, abstractmethod +from enum import Enum +from typing import Dict, Any + +from dbt.adapters.exceptions import FailedToConnectError +from dbt_common.exceptions import DbtRuntimeError + + +# Define an Enum for the supported token endpoint types +class TokenServiceBase(ABC): + def __init__(self, token_endpoint: Dict[str, Any]): + expected_keys = {"type", "request_url", "request_data"} + for key in expected_keys: + if key not in token_endpoint: + raise FailedToConnectError(f"Missing required key in token_endpoint: '{key}'") + + self.type: str = token_endpoint["type"] + self.url: str = token_endpoint["request_url"] + self.data: str = token_endpoint["request_data"] + + self.other_params = {k: v for k, v in token_endpoint.items() if k not in expected_keys} + + @abstractmethod + def build_header_payload(self) -> Dict[str, Any]: + pass + + def handle_request(self) -> requests.Response: + """ + Handles the request with rate limiting and error handling. + """ + response = requests.post(self.url, headers=self.build_header_payload(), data=self.data) + + if response.status_code == 429: + raise DbtRuntimeError( + "Rate limit on identity provider's token dispatch has been reached. " + "Consider increasing your identity provider's refresh token rate or " + "lower dbt's maximum concurrent thread count." + ) + + response.raise_for_status() + return response + + +class OktaIdpTokenService(TokenServiceBase): + def build_header_payload(self) -> Dict[str, Any]: + if encoded_idp_client_creds := self.other_params.get("idp_auth_credentials"): + return { + "accept": "application/json", + "authorization": f"Basic {encoded_idp_client_creds}", + "content-type": "application/x-www-form-urlencoded", + } + else: + raise FailedToConnectError( + "Missing 'idp_auth_credentials' from token_endpoint. Please provide client_id:client_secret in base64 encoded format as a profile entry under token_endpoint." + ) + + +class EntraIdpTokenService(TokenServiceBase): + """ + formatted based on docs: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow#refresh-the-access-token + """ + + def build_header_payload(self) -> Dict[str, Any]: + return { + "accept": "application/json", + "content-type": "application/x-www-form-urlencoded", + } + + +class TokenServiceType(Enum): + OKTA = "okta" + ENTRA = "entra" + + +def create_token_service_client(token_endpoint: Dict[str, Any]) -> TokenServiceBase: + if (service_type := token_endpoint.get("type")) is None: + raise FailedToConnectError("Missing required key in token_endpoint: 'type'") + + if service_type == TokenServiceType.OKTA.value: + return OktaIdpTokenService(token_endpoint) + elif service_type == TokenServiceType.ENTRA.value: + return EntraIdpTokenService(token_endpoint) + else: + raise ValueError( + f"Unsupported identity provider type: {service_type}. Select 'okta' or 'entra.'" + ) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index a23563c72..9632be77b 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,17 +1,19 @@ import re +import redshift_connector +import sqlparse + from multiprocessing import Lock from contextlib import contextmanager from typing import Any, Callable, Dict, Tuple, Union, Optional, List, TYPE_CHECKING from dataclasses import dataclass, field -import sqlparse -import redshift_connector from dbt.adapters.exceptions import FailedToConnectError from redshift_connector.utils.oids import get_datatype_name from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.redshift.auth_providers import create_token_service_client from dbt_common.contracts.util import Replaceable from dbt_common.dataclass_schema import dbtClassMixin, StrEnum, ValidationError from dbt_common.helper_types import Port @@ -37,20 +39,16 @@ def get_message(self) -> str: logger = AdapterLogger("Redshift") -class IdentityCenterTokenType(StrEnum): - ACCESS_TOKEN = "ACCESS_TOKEN" - EXT_JWT = "EXT_JWT" - - class RedshiftConnectionMethod(StrEnum): DATABASE = "database" IAM = "iam" IAM_ROLE = "iam_role" IAM_IDENTITY_CENTER_BROWSER = "browser_identity_center" + IAM_IDENTITY_CENTER_TOKEN = "oauth_token_identity_center" @classmethod def uses_identity_center(cls, method: str) -> bool: - return method in (cls.IAM_IDENTITY_CENTER_BROWSER,) + return method in (cls.IAM_IDENTITY_CENTER_BROWSER, cls.IAM_IDENTITY_CENTER_TOKEN) @classmethod def is_iam(cls, method: str) -> bool: @@ -153,6 +151,12 @@ class RedshiftCredentials(Credentials): idc_client_display_name: Optional[str] = "Amazon Redshift driver" idp_response_timeout: Optional[int] = None + # token_endpoint + # a field that we expect to be a dictionary of values used to create + # access tokens from an external identity provider integrated with a redshift + # and aws org or account Iam Idc instance + token_endpoint: Optional[Dict[str, str]] = None + _ALIASES = {"dbname": "database", "pass": "password"} @property @@ -323,6 +327,34 @@ def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]: return __iam_kwargs(credentials) | idc_kwargs + def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]: + """ + Accepts a `credentials` object with a `token_endpoint` field that corresponds to + either Okta or Entra authentication services. + + We only support token_type=EXT_JWT tokens. token_type=ACCESS_TOKEN has not been + tested. It can be added with a presenting use-case. + """ + + logger.debug("Connecting to Redshift with '{credentials.method}' credentials method") + + __validate_required_fields("oauth_token_identity_center", ("token_endpoint",)) + + token_service = create_token_service_client(credentials.token_endpoint) + response = token_service.handle_request() + try: + access_token = response.json()["access_token"] + except KeyError: + raise FailedToConnectError( + "access_token missing from Idp token request. Please confirm correct configuration of the token_endpoint field in profiles.yml and that your Idp can use a refresh token to obtain an OIDC-compliant access token." + ) + + return __iam_kwargs(credentials) | { + "credentials_provider": "IdpTokenAuthPlugin", + "token": access_token, + "token_type": "EXT_JWT", + } + # # Head of function execution # @@ -333,6 +365,7 @@ def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]: RedshiftConnectionMethod.IAM: __iam_user_kwargs, RedshiftConnectionMethod.IAM_ROLE: __iam_role_kwargs, RedshiftConnectionMethod.IAM_IDENTITY_CENTER_BROWSER: __iam_idc_browser_kwargs, + RedshiftConnectionMethod.IAM_IDENTITY_CENTER_TOKEN: __iam_idc_token_kwargs, } try: diff --git a/hatch.toml b/hatch.toml index 3a3990a6c..44bae3b31 100644 --- a/hatch.toml +++ b/hatch.toml @@ -11,18 +11,19 @@ packages = ["dbt"] dependencies = [ "dbt-adapters @ git+https://github.com/dbt-labs/dbt-adapters.git", "dbt-common @ git+https://github.com/dbt-labs/dbt-common.git", - "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter", "dbt-core @ git+https://github.com/dbt-labs/dbt-core.git#subdirectory=core", + "dbt-tests-adapter @ git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter", "ddtrace==2.3.0", + "freezegun", "ipdb~=0.13.13", "pre-commit==3.7.0", - "freezegun", - "pytest>=7.0,<8.0", "pytest-csv~=3.0", "pytest-dotenv", "pytest-logbook~=1.2", "pytest-mock", "pytest-xdist", + "pytest>=7.0,<8.0", + "requests", ] [envs.default.scripts] diff --git a/pyproject.toml b/pyproject.toml index e68aa2607..4399915e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ # installed via dbt-core but referenced directly; don't pin to avoid version conflicts with dbt-core "sqlparse>=0.5.0,<0.6.0", "agate", + "requests", ] [project.urls] diff --git a/tests/functional/test_auth_method.py b/tests/functional/test_auth_method.py index b2273e02c..493bc5af6 100644 --- a/tests/functional/test_auth_method.py +++ b/tests/functional/test_auth_method.py @@ -101,3 +101,25 @@ def dbt_profile_target(self): "host": "", # host is a required field in dbt-core "port": 0, # port is a required field in dbt-core } + + +@pytest.mark.skip( + reason="We need to cut over to new adapters team AWS account which has infra to support this as an automated test. This will include a GHA step that renders a refresh token and loading secrets into Github secrets for the <> delimited placeholder values below" +) +class TestIamIdcAuthProfileOktaIdp(AuthMethod): + @pytest.fixture(scope="class") + def dbt_profile_target(self): + return { + "type": "redshift", + "method": "oauth_token_identity_center", + "host": os.getenv("REDSHIFT_TEST_HOST"), + "port": 5439, + "dbname": "dev", + "threads": 1, + "token_endpoint": { + "type": "okta", + "request_url": "https://.oktapreview.com/oauth2/default/v1/token", + "idp_auth_credentials": "", + "request_data": "grant_type=refresh_token&redirect_uri=&refresh_token=", + }, + } diff --git a/tests/unit/test_auth_method.py b/tests/unit/test_auth_method.py index 16d13268f..46412e9d2 100644 --- a/tests/unit/test_auth_method.py +++ b/tests/unit/test_auth_method.py @@ -1,3 +1,5 @@ +import requests + from multiprocessing import get_context from unittest import TestCase, mock from unittest.mock import MagicMock @@ -673,3 +675,113 @@ def test_invalid_adapter_missing_fields(self): "'idc_region', 'issuer_url' field(s) are required for 'browser_identity_center' credentials method" in context.exception.msg ) + + +class TestIAMIdcToken(AuthMethod): + @mock.patch("redshift_connector.connect", MagicMock()) + def test_profile_idc_token_all_required_fields_okta(self): + """This test doesn't follow the idiom elsewhere in this file because we + a real test would need a valid refresh token which would require a valid + authorization request, neither of which are possible in automated testing at + merge. This is a surrogate test. + """ + self.config.credentials = self.config.credentials.replace( + method="oauth_token_identity_center", + token_endpoint={ + "type": "okta", + "request_url": "https://dbtcs.oktapreview.com/oauth2/default/v1/token", + "idp_auth_credentials": "my_auth_creds", + "request_data": "grant_type=refresh_token&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Flogin%2Foauth2%2Fcode%2Fokta&refresh_token=my_token", + }, + ) + with self.assertRaises(requests.exceptions.HTTPError) as context: + """ + An http says we've made it in operation to call the token request which fails + due to invalid refresh token and auth creds + """ + connection = self.adapter.acquire_connection("dummy") + connection.handle + + assert "401 Client Error: Unauthorized for url" in str(context.exception) + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_profile_idc_token_all_required_fields_entra(self): + """This test doesn't follow the idiom elsewhere in this file because we + a real test would need a valid refresh token which would require a valid + authorization request, neither of which are possible in automated testing at + merge. This is a surrogate test. + """ + self.config.credentials = self.config.credentials.replace( + method="oauth_token_identity_center", + token_endpoint={ + "type": "entra", + "request_url": "https://login.microsoftonline.com/my_tenant/oauth2/v2.0/token", + "request_data": "my_data", + }, + ) + with self.assertRaises(requests.exceptions.HTTPError) as context: + """ + An http says we've made it in operation to call the token request which fails + due to invalid refresh token and auth creds + """ + connection = self.adapter.acquire_connection("dummy") + connection.handle + + assert "400 Client Error: Bad Request for url" in str(context.exception) + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_invalid_idc_token_missing_field(self): + # Successful test + self.config.credentials = self.config.credentials.replace( + method="oauth_token_identity_center", + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + assert ( + "'token_endpoint' field(s) are required for 'oauth_token_identity_center' credentials method" + in context.exception.msg + ) + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_invalid_idc_token_missing_token_endpoint_subfield_okta(self): + # Successful test + self.config.credentials = self.config.credentials.replace( + method="oauth_token_identity_center", + token_endpoint={ + "type": "okta", + "request_data": "my_data", + "idp_auth_credentials": "my_auth_creds", + }, + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + assert "Missing required key in token_endpoint: 'request_url'" in context.exception.msg + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_invalid_idc_token_missing_token_endpoint_subfield_entra(self): + # Successful test + self.config.credentials = self.config.credentials.replace( + method="oauth_token_identity_center", + token_endpoint={ + "type": "entra", + "request_url": "https://dbtcs.oktapreview.com/oauth2/default/v1/token", + }, + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + assert "Missing required key in token_endpoint: 'request_data'" in context.exception.msg + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_invalid_idc_token_missing_token_endpoint_type(self): + # Successful test + self.config.credentials = self.config.credentials.replace( + method="oauth_token_identity_center", + token_endpoint={}, + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + assert "Missing required key in token_endpoint: 'type'" in context.exception.msg