From 11b0b8efc79898b7ea386116c2a37aa66eff02a4 Mon Sep 17 00:00:00 2001
From: "Gavin M. Roy"
Date: Sat, 8 Feb 2025 16:14:14 -0800
Subject: [PATCH] **Refactor and modernize codebase for improved clarity and
 functionality**

Refactored several modules and test cases to use type annotations,
f-strings, and other Python 3 modernizations. Consolidated imports for
better maintainability and readability. Enhanced context management and
error handling, and replaced redundant or obsolete patterns such as list
comprehensions used where a plain list() call suffices and legacy
str.format() string formatting.
---
 pgdumplib/constants.py   |  25 +-
 pgdumplib/converters.py  |  23 +-
 pgdumplib/dump.py        | 519 ++++++++++++++++++---------------------
 pgdumplib/exceptions.py  |   9 +-
 pgdumplib/models.py      |  59 +++++
 tests/__init__.py        |   2 +-
 tests/test_converters.py |  21 +-
 tests/test_dump.py       |  31 ++-
 tests/test_edge_cases.py |  13 +-
 tests/test_exceptions.py |   9 +-
 tests/test_save_dump.py  |  72 +++---
 11 files changed, 377 insertions(+), 406 deletions(-)
 create mode 100644 pgdumplib/models.py

diff --git a/pgdumplib/constants.py b/pgdumplib/constants.py
index 5a95adc..17de5a1 100644
--- a/pgdumplib/constants.py
+++ b/pgdumplib/constants.py
@@ -4,7 +4,6 @@
 unless you are hacking on the library itself.

 """
-import typing

 K_VERSION_MAP = {
     ((9, 0, 0), (10, 2)): (1, 12, 0),
@@ -24,14 +23,7 @@
 FORMAT_NULL: int = 4
 FORMAT_DIRECTORY: int = 5

-FORMATS: typing.List[str] = [
-    'Unknown',
-    'Custom',
-    'Files',
-    'Tar',
-    'Null',
-    'Directory'
-]
+FORMATS: list[str] = ['Unknown', 'Custom', 'Files', 'Tar', 'Null', 'Directory']

 K_OFFSET_POS_NOT_SET: int = 1
 """Specifies the entry has data but no offset"""
@@ -42,10 +34,10 @@
 MAGIC: bytes = b'PGDMP'

-MIN_VER: typing.Tuple[int, int, int] = (1, 12, 0)
+MIN_VER: tuple[int, int, int] = (1, 12, 0)
 """The minimum supported version of pg_dump files to support"""

-MAX_VER: typing.Tuple[int, int, int] = (1, 14, 0)
+MAX_VER: tuple[int, int, int] = (1, 14, 0)
 """The maximum supported version of pg_dump files to support"""

 PGDUMP_STRFTIME_FMT: str = '%Y-%m-%d %H:%M:%S %Z'
@@ -62,14 +54,11 @@
 SECTION_POST_DATA: str = 'Post-Data'
 """Post-data section for an entry in a dump's table of contents"""

-SECTIONS: typing.List[str] = [
-    SECTION_NONE,
-    SECTION_PRE_DATA,
-    SECTION_DATA,
-    SECTION_POST_DATA
+SECTIONS: list[str] = [
+    SECTION_NONE, SECTION_PRE_DATA, SECTION_DATA, SECTION_POST_DATA
 ]

-VERSION: typing.Tuple[int, int, int] = (1, 12, 0)
+VERSION: tuple[int, int, int] = (1, 12, 0)
 """pg_dump file format version to create by default"""

 ZLIB_OUT_SIZE: int = 4096
@@ -144,7 +133,7 @@
 USER_MAPPING: str = 'USER MAPPING'
 VIEW: str = 'VIEW'

-SECTION_MAPPING: typing.Dict[str, str] = {
+SECTION_MAPPING: dict[str, str] = {
     ACCESS_METHOD: SECTION_PRE_DATA,
     ACL: SECTION_NONE,
     AGGREGATE: SECTION_PRE_DATA,
diff --git a/pgdumplib/converters.py b/pgdumplib/converters.py
index 61028e8..5687cc0 100644
--- a/pgdumplib/converters.py
+++ b/pgdumplib/converters.py
@@ -16,7 +16,6 @@
 import datetime
 import decimal
 import ipaddress
-import typing
 import uuid

 import pendulum
@@ -33,7 +32,7 @@ class DataConverter:

     """
     @staticmethod
-    def convert(row: str) -> typing.Tuple[typing.Optional[str], ...]:
+    def convert(row: str) -> tuple[str | None, ...]:
         """Convert the string-based row into a tuple of columns.
         :param str row: The row to convert
@@ -56,17 +55,9 @@ def convert(row: str) -> str:
         return row


-SmartColumn = typing.Union[
-    None,
-    str,
-    int,
-    datetime.datetime,
-    decimal.Decimal,
-    ipaddress.IPv4Address,
-    ipaddress.IPv4Network,
-    ipaddress.IPv6Address,
-    ipaddress.IPv6Network,
-    uuid.UUID]
+SmartColumn = (None | str | int | datetime.datetime | decimal.Decimal
+               | ipaddress.IPv4Address | ipaddress.IPv4Network
+               | ipaddress.IPv6Address | ipaddress.IPv6Network | uuid.UUID)


 class SmartDataConverter(DataConverter):
@@ -89,7 +80,7 @@ class SmartDataConverter(DataConverter):
     - :py:class:`uuid.UUID`

     """
-    def convert(self, row: str) -> typing.Tuple[SmartColumn, ...]:
+    def convert(self, row: str) -> tuple[SmartColumn, ...]:
         """Convert the string-based row into a tuple of columns"""
         return tuple(self._convert_column(c) for c in row.split('\t'))

@@ -119,8 +110,8 @@ def _convert_column(column: str) -> SmartColumn:
             pass
         for tz_fmt in {'Z', 'ZZ', 'z', 'zz'}:
             try:
-                return pendulum.from_format(
-                    column, 'YYYY-MM-DD HH:mm:ss {}'.format(tz_fmt))
+                return pendulum.from_format(column,
+                                            f'YYYY-MM-DD HH:mm:ss {tz_fmt}')
             except ValueError:
                 pass
         return column
diff --git a/pgdumplib/dump.py b/pgdumplib/dump.py
index 1936eb0..86d0d15 100644
--- a/pgdumplib/dump.py
+++ b/pgdumplib/dump.py
@@ -21,10 +21,7 @@
 cleaned up when the :py:class:`~pgdumplib.dump.Dump` instance is released.

 """
-from __future__ import annotations
-
 import contextlib
-import dataclasses
 import datetime
 import gzip
 import io
@@ -39,7 +36,7 @@

 import toposort

-from pgdumplib import constants, converters, exceptions, version
+from pgdumplib import constants, converters, exceptions, models, version

 LOGGER = logging.getLogger(__name__)

@@ -47,6 +44,80 @@
 VERSION_INFO = '{} (pgdumplib {})'


+Converters = (type[converters.DataConverter] | type[converters.NoOpConverter]
+              | type[converters.SmartDataConverter])
+
+
+class TableData:
+    """Encapsulates table data in a temporary file, providing an API for
+    appending data one row at a time.
+
+    Do not create this class directly; instead, invoke
+    :py:meth:`~pgdumplib.dump.Dump.table_data_writer`.
+
+    """
+    def __init__(self, dump_id: int, tempdir: str, encoding: str):
+        self.dump_id = dump_id
+        self._encoding = encoding
+        self._path = pathlib.Path(tempdir) / f'{dump_id}.gz'
+        self._handle = gzip.open(self._path, 'wb')
+
+    def append(self, *args) -> None:
+        """Append a row to the table data, passing columns in as args
+
+        Column order must match the order specified when
+        :py:meth:`~pgdumplib.dump.Dump.table_data_writer` was invoked.
+
+        All columns will be coerced to a string with special attention
+        paid to ``None``, converting it to the null marker (``\\N``) and
+        :py:class:`datetime.datetime` objects, which will have the proper
+        pg_dump timestamp format applied to them.
+
+        """
+        row = '\t'.join([self._convert(c) for c in args])
+        self._handle.write(f'{row}\n'.encode(self._encoding))
+
+    def finish(self) -> None:
+        """Invoked prior to saving a dump to close the temporary data
+        handle and switch the class into read-only mode.
+
+        For use by :py:class:`pgdumplib.dump.Dump` only.
+
+        """
+        if not self._handle.closed:
+            self._handle.close()
+        self._handle = gzip.open(self._path, 'rb')
+
+    def read(self) -> bytes:
+        """Read the data from disk for writing to the dump
+
+        For use by :py:class:`pgdumplib.dump.Dump` only.
+ + """ + self._handle.seek(0) + return self._handle.read() + + @property + def size(self) -> int: + """Return the current size of the data on disk""" + self._handle.seek(0, io.SEEK_END) # Seek to end to figure out size + size = self._handle.tell() + self._handle.seek(0) + return size + + @staticmethod + def _convert(column: typing.Any) -> str: + """Convert the column to a string + + :param column: The column to convert + + """ + if isinstance(column, datetime.datetime): + return column.strftime(constants.PGDUMP_STRFTIME_FMT) + elif column is None: + return '\\N' + return str(column) + class Dump: """Create a new instance of the :py:class:`~pgdumplib.dump.Dump` class @@ -60,36 +131,37 @@ class Dump: (Default: :py:class:`pgdumplib.converters.DataConverter`) """ - def __init__( - self, dbname: str = 'pgdumplib', encoding: str = 'UTF8', - converter: typing.Optional[ - typing.Type[converters.DataConverter], - typing.Type[converters.NoOpConverter], - typing.Type[converters.SmartDataConverter]] = None, - appear_as: str = '12.0'): + def __init__(self, + dbname: str = 'pgdumplib', + encoding: str = 'UTF8', + converter: Converters | None = None, + appear_as: str = '12.0'): self.compression = False self.dbname = dbname self.dump_version = VERSION_INFO.format(appear_as, version) self.encoding = encoding self.entries = [ - Entry( - dump_id=1, tag=constants.ENCODING, desc=constants.ENCODING, - defn="SET client_encoding = '{}';\n".format(self.encoding)), - Entry( - dump_id=2, tag='STDSTRINGS', desc='STDSTRINGS', - defn="SET standard_conforming_strings = 'on';\n"), - Entry( - dump_id=3, tag='SEARCHPATH', desc='SEARCHPATH', - defn='SELECT pg_catalog.set_config(' - "'search_path', '', false);\n") + models.Entry(dump_id=1, + tag=constants.ENCODING, + desc=constants.ENCODING, + defn=f"SET client_encoding = '{self.encoding}';\n"), + models.Entry(dump_id=2, + tag='STDSTRINGS', + desc='STDSTRINGS', + defn="SET standard_conforming_strings = 'on';\n"), + models.Entry(dump_id=3, + tag='SEARCHPATH', + desc='SEARCHPATH', + defn='SELECT pg_catalog.set_config(' + "'search_path', '', false);\n") ] self.server_version = self.dump_version - self.timestamp = datetime.datetime.now() + self.timestamp = datetime.datetime.now(tz=datetime.UTC) converter = converter or converters.DataConverter self._converter: converters.DataConverter = converter() self._format: str = 'Custom' - self._handle: typing.Optional[typing.BinaryIO] = None + self._handle: typing.BinaryIO | None = None self._intsize: int = 4 self._offsize: int = 8 self._temp_dir = tempfile.TemporaryDirectory() @@ -98,25 +170,25 @@ def __init__( self._vmaj: int = k_version[0] self._vmin: int = k_version[1] self._vrev: int = k_version[2] - self._writers: typing.Dict[int, TableData] = {} + self._writers: dict[int, TableData] = {} def __repr__(self) -> str: - return ''.format( - self._format, self.timestamp.isoformat(), len(self.entries)) - - def add_entry( - self, - desc: str, - namespace: typing.Optional[str] = None, - tag: typing.Optional[str] = None, - owner: typing.Optional[str] = None, - defn: typing.Optional[str] = None, - drop_stmt: typing.Optional[str] = None, - copy_stmt: typing.Optional[str] = None, - dependencies: typing.Optional[typing.List[int]] = None, - tablespace: typing.Optional[str] = None, - tableam: typing.Optional[str] = None, - dump_id: typing.Optional[int] = None) -> Entry: + return f'' + + def add_entry(self, + desc: str, + namespace: str | None = None, + tag: str | None = None, + owner: str | None = None, + defn: str | None = None, + drop_stmt: str 
+                  drop_stmt: str | None = None,
+                  copy_stmt: str | None = None,
+                  dependencies: list[int] | None = None,
+                  tablespace: str | None = None,
+                  tableam: str | None = None,
+                  dump_id: int | None = None) -> models.Entry:
         """Add an entry to the dump

         The ``namespace`` and ``tag`` are required.
@@ -160,7 +232,7 @@ def add_entry(

         """
         if desc not in constants.SECTION_MAPPING:
-            raise ValueError('Invalid desc: {}'.format(desc))
+            raise ValueError(f'Invalid desc: {desc}')
         if dump_id is not None and dump_id < 1:
             raise ValueError('dump_id must be greater than 0')
@@ -173,21 +245,21 @@ def add_entry(
         for dependency in dependencies or []:
             if dependency not in dump_ids:
                 raise ValueError(
-                    'Dependency dump_id {!r} not found'.format(dependency))
-        self.entries.append(Entry(
-            dump_id or self._next_dump_id(), False, '', '', tag or '', desc,
-            defn or '', drop_stmt or '', copy_stmt or '', namespace or '',
-            tablespace or '', tableam or '', owner or '', False,
-            dependencies or []))
+                    f'Dependency dump_id {dependency!r} not found')
+        self.entries.append(
+            models.Entry(dump_id or self._next_dump_id(), False, '', '', tag
+                         or '', desc, defn or '', drop_stmt or '', copy_stmt
+                         or '', namespace or '', tablespace or '', tableam
+                         or '', owner or '', False, dependencies or []))
         return self.entries[-1]

-    def blobs(self) -> typing.Generator[typing.Tuple[int, bytes], None, None]:
+    def blobs(self) -> typing.Generator[tuple[int, bytes], None, None]:
         """Iterator that returns each blob in the dump

         :rtype: tuple(int, bytes)

         """
-        def read_oid(fd: typing.BinaryIO) -> typing.Optional[int]:
+        def read_oid(fd: typing.BinaryIO) -> int | None:
             """Small helper function to deduplicate code"""
             try:
                 return struct.unpack('I', fd.read(4))[0]
@@ -197,17 +269,16 @@ def read_oid(fd: typing.BinaryIO) -> typing.Optional[int]:
         for entry in self._data_entries:
             if entry.desc == constants.BLOBS:
                 with self._tempfile(entry.dump_id, 'rb') as handle:
-                    oid: typing.Optional[int] = read_oid(handle)
+                    oid: int | None = read_oid(handle)
                     while oid:
                         length: int = struct.unpack('I', handle.read(4))[0]
                         yield oid, handle.read(length)
                         oid = read_oid(handle)

-    def get_entry(self, dump_id: int) -> typing.Optional[Entry]:
+    def get_entry(self, dump_id: int) -> models.Entry | None:
         """Return the entry for the given `dump_id`

         :param int dump_id: The dump ID of the entry to return.
-        :rtype: pgdumplib.dump.Entry or None

         """
         for entry in self.entries:
@@ -215,7 +286,7 @@ def get_entry(self, dump_id: int) -> typing.Optional[Entry]:
                 return entry
         return None

-    def load(self, path: os.PathLike) -> Dump:
+    def load(self, path: os.PathLike) -> typing.Self:
         """Load the dump file, including extracting all data into a
         temporary directory

@@ -225,7 +296,7 @@ def load(self, path: os.PathLike) -> Dump:

         """
         if not pathlib.Path(path).exists():
-            raise ValueError('Path {!r} does not exist'.format(path))
+            raise ValueError(f'Path {path!r} does not exist')

         LOGGER.debug('Loading dump file from %s', path)

@@ -234,8 +305,7 @@ def load(self, path: os.PathLike) -> Dump:
             self._read_header()
             if not constants.MIN_VER <= self.version <= constants.MAX_VER:
                 raise ValueError(
-                    'Unsupported backup version: {}.{}.{}'.format(
-                        *self.version))
+                    'Unsupported backup version: {}.{}.{}'.format(*self.version))

             self.compression = self._read_int() != 0
             self.timestamp = self._read_timestamp()
@@ -255,18 +325,18 @@ def load(self, path: os.PathLike) -> Dump:
                 self._handle.seek(entry.offset, io.SEEK_SET)
                 block_type, dump_id = self._read_block_header()
                 if not dump_id or dump_id != entry.dump_id:
-                    raise RuntimeError('Dump IDs do not match ({} != {}'.format(
-                        dump_id, entry.dump_id))
+                    raise RuntimeError(
+                        f'Dump IDs do not match ({dump_id} != {entry.dump_id})')
                 if block_type == constants.BLK_DATA:
                     self._cache_table_data(dump_id)
                 elif block_type == constants.BLK_BLOBS:
                     self._cache_blobs(dump_id)
                 else:
-                    raise RuntimeError('Unknown block type: {}'.format(block_type))
+                    raise RuntimeError(f'Unknown block type: {block_type}')
         return self

     def lookup_entry(self, desc: str, namespace: str, tag: str) \
-            -> typing.Optional[Entry]:
+            -> models.Entry | None:
         """Return the entry for the given namespace and tag

         :param str desc: The desc / object type of the entry
@@ -278,13 +348,13 @@ def lookup_entry(self, desc: str, namespace: str, tag: str) \

         """
         if desc not in constants.SECTION_MAPPING:
-            raise ValueError('Invalid desc: {}'.format(desc))
+            raise ValueError(f'Invalid desc: {desc}')
         for entry in [e for e in self.entries if e.desc == desc]:
             if entry.namespace == namespace and entry.tag == tag:
                 return entry
         return None

-    def save(self, path: os.PathLike) -> typing.NoReturn:
+    def save(self, path: os.PathLike) -> None:
         """Save the Dump file to the specified path

         :param os.PathLike path: The path to save the dump to
@@ -298,8 +368,7 @@ def save(self, path: os.PathLike) -> typing.NoReturn:
             self._handle.close()

     def table_data(self, namespace: str, table: str) \
-            -> typing.Generator[
-                typing.Union[str, typing.Tuple[typing.Any, ...]], None, None]:
+            -> typing.Generator[str | tuple[typing.Any, ...], None, None]:
         """Iterator that returns data for the given namespace and table

         :param str namespace: The namespace/schema for the table
@@ -315,7 +384,9 @@ def table_data(self, namespace: str, table: str) \
         raise exceptions.EntityNotFoundError(namespace=namespace, table=table)

     @contextlib.contextmanager
-    def table_data_writer(self, entry: Entry, columns: typing.Sequence) \
+    def table_data_writer(self,
+                          entry: models.Entry,
+                          columns: typing.Sequence) \
             -> typing.Generator[TableData, None, None]:
         """A context manager that is used to return a
         :py:class:`~pgdumplib.dump.TableData` instance, which can be used
@@ -331,21 +402,26 @@ def table_data_writer(self, entry: Entry, columns: typing.Sequence) \
         """
         if entry.dump_id not in self._writers.keys():
             dump_id = self._next_dump_id()
-            self.entries.append(Entry(
-                dump_id=dump_id, had_dumper=True, tag=entry.tag,
-                desc=constants.TABLE_DATA,
-                copy_stmt='COPY {}.{} ({}) FROM stdin;'.format(
-                    entry.namespace, entry.tag, ', '.join(columns)),
-                namespace=entry.namespace, owner=entry.owner,
-                dependencies=[entry.dump_id],
-                data_state=constants.K_OFFSET_POS_NOT_SET))
-            self._writers[entry.dump_id] = TableData(
-                dump_id, self._temp_dir.name, self.encoding)
+            self.entries.append(
+                models.Entry(dump_id=dump_id,
+                             had_dumper=True,
+                             tag=entry.tag,
+                             desc=constants.TABLE_DATA,
+                             copy_stmt='COPY {}.{} ({}) FROM stdin;'.format(
+                                 entry.namespace, entry.tag,
+                                 ', '.join(columns)),
+                             namespace=entry.namespace,
+                             owner=entry.owner,
+                             dependencies=[entry.dump_id],
+                             data_state=constants.K_OFFSET_POS_NOT_SET))
+            self._writers[entry.dump_id] = TableData(dump_id,
+                                                     self._temp_dir.name,
+                                                     self.encoding)
         yield self._writers[entry.dump_id]
         return None

     @property
-    def version(self) -> typing.Tuple[int, int, int]:
+    def version(self) -> tuple[int, int, int]:
         """Return the version as a tuple to make version comparisons easier.

         :rtype: tuple
@@ -353,7 +429,7 @@ def version(self) -> typing.Tuple[int, int, int]:
         """
         return self._vmaj, self._vmin, self._vrev

-    def _cache_blobs(self, dump_id: int) -> typing.NoReturn:
+    def _cache_blobs(self, dump_id: int) -> None:
         """Create a temp cache file for blob data

         :param int dump_id: The dump ID for the filename
@@ -367,7 +443,7 @@ def _cache_blobs(self, dump_id: int) -> typing.NoReturn:
                 handle.write(blob)
                 count += 1

-    def _cache_table_data(self, dump_id: int) -> typing.NoReturn:
+    def _cache_table_data(self, dump_id: int) -> None:
         """Create a temp cache file for the table data

         :param int dump_id: The dump ID for the filename
@@ -377,7 +453,7 @@ def _cache_table_data(self, dump_id: int) -> typing.NoReturn:
             handle.write(self._read_data())

     @property
-    def _data_entries(self) -> typing.List[Entry]:
+    def _data_entries(self) -> list[models.Entry]:
         """Return the list of entries that are in the data section

         :rtype: list
@@ -386,13 +462,12 @@ def _data_entries(self) -> typing.List[Entry]:
         return [e for e in self.entries if e.section == constants.SECTION_DATA]

     @staticmethod
-    def _get_k_version(appear_as: typing.Tuple[int, int]) \
-            -> typing.Tuple[int, int, int]:
+    def _get_k_version(appear_as: tuple[int, int]) \
+            -> tuple[int, int, int]:
         for (min_ver, max_ver), value in constants.K_VERSION_MAP.items():
             if min_ver <= appear_as <= max_ver:
                 return value
-        raise RuntimeError(
-            'Unsupported PostgreSQL version: {}'.format(appear_as))
+        raise RuntimeError(f'Unsupported PostgreSQL version: {appear_as}')

     def _next_dump_id(self) -> int:
         """Get the next ``dump_id`` that is available for adding an entry
@@ -402,8 +477,7 @@ def _next_dump_id(self) -> int:
         """
         return max(e.dump_id for e in self.entries) + 1

-    def _read_blobs(self) -> typing.Generator[
-            typing.Tuple[int, bytes], None, None]:
+    def _read_blobs(self) -> typing.Generator[tuple[int, bytes], None, None]:
         """Read blobs, returning a tuple of the blob ID and the blob data

         :rtype: (int, bytes)
@@ -418,7 +492,7 @@ def _read_blobs(self) -> typing.Generator[
             if oid == 0:
                 oid = self._read_int()

-    def _read_block_header(self) -> typing.Tuple[bytes, typing.Optional[int]]:
+    def _read_block_header(self) -> tuple[bytes, int | None]:
         """Read the block header in

         :rtype: bytes, int
@@ -426,7 +500,7 @@ def _read_block_header(self) -> typing.Tuple[bytes, typing.Optional[int]]:
         """
         return self._handle.read(1), self._read_int()

-    def _read_byte(self) -> typing.Optional[int]:
+    def _read_byte(self) -> int | None:
         """Read in an individual byte

         :rtype: int
@@ -507,12 +581,12 @@ def _read_dependencies(self) -> list:
                 values.add(int(value))
         return sorted(values)

-    def _read_entries(self) -> typing.NoReturn:
+    def _read_entries(self) -> None:
         """Read in all of the entries"""
         for _i in range(0, self._read_int() or 0):
             self._read_entry()

-    def _read_entry(self) -> typing.NoReturn:
+    def _read_entry(self) -> None:
         """Read in an individual entry and append it to the entries stack"""
         dump_id = self._read_int()
         had_dumper = bool(self._read_int())
@@ -534,15 +608,26 @@ def _read_entry(self) -> typing.NoReturn:
         with_oids = self._read_bytes() == b'true'
         dependencies = self._read_dependencies()
         data_state, offset = self._read_offset()
-        self.entries.append(Entry(
-            dump_id=dump_id, had_dumper=had_dumper, table_oid=table_oid,
-            oid=oid, tag=tag, desc=desc, defn=defn, drop_stmt=drop_stmt,
-            copy_stmt=copy_stmt, namespace=namespace, tablespace=tablespace,
-            tableam=tableam, owner=owner, with_oids=with_oids,
-            dependencies=dependencies, data_state=data_state or 0,
-            offset=offset or 0))
-
-    def _read_header(self) -> typing.NoReturn:
+        self.entries.append(
+            models.Entry(dump_id=dump_id,
+                         had_dumper=had_dumper,
+                         table_oid=table_oid,
+                         oid=oid,
+                         tag=tag,
+                         desc=desc,
+                         defn=defn,
+                         drop_stmt=drop_stmt,
+                         copy_stmt=copy_stmt,
+                         namespace=namespace,
+                         tablespace=tablespace,
+                         tableam=tableam,
+                         owner=owner,
+                         with_oids=with_oids,
+                         dependencies=dependencies,
+                         data_state=data_state or 0,
+                         offset=offset or 0))
+
+    def _read_header(self) -> None:
         """Read in the dump header

         :raises: ValueError
@@ -557,10 +642,10 @@ def _read_header(self) -> typing.NoReturn:
         self._offsize = struct.unpack('B', self._handle.read(1))[0]
         self._format = constants.FORMATS[struct.unpack(
             'B', self._handle.read(1))[0]]
-        LOGGER.debug('Archive version %i.%i.%i',
-                     self._vmaj, self._vmin, self._vrev)
+        LOGGER.debug('Archive version %i.%i.%i', self._vmaj, self._vmin,
+                     self._vrev)

-    def _read_int(self) -> typing.Optional[int]:
+    def _read_int(self) -> int | None:
         """Read in a signed integer

         :rtype: int or None
@@ -577,7 +662,7 @@ def _read_int(self) -> typing.Optional[int]:
             bs += 8
         return -value if sign else value

-    def _read_offset(self) -> typing.Tuple[int, int]:
+    def _read_offset(self) -> tuple[int, int]:
         """Read in the value for the length of the data stored in the file

         :rtype: int, int
@@ -613,14 +698,24 @@ def _read_timestamp(self) -> datetime.datetime:
         :rtype: datetime.datetime

         """
-        second, minute, hour, day, month, year = (
-            self._read_int(), self._read_int(), self._read_int(),
-            self._read_int(), (self._read_int() or 0) + 1,
-            (self._read_int() or 0) + 1900)
+        second, minute, hour, day, month, year = (self._read_int(),
+                                                  self._read_int(),
+                                                  self._read_int(),
+                                                  self._read_int(),
+                                                  (self._read_int() or 0) + 1,
+                                                  (self._read_int() or 0)
+                                                  + 1900)
         self._read_int()  # DST flag
-        return datetime.datetime(year, month, day, hour, minute, second, 0)
-
-    def _save(self) -> typing.NoReturn:
+        return datetime.datetime(year,
+                                 month,
+                                 day,
+                                 hour,
+                                 minute,
+                                 second,
+                                 0,
+                                 tzinfo=datetime.UTC)
+
+    def _save(self) -> None:
         """Save the dump file to disk"""
         self._write_toc()
         self._write_entries()
@@ -628,7 +723,7 @@ def _save(self) -> typing.NoReturn:
         self._write_toc()  # Overwrite ToC and entries
         self._write_entries()

-    def _set_encoding(self) -> typing.NoReturn:
+    def _set_encoding(self) -> None:
         """If the encoding is found in the dump entries, set the
         encoding to `self.encoding`.
@@ -649,14 +744,14 @@ def _tempfile(self, dump_id: int, mode: str) \
         :param str mode: The mode (rb, wb)

         """
-        path = pathlib.Path(self._temp_dir.name) / '{}.gz'.format(dump_id)
+        path = pathlib.Path(self._temp_dir.name) / f'{dump_id}.gz'
         if not path.exists() and mode.startswith('r'):
             raise exceptions.NoDataError()
         with gzip.open(path, mode) as handle:
             try:
                 yield handle
-            finally:
-                return
+            except Exception:
+                raise

     def _write_blobs(self, dump_id: int) -> int:
         """Write the blobs for the entry.
@@ -681,7 +776,7 @@ def _write_blobs(self, dump_id: int) -> int:
         self._write_int(0)
         return length

-    def _write_byte(self, value: int) -> typing.NoReturn:
+    def _write_byte(self, value: int) -> None:
         """Write a byte to the handle

         :param int value: The byte value
@@ -716,39 +811,26 @@ def _write_entries(self):
                 self._write_entry(entry)
                 saved.add(entry.dump_id)

-        saved = self._write_section(
-            constants.SECTION_PRE_DATA, [
-                constants.GROUP,
-                constants.ROLE,
-                constants.USER,
-                constants.SCHEMA,
-                constants.EXTENSION,
-                constants.AGGREGATE,
-                constants.OPERATOR,
-                constants.OPERATOR_CLASS,
-                constants.CAST,
-                constants.COLLATION,
-                constants.CONVERSION,
-                constants.PROCEDURAL_LANGUAGE,
-                constants.FOREIGN_DATA_WRAPPER,
-                constants.FOREIGN_SERVER,
-                constants.SERVER,
-                constants.DOMAIN,
-                constants.TYPE,
-                constants.SHELL_TYPE], saved)
+        saved = self._write_section(constants.SECTION_PRE_DATA, [
+            constants.GROUP, constants.ROLE, constants.USER, constants.SCHEMA,
+            constants.EXTENSION, constants.AGGREGATE, constants.OPERATOR,
+            constants.OPERATOR_CLASS, constants.CAST, constants.COLLATION,
+            constants.CONVERSION, constants.PROCEDURAL_LANGUAGE,
+            constants.FOREIGN_DATA_WRAPPER, constants.FOREIGN_SERVER,
+            constants.SERVER, constants.DOMAIN, constants.TYPE,
+            constants.SHELL_TYPE
+        ], saved)
         saved = self._write_section(constants.SECTION_DATA, [], saved)
-        saved = self._write_section(
-            constants.SECTION_POST_DATA, [
-                constants.CHECK_CONSTRAINT,
-                constants.CONSTRAINT,
-                constants.INDEX], saved)
+        saved = self._write_section(constants.SECTION_POST_DATA, [
+            constants.CHECK_CONSTRAINT, constants.CONSTRAINT, constants.INDEX
+        ], saved)
        saved = self._write_section(constants.SECTION_NONE, [], saved)

         LOGGER.debug('Wrote %i of %i entries', len(saved), len(self.entries))

-    def _write_entry(self, entry: Entry) -> typing.NoReturn:
+    def _write_entry(self, entry: models.Entry) -> None:
         """Write the entry

         :param models.Entry entry: The entry to write
@@ -777,10 +859,10 @@ def _write_entry(self, entry: Entry) -> typing.NoReturn:
             self._write_int(-1)
         self._write_offset(entry.offset, entry.data_state)

-    def _write_header(self) -> typing.NoReturn:
+    def _write_header(self) -> None:
         """Write the file header"""
-        LOGGER.debug('Writing archive version %i.%i.%i',
-                     self._vmaj, self._vmin, self._vrev)
+        LOGGER.debug('Writing archive version %i.%i.%i', self._vmaj,
+                     self._vmin, self._vrev)
         self._handle.write(constants.MAGIC)
         self._write_byte(self._vmaj)
         self._write_byte(self._vmin)
@@ -789,7 +871,7 @@ def _write_header(self) -> typing.NoReturn:
         self._write_byte(self._offsize)
         self._write_byte(constants.FORMATS.index(self._format))

-    def _write_int(self, value: int) -> typing.NoReturn:
+    def _write_int(self, value: int) -> None:
         """Write an integer value

         :param int value:
@@ -802,7 +884,7 @@ def _write_int(self, value: int) -> typing.NoReturn:
             self._write_byte(value & 0xFF)
             value >>= 8

-    def _write_offset(self, value: int, data_state: int) -> typing.NoReturn:
+    def _write_offset(self, value: int, data_state: int) -> None:
"""Write the offset value. :param int value: The value to write @@ -810,7 +892,7 @@ def _write_offset(self, value: int, data_state: int) -> typing.NoReturn: """ self._write_byte(data_state) - for offset in range(0, self._offsize): + for _offset in range(0, self._offsize): self._write_byte(value & 0xFF) value >>= 8 @@ -820,14 +902,16 @@ def _write_section(self, section: str, obj_types: list, saved: set) -> set: self._write_entry(entry) saved.add(entry.dump_id) for dump_id in toposort.toposort_flatten( - {e.dump_id: set(e.dependencies) for e in self.entries - if e.section == section}, True): + { + e.dump_id: set(e.dependencies) + for e in self.entries if e.section == section + }, True): if dump_id not in saved: self._write_entry(self.get_entry(dump_id)) saved.add(dump_id) return saved - def _write_str(self, value: str) -> typing.NoReturn: + def _write_str(self, value: str) -> None: """Write a string :param str value: The string to write @@ -868,7 +952,7 @@ def _write_table_data(self, dump_id: int) -> int: self._write_int(0) # End of data indicator return size - def _write_timestamp(self, value: datetime.datetime) -> typing.NoReturn: + def _write_timestamp(self, value: datetime.datetime) -> None: """Write a datetime.datetime value :param datetime.datetime value: The value to write @@ -882,7 +966,7 @@ def _write_timestamp(self, value: datetime.datetime) -> typing.NoReturn: self._write_int(value.year - 1900) self._write_int(1 if value.dst() else 0) - def _write_toc(self) -> typing.NoReturn: + def _write_toc(self) -> None: """Write the ToC for the file""" self._handle.seek(0) self._write_header() @@ -891,136 +975,3 @@ def _write_toc(self) -> typing.NoReturn: self._write_str(self.dbname) self._write_str(self.server_version) self._write_str(self.dump_version) - - -@dataclasses.dataclass(eq=True) -class Entry: - """The entry model represents a single entry in the dataclass - - Custom formatted dump files are primarily comprised of entries, which - contain all of the metadata and DDL required to construct the database. - - For table data and blobs, there are entries that contain offset locations - in the dump file that instruct the reader as to where the data lives - in the file. - - :var int dump_id: The dump id, will be auto-calculated if left empty - :var bool had_dumper: Indicates - :var str oid: The OID of the object the entry represents - :var str tag: The name/table/relation/etc of the entry - :var str desc: The entry description - :var str defn: The DDL definition for the entry - :var str drop_stmt: A drop statement used to drop the entry before - :var str copy_stmt: A copy statement used when there is a corresponding - data section. - :var str namespace: The namespace of the entry - :var str tablespace: The tablespace to use - :var str tableam: The table access method - :var str owner: The owner of the object in Postgres - :var bool with_oids: Indicates ... - :var list dependencies: A list of dump_ids of objects that the entry - is dependent upon. 
-    :var int data_state: Indicates if the entry has data and how it is stored
-    :var int offset: If the entry has data, the offset to the data in the file
-    :var str section: The section of the dump file the entry belongs to
-
-    """
-    dump_id: int
-    had_dumper: bool = False
-    table_oid: str = '0'
-    oid: str = '0'
-    tag: typing.Optional[str] = None
-    desc: typing.Optional[str] = None
-    defn: typing.Optional[str] = None
-    drop_stmt: typing.Optional[str] = None
-    copy_stmt: typing.Optional[str] = None
-    namespace: typing.Optional[str] = None
-    tablespace: typing.Optional[str] = None
-    tableam: typing.Optional[str] = None
-    owner: typing.Optional[str] = None
-    with_oids: bool = False
-    dependencies: typing.List[int] = dataclasses.field(default_factory=list)
-    data_state: int = constants.K_OFFSET_NO_DATA
-    offset: int = 0
-
-    @property
-    def section(self) -> str:
-        """Return the section the entry belongs to"""
-        return constants.SECTION_MAPPING[self.desc]
-
-
-class TableData:
-    """Used to encapsulate table data using temporary file and allowing
-    for an API that allows for the appending of data one row at a time.
-
-    Do not create this class directly, instead invoke
-    :py:meth:`~pgdumplib.dump.Dump.table_data_writer`.
-
-    """
-    def __init__(self, dump_id: int, tempdir: str, encoding: str):
-        self.dump_id = dump_id
-        self._encoding = encoding
-        self._path = pathlib.Path(tempdir) / '{}.gz'.format(dump_id)
-        self._handle = gzip.open(self._path, 'wb')
-
-    def append(self, *args) -> typing.NoReturn:
-        """Append a row to the table data, passing columns in as args
-
-        Column order must match the order specified when
-        :py:meth:`~pgdumplib.dump.Dump.table_data_writer` was invoked.
-
-        All columns will be coerced to a string with special attention
-        paid to ``None``, converting it to the null marker (``\\N``) and
-        :py:class:`datetime.datetime` objects, which will have the proper
-        pg_dump timestamp format applied to them.
-
-        """
-        row = '\t'.join([self._convert(c) for c in args])
-        self._handle.write('{}\n'.format(row).encode(self._encoding))
-
-    def finish(self) -> typing.NoReturn:
-        """Invoked prior to saving a dump to close the temporary data
-        handle and switch the class into read-only mode.
-
-        For use by :py:class:`pgdumplib.dump.Dump` only.
-
-        """
-        if not self._handle.closed:
-            self._handle.close()
-        self._handle = gzip.open(self._path, 'rb')
-
-    def read(self) -> bytes:
-        """Read the data from disk for writing to the dump
-
-        For use by :py:class:`pgdumplib.dump.Dump` only.
-
-        :rtype: bytes
-
-        """
-        self._handle.seek(0)
-        return self._handle.read()
-
-    @property
-    def size(self) -> int:
-        """Return the current size of the data on disk
-
-        :rtype: int
-
-        """
-        self._handle.seek(0, io.SEEK_END)  # Seek to end to figure out size
-        size = self._handle.tell()
-        self._handle.seek(0)
-        return size
-
-    @staticmethod
-    def _convert(column: typing.Any) -> str:
-        """Convert the column to a string
-
-        :param any column: The column to convert
-
-        """
-        if isinstance(column, datetime.datetime):
-            return column.strftime(constants.PGDUMP_STRFTIME_FMT)
-        elif column is None:
-            return '\\N'
-        return str(column)
diff --git a/pgdumplib/exceptions.py b/pgdumplib/exceptions.py
index 39f146b..51abd7a 100644
--- a/pgdumplib/exceptions.py
+++ b/pgdumplib/exceptions.py
@@ -20,16 +20,15 @@ class EntityNotFoundError(PgDumpLibException):
     and ``table`` specified were not found.
""" - def __init__(self, namespace: str, table: str): super().__init__() self.namespace = namespace self.table = table def __repr__(self) -> str: # pragma: nocover - return ''.format( - self.namespace, self.table) + return f'' def __str__(self) -> str: # pragma: nocover - return 'Did not find {}.{} in the table of contents'.format( - self.namespace, self.table) + return f'Did not find {self.namespace}.{self.table} in the table ' \ + f'of contents' diff --git a/pgdumplib/models.py b/pgdumplib/models.py new file mode 100644 index 0000000..831b06a --- /dev/null +++ b/pgdumplib/models.py @@ -0,0 +1,59 @@ +import dataclasses + +from pgdumplib import constants + + +@dataclasses.dataclass(eq=True) +class Entry: + """The entry model represents a single entry in the dataclass + + Custom formatted dump files are primarily comprised of entries, which + contain all of the metadata and DDL required to construct the database. + + For table data and blobs, there are entries that contain offset locations + in the dump file that instruct the reader as to where the data lives + in the file. + + :var dump_id: The dump id, will be auto-calculated if left empty + :var had_dumper: Indicates + :var oid: The OID of the object the entry represents + :var tag: The name/table/relation/etc of the entry + :var desc: The entry description + :var defn: The DDL definition for the entry + :var drop_stmt: A drop statement used to drop the entry before + :var copy_stmt: A copy statement used when there is a corresponding + data section. + :var namespace: The namespace of the entry + :var tablespace: The tablespace to use + :var tableam: The table access method + :var owner: The owner of the object in Postgres + :var with_oids: Indicates ... + :var dependencies: A list of dump_ids of objects that the entry + is dependent upon. 
+    :var data_state: Indicates if the entry has data and how it is stored
+    :var offset: If the entry has data, the offset to the data in the file
+    :var section: The section of the dump file the entry belongs to
+
+    """
+    dump_id: int
+    had_dumper: bool = False
+    table_oid: str = '0'
+    oid: str = '0'
+    tag: str | None = None
+    desc: str = 'Unknown'
+    defn: str | None = None
+    drop_stmt: str | None = None
+    copy_stmt: str | None = None
+    namespace: str | None = None
+    tablespace: str | None = None
+    tableam: str | None = None
+    owner: str | None = None
+    with_oids: bool = False
+    dependencies: list[int] = dataclasses.field(default_factory=list)
+    data_state: int = constants.K_OFFSET_NO_DATA
+    offset: int = 0
+
+    @property
+    def section(self) -> str:
+        """Return the section the entry belongs to"""
+        return constants.SECTION_MAPPING[self.desc]
diff --git a/tests/__init__.py b/tests/__init__.py
index 80846cf..21e4739 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -10,5 +10,5 @@ def setup_module():
                 line = line[7:]
             name, _, value = line.strip().partition('=')
             os.environ[name] = value
-    except IOError:
+    except OSError:
         pass
diff --git a/tests/test_converters.py b/tests/test_converters.py
index c5abdb4..4b0741b 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -10,21 +10,18 @@


 class TestCase(unittest.TestCase):
-
     def test_data_converter(self):
         data = []
         for row in range(0, 10):
             data.append([
                 str(row),
                 str(uuid.uuid4()),
-                str(datetime.datetime.utcnow()),
-                str(uuid.uuid4()),
-                str(uuid.uuid4()),
-                None
+                str(datetime.datetime.now(tz=datetime.UTC)),
+                str(uuid.uuid4()), None
             ])

         converter = converters.DataConverter()

-        for offset, expectation in enumerate(data):
+        for _offset, expectation in enumerate(data):
             line = '\t'.join(['\\N' if e is None else e for e in expectation])
             self.assertListEqual(list(converter.convert(line)), expectation)

@@ -34,7 +31,6 @@ def test_noop_converter(self):
             self.assertEqual(converter.convert(value), value)

     def test_smart_data_converter(self):
-
         def convert(value):
             """Convert the value to the proper string type"""
             if value is None:
@@ -47,20 +43,19 @@ def convert(value):
         data = []
         for row in range(0, 10):
             data.append([
-                row,
-                None,
+                row, None,
                 fake.pydecimal(positive=True, left_digits=5, right_digits=3),
                 uuid.uuid4(),
                 ipaddress.IPv4Network(fake.ipv4(True)),
                 ipaddress.IPv4Address(fake.ipv4()),
                 ipaddress.IPv6Address(fake.ipv6()),
-                maya.now().datetime(
-                    to_timezone='US/Eastern', naive=True).strftime(
-                        constants.PGDUMP_STRFTIME_FMT)
+                maya.now().datetime(to_timezone='US/Eastern',
+                                    naive=True).strftime(
+                                        constants.PGDUMP_STRFTIME_FMT)
             ])

         converter = converters.SmartDataConverter()

-        for offset, expectation in enumerate(data):
+        for _offset, expectation in enumerate(data):
             line = '\t'.join([convert(e) for e in expectation])
             row = list(converter.convert(line))
             self.assertListEqual(row, expectation)
diff --git a/tests/test_dump.py b/tests/test_dump.py
index d665b67..1c1adde 100644
--- a/tests/test_dump.py
+++ b/tests/test_dump.py
@@ -8,8 +8,8 @@
 import subprocess
 import tempfile
 import unittest
-from unittest import mock
 import uuid
+from unittest import mock

 import psycopg2

@@ -76,8 +76,8 @@ def test_read_dump_entity_not_found(self):
                 LOGGER.debug('Line: %r', line)

     def test_lookup_entry(self):
-        entry = self.dump.lookup_entry(
-            constants.TABLE, 'public', 'pgbench_accounts')
+        entry = self.dump.lookup_entry(constants.TABLE, 'public',
+                                       'pgbench_accounts')
         self.assertEqual(entry.namespace, 'public')
         self.assertEqual(entry.tag, 'pgbench_accounts')
         self.assertEqual(entry.section, constants.SECTION_PRE_DATA)
@@ -91,8 +91,8 @@ def test_lookup_entry_invalid_desc(self):
             self.dump.lookup_entry('foo', 'public', 'pgbench_accounts')

     def test_get_entry(self):
-        entry = self.dump.lookup_entry(
-            constants.TABLE, 'public', 'pgbench_accounts')
+        entry = self.dump.lookup_entry(constants.TABLE, 'public',
+                                       'pgbench_accounts')
         self.assertEqual(self.dump.get_entry(entry.dump_id), entry)

     def test_get_entry_not_found(self):
@@ -130,7 +130,7 @@ def test_read_dump_data(self):
         for line in self.dump.table_data('public', 'pgbench_accounts'):
             self.assertTrue(
                 line.startswith('INSERT INTO public.pgbench_accounts'),
-                'Unexpected start @ row {}: {!r}'.format(count, line))
+                f'Unexpected start @ row {count}: {line!r}')
             count += 1
         self.assertEqual(count, 100000)

@@ -149,19 +149,17 @@ def test_table_data_empty(self):
             super().test_table_data_empty()

     def test_read_blobs(self):
-        self.assertEqual(len([b for b in self.dump.blobs()]), 0)
+        self.assertEqual(len(list(self.dump.blobs())), 0)


 class ErrorsTestCase(unittest.TestCase):
-
     def test_missing_file_raises_value_error(self):
         path = pathlib.Path(tempfile.gettempdir()) / str(uuid.uuid4())
         with self.assertRaises(ValueError):
             pgdumplib.load(path)

     def test_min_version_failure_raises(self):
-        min_ver = (constants.MIN_VER[0],
-                   constants.MIN_VER[1] + 10,
+        min_ver = (constants.MIN_VER[0], constants.MIN_VER[1] + 10,
                    constants.MIN_VER[2])
         LOGGER.debug('Setting pgdumplib.constants.MIN_VER to %s', min_ver)
         with mock.patch('pgdumplib.constants.MIN_VER', min_ver):
@@ -187,7 +185,6 @@ def test_invalid_dump_file(self):


 class NewDumpTestCase(unittest.TestCase):
-
     def test_pgdumplib_new(self):
         dmp = pgdumplib.new('test', 'UTF8', converters.SmartDataConverter)
         self.assertIsInstance(dmp, dump.Dump)
@@ -213,9 +210,10 @@ def _read_dump(self):

     @classmethod
     def _read_dump_info(cls, remote_path) -> DumpInfo:
-        restore = subprocess.run(
-            ['pg_restore', '-l', str(remote_path)],
-            check=True, capture_output=True)
+        restore = subprocess.run(  # noqa: S603
+            ['pg_restore', '-l', str(remote_path)],  # noqa: S607
+            check=True,
+            capture_output=True)
         stdout = restore.stdout.decode('utf-8')
         data = {}
         for key, pattern in PATTERNS.items():
@@ -246,8 +244,8 @@ def test_toc_entry_count(self):
         self.assertEqual(len(self.dump.entries), self.info.entry_count)

     def test_toc_server_version(self):
-        self.assertEqual(
-            self.dump.server_version, self.info.server_version)
+        self.assertEqual(self.dump.server_version, self.info.server_version)
+
     # def test_toc_timestamp(self):
     #     self.assertEqual(
     #         self.dump.timestamp.isoformat(), self.info.timestamp.isoformat())
@@ -270,7 +268,6 @@ class RestoreComparisonDataOnlyTestCase(RestoreComparisonTestCase):


 class KVersionTestCase(unittest.TestCase):
-
     def test_default(self):
         instance = dump.Dump()
         self.assertEqual(instance.version, (1, 14, 0))
diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py
index dfe1c6f..45ea25b 100644
--- a/tests/test_edge_cases.py
+++ b/tests/test_edge_cases.py
@@ -13,7 +13,6 @@


 class EdgeTestCase(unittest.TestCase):
-
     @staticmethod
     def _write_byte(handle, value) -> None:
         """Write a byte to the handle"""
@@ -23,7 +22,7 @@ def _write_int(self, handle, value):
         self._write_byte(handle, 1 if value < 0 else 0)
         if value < 0:
             value = -value
-        for offset in range(0, 4):
+        for _offset in range(0, 4):
             self._write_byte(handle, value & 0xFF)
             value >>= 8

@@ -35,7 +34,9 @@ def tearDown(self) -> None:

     def test_invalid_dependency(self):
         dmp = pgdumplib.new('test')
         with self.assertRaises(ValueError):
-            dmp.add_entry(constants.TABLE, '', 'block_table',
+            dmp.add_entry(constants.TABLE,
+                          '',
+                          'block_table',
                           dependencies=[1024])

     def test_invalid_block_type_in_data(self):
@@ -69,8 +70,8 @@ def test_encoding_no_entries(self):
     def test_dump_id_mismatch_in_data(self):
         dmp = pgdumplib.new('test')
         dmp.add_entry(constants.TABLE_DATA, '', 'block_table', dump_id=1024)
-        with gzip.open(
-                pathlib.Path(dmp._temp_dir.name) / '1024.gz', 'wb') as handle:
+        with gzip.open(pathlib.Path(dmp._temp_dir.name) / '1024.gz',
+                       'wb') as handle:
             handle.write(b'1\t\1\t\1\n')
         dmp.save('build/data/dump.test')

@@ -86,7 +87,7 @@ def test_no_data(self):
             h.write(b'')
         dmp.save('build/data/dump.test')
         dmp = pgdumplib.load('build/data/dump.test')
-        data = [line for line in dmp.table_data('', 'empty_table')]
+        data = list(dmp.table_data('', 'empty_table'))
         self.assertEqual(len(data), 0)

     def test_runtime_error_when_pos_not_set(self):
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 0b1a0eb..0961bc6 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -4,13 +4,12 @@


 class ExceptionTestCase(unittest.TestCase):
-
     def test_repr_formatting(self):
         exc = exceptions.EntityNotFoundError('public', 'table')
-        self.assertEqual(
-            repr(exc), "<EntityNotFoundError namespace=public table=table>")
+        self.assertEqual(repr(exc),
+                         "<EntityNotFoundError namespace=public table=table>")

     def test_str_formatting(self):
         exc = exceptions.EntityNotFoundError('public', 'table')
-        self.assertEqual(
-            str(exc), 'Did not find public.table in the table of contents')
+        self.assertEqual(str(exc),
+                         'Did not find public.table in the table of contents')
diff --git a/tests/test_save_dump.py b/tests/test_save_dump.py
index 8debff4..10261b1 100644
--- a/tests/test_save_dump.py
+++ b/tests/test_save_dump.py
@@ -1,9 +1,9 @@
 import dataclasses
+import datetime
 import pathlib
 import unittest
 import uuid

-from dateutil import tz
 import faker
 from faker.providers import date_time

@@ -12,7 +12,6 @@


 class SavedDumpTestCase(unittest.TestCase):
-
     def setUp(self):
         dmp = pgdumplib.load('build/data/dump.compressed')
         dmp.save('build/data/dump.test')
@@ -42,36 +41,33 @@ def test_entries_mostly_match(self):
             saved_entry = self.saved.get_entry(original.dump_id)
             for attr in attrs:
                 self.assertEqual(
-                    getattr(original, attr),
-                    getattr(saved_entry, attr),
-                    '{} does not match: {} != {}'.format(
-                        attr, getattr(original, attr),
-                        getattr(saved_entry, attr)))
+                    getattr(original, attr), getattr(saved_entry, attr),
+                    f'{attr} does not match: {getattr(original, attr)} != '
+                    f'{getattr(saved_entry, attr)}')

     def test_table_data_matches(self):
         for entry in range(0, len(self.original.entries)):
             if self.original.entries[entry].desc != constants.TABLE_DATA:
                 continue

-            original_data = [row for row in self.original.table_data(
-                self.original.entries[entry].namespace,
-                self.original.entries[entry].tag)]
+            original_data = list(
+                self.original.table_data(
+                    self.original.entries[entry].namespace,
+                    self.original.entries[entry].tag))

-            saved_data = [row for row in self.saved.table_data(
-                self.original.entries[entry].namespace,
-                self.original.entries[entry].tag)]
+            saved_data = list(
+                self.saved.table_data(self.original.entries[entry].namespace,
+                                      self.original.entries[entry].tag))

             for offset in range(0, len(original_data)):
                 self.assertListEqual(
                     list(original_data[offset]), list(saved_data[offset]),
-                    'Data in {}.{} does not match for row {}'.format(
-                        self.original.entries[entry].namespace,
-                        self.original.entries[entry].tag,
-                        offset))
+                    f'Data in {self.original.entries[entry].namespace}.'
+                    f'{self.original.entries[entry].tag} does not match '
+                    f'for row {offset}')


 class EmptyDumpTestCase(unittest.TestCase):
-
     def test_empty_dump_has_base_entries(self):
         dump = pgdumplib.new('test', 'UTF8')
         self.assertEqual(len(dump.entries), 3)
@@ -85,7 +81,6 @@ def test_empty_save_does_not_err(self):


 class CreateDumpTestCase(unittest.TestCase):
-
     def tearDown(self) -> None:
         test_file = pathlib.Path('build/data/dump.test')
         if test_file.exists():
@@ -93,52 +88,47 @@ def tearDown(self) -> None:

     def test_dump_expectations(self):
         dmp = pgdumplib.new('test', 'UTF8')
-        database = dmp.add_entry(
-            desc=constants.DATABASE,
-            tag='postgres',
-            owner='postgres',
-            defn="""\
+        database = dmp.add_entry(desc=constants.DATABASE,
+                                 tag='postgres',
+                                 owner='postgres',
+                                 defn="""\
 CREATE DATABASE postgres WITH TEMPLATE = template0 ENCODING = 'UTF8' \
 LC_COLLATE = 'en_US.utf8' LC_CTYPE = 'en_US.utf8';""",
-            drop_stmt='DROP DATABASE postgres')
+                                 drop_stmt='DROP DATABASE postgres')

-        dmp.add_entry(
-            constants.COMMENT,
-            tag='DATABASE postgres',
-            owner='postgres',
-            defn="""\
+        dmp.add_entry(constants.COMMENT,
+                      tag='DATABASE postgres',
+                      owner='postgres',
+                      defn="""\
 COMMENT ON DATABASE postgres IS 'default administrative connection database';""",
-            dependencies=[database.dump_id])
+                      dependencies=[database.dump_id])

         example = dmp.add_entry(
             constants.TABLE, 'public', 'example', 'postgres',
             'CREATE TABLE public.example (\
                 id UUID NOT NULL PRIMARY KEY, \
                 created_at TIMESTAMP WITH TIME ZONE, \
-                value TEXT NOT NULL);',
-            'DROP TABLE public.example')
+                value TEXT NOT NULL);', 'DROP TABLE public.example')

         columns = 'id', 'created_at', 'value'

         fake = faker.Faker()
         fake.add_provider(date_time)

-        rows = [
-            (uuid.uuid4(), fake.date_time(tzinfo=tz.tzutc()), 'foo'),
-            (uuid.uuid4(), fake.date_time(tzinfo=tz.tzutc()), 'bar'),
-            (uuid.uuid4(), fake.date_time(tzinfo=tz.tzutc()), 'baz'),
-            (uuid.uuid4(), fake.date_time(tzinfo=tz.tzutc()), 'qux')
-        ]
+        rows = [(uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'foo'),
+                (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'bar'),
+                (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'baz'),
+                (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), 'qux')]

         with dmp.table_data_writer(example, columns) as writer:
             for row in rows:
                 writer.append(*row)

-        row = (uuid.uuid4(), fake.date_time(tzinfo=tz.tzutc()), None)
+        row = (uuid.uuid4(), fake.date_time(tzinfo=datetime.UTC), None)
         rows.append(row)

         # Append a second time to get same writer
@@ -155,5 +145,5 @@ def test_dump_expectations(self):
         self.assertEqual(entry.desc, 'DATABASE')
         self.assertEqual(entry.owner, 'postgres')
         self.assertEqual(entry.tag, 'postgres')
-        values = [row for row in dmp.table_data('public', 'example')]
+        values = list(dmp.table_data('public', 'example'))
         self.assertListEqual(values, rows)
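For reviewers, a minimal usage sketch of the public API as refactored by this
patch, based only on the calls exercised in the tests above (pgdumplib.new(),
Dump.add_entry(), Dump.table_data_writer(), Dump.save(), pgdumplib.load(), and
Dump.table_data()). The table definition and the dump file path are
hypothetical, illustration-only values:

    import pgdumplib
    from pgdumplib import constants

    # Build a new dump using the keyword-friendly signatures from this patch
    dmp = pgdumplib.new('example', 'UTF8')
    table = dmp.add_entry(
        constants.TABLE, 'public', 'demo', 'postgres',
        'CREATE TABLE public.demo (id INT, value TEXT);',  # hypothetical DDL
        'DROP TABLE public.demo')

    # table_data_writer() yields the TableData class now defined near the top
    # of dump.py; append() converts None to \N and datetime values to the
    # pg_dump timestamp format before writing
    with dmp.table_data_writer(table, ('id', 'value')) as writer:
        writer.append(1, 'foo')
        writer.append(2, None)

    dmp.save('demo.dump')  # hypothetical path

    # Round-trip: load() is now annotated -> typing.Self, and entries are
    # models.Entry instances from the new pgdumplib/models.py module
    reloaded = pgdumplib.load('demo.dump')
    for row in reloaded.table_data('public', 'demo'):
        print(row)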