diff --git a/CHANGELOG.md b/CHANGELOG.md index 786822d..0519910 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## ?.?.? + +* `okdata.aws.status.sdk.Status` now accepts an additional optional + parameter, `sdk_config`, which allows the underlying Status SDK to + be configured. + ## 2.1.0 - 2024-02-15 * New utility function `okdata.aws.ssm.get_secret` for retrieving secure strings diff --git a/okdata/aws/status/sdk.py b/okdata/aws/status/sdk.py index 36bccfc..3c60dbb 100644 --- a/okdata/aws/status/sdk.py +++ b/okdata/aws/status/sdk.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import Dict, Union +from okdata.sdk.config import Config from okdata.sdk.status import Status as StatusSDK from requests.exceptions import HTTPError, RetryError @@ -12,7 +13,11 @@ class Status: - def __init__(self, status_data: Union[StatusData, Dict, str]): + def __init__( + self, + status_data: Union[StatusData, Dict, str], + sdk_config: Config = None, + ): if isinstance(status_data, str): # TODO: Remove in future - for backwards-compatibility: # Status-class used directly in state-machine-event only(?), @@ -22,7 +27,7 @@ def __init__(self, status_data: Union[StatusData, Dict, str]): status_data = StatusData.parse_obj(status_data) self.status_data = status_data - self._sdk = StatusSDK() + self._sdk = StatusSDK(sdk_config) def _process_payload(self): if self.status_data.trace_id is None: diff --git a/tests/test_status.py b/tests/test_status.py index a893043..f0de346 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -1,10 +1,13 @@ import re -import pytest from copy import deepcopy +from unittest.mock import patch + +import pytest from freezegun import freeze_time +from okdata.sdk.config import Config -from okdata.aws.status.sdk import Status from okdata.aws.status.model import StatusData, TraceStatus, TraceEventStatus +from okdata.aws.status.sdk import Status from okdata.aws.status.wrapper import _status_from_lambda_context @@ -66,6 +69,14 @@ def test_status_data_from_object(self): s = Status(status_data) assert s.status_data.trace_id == trace_id + def test_status_with_sdk_config(self): + config = Config(config={"foo": "bar"}) + # Mock out the `Authenticate` class so that the SDK doesn't try to set + # up authentication with a malformed config. + with patch("okdata.sdk.sdk.Authenticate"): + s = Status(trace_id, config) + assert s._sdk.config.config == {"foo": "bar"} + def test_status_data_from_lambda(self): lambda_context = MockLambdaContext() s = _status_from_lambda_context(