Skip to content

Commit

Permalink
[DOP-21268] - update User model
Browse files Browse the repository at this point in the history
  • Loading branch information
maxim-lixakov committed Nov 12, 2024
1 parent 416b40c commit cb15595
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 18 deletions.
9 changes: 8 additions & 1 deletion syncmaster/backend/providers/auth/dummy_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ async def get_token_password_grant(
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
user = await self._uow.user.create(username=login, is_active=True)
user = await self._uow.user.create(
username=login,
email=f"{login}@example.com",
first_name=f"{login}_first",
middle_name=f"{login}_middle",
last_name=f"{login}_last",
is_active=True,
)

log.info("User with id %r found", user.id)
if not user.is_active:
Expand Down
49 changes: 33 additions & 16 deletions syncmaster/backend/providers/auth/keycloak_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,8 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
refresh_token = request.session.get("refresh_token")

if not access_token:
state = generate_state(request.url.path) # initial url request
auth_url = self.keycloak_openid.auth_url(
redirect_uri=self.settings.redirect_uri,
scope=self.settings.scope,
state=state,
)
raise RedirectException(redirect_url=auth_url)
log.debug("No access token found in session.")
self.redirect_to_auth(request.url.path)

try:
token_info = self.keycloak_openid.decode_token(token=access_token)
Expand All @@ -93,20 +88,29 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:

if not token_info and refresh_token:
log.debug("Access token invalid. Attempting to refresh.")
new_tokens = await self.refresh_access_token(refresh_token)

new_access_token = new_tokens.get("access_token")
new_refresh_token = new_tokens.get("refresh_token")
request.session["access_token"] = new_access_token
request.session["refresh_token"] = new_refresh_token
try:
new_tokens = await self.refresh_access_token(refresh_token)

token_info = self.keycloak_openid.decode_token(
token=new_access_token,
)
log.debug("Access token refreshed and decoded successfully.")
new_access_token = new_tokens.get("access_token")
new_refresh_token = new_tokens.get("refresh_token")
request.session["access_token"] = new_access_token
request.session["refresh_token"] = new_refresh_token

token_info = self.keycloak_openid.decode_token(
token=new_access_token,
)
log.debug("Access token refreshed and decoded successfully.")
except Exception as e:
log.debug("Failed to refresh access token: %s", e)
self.redirect_to_auth(request.url.path)

user_id = token_info.get("sub")
login = token_info.get("preferred_username")
email = token_info.get("email")
first_name = token_info.get("given_name")
middle_name = token_info.get("middle_name")
last_name = token_info.get("family_name")

if not user_id:
raise AuthorizationError("Invalid token payload")
Expand All @@ -117,6 +121,10 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
except EntityNotFoundError:
user = await self._uow.user.create(
username=login,
email=email,
first_name=first_name,
middle_name=middle_name,
last_name=last_name,
is_active=True,
)
return user
Expand All @@ -127,3 +135,12 @@ async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]:

async def get_user_info(self, access_token: str) -> dict[str, Any]:
return self.keycloak_openid.userinfo(access_token)

def redirect_to_auth(self, path: str) -> None:
state = generate_state(path)
auth_url = self.keycloak_openid.auth_url(
redirect_uri=self.settings.redirect_uri,
scope=self.settings.scope,
state=state,
)
raise RedirectException(redirect_url=auth_url)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def upgrade():
"user",
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("username", sa.String(length=256), nullable=False),
sa.Column("email", sa.String(length=256), nullable=False),
sa.Column("first_name", sa.String(length=256)),
sa.Column("last_name", sa.String(length=256)),
sa.Column("middle_name", sa.String(length=256)),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False),
Expand Down
4 changes: 4 additions & 0 deletions syncmaster/db/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
class User(Base, TimestampMixin, DeletableMixin):
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
username: Mapped[str] = mapped_column(String(256), nullable=False, unique=True, index=True)
email: Mapped[str] = mapped_column(String(256), nullable=False, unique=True)
first_name: Mapped[str] = mapped_column(String(256))
last_name: Mapped[str] = mapped_column(String(256))
middle_name: Mapped[str] = mapped_column(String(256))
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)

Expand Down
15 changes: 14 additions & 1 deletion syncmaster/db/repositories/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,24 @@ async def update(self, user_id: int, data: dict) -> User:
except IntegrityError as e:
self._raise_error(e)

async def create(self, username: str, is_active: bool, is_superuser: bool = False) -> User:
async def create(
self,
username: str,
email: str,
is_active: bool,
first_name: str | None,
middle_name: str | None,
last_name: str | None,
is_superuser: bool = False,
) -> User:
query = (
insert(User)
.values(
username=username,
email=email,
first_name=first_name,
middle_name=middle_name,
last_name=last_name,
is_active=is_active,
is_superuser=is_superuser,
)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,21 @@ async def create_user_cm(
is_active: bool = False,
is_superuser: bool = False,
is_deleted: bool = False,
email: str = None,
first_name: str = None,
middle_name: str = None,
last_name: str = None,
) -> AsyncGenerator[User, None]:
email = email or f"{username}@user.user"
first_name = first_name or f"{username}_first"
middle_name = middle_name or f"{username}_middle"
last_name = last_name or f"{username}_last"
u = User(
username=username,
email=email,
first_name=first_name,
middle_name=middle_name,
last_name=last_name,
is_active=is_active,
is_superuser=is_superuser,
is_deleted=is_deleted,
Expand All @@ -54,9 +66,21 @@ async def create_user(
is_active: bool = False,
is_superuser: bool = False,
is_deleted: bool = False,
email: str = None,
first_name: str = None,
middle_name: str = None,
last_name: str = None,
) -> User:
email = email or f"{username}@user.user"
first_name = first_name or f"{username}_first"
middle_name = middle_name or f"{username}_middle"
last_name = last_name or f"{username}_last"
u = User(
username=username,
email=email,
first_name=first_name,
middle_name=middle_name,
last_name=last_name,
is_active=is_active,
is_superuser=is_superuser,
is_deleted=is_deleted,
Expand Down

0 comments on commit cb15595

Please sign in to comment.