[DOP-16676] Use Postgres advisory locks to avoid insert conflicts
dolfinus committed Jul 19, 2024
1 parent ccc8db5 commit 17b7893
Showing 13 changed files with 346 additions and 179 deletions.
12 changes: 6 additions & 6 deletions data_rentgen/consumer/handlers.py
@@ -52,7 +52,7 @@ def get_unit_of_work(session: AsyncSession = Depends(Stub(AsyncSession))) -> Uni
     return UnitOfWork(session)
 
 
-@router.subscriber("input.runs")
+@router.subscriber("input.runs", group_id="data-rentgen")
 async def runs_handler(event: OpenLineageRunEvent, unit_of_work: UnitOfWork = Depends(get_unit_of_work)):
     if event.job.facets.jobType and event.job.facets.jobType.jobType == OpenLineageJobType.JOB:
         await handle_operation(event, unit_of_work)
@@ -66,7 +66,7 @@ async def handle_run(event: OpenLineageRunEvent, unit_of_work: UnitOfWork) -> No
 
     async with unit_of_work:
         raw_job = extract_job(event.job)
-        job = await get_or_create_job(raw_job, unit_of_work)
+        job = await create_or_update_job(raw_job, unit_of_work)
 
     async with unit_of_work:
         raw_user = extract_run_user(event)
@@ -149,13 +149,13 @@ async def get_or_create_parent_run(event: OpenLineageRunEvent, unit_of_work: Uni
         return None
 
     raw_parent_run = extract_parent_run(event.run.facets.parent)
-    parent_job = await get_or_create_job(raw_parent_run.job, unit_of_work)
+    parent_job = await create_or_update_job(raw_parent_run.job, unit_of_work)
     return await unit_of_work.run.get_or_create_minimal(raw_parent_run, parent_job.id)
 
 
-async def get_or_create_job(job: JobDTO, unit_of_work: UnitOfWork) -> Job:
+async def create_or_update_job(job: JobDTO, unit_of_work: UnitOfWork) -> Job:
     matching_location = await unit_of_work.location.get_or_create(job.location)
-    return await unit_of_work.job.get_or_create(job, matching_location.id)
+    return await unit_of_work.job.create_or_update(job, matching_location.id)
 
 
 async def get_or_create_user(user: UserDTO, unit_of_work: UnitOfWork) -> User:
@@ -190,7 +190,7 @@ async def get_or_create_dataset_symlink(
     symlink_type: DatasetSymlinkTypeDTO,
     unit_of_work: UnitOfWork,
 ) -> DatasetSymlink:
-    return await unit_of_work.dataset.create_or_update_symlink(
+    return await unit_of_work.dataset_symlink.create_or_update(
         from_dataset.id,
         to_dataset.id,
         DatasetSymlinkType(symlink_type),
1 change: 1 addition & 0 deletions data_rentgen/db/models/job.py
@@ -49,5 +49,6 @@ class Job(Base):
         ChoiceType(JobType, impl=String(32)),
         index=True,
         nullable=False,
+        default=JobType.UNKNOWN,
         doc="Job type, e.g. AIRFLOW_DAG, AIRFLOW_TASK, SPARK_APPLICATION",
     )
30 changes: 20 additions & 10 deletions data_rentgen/db/repositories/base.py
@@ -3,11 +3,12 @@
 from __future__ import annotations
 
 from abc import ABC
-from typing import Generic, Tuple, TypeVar
+from hashlib import sha1
+from typing import Any, Generic, Tuple, TypeVar
 
 from sqlalchemy import ScalarResult, Select, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql import ColumnElement, SQLColumnExpression
+from sqlalchemy.sql import SQLColumnExpression
 
 from data_rentgen.db.models import Base
 from data_rentgen.dto import PaginationDTO
@@ -31,14 +32,6 @@ def model_type(cls) -> type[Model]:
         # Get `User` from `UserRepository(Repository[User])`
         return cls.__orig_bases__[0].__args__[0]  # type: ignore[attr-defined]
 
-    async def _get(
-        self,
-        *where: ColumnElement,
-    ) -> Model | None:
-        model_type = self.model_type()
-        query: Select = select(model_type).where(*where)
-        return await self._session.scalar(query)
-
     async def _paginate_by_query(
         self,
         order_by: list[SQLColumnExpression],
@@ -71,3 +64,20 @@ async def _count(
 
         result = await self._session.scalars(query)
         return result.one()
+
+    async def _lock(
+        self,
+        *keys: Any,
+    ) -> None:
+        """
+        Take a lock on a specific table and set of keys, to avoid inserting the same row multiple times.
+        Based on [pg_advisory_xact_lock](https://www.postgresql.org/docs/current/functions-admin.html).
+        Lock is held until the transaction is committed or rolled back.
+        """
+        model_type = self.model_type()
+        data = ".".join(map(str, [model_type.__table__, *keys]))
+        digest = sha1(data.encode("utf-8"), usedforsecurity=False).digest()
+        # sha1 returns 160bit hash, we need only first 64 bits
+        lock_key = int.from_bytes(digest[:8], byteorder="big", signed=True)
+        statement = select(func.pg_advisory_xact_lock(lock_key))
+        await self._session.execute(statement)
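
This `_lock` helper is the core of the commit: each repository derives a single signed 64-bit key from its table name plus row keys, and serializes concurrent writers on that key. A minimal standalone sketch of the same derivation (table and key values here are purely illustrative, not from this repo):

from hashlib import sha1

def advisory_lock_key(table: str, *keys: object) -> int:
    """Derive the signed 64-bit key passed to pg_advisory_xact_lock."""
    data = ".".join(map(str, [table, *keys]))
    digest = sha1(data.encode("utf-8"), usedforsecurity=False).digest()
    # sha1 yields 20 bytes; Postgres advisory locks take a signed bigint, so keep the first 8
    return int.from_bytes(digest[:8], byteorder="big", signed=True)

# Deterministic: workers inserting the same logical row compute the same key,
# so pg_advisory_xact_lock(key) makes the server queue them one after another.
assert advisory_lock_key("job", 42, "daily_etl") == advisory_lock_key("job", 42, "daily_etl")

Because pg_advisory_xact_lock blocks until the holding transaction commits or rolls back, a waiting worker resumes only after the winner's row is visible, which is what makes the re-read in the repositories below reliable.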
55 changes: 24 additions & 31 deletions data_rentgen/db/repositories/dataset.py
@@ -5,46 +5,23 @@
 from sqlalchemy.orm import selectinload
 
 from data_rentgen.db.models import Dataset, Location
-from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
 from data_rentgen.db.repositories.base import Repository
 from data_rentgen.dto import DatasetDTO, PaginationDTO
 
 
 class DatasetRepository(Repository[Dataset]):
     async def create_or_update(self, dataset: DatasetDTO, location_id: int) -> Dataset:
-        statement = select(Dataset).where(
-            Dataset.location_id == location_id,
-            Dataset.name == dataset.name,
-        )
-        result = await self._session.scalar(statement)
+        result = await self._get(location_id, dataset.name)
         if not result:
-            result = Dataset(location_id=location_id, name=dataset.name, format=dataset.format)
-            self._session.add(result)
-        elif dataset.format:
-            result.format = dataset.format
+            # try one more time, but with lock acquired.
+            # if another worker already created the same row, just use it. if not - create with holding the lock.
+            await self._lock(location_id, dataset.name)
+            result = await self._get(location_id, dataset.name)
 
-        await self._session.flush([result])
-        return result
-
-    async def create_or_update_symlink(
-        self,
-        from_dataset_id: int,
-        to_dataset_id: int,
-        symlink_type: DatasetSymlinkType,
-    ) -> DatasetSymlink:
-        statement = select(DatasetSymlink).where(
-            DatasetSymlink.from_dataset_id == from_dataset_id,
-            DatasetSymlink.to_dataset_id == to_dataset_id,
-        )
-        result = await self._session.scalar(statement)
-        if not result:
-            result = DatasetSymlink(from_dataset_id=from_dataset_id, to_dataset_id=to_dataset_id, type=symlink_type)
-            self._session.add(result)
-        else:
-            result.type = symlink_type
-
-        await self._session.flush([result])
-        return result
+        if not result:
+            return await self._create(dataset, location_id)
+        return await self._update(result, dataset)
 
     async def paginate(self, page: int, page_size: int, dataset_id: list[int]) -> PaginationDTO[Dataset]:
         query = (
@@ -53,3 +30,19 @@ async def paginate(self, page: int, page_size: int, dataset_id: list[int]) -> Pa
             .options(selectinload(Dataset.location).selectinload(Location.addresses))
         )
         return await self._paginate_by_query(order_by=[Dataset.id], page=page, page_size=page_size, query=query)
+
+    async def _get(self, location_id: int, name: str) -> Dataset | None:
+        statement = select(Dataset).where(Dataset.location_id == location_id, Dataset.name == name)
+        return await self._session.scalar(statement)
+
+    async def _create(self, dataset: DatasetDTO, location_id: int) -> Dataset:
+        result = Dataset(location_id=location_id, name=dataset.name, format=dataset.format)
+        self._session.add(result)
+        await self._session.flush([result])
+        return result
+
+    async def _update(self, existing: Dataset, new: DatasetDTO) -> Dataset:
+        if new.format:
+            existing.format = new.format
+        await self._session.flush([existing])
+        return existing
49 changes: 49 additions & 0 deletions data_rentgen/db/repositories/dataset_symlink.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: 2024 MTS PJSC
+# SPDX-License-Identifier: Apache-2.0
+
+from sqlalchemy import select
+
+from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
+from data_rentgen.db.repositories.base import Repository
+
+
+class DatasetSymlinkRepository(Repository[DatasetSymlink]):
+    async def create_or_update(
+        self,
+        from_dataset_id: int,
+        to_dataset_id: int,
+        symlink_type: DatasetSymlinkType,
+    ) -> DatasetSymlink:
+        result = await self._get(from_dataset_id, to_dataset_id)
+        if not result:
+            # try one more time, but with lock acquired.
+            # if another worker already created the same row, just use it. if not - create with holding the lock.
+            await self._lock(from_dataset_id, to_dataset_id)
+            result = await self._get(from_dataset_id, to_dataset_id)
+
+        if not result:
+            return await self._create(from_dataset_id, to_dataset_id, symlink_type)
+        return await self._update(result, symlink_type)
+
+    async def _get(self, from_dataset_id: int, to_dataset_id: int) -> DatasetSymlink | None:
+        query = select(DatasetSymlink).where(
+            DatasetSymlink.from_dataset_id == from_dataset_id,
+            DatasetSymlink.to_dataset_id == to_dataset_id,
+        )
+        return await self._session.scalar(query)
+
+    async def _create(
+        self,
+        from_dataset_id: int,
+        to_dataset_id: int,
+        symlink_type: DatasetSymlinkType,
+    ) -> DatasetSymlink:
+        result = DatasetSymlink(from_dataset_id=from_dataset_id, to_dataset_id=to_dataset_id, type=symlink_type)
+        self._session.add(result)
+        await self._session.flush([result])
+        return result
+
+    async def _update(self, existing: DatasetSymlink, new_type: DatasetSymlinkType) -> DatasetSymlink:
+        existing.type = new_type
+        await self._session.flush([existing])
+        return existing
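
This new repository is the cleanest instance of the check / lock / re-check pattern the commit rolls out everywhere. As a rough model of why the second check is needed, here is a toy in-memory sketch, with an asyncio.Lock standing in for the advisory lock; every name in it is illustrative and not part of the commit:

import asyncio

rows: dict[tuple[int, int], str] = {}            # (from_dataset_id, to_dataset_id) -> type
locks: dict[tuple[int, int], asyncio.Lock] = {}  # stand-in for Postgres advisory locks

async def create_or_update(from_id: int, to_id: int, symlink_type: str) -> str:
    key = (from_id, to_id)
    if key not in rows:                                    # first check, no lock held
        async with locks.setdefault(key, asyncio.Lock()):  # only one creator proceeds
            if key not in rows:                            # re-check: did another worker win?
                rows[key] = symlink_type                   # no: create the row ourselves
                return rows[key]
    rows[key] = symlink_type                               # row already exists: update its type
    return rows[key]

async def main() -> None:
    # Two concurrent workers converge on one row instead of double-inserting.
    await asyncio.gather(create_or_update(1, 2, "METASTORE"), create_or_update(1, 2, "METASTORE"))
    print(rows)  # {(1, 2): 'METASTORE'}

asyncio.run(main())

The real lock differs in one important way: Postgres holds it until the transaction commits or rolls back rather than until the end of a with-block, which is exactly what lets the losing worker observe the winner's committed row on its second _get.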
71 changes: 47 additions & 24 deletions data_rentgen/db/repositories/interaction.py
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: 2024 MTS PJSC
 # SPDX-License-Identifier: Apache-2.0
 
+from datetime import datetime
+
 from sqlalchemy import select
 from uuid6 import UUID
 
@@ -25,35 +27,56 @@ async def create_or_update(
         # to avoid scanning all partitions and speed up insert queries
         created_at = extract_timestamp_from_uuid(operation_id)
 
         # instead of using UniqueConstraint on multiple fields, one of which (schema_id) can be NULL,
         # use them to calculate unique id
         id_components = f"{operation_id}.{dataset_id}.{interaction.type}.{schema_id}"
         interaction_id = generate_incremental_uuid(created_at, id_components.encode("utf-8"))
 
-        query = select(Interaction).where(
-            Interaction.created_at == created_at,
-            Interaction.id == interaction_id,
-        )
-        result = await self._session.scalar(query)
+        result = await self._get(created_at, interaction_id)
         if not result:
-            result = Interaction(
-                created_at=created_at,
-                id=interaction_id,
-                operation_id=operation_id,
-                dataset_id=dataset_id,
-                type=InteractionType(interaction.type),
-                schema_id=schema_id,
-                num_bytes=interaction.num_bytes,
-                num_rows=interaction.num_rows,
-                num_files=interaction.num_files,
-            )
-            self._session.add(result)
-        else:
-            if interaction.num_bytes is not None:
-                result.num_bytes = interaction.num_bytes
-            if interaction.num_rows is not None:
-                result.num_rows = interaction.num_rows
-            if interaction.num_files is not None:
-                result.num_files = interaction.num_files
+            # try one more time, but with lock acquired.
+            # if another worker already created the same row, just use it. if not - create with holding the lock.
+            await self._lock(interaction_id)
+            result = await self._get(created_at, interaction_id)
+
+        if not result:
+            return await self._create(created_at, interaction_id, interaction, operation_id, dataset_id, schema_id)
+        return await self._update(result, interaction)
+
+    async def _get(self, created_at: datetime, interaction_id: UUID) -> Interaction | None:
+        query = select(Interaction).where(Interaction.created_at == created_at, Interaction.id == interaction_id)
+        return await self._session.scalar(query)
+
+    async def _create(
+        self,
+        created_at: datetime,
+        interaction_id: UUID,
+        interaction: InteractionDTO,
+        operation_id: UUID,
+        dataset_id: int,
+        schema_id: int | None = None,
+    ) -> Interaction:
+        result = Interaction(
+            created_at=created_at,
+            id=interaction_id,
+            operation_id=operation_id,
+            dataset_id=dataset_id,
+            type=InteractionType(interaction.type),
+            schema_id=schema_id,
+            num_bytes=interaction.num_bytes,
+            num_rows=interaction.num_rows,
+            num_files=interaction.num_files,
+        )
+        self._session.add(result)
+        await self._session.flush([result])
+        return result
+
+    async def _update(self, existing: Interaction, new: InteractionDTO) -> Interaction:
+        if new.num_bytes is not None:
+            existing.num_bytes = new.num_bytes
+        if new.num_rows is not None:
+            existing.num_rows = new.num_rows
+        if new.num_files is not None:
+            existing.num_files = new.num_files
+        await self._session.flush([existing])
+        return existing
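
The comment about schema_id is worth unpacking: standard Postgres UNIQUE constraints treat NULLs as distinct, so a constraint over (operation_id, dataset_id, type, schema_id) would never conflict for rows whose schema_id is NULL. Deriving one deterministic id from all the components sidesteps that. A hedged sketch of the idea (generate_incremental_uuid is this project's own helper; the uuid5 call below is only an illustration):

from uuid import NAMESPACE_OID, UUID, uuid5

def deterministic_interaction_id(
    operation_id: UUID,
    dataset_id: int,
    interaction_type: str,
    schema_id: int | None,
) -> UUID:
    # Equal components (even schema_id=None) always yield the same id, so the
    # lock-and-recheck above deduplicates rows that UNIQUE(..., NULL) cannot.
    components = f"{operation_id}.{dataset_id}.{interaction_type}.{schema_id}"
    return uuid5(NAMESPACE_OID, components)

a = deterministic_interaction_id(UUID(int=1), 10, "READ", None)
b = deterministic_interaction_id(UUID(int=1), 10, "READ", None)
assert a == b  # would-be duplicates collapse onto one primary-key value

(Postgres 15 later added UNIQUE NULLS NOT DISTINCT as an alternative.)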
39 changes: 28 additions & 11 deletions data_rentgen/db/repositories/job.py
@@ -16,17 +16,34 @@ async def paginate(self, page: int, page_size: int, job_id: list[int]) -> Pagina
         )
         return await self._paginate_by_query(order_by=[Job.id], page=page, page_size=page_size, query=query)
 
-    async def get_or_create(self, job: JobDTO, location_id: int) -> Job:
-        statement = select(Job).where(Job.location_id == location_id, Job.name == job.name)
-        result = await self._session.scalar(statement)
+    async def create_or_update(self, job: JobDTO, location_id: int) -> Job:
+        result = await self._get(location_id, job.name)
         if not result:
-            result = Job(
-                location_id=location_id,
-                name=job.name,
-                type=JobType(job.type) if job.type else JobType.UNKNOWN,
-            )
-            self._session.add(result)
-        elif job.type:
-            result.type = JobType(job.type)
+            # try one more time, but with lock acquired.
+            # if another worker already created the same row, just use it. if not - create with holding the lock.
+            await self._lock(location_id, job.name)
+            result = await self._get(location_id, job.name)
 
+        if not result:
+            return await self._create(job, location_id)
+        return await self._update(result, job)
+
+    async def _get(self, location_id: int, name: str) -> Job | None:
+        statement = select(Job).where(Job.location_id == location_id, Job.name == name)
+        return await self._session.scalar(statement)
+
+    async def _create(self, job: JobDTO, location_id: int) -> Job:
+        result = Job(
+            location_id=location_id,
+            name=job.name,
+            type=JobType(job.type) if job.type else JobType.UNKNOWN,
+        )
+        self._session.add(result)
+        await self._session.flush([result])
+        return result
+
+    async def _update(self, existing: Job, new: JobDTO) -> Job:
+        if new.type:
+            existing.type = JobType(new.type)
+        await self._session.flush([existing])
+        return existing
(The remaining 6 changed files are not shown here.)
