diff --git a/fence/blueprints/login/base.py b/fence/blueprints/login/base.py index 7cee07fbe..28f99459d 100644 --- a/fence/blueprints/login/base.py +++ b/fence/blueprints/login/base.py @@ -104,11 +104,9 @@ def __init__( self.app = app - # This block of code probably need to be made more concise - if "persist_refresh_token" in config["OPENID_CONNECT"].get(self.idp_name, {}): - self.persist_refresh_token = config["OPENID_CONNECT"][self.idp_name][ - "persist_refresh_token" - ] + self.persist_refresh_token = ( + config["OPENID_CONNECT"].get(self.idp_name, {}).get("persist_refresh_token") + ) if "is_authz_groups_sync_enabled" in config["OPENID_CONNECT"].get( self.idp_name, {} @@ -163,6 +161,7 @@ def get(self): # default to now + REFRESH_TOKEN_EXPIRES_IN if expires is None: expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"] + logger.info(f"Refresh token not in JWT, using default: {expires}") # Store refresh token in db should_persist_token = ( diff --git a/fence/config-default.yaml b/fence/config-default.yaml index f25cf6f2b..16c39fc30 100755 --- a/fence/config-default.yaml +++ b/fence/config-default.yaml @@ -45,7 +45,7 @@ ENCRYPTION_KEY: '' # ////////////////////////////////////////////////////////////////////////////////////// # flask's debug setting # WARNING: DO NOT ENABLE IN PRODUCTION (for testing purposes only) -DEBUG: true +DEBUG: false # if true, will automatically login a user with username "test" # WARNING: DO NOT ENABLE IN PRODUCTION (for testing purposes only) MOCK_AUTH: false @@ -127,6 +127,10 @@ OPENID_CONNECT: # or removed from relevant groups in the local system to ensure their group memberships # remain up-to-date. If this flag is disabled, no group synchronization occurs is_authz_groups_sync_enabled: true + # Key used to retrieve group information from the token + group_claim_field: "groups" + # IdP group membership expiration (seconds).
+ group_membership_expiration_duration: 604800 authz_groups_sync: # This defines the prefix used to identify authorization groups. group_prefix: "some_prefix" diff --git a/fence/config.py b/fence/config.py index 775296025..9d331c4d5 100644 --- a/fence/config.py +++ b/fence/config.py @@ -139,6 +139,10 @@ def post_process(self): ) for idp_id, idp in self._configs.get("OPENID_CONNECT", {}).items(): + if not isinstance(idp, dict): + raise TypeError( + f"Expected 'OPENID_CONNECT' entry '{idp_id}' to be a dictionary." + ) mfa_info = idp.get("multifactor_auth_claim_info") if mfa_info and mfa_info["claim"] not in ["amr", "acr"]: logger.warning( diff --git a/fence/error_handler.py b/fence/error_handler.py index 6ac6f99dc..a8b6a5a0e 100644 --- a/fence/error_handler.py +++ b/fence/error_handler.py @@ -15,36 +15,44 @@ def get_error_response(error: Exception): + """ + Generates a response for the given error with detailed logs and appropriate status codes. + + Args: + error (Exception): The error that occurred. + + Returns: + Tuple (str, int): Rendered error HTML and HTTP status code. + """ details, status_code = get_error_details_and_status(error) support_email = config.get("SUPPORT_EMAIL_FOR_ERRORS") app_name = config.get("APP_NAME", "Gen3 Data Commons") - message = details.get("message") - error_id = _get_error_identifier() logger.error( - "{} HTTP error occured. ID: {}\nDetails: {}".format( - status_code, error_id, str(details) + "{} HTTP error occurred. ID: {}\nDetails: {}\nTraceback: {}".format( + status_code, error_id, details, traceback.format_exc() ) ) - # TODO: Issue: Error messages are obfuscated, the line below needs be - # uncommented when troubleshooting errors. - # Breaks tests if not commented out / removed. We need a fix for this.
- # raise error + # Decide whether to re-raise errors or handle gracefully based on the debug flag + debug_mode = config.get("DEBUG", False) + + if debug_mode: + # Re-raise the error in debug mode for troubleshooting + raise error - # don't include internal details in the public error message - # to do this, only include error messages for known http status codes - # that are less that 500 + # Prepare user-facing message + message = details.get("message") valid_http_status_codes = [ int(code) for code in list(http_responses.keys()) if int(code) < 500 ] + try: status_code = int(status_code) if status_code not in valid_http_status_codes: message = None except (ValueError, TypeError): - # this handles case where status_code is NOT a valid integer (e.g. HTTP status code) message = None status_code = 500 @@ -65,6 +73,15 @@ def get_error_response(error: Exception): def get_error_details_and_status(error): + """ + Extracts details and HTTP status code from the given error. + + Args: + error (Exception): The error to process. + + Returns: + Tuple (dict, int): Error details as a dictionary and HTTP status code. + """ message = error.message if hasattr(error, "message") else str(error) if isinstance(error, APIError): if hasattr(error, "json") and error.json: @@ -76,11 +93,11 @@ def get_error_details_and_status(error): error_response = {"message": error.description}, error.status_code elif isinstance(error, HTTPException): error_response = ( - {"message": getattr(error, "description")}, + {"message": getattr(error, "description", str(error))}, error.get_response().status_code, ) else: - logger.exception("Catch exception") + logger.exception("Unexpected exception occurred") error_code = 500 if hasattr(error, "code"): error_code = error.code @@ -92,4 +109,10 @@ def get_error_details_and_status(error): def _get_error_identifier(): + """ + Generates a unique identifier for tracking the error. + + Returns: + UUID: A unique identifier for the error. 
+ """ return uuid.uuid4() diff --git a/fence/job/access_token_updater.py b/fence/job/access_token_updater.py index 6909357b4..8c1c15b1c 100644 --- a/fence/job/access_token_updater.py +++ b/fence/job/access_token_updater.py @@ -14,7 +14,7 @@ logger = get_logger(__name__, log_level="debug") -class AccessTokenUpdater(object): +class TokenAndAuthUpdater(object): def __init__( self, chunk_size=None, @@ -51,14 +51,8 @@ def __init__( self.oidc_clients_requiring_token_refresh = {} # keep this as a special case, because RAS will not set group information configuration. - # Initialize visa clients: oidc = config.get("OPENID_CONNECT", {}) - if not isinstance(oidc, dict): - raise TypeError( - "Expected 'OPENID_CONNECT' configuration to be a dictionary." - ) - if "ras" not in oidc: self.logger.error("RAS client not configured") else: @@ -96,7 +90,6 @@ async def update_tokens(self, db_session): """ start_time = time.time() - # Change this line to reflect we are refreshing tokens, not just visas self.logger.info("Initializing Visa Update and Token refreshing Cronjob . . .") self.logger.info("Total concurrency size: {}".format(self.concurrency)) self.logger.info("Total thread pool size: {}".format(self.thread_pool_size)) diff --git a/fence/resources/openid/idp_oauth2.py b/fence/resources/openid/idp_oauth2.py index 92181d027..598af742c 100644 --- a/fence/resources/openid/idp_oauth2.py +++ b/fence/resources/openid/idp_oauth2.py @@ -1,3 +1,5 @@ +from email.policy import default + from authlib.integrations.requests_client import OAuth2Session from boto3 import client from cached_property import cached_property @@ -94,38 +96,6 @@ def get_jwt_keys(self, jwks_uri): return None return resp.json()["keys"] - def get_raw_token_claims(self, token_id): - """Extracts unvalidated claims from a JWT (JSON Web Token). - - This function decodes a JWT and extracts claims without verifying - the token's signature or audience. 
It is intended for cases where - access to the raw, unvalidated token claims is sufficient. - - Args: - token_id (str): The JWT token from which to extract claims. - - Returns: - dict: A dictionary of token claims if decoding is successful. - - Raises: - JWTError: If there is an error decoding the token without validation. - - Notes: - This function does not perform any validation of the token. It should - only be used in contexts where validation is not critical or is handled - elsewhere in the application. - """ - try: - # Decode without verification - unvalidated_claims = jwt.decode( - token_id, options={"verify_signature": False} - ) - self.logger.info("Raw token claims extracted successfully.") - return unvalidated_claims - except JWTError as e: - self.logger.error(f"Error extracting claims: {e}") - raise JWTError("Unable to decode the token without validation.") - def decode_and_validate_token(self, token_id, keys, audience, verify_aud=True): """Decodes and validates a JWT (JSON Web Token) using provided keys and audience. @@ -279,7 +249,8 @@ def get_auth_info(self, code): if self.read_authz_groups_from_tokens: try: - groups = claims.get("groups") + group_claim_field = self.settings.get("group_claim_field", "groups") + groups = claims.get(group_claim_field) group_prefix = self.settings.get("authz_groups_sync", {}).get( "group_prefix", "" ) @@ -315,15 +286,15 @@ def get_access_token(self, user, token_endpoint, db_session=None): """ Get access_token using a refresh_token and store new refresh in upstream_refresh_token table. """ - # this function is not correct. use self.session.fetch_access_token, - # validate the token for audience and then return the validated token. - # Still store the refresh token. it will be needed for periodic re-fetching of information. 
refresh_token = None expires = None - # get refresh_token and expiration from db + + # Get the refresh_token and expiration from the database for row in sorted(user.upstream_refresh_tokens, key=lambda row: row.expires): refresh_token = row.refresh_token expires = row.expires + + # Check if the token is expired if time.time() > expires: # reset to check for next token refresh_token = None @@ -336,21 +307,29 @@ def get_access_token(self, user, token_endpoint, db_session=None): if not refresh_token: raise AuthError("User doesn't have a valid, non-expired refresh token") - token_response = self.session.refresh_token( - url=token_endpoint, - proxies=self.get_proxies(), - refresh_token=refresh_token, - ) - refresh_token = token_response["refresh_token"] + try: + token_response = self.session.refresh_token( + url=token_endpoint, + proxies=self.get_proxies(), + refresh_token=refresh_token, + ) - self.store_refresh_token( - user, - refresh_token=refresh_token, - expires=expires, - db_session=db_session, - ) + refresh_token = token_response["refresh_token"] + # Fetching the expires at from token_response. + # Defaulting to 1 hour if not available. + expires_at = token_response.get("expires_at", int(time.time()) + 3600) + + self.store_refresh_token( + user, + refresh_token=refresh_token, + expires=expires_at, + db_session=db_session, + ) - return token_response + return token_response + except Exception as e: + self.logger.exception(f"Error refreshing token for user {user.id}: {e}") + raise AuthError("Failed to refresh access token.") from e def has_mfa_claim(self, decoded_id_token): """ @@ -405,8 +384,24 @@ def store_refresh_token(self, user, refresh_token, expires, db_session=None): db_session.commit() def get_groups_from_token(self, decoded_id_token, group_prefix=""): - """Retrieve and format groups from the decoded token.""" - authz_groups_from_idp = decoded_id_token.get("groups", []) + """ + Retrieve and format groups from the decoded token based on a configurable field name.
+ + Args: + decoded_id_token (dict): The decoded token containing claims. + group_prefix (str): The prefix to strip from group names. + + Returns: + list: A list of formatted group names. + + Variables: + group_claim_field (str): The field name in the token that contains the group information. + authz_groups_from_idp (list): The list of groups retrieved from the token, potentially empty. + """ + # Retrieve the configured field name for groups, defaulting to 'groups' + group_claim_field = self.settings.get("group_claim_field", "groups") + authz_groups_from_idp = decoded_id_token.get(group_claim_field, []) + if authz_groups_from_idp: authz_groups_from_idp = [ group.removeprefix(group_prefix).lstrip("/") @@ -455,9 +450,6 @@ def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs) """ db_session = db_session or current_app.scoped_session() - # Initialize the failure flag for group removal - removal_failed = False - expires_at = None try: @@ -505,48 +497,26 @@ def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs) idp_group_names = set(authz_groups_from_idp) + # Expiration for group membership. 
Default 7 days + group_membership_duration = self.settings.get( + "group_membership_expiration_duration", 3600 * 24 * 7 + ) + group_membership_expires_at = datetime.datetime.now( + tz=datetime.timezone.utc + ) + datetime.timedelta(seconds=group_membership_duration) + # Add user to all matching groups from IDP for arborist_group in arborist_groups: if arborist_group["name"] in idp_group_names: self.logger.info( - f"Adding {user.username} to group: {arborist_group['name']}, sub: {user.id} exp: {exp}" + f"Adding {user.username} to group: {arborist_group['name']}, sub: {user.id} exp: {group_membership_expires_at}" ) self.arborist.add_user_to_group( username=user.username, group_name=arborist_group["name"], - expires_at=exp, + expires_at=group_membership_expires_at, ) - - # Remove user from groups in Arborist that they are not part of in IDP - for arborist_group in arborist_groups: - if arborist_group["name"] not in idp_group_names: - if user.username in arborist_group.get("users", []): - try: - self.remove_user_from_arborist_group( - user.username, arborist_group["name"] - ) - except Exception as e: - self.logger.error( - f"Failed to remove {user.username} from group {arborist_group['name']}: {e}" - ) - removal_failed = ( - # Set the failure flag if any removal fails - True - ) - else: self.logger.warning( f"is_authz_groups_sync_enabled feature is enabled, but did not receive groups from idp {self.idp} for user: {user.username}" ) - - # Raise an exception if any group removal failed - if removal_failed: - raise Exception("One or more group removals failed.") - - def remove_user_from_arborist_group(self, username, group_name): - """ - Attempt to remove a user from an Arborist group, catching any errors to allow - processing of remaining groups. Logs errors and re-raises them after all removals are attempted. 
- """ - self.logger.info(f"Removing {username} from group: {group_name}") - self.arborist.remove_user_from_group(username=username, group_name=group_name) diff --git a/fence/scripting/fence_create.py b/fence/scripting/fence_create.py index 9a94e3601..77d080a5d 100644 --- a/fence/scripting/fence_create.py +++ b/fence/scripting/fence_create.py @@ -38,7 +38,7 @@ generate_signed_refresh_token, issued_and_expiration_times, ) -from fence.job.access_token_updater import AccessTokenUpdater +from fence.job.access_token_updater import TokenAndAuthUpdater from fence.models import ( Client, GoogleServiceAccount, @@ -1821,7 +1821,7 @@ def access_token_polling_job( logger=get_logger("user_syncer.arborist_client"), ) driver = get_SQLAlchemyDriver(db) - job = AccessTokenUpdater( + job = TokenAndAuthUpdater( chunk_size=int(chunk_size) if chunk_size else None, concurrency=int(concurrency) if concurrency else None, thread_pool_size=int(thread_pool_size) if thread_pool_size else None, diff --git a/tests/dbgap_sync/test_user_sync.py b/tests/dbgap_sync/test_user_sync.py index 7cc565c4a..a95a3f5d8 100644 --- a/tests/dbgap_sync/test_user_sync.py +++ b/tests/dbgap_sync/test_user_sync.py @@ -10,7 +10,7 @@ from fence import models from fence.resources.google.access_utils import GoogleUpdateException from fence.config import config -from fence.job.access_token_updater import AccessTokenUpdater +from fence.job.access_token_updater import TokenAndAuthUpdater from fence.utils import DEFAULT_BACKOFF_SETTINGS from tests.dbgap_sync.conftest import ( @@ -490,7 +490,9 @@ def test_sync_with_google_errors(syncer, monkeypatch): syncer._update_arborist = MagicMock() syncer._update_authz_in_arborist = MagicMock() - with patch("fence.sync.sync_users.update_google_groups_for_users") as mock_bulk_update: + with patch( + "fence.sync.sync_users.update_google_groups_for_users" + ) as mock_bulk_update: mock_bulk_update.side_effect = GoogleUpdateException("Something's Wrong!") with 
pytest.raises(GoogleUpdateException): syncer.sync() @@ -498,21 +500,30 @@ def test_sync_with_google_errors(syncer, monkeypatch): syncer._update_arborist.assert_called() syncer._update_authz_in_arborist.assert_called() + @patch("fence.sync.sync_users.paramiko.SSHClient") @patch("os.makedirs") @patch("os.path.exists", return_value=False) @pytest.mark.parametrize("syncer", ["google", "cleversafe"], indirect=True) -def test_sync_with_sftp_connection_errors(mock_path, mock_makedir, mock_ssh_client, syncer, monkeypatch): +def test_sync_with_sftp_connection_errors( + mock_path, mock_makedir, mock_ssh_client, syncer, monkeypatch +): """ Verifies that when there is an sftp connection error connection, that the connection is retried the max amount of tries as configured by DEFAULT_BACKOFF_SETTINGS """ monkeypatch.setattr(syncer, "is_sync_from_dbgap_server", True) - mock_ssh_client.return_value.__enter__.return_value.connect.side_effect = Exception("Authentication timed out") + mock_ssh_client.return_value.__enter__.return_value.connect.side_effect = Exception( + "Authentication timed out" + ) # usersync System Exits if any exception is raised during download. 
with pytest.raises(SystemExit): syncer.sync() - assert mock_ssh_client.return_value.__enter__.return_value.connect.call_count == DEFAULT_BACKOFF_SETTINGS['max_tries'] + assert ( + mock_ssh_client.return_value.__enter__.return_value.connect.call_count + == DEFAULT_BACKOFF_SETTINGS["max_tries"] + ) + @pytest.mark.parametrize("syncer", ["google", "cleversafe"], indirect=True) def test_sync_from_files(syncer, db_session, storage_client): @@ -998,7 +1009,7 @@ def test_user_sync_with_visa_sync_job( # use refresh tokens from users to call access token polling "fence-create update-visa" # and sync authorization from visas - job = AccessTokenUpdater() + job = TokenAndAuthUpdater() job.pkey_cache = { "https://stsstg.nih.gov": { kid: rsa_public_key, @@ -1063,6 +1074,7 @@ def test_revoke_all_policies_no_user(db_session, syncer): # we only care that this doesn't error assert True + @pytest.mark.parametrize("syncer", ["cleversafe", "google"], indirect=True) def test_revoke_all_policies_preserve_mfa(monkeypatch, db_session, syncer): """ diff --git a/tests/job/test_access_token_updater.py b/tests/job/test_access_token_updater.py index 0ba9f6368..f22df71f9 100644 --- a/tests/job/test_access_token_updater.py +++ b/tests/job/test_access_token_updater.py @@ -4,7 +4,7 @@ from fence.models import User from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient -from fence.job.access_token_updater import AccessTokenUpdater +from fence.job.access_token_updater import TokenAndAuthUpdater @pytest.fixture(scope="session", autouse=True) @@ -59,7 +59,7 @@ def mock_oidc_clients(): @pytest.fixture def access_token_updater_config(mock_oidc_clients): - """Fixture to instantiate AccessTokenUpdater with mocked OIDC clients.""" + """Fixture to instantiate TokenAndAuthUpdater with mocked OIDC clients.""" with patch( "fence.config", { @@ -70,7 +70,7 @@ def access_token_updater_config(mock_oidc_clients): 
"ENABLE_AUTHZ_GROUPS_FROM_OIDC": True, }, ): - updater = AccessTokenUpdater() + updater = TokenAndAuthUpdater() # Ensure this is a dictionary rather than a list updater.oidc_clients_requiring_token_refresh = { diff --git a/tests/ras/test_ras.py b/tests/ras/test_ras.py index c1439e056..ab2bb2258 100644 --- a/tests/ras/test_ras.py +++ b/tests/ras/test_ras.py @@ -25,7 +25,7 @@ from tests.utils import add_test_ras_user, TEST_RAS_USERNAME, TEST_RAS_SUB from tests.dbgap_sync.conftest import add_visa_manually -from fence.job.access_token_updater import AccessTokenUpdater +from fence.job.access_token_updater import TokenAndAuthUpdater import tests.utils from tests.conftest import get_subjects_to_passports @@ -95,6 +95,7 @@ def test_update_visa_token( """ Test to check visa table is updated when getting new visa """ + # ensure we don't actually try to reach out to external sites to refresh public keys def validate_jwt_no_key_refresh(*args, **kwargs): kwargs.update({"attempt_refresh": False}) @@ -713,7 +714,7 @@ def _get_userinfo(*args, **kwargs): mock_userinfo.side_effect = _get_userinfo # test "fence-create update-visa" - job = AccessTokenUpdater() + job = TokenAndAuthUpdater() job.pkey_cache = { "https://stsstg.nih.gov": { kid: rsa_public_key, diff --git a/tests/test-fence-config.yaml b/tests/test-fence-config.yaml index bb055b835..6e6130918 100755 --- a/tests/test-fence-config.yaml +++ b/tests/test-fence-config.yaml @@ -44,7 +44,7 @@ ENCRYPTION_KEY: '' # ////////////////////////////////////////////////////////////////////////////////////// # flask's debug setting # WARNING: DO NOT ENABLE IN PRODUCTION -DEBUG: true +DEBUG: false # if true, will automatically login a user with username "test" MOCK_AUTH: true # if true, will only fake a successful login response from Google in /login/google @@ -160,6 +160,10 @@ OPENID_CONNECT: # or removed from relevant groups in the local system to ensure their group memberships # remain up-to-date. 
If this flag is disabled, no group synchronization occurs is_authz_groups_sync_enabled: false + # Key used to retrieve group information from the token + group_claim_field: "groups" + # IdP group membership expiration (seconds). + group_membership_expiration_duration: 604800 authz_groups_sync: # This defines the prefix used to identify authorization groups. group_prefix: /covid