Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Retrieve extra user info from claims from zgw_auth_backend #62

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion src/dowc/accounts/authentication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
import logging
from typing import Dict

from django.contrib.auth import get_user_model
from django.utils.translation import ugettext_lazy as _

from requests.models import Request
from rest_framework import exceptions
from rest_framework.authentication import TokenAuthentication as _TokenAuthentication
from rest_framework.authentication import (
TokenAuthentication as _TokenAuthentication,
get_authorization_header,
)
from zgw_auth_backend.authentication import ZGWAuthentication as _ZGWAuthentication
from zgw_auth_backend.zgw import ZGWAuth

logger = logging.getLogger(__name__)


class ApplicationTokenAuthentication(_TokenAuthentication):
Expand All @@ -16,3 +28,72 @@ def authenticate_credentials(self, key):
raise exceptions.AuthenticationFailed(_("Invalid token."))

return (None, token)


class ZGWAuthentication(_ZGWAuthentication):
"""
Taken from zgw_auth_backend and adapted to further suit our needs.
We want to include first and last names and check every authentication
if an update is needed to reflect changes done to their first
and last name.

"""

def authenticate(self, request: Request):
auth = get_authorization_header(request).split()

if not auth or auth[0].lower() != b"bearer":
return None

if len(auth) == 1:
msg = _("Invalid bearer header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _(
"Invalid bearer header. Credentials string should not contain spaces."
)
raise exceptions.AuthenticationFailed(msg)

auth = ZGWAuth(auth[1].decode("utf-8"))

user_id = auth.payload.get("user_id")
if not user_id:
msg = _("Invalid 'user_id' claim. The 'user_id' should not be empty.")
raise exceptions.AuthenticationFailed(msg)

email = auth.payload.get("email", "")
return self.authenticate_user_id(user_id, email, auth.payload)

def authenticate_user_id(self, username: str, email: str, payload: Dict):
UserModel = get_user_model()
fields = {UserModel.USERNAME_FIELD: username}
user, created = UserModel._default_manager.get_or_create(**fields)
if created:
msg = "Created user object for username %s" % username
logger.info(msg)

if email:
email_field = UserModel.get_email_field_name()
email_value = getattr(user, email_field)
if not email_value or email_value != email:
setattr(user, email_field, email)
user.save()
msg = "Set email to %s of user with username %s" % (email, username)
logger.info(msg)

extra_user_info_fields = ["first_name", "last_name"]
data = {
field: value
for field, value in payload.items()
if field in extra_user_info_fields
}
for field, value in data.items():
if not getattr(user, field) == value:
setattr(user, field, value)
try:
user.save(update_fields=[field])
except ValueError:
logger.error(exc_info=True)
continue

return (user, None)
1 change: 1 addition & 0 deletions src/dowc/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def get_magic_url(self, obj) -> str:

if obj.purpose in [DocFileTypes.read, DocFileTypes.write]:
fn, fext = os.path.splitext(obj.document.name)
print("wut?")
if scheme_name := EXTENSION_HANDLER.get(fext, ""):
command_argument = {
DocFileTypes.read: ":ofv|u|",
Expand Down
76 changes: 75 additions & 1 deletion src/dowc/api/tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
"""
Test that authorization is required for the API endpoints.

Test that authorization creates or gets the user.
"""
import uuid
from unittest.mock import patch

from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.reverse import reverse, reverse_lazy
from rest_framework.test import APITestCase
from zds_client import ClientAuth
from zgw_auth_backend.models import ApplicationCredentials
from zgw_consumers.api_models.base import factory
from zgw_consumers.api_models.documenten import Document
from zgw_consumers.constants import APITypes
from zgw_consumers.models import Service
from zgw_consumers.test import generate_oas_component, mock_service_oas_get

from dowc.accounts.models import User
from dowc.accounts.tests.factories import UserFactory
from dowc.core.constants import DocFileTypes


class AuthTests(APITestCase):
Expand Down Expand Up @@ -40,3 +54,63 @@ def test_invalid_token(self):
with self.subTest(method=method, path=path):
response = getattr(self.client, method)(path, **headers)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_create_user_during_authentication(self):
drc_url = "https://some.drc.nl/api/v1/"
Service.objects.create(api_type=APITypes.drc, api_root=drc_url)
list_url = reverse_lazy("documentfile-list")

# Create mock url for drc object
_uuid = str(uuid.uuid4())
doc_url = f"{drc_url}enkelvoudiginformatieobjecten/{_uuid}"

# No users exist
self.assertEqual(User.objects.count(), 0)
data = {
"drc_url": doc_url,
"purpose": DocFileTypes.read,
"info_url": "http://www.some-referer-url.com/",
"user_id": "some-user",
}
ApplicationCredentials.objects.create(client_id="dummy", secret="secret")
auth = ClientAuth("dummy", "secret", user_id="some-user").credentials()

response = self.client.post(
list_url, data, HTTP_AUTHORIZATION=auth["Authorization"]
)

self.assertEqual(User.objects.get().username, "some-user")

def test_update_user_during_authentication(self):
drc_url = "https://some.drc.nl/api/v1/"
Service.objects.create(api_type=APITypes.drc, api_root=drc_url)
list_url = reverse_lazy("documentfile-list")

# Create mock url for drc object
_uuid = str(uuid.uuid4())
doc_url = f"{drc_url}enkelvoudiginformatieobjecten/{_uuid}"

# User exists
user = UserFactory.create(
username="some-user", first_name="First", last_name="Last"
)
self.assertEqual(User.objects.count(), 1)
data = {
"drc_url": doc_url,
"purpose": DocFileTypes.read,
"info_url": "http://www.some-referer-url.com/",
}
ApplicationCredentials.objects.create(client_id="dummy", secret="secret")
auth = ClientAuth(
"dummy",
"secret",
user_id="some-user",
first_name="some other first",
last_name="some other last",
).credentials()

response = self.client.post(
list_url, data, HTTP_AUTHORIZATION=auth["Authorization"]
)

self.assertEqual(User.objects.get().first_name, "some other first")
2 changes: 1 addition & 1 deletion src/dowc/conf/includes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": (
"rest_framework.authentication.TokenAuthentication",
"zgw_auth_backend.authentication.ZGWAuthentication",
"dowc.accounts.authentication.ZGWAuthentication",
),
"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
"DEFAULT_FILTER_BACKENDS": [
Expand Down
Loading