Merge pull request #2362 from opensafely-core/evansd/batch-download-nonunique

Allow `fetch_table_in_batches` to work without unique key
evansd authored Jan 21, 2025
2 parents a453c71 + fc66bcd commit ace5072
Showing 4 changed files with 333 additions and 49 deletions.
1 change: 1 addition & 0 deletions ehrql/query_engines/mssql.py
@@ -197,6 +197,7 @@ def get_results(self, dataset):
             execute_with_retry,
             results_table,
             key_column=results_table.c.patient_id,
+            key_is_unique=True,
             # This value was copied from the previous cohortextractor. I suspect it
             # has no real scientific basis.
             batch_size=32000,
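To make the new parameter concrete, here is a minimal sketch of driving `fetch_table_in_batches` with a unique key, mirroring the call site above. The in-memory SQLite engine, table, and data are invented for illustration; only `fetch_table_in_batches` and its signature come from this diff:

```python
import sqlalchemy

from ehrql.utils.sqlalchemy_exec_utils import fetch_table_in_batches

# Illustrative stand-in for the real results table; not part of this diff.
metadata = sqlalchemy.MetaData()
results_table = sqlalchemy.Table(
    "results",
    metadata,
    sqlalchemy.Column("patient_id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("value", sqlalchemy.String),
)

engine = sqlalchemy.create_engine("sqlite://")
metadata.create_all(engine)

with engine.connect() as connection:
    connection.execute(
        results_table.insert(),
        [{"patient_id": i, "value": f"v{i}"} for i in range(10)],
    )
    # patient_id is the primary key, so the simpler unique-key paging applies
    rows = fetch_table_in_batches(
        connection.execute,
        results_table,
        key_column=results_table.c.patient_id,
        key_is_unique=True,
        batch_size=4,  # tiny batch size so the example pages more than once
        log=print,
    )
    print([tuple(row) for row in rows])
```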
151 changes: 148 additions & 3 deletions ehrql/utils/sqlalchemy_exec_utils.py
@@ -5,7 +5,7 @@


 def fetch_table_in_batches(
-    execute, table, key_column, batch_size=32000, log=lambda *_: None
+    execute, table, key_column, key_is_unique, batch_size=32000, log=lambda *_: None
 ):
"""
Returns an iterator over all the rows in a table by querying it in batches
Expand All @@ -14,19 +14,42 @@ def fetch_table_in_batches(
execute: callable which accepts a SQLAlchemy query and returns results (can be
just a Connection.execute method)
table: SQLAlchemy TableClause
key_column: reference to a unique orderable column on `table`, used for
key_column: reference to an orderable column on `table`, used for
paging (note that this will need an index on it to avoid terrible
performance)
key_is_unique: if the key_column contains only unique values then we can use a
simpler and more efficient algorithm to do the paging
batch_size: how many results to fetch in each batch
log: callback to receive log messages
"""
if key_is_unique:
return fetch_table_in_batches_unique(
execute, table, key_column, batch_size, log
)
else:
return fetch_table_in_batches_nonunique(
execute, table, key_column, batch_size, log
)


+def fetch_table_in_batches_unique(
+    execute, table, key_column, batch_size=32000, log=lambda *_: None
+):
+    """
+    Returns an iterator over all the rows in a table by querying it in batches using a
+    unique key column
+    """
     assert batch_size > 0
     batch_count = 1
     total_rows = 0
     min_key = None

     key_column_index = table.columns.values().index(key_column)

-    log(f"Fetching rows from '{table}' in batches of {batch_size}")
+    log(
+        f"Fetching rows from '{table}' in batches of {batch_size} using unique "
+        f"column '{key_column.name}'"
+    )
     while True:
         query = select(table).order_by(key_column).limit(batch_size)
         if min_key is not None:
@@ -50,6 +73,128 @@
         min_key = row[key_column_index]


+def fetch_table_in_batches_nonunique(
+    execute, table, key_column, batch_size=32000, log=lambda *_: None
+):
+    """
+    Returns an iterator over all the rows in a table by querying it in batches using a
+    non-unique key column
+
+    The algorithm below is designed (and tested) to work correctly without relying on
+    sort-stability. That is, if we repeatedly ask the database for results sorted by X
+    then rows with the same value for X may be returned in a different order each time.
+
+    Handling this involves some inefficiency in the form of slightly overlapping
+    batches. If we have a table like this:
+
+        patient_id | value
+        -----------+-------
+                 1 | a
+                 1 | b
+                 2 | c
+                 3 | d
+                 3 | e
+                 4 | f
+
+    and we fetch a batch of four, ordered by `patient_id`, we could get results like
+    this (note the different order of `value`):
+
+        patient_id | value
+        -----------+-------
+                 1 | b
+                 1 | a
+                 2 | c
+                 3 | e
+
+    We can be sure we've got all the results for patients 1 and 2, but we don't know
+    how many more results there might be for patient 3. And we can't say "give me
+    results for patients 3 and above but skip the first one" because we can't
+    guarantee that the row we've already got (with value `e`) is going to be the first
+    row in the next set of results.
+
+    So instead, we have to ask for the next batch with `patient_id` greater than 2:
+
+        patient_id | value
+        -----------+-------
+                 3 | d
+                 3 | e
+                 4 | f
+
+    Now, because we get rows for patient 4, we can be sure we've got all the rows for
+    patient 3. And, because we requested a batch of size four and only got three
+    results, we know we've reached the end of the table, and therefore that we've got
+    all the rows for patient 4 as well.
+
+    But in order to do this we ended up fetching some rows for patient 3 twice. In
+    general, the degree of inefficiency here will depend on the number of repeated
+    keys at the end of each batch relative to the batch size. Given that we use batch
+    sizes in at least the tens of thousands, the maximum number of rows per patient is
+    likely to be so far below this that the inefficiency will be negligible.
+
+    There is also an edge case: if the maximum number of rows per patient equals or
+    exceeds the batch size then the algorithm can make no progress at all. Again,
+    given the likely sizes involved this seems very unlikely, but we add a check to
+    raise an explicit error if this ever happens.
+    """
+    assert batch_size > 1
+    batch_count = 1
+    total_rows = 0
+    current_key = None
+    last_fully_fetched_key = None
+    accumulated_rows = []
+
+    key_column_index = table.columns.values().index(key_column)
+
+    log(
+        f"Fetching rows from '{table}' in batches of {batch_size} using non-unique "
+        f"column '{key_column.name}'"
+    )
+    while True:
+        query = select(table).order_by(key_column).limit(batch_size)
+        if last_fully_fetched_key is not None:
+            query = query.where(key_column > last_fully_fetched_key)
+
+        log(f"Fetching batch {batch_count}")
+        results = execute(query)
+
+        # We iterate over the results for the batch, accumulating rows in a list
+        row_count = 0
+        for row in results:
+            row_count += 1
+            next_key = row[key_column_index]
+            # Whenever the value of the key changes we know we've now got a complete
+            # set of rows with the _previous_ key, so we emit those rows, empty the
+            # accumulator, and mark the new value of the key as the current one
+            if next_key != current_key:
+                yield from accumulated_rows
+                accumulated_rows.clear()
+                last_fully_fetched_key = current_key
+                current_key = next_key
+            accumulated_rows.append(row)

+        # The total number of rows we've emitted is the number we've read minus any
+        # still left in the accumulator
+        total_rows += row_count - len(accumulated_rows)
+        batch_count += 1

+        if row_count < batch_size:
+            # If we got fewer rows than we asked for then we've reached the end of
+            # the table: emit any remaining rows, log, and exit
+            yield from accumulated_rows
+            total_rows += len(accumulated_rows)
+            log(f"Fetch complete, total rows: {total_rows}")
+            break
+        elif row_count == len(accumulated_rows):
+            # If we didn't emit _any_ rows then we must have a group of rows with the
+            # same key that is equal to, or larger than, the batch size. We cannot
+            # handle this situation so we throw an error. (Given the sizes involved
+            # it seems unlikely we could hit this in production.)
+            raise AssertionError("`batch_size` too small to make progress")
+        else:
+            # Otherwise we empty the accumulator and fetch another batch
+            accumulated_rows.clear()


 def execute_with_retry_factory(
     connection, max_retries=0, retry_sleep=0, backoff_factor=1, log=lambda *_: None
 ):
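To see the overlapping-batch behaviour described in the docstring, here is a small self-contained sketch (the in-memory SQLite table and data are invented for illustration, reusing the docstring's worked example) that logs each fetch:

```python
import sqlalchemy

from ehrql.utils.sqlalchemy_exec_utils import fetch_table_in_batches

# Events table with a deliberately non-unique patient_id; invented for illustration.
metadata = sqlalchemy.MetaData()
events = sqlalchemy.Table(
    "events",
    metadata,
    sqlalchemy.Column("patient_id", sqlalchemy.Integer),
    sqlalchemy.Column("value", sqlalchemy.String),
)

engine = sqlalchemy.create_engine("sqlite://")
metadata.create_all(engine)

with engine.connect() as connection:
    # The same six rows as the docstring's worked example
    connection.execute(
        events.insert(),
        [
            {"patient_id": 1, "value": "a"},
            {"patient_id": 1, "value": "b"},
            {"patient_id": 2, "value": "c"},
            {"patient_id": 3, "value": "d"},
            {"patient_id": 3, "value": "e"},
            {"patient_id": 4, "value": "f"},
        ],
    )
    rows = fetch_table_in_batches(
        connection.execute,
        events,
        key_column=events.c.patient_id,
        key_is_unique=False,
        # A batch of four forces the second batch to re-fetch patient 3's rows
        batch_size=4,
        log=print,
    )
    # Each row is emitted exactly once despite the overlapping fetches
    print(sorted(tuple(row) for row in rows))
```

With a batch size of four, the second query issued is effectively `SELECT ... WHERE patient_id > 2 ORDER BY patient_id LIMIT 4`, which re-reads patient 3's rows but, because the accumulator was cleared between batches, never emits them twice.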
41 changes: 36 additions & 5 deletions tests/integration/utils/test_sqlalchemy_exec_utils.py
@@ -10,27 +10,58 @@

 class SomeTable(Base):
     __tablename__ = "some_table"
-    pk = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
+    pk = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=False)
+    key = sqlalchemy.Column(sqlalchemy.Integer)
     foo = sqlalchemy.Column(sqlalchemy.String)


-def test_fetch_table_in_batches(engine):
+def test_fetch_table_in_batches_unique(engine):
     if engine.name == "in_memory":
         pytest.skip("SQL tests do not apply to in-memory engine")

     table_size = 15
     batch_size = 6

-    table_data = [(i, f"foo{i}") for i in range(table_size)]
+    table_data = [(i, i, f"foo{i}") for i in range(table_size)]

-    engine.setup([SomeTable(pk=row[0], foo=row[1]) for row in table_data])
+    engine.setup([SomeTable(pk=row[0], key=row[1], foo=row[2]) for row in table_data])

     table = SomeTable.__table__

     with engine.sqlalchemy_engine().connect() as connection:
         results = fetch_table_in_batches(
-            connection.execute, table, table.c.pk, batch_size=batch_size
+            connection.execute,
+            table,
+            table.c.key,
+            key_is_unique=True,
+            batch_size=batch_size,
         )
         results = list(results)

     assert results == table_data


+def test_fetch_table_in_batches_nonunique(engine):
+    if engine.name == "in_memory":
+        pytest.skip("SQL tests do not apply to in-memory engine")
+
+    batch_size = 6
+    repeats = [1, 2, 3, 4, 5, 0, 5, 4, 3, 2, 1]
+    keys = [key for key, n in enumerate(repeats) for _ in range(n)]
+    table_data = [(i, key, f"foo{i}") for i, key in enumerate(keys)]
+
+    engine.setup([SomeTable(pk=row[0], key=row[1], foo=row[2]) for row in table_data])
+
+    table = SomeTable.__table__
+
+    with engine.sqlalchemy_engine().connect() as connection:
+        results = fetch_table_in_batches(
+            connection.execute,
+            table,
+            table.c.key,
+            key_is_unique=False,
+            batch_size=batch_size,
+        )
+        results = sorted(results)
+
+    assert results == table_data
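One point worth noting about the assertion above: the non-unique variant emits each key group as soon as it is complete, and the database is free to order rows within a key group however it likes, so the combined output is not guaranteed to come back in `pk` order. Sorting before comparing checks that exactly the right rows are returned without asserting anything about their order.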
