diff --git a/gel/orm/django/generator.py b/gel/orm/django/generator.py index 6ce63aea..bfcd1328 100644 --- a/gel/orm/django/generator.py +++ b/gel/orm/django/generator.py @@ -22,9 +22,9 @@ # values are controlled in Django via settings (USE_TZ) and are mutually # exclusive in the same app under default circumstances. 'std::datetime': 'DateTimeField', - 'cal::local_date': 'DateField', - 'cal::local_datetime': 'DateTimeField', - 'cal::local_time': 'TimeField', + 'std::cal::local_date': 'DateField', + 'std::cal::local_datetime': 'DateTimeField', + 'std::cal::local_time': 'TimeField', # all kinds of durations are not supported due to this error: # iso_8601 intervalstyle currently not supported } @@ -38,6 +38,8 @@ # from django.db import models +from django.contrib.postgres import fields as pgf + class GelUUIDField(models.UUIDField): # This field must be treated as a auto-generated UUID. @@ -55,6 +57,17 @@ class GelPGMeta: ''' CLOSEPAR_RE = re.compile(r'\)(?=\s+#|$)') +ARRAY_RE = re.compile(r'^array<(?P.+)>$') +NAME_RE = re.compile(r'^(?P\w+?)(?P\d*)$') + + +def field_name_sort(item): + key, val = item + + match = NAME_RE.fullmatch(key) + res = (match.group('alpha'), int(match.group('num') or -1)) + + return res class ModelClass(object): @@ -166,11 +179,11 @@ def build_models(self, maps): mod.links['source'] = ( f"LTForeignKey({source!r}, models.DO_NOTHING, " - f"db_column='source', primary_key=True)" + f"db_column='source')" ) mod.links['target'] = ( f"LTForeignKey({target!r}, models.DO_NOTHING, " - f"db_column='target')" + f"db_column='target', primary_key=True)" ) # Update the source model with the corresponding @@ -197,6 +210,12 @@ def render_prop(self, prop): req = 'blank=True, null=True' target = prop['target']['name'] + is_array = False + match = ARRAY_RE.fullmatch(target) + if match: + is_array = True + target = match.group('el') + try: ftype = GEL_SCALAR_MAP[target] except KeyError: @@ -206,7 +225,10 @@ def render_prop(self, prop): ) return '' - return f'models.{ftype}({req})' + if is_array: + return f'pgf.ArrayField(models.{ftype}({req}))' + else: + return f'models.{ftype}({req})' def render_link(self, link, bklink=None): if link['required']: @@ -267,7 +289,7 @@ def render_models(self, spec): self.out = f self.write(BASE_STUB) - for mod in modmap.values(): + for mod in sorted(modmap.values(), key=lambda x: x.name): self.write() self.write() self.render_model_class(mod) @@ -284,19 +306,22 @@ def render_model_class(self, mod): if mod.props: self.write() self.write(f'# properties as Fields') - for name, val in mod.props.items(): + props = sorted(mod.props.items(), key=field_name_sort) + for name, val in props: self.write(f'{name} = {val}') if mod.links: self.write() self.write(f'# links as ForeignKeys') - for name, val in mod.links.items(): + links = sorted(mod.links.items(), key=field_name_sort) + for name, val in links: self.write(f'{name} = {val}') if mod.mlinks: self.write() self.write(f'# multi links as ManyToManyFields') - for name, val in mod.mlinks.items(): + mlinks = sorted(mod.mlinks.items(), key=field_name_sort) + for name, val in mlinks: self.write(f'{name} = {val}') if '.' not in mod.table: diff --git a/gel/orm/introspection.py b/gel/orm/introspection.py index ec43464f..39e436dc 100644 --- a/gel/orm/introspection.py +++ b/gel/orm/introspection.py @@ -31,7 +31,7 @@ ), target: {name}, }, - } filter .name != '__type__', + } filter .name != '__type__' and not exists .expr, properties: { name, readonly, @@ -42,7 +42,7 @@ filter .name = 'std::exclusive' ), target: {name}, - }, + } filter not exists .expr, backlinks := >[], } filter diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py index 9af5172e..f9949383 100644 --- a/gel/orm/sqla.py +++ b/gel/orm/sqla.py @@ -9,48 +9,64 @@ GEL_SCALAR_MAP = { - 'std::bool': ('bool', 'Boolean'), - 'std::str': ('str', 'String'), - 'std::int16': ('int', 'Integer'), - 'std::int32': ('int', 'Integer'), - 'std::int64': ('int', 'Integer'), - 'std::float32': ('float', 'Float'), - 'std::float64': ('float', 'Float'), - 'std::uuid': ('uuid.UUID', 'Uuid'), + 'std::bool': ('bool', 'sa.Boolean'), + 'std::str': ('str', 'sa.String'), + 'std::int16': ('int', 'sa.Integer'), + 'std::int32': ('int', 'sa.Integer'), + 'std::int64': ('int', 'sa.Integer'), + 'std::float32': ('float', 'sa.Float'), + 'std::float64': ('float', 'sa.Float'), + 'std::uuid': ('uuid.UUID', 'sa.Uuid'), + 'std::bytes': ('bytes', 'sa.LargeBinary'), + 'std::cal::local_date': ('datetime.date', 'sa.Date'), + 'std::cal::local_time': ('datetime.time', 'sa.Time'), + 'std::cal::local_datetime': ('datetime.datetime', 'sa.DateTime'), + 'std::datetime': ('datetime.datetime', 'sa.TIMESTAMP'), } -CLEAN_RE = re.compile(r'[^A-Za-z0-9]+') +ARRAY_RE = re.compile(r'^array<(?P.+)>$') +NAME_RE = re.compile(r'^(?P\w+?)(?P\d*)$') COMMENT = '''\ # # Automatically generated from Gel schema. +# +# Do not edit directly as re-generating this file will overwrite any changes. #\ ''' BASE_STUB = f'''\ {COMMENT} -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import orm as orm -class Base(DeclarativeBase): +class Base(orm.DeclarativeBase): pass\ ''' MODELS_STUB = f'''\ {COMMENT} +import datetime import uuid -from typing import List -from typing import Optional +from typing import List, Optional -from sqlalchemy import MetaData, Table, Column, ForeignKey -from sqlalchemy import String, Uuid, Integer, Float, Boolean -from sqlalchemy.orm import Mapped, mapped_column, relationship +import sqlalchemy as sa +from sqlalchemy import orm as orm ''' +def field_name_sort(spec): + key = spec['name'] + + match = NAME_RE.fullmatch(key) + res = (match.group('alpha'), int(match.group('num') or -1)) + + return res + + class ModelGenerator(FilePrinter): def __init__(self, *, outdir=None, basemodule=None): # set the output to be stdout by default, but this is generally @@ -118,9 +134,9 @@ def init_module(self, mod, modules): def get_fk(self, mod, table, curmod): if mod == curmod: # No need for anything fancy within the same schema - return f'ForeignKey("{table}.id")' + return f'sa.ForeignKey("{table}.id")' else: - return f'ForeignKey("{mod}.{table}.id")' + return f'sa.ForeignKey("{mod}.{table}.id")' def get_py_name(self, mod, name, curmod): if False and mod == curmod: @@ -179,7 +195,8 @@ def render_models(self, spec): self.write(MODELS_STUB) self.write(f'from ._sqlabase import Base') - for rec in spec['link_tables']: + link_tables = sorted(spec['link_tables'], key=lambda x: x['name']) + for rec in link_tables: self.write() self.render_link_table(rec) @@ -189,11 +206,19 @@ def render_models(self, spec): # skip apparently empty modules continue - for lobj in maps.get('link_objects', {}).values(): + link_objects = sorted( + maps.get('link_objects', {}).values(), + key=lambda x: x['name'], + ) + for lobj in link_objects: self.write() self.render_link_object(lobj, modules) - for rec in maps.get('object_types', {}).values(): + object_types = sorted( + maps.get('object_types', {}).values(), + key=lambda x: x['name'], + ) + for rec in object_types: self.write() self.render_type(rec, modules) @@ -204,13 +229,13 @@ def render_link_table(self, spec): t_fk = self.get_fk(tmod, target, 'default') self.write() - self.write(f'{spec["name"]} = Table(') + self.write(f'{spec["name"]} = sa.Table(') self.indent() self.write(f'{spec["table"]!r},') self.write(f'Base.metadata,') # source is in the same module as this table - self.write(f'Column("source", {s_fk}),') - self.write(f'Column("target", {t_fk}),') + self.write(f'sa.Column("source", {s_fk}),') + self.write(f'sa.Column("target", {t_fk}),') self.write(f'schema={mod!r},') self.dedent() self.write(f')') @@ -243,10 +268,13 @@ def render_link_object(self, spec, modules): tmod, target = get_mod_and_name(link['target']['name']) fk = self.get_fk(tmod, target, mod) pyname = self.get_py_name(tmod, target, mod) - self.write(f'{lname}_id: Mapped[uuid.UUID] = mapped_column(') + self.write(f'{lname}_id: orm.Mapped[uuid.UUID] = orm.mapped_column(') self.indent() - self.write(f'{lname!r}, Uuid(), {fk},') - self.write(f'primary_key=True, nullable=False,') + self.write(f'{lname!r},') + self.write(f'sa.Uuid(),') + self.write(f'{fk},') + self.write(f'primary_key=True,') + self.write(f'nullable=False,') self.dedent() self.write(')') @@ -260,8 +288,8 @@ def render_link_object(self, spec, modules): ) self.write( - f'{lname}: Mapped[{pyname}] = ' - f'relationship(back_populates={bklink!r})' + f'{lname}: orm.Mapped[{pyname}] = ' + f'orm.relationship(back_populates={bklink!r})' ) if spec['properties']: @@ -292,25 +320,30 @@ def render_type(self, spec, modules): self.write() # Add two fields that all objects have - self.write(f'id: Mapped[uuid.UUID] = mapped_column(') + self.write(f'id: orm.Mapped[uuid.UUID] = orm.mapped_column(') self.indent() - self.write( - f"Uuid(), primary_key=True, server_default='uuid_generate_v4()')") + self.write(f"sa.Uuid(),") + self.write(f"primary_key=True,") + self.write(f"server_default='uuid_generate_v4()',") self.dedent() + self.write(f')') # This is maintained entirely by Gel, the server_default simply # indicates to SQLAlchemy that this value may be omitted. - self.write(f'gel_type_id: Mapped[uuid.UUID] = mapped_column(') + self.write(f'gel_type_id: orm.Mapped[uuid.UUID] = orm.mapped_column(') self.indent() - self.write( - f"'__type__', Uuid(), server_default='PLACEHOLDER')") + self.write(f"'__type__',") + self.write(f"sa.Uuid(),") + self.write(f"server_default='PLACEHOLDER',") self.dedent() + self.write(f")") if spec['properties']: self.write() self.write('# Properties:') - for prop in spec['properties']: + properties = sorted(spec['properties'], key=field_name_sort) + for prop in properties: if prop['name'] != 'id': self.render_prop(prop, mod, name, modules) @@ -318,14 +351,16 @@ def render_type(self, spec, modules): self.write() self.write('# Links:') - for link in spec['links']: + links = sorted(spec['links'], key=field_name_sort) + for link in links: self.render_link(link, mod, name, modules) if spec['backlinks']: self.write() self.write('# Back-links:') - for link in spec['backlinks']: + backlinks = sorted(spec['backlinks'], key=field_name_sort) + for link in backlinks: self.render_backlink(link, mod, modules) self.dedent() @@ -336,8 +371,15 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): cardinality = spec['cardinality'] target = spec['target']['name'] + is_array = False + match = ARRAY_RE.fullmatch(target) + if match: + is_array = True + target = match.group('el') + try: pytype, sqlatype = GEL_SCALAR_MAP[target] + sqlatype = sqlatype + '()' except KeyError: warnings.warn( f'Scalar type {target} is not supported', @@ -346,23 +388,29 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): # Skip rendering this one return + if is_array: + pytype = f'List[{pytype}]' + sqlatype = f'sa.ARRAY({sqlatype})' + if is_pk: # special case of a primary key property (should only happen to # 'target' in multi property table) - self.write( - f'{name}: Mapped[{pytype}] = mapped_column(' - f'{sqlatype}(), primary_key=True, nullable=False)' - ) + self.write(f'{name}: orm.Mapped[{pytype}] = orm.mapped_column(') + self.indent() + self.write(f'{sqlatype}, primary_key=True, nullable=False,') + self.dedent() + self.write(f')') elif cardinality == 'Many': # skip it return else: # plain property - self.write( - f'{name}: Mapped[{pytype}] = ' - f'mapped_column({sqlatype}(), nullable={nullable})' - ) + self.write(f'{name}: orm.Mapped[{pytype}] = orm.mapped_column(') + self.indent() + self.write(f'{sqlatype}, nullable={nullable},') + self.dedent() + self.write(f')') def render_link(self, spec, mod, parent, modules): name = spec['name'] @@ -383,22 +431,22 @@ def render_link(self, spec, mod, parent, modules): if cardinality == 'One': self.write( - f'{name}: Mapped[{pyname}] = ' - f"relationship(back_populates='source')" + f'{name}: orm.Mapped[{pyname}] = ' + f"orm.relationship(back_populates='source')" ) elif cardinality == 'Many': self.write( - f'{name}: Mapped[List[{pyname}]] = ' - f"relationship(back_populates='source')" + f'{name}: orm.Mapped[List[{pyname}]] = ' + f"orm.relationship(back_populates='source')" ) if cardinality == 'One': - tmap = f'Mapped[{pyname}]' + tmap = f'orm.Mapped[{pyname}]' elif cardinality == 'Many': - tmap = f'Mapped[List[{pyname}]]' + tmap = f'orm.Mapped[List[{pyname}]]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without source. - self.write(f'{name}: {tmap} = relationship(') + self.write(f'{name}: {tmap} = orm.relationship(') self.indent() self.write(f"back_populates='source',") self.write(f"cascade='all, delete-orphan',") @@ -411,26 +459,29 @@ def render_link(self, spec, mod, parent, modules): if cardinality == 'One': self.write( - f'{name}_id: Mapped[uuid.UUID] = ' - f'mapped_column(Uuid(), ' - f'{fk}, nullable={nullable})' - ) - self.write( - f'{name}: Mapped[{pyname}] = ' - f'relationship(back_populates={bklink!r})' - ) + f'{name}_id: orm.Mapped[uuid.UUID] = orm.mapped_column(') + self.indent() + self.write(f'sa.Uuid(), {fk}, nullable={nullable},') + self.dedent() + self.write(f')') + + self.write(f'{name}: orm.Mapped[{pyname}] = orm.relationship(') + self.indent() + self.write(f'back_populates={bklink!r},') + self.dedent() + self.write(f')') elif cardinality == 'Many': secondary = f'{parent}_{name}_table' + self.write( - f'{name}: Mapped[List[{pyname}]] = relationship(') + f'{name}: orm.Mapped[List[{pyname}]] = orm.relationship(') self.indent() - self.write( - f'{pyname}, secondary={secondary}, ' - f'back_populates={bklink!r},' - ) + self.write(f'{pyname},') + self.write(f'secondary={secondary},') + self.write(f'back_populates={bklink!r},') self.dedent() - self.write(')') + self.write(f')') def render_backlink(self, spec, mod, modules): name = spec['name'] @@ -449,12 +500,12 @@ def render_backlink(self, spec, mod, modules): pyname = self.get_py_name(tmod, target, mod) if cardinality == 'One': - tmap = f'Mapped[{pyname}]' + tmap = f'orm.Mapped[{pyname}]' elif cardinality == 'Many': - tmap = f'Mapped[List[{pyname}]]' + tmap = f'orm.Mapped[List[{pyname}]]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without target. - self.write(f'{name}: {tmap} = relationship(') + self.write(f'{name}: {tmap} = orm.relationship(') self.indent() self.write(f"back_populates='target',") self.write(f"cascade='all, delete-orphan',") @@ -467,25 +518,29 @@ def render_backlink(self, spec, mod, modules): # This is a backlink from a single link. There is no link table # involved. if cardinality == 'One': - self.write( - f'{name}: Mapped[{pyname}] = ' - f'relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: orm.Mapped[{pyname}] = \\') + self.indent() + self.write(f'orm.relationship(back_populates={bklink!r})') + self.dedent() + elif cardinality == 'Many': - self.write( - f'{name}: Mapped[List[{pyname}]] = ' - f'relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: orm.Mapped[List[{pyname}]] = \\') + self.indent() + self.write(f'orm.relationship(back_populates={bklink!r})') + self.dedent() else: # This backlink involves a link table, so we still treat it as # a Many-to-Many. secondary = f'{target}_{bklink}_table' - self.write(f'{name}: Mapped[List[{pyname}]] = relationship(') + + self.write(f'{name}: orm.Mapped[List[{pyname}]] = \\') self.indent() - self.write( - f'{pyname}, secondary={secondary}, ' - f'back_populates={bklink!r},' - ) + self.write(f'orm.relationship(') + self.indent() + self.write(f'{pyname},') + self.write(f'secondary={secondary},') + self.write(f'back_populates={bklink!r},') self.dedent() self.write(')') + self.dedent() diff --git a/gel/orm/sqlmodel.py b/gel/orm/sqlmodel.py index f8deccce..8da16973 100644 --- a/gel/orm/sqlmodel.py +++ b/gel/orm/sqlmodel.py @@ -9,34 +9,52 @@ GEL_SCALAR_MAP = { - 'std::bool': 'bool', - 'std::str': 'str', - 'std::int16': 'int', - 'std::int32': 'int', - 'std::int64': 'int', - 'std::float32': 'float', - 'std::float64': 'float', - 'std::uuid': 'uuid.UUID', + 'std::bool': ('bool', None), + 'std::str': ('str', None), + 'std::int16': ('int', None), + 'std::int32': ('int', None), + 'std::int64': ('int', None), + 'std::float32': ('float', None), + 'std::float64': ('float', None), + 'std::uuid': ('uuid.UUID', None), + 'std::bytes': ('bytes', None), + 'std::cal::local_date': ('datetime.date', None), + 'std::cal::local_time': ('datetime.time', None), + 'std::cal::local_datetime': ('datetime.datetime', 'sa.DateTime()'), + 'std::datetime': ('datetime.datetime', 'sa.TIMESTAMP(timezone=True)'), } CLEAN_RE = re.compile(r'[^A-Za-z0-9]+') +NAME_RE = re.compile(r'^(?P\w+?)(?P\d*)$') COMMENT = '''\ # # Automatically generated from Gel schema. +# +# Do not edit directly as re-generating this file will overwrite any changes. #\ ''' MODELS_STUB = f'''\ {COMMENT} +import datetime import uuid -from sqlmodel import SQLModel, Field, Relationship -from sqlalchemy import Column, ForeignKey +import sqlmodel as sm +import sqlalchemy as sa ''' +def field_name_sort(spec): + key = spec['name'] + + match = NAME_RE.fullmatch(key) + res = (match.group('alpha'), int(match.group('num') or -1)) + + return res + + class ModelGenerator(FilePrinter): def __init__(self, *, outdir=None, basemodule=None): # set the output to be stdout by default, but this is generally @@ -105,9 +123,9 @@ def get_fk(self, mod, table, curmod): def get_sqla_fk(self, mod, table, curmod): if mod == curmod: # No need for anything fancy within the same schema - return f'ForeignKey("{table}.id")' + return f'sa.ForeignKey("{table}.id")' else: - return f'ForeignKey("{mod}.{table}.id")' + return f'sa.ForeignKey("{mod}.{table}.id")' def get_py_name(self, mod, name, curmod): if mod == curmod: @@ -182,10 +200,18 @@ def render_models(self, spec): # skip apparently empty modules return - for lobj in maps.get('link_objects', {}).values(): + link_objects = sorted( + maps.get('link_objects', {}).values(), + key=lambda x: x['name'] + ) + for lobj in link_objects: self.write() self.render_link_object(lobj, modules) + objects = sorted( + maps.get('object_types', {}).values(), + key=lambda x: x['name'] + ) for rec in maps.get('object_types', {}).values(): self.write() self.render_type(rec, modules) @@ -207,7 +233,7 @@ def render_link_table(self, spec): return self.write() - self.write(f'class {spec["name"]}(SQLModel, table=True):') + self.write(f'class {spec["name"]}(sm.SQLModel, table=True):') self.indent() self.write(f'__tablename__ = {spec["table"]!r}') if mod != 'default': @@ -218,8 +244,17 @@ def render_link_table(self, spec): self.write('__mapper_args__ = {"confirm_deleted_rows": False}') self.write() # source is in the same module as this table - self.write(f'source: uuid.UUID = Field({s_fk}, primary_key=True)') - self.write(f'target: uuid.UUID = Field({t_fk}, primary_key=True)') + self.write(f'source: uuid.UUID = sm.Field(') + self.indent() + self.write(f'{s_fk}, primary_key=True,') + self.dedent() + self.write(f')') + + self.write(f'target: uuid.UUID = sm.Field(') + self.indent() + self.write(f'{t_fk}, primary_key=True,') + self.dedent() + self.write(f')') self.dedent() def render_link_object(self, spec, modules): @@ -237,7 +272,7 @@ def render_link_object(self, spec, modules): return self.write() - self.write(f'class {name}(SQLModel, table=True):') + self.write(f'class {name}(sm.SQLModel, table=True):') self.indent() self.write(f'__tablename__ = {sql_name!r}') if mod != 'default': @@ -268,14 +303,15 @@ def render_link_object(self, spec, modules): fk = self.get_fk(tmod, target, mod) sqlafk = self.get_sqla_fk(tmod, target, mod) pyname = self.get_py_name(tmod, target, mod) - self.write(f'{lname}_id: uuid.UUID = Field(sa_column=Column(') + self.write( + f'{lname}_id: uuid.UUID = sm.Field(sa_column=sa.Column(') self.indent() self.write(f'{lname!r},') self.write(f'{sqlafk},') self.write(f'primary_key=True,') self.write(f'nullable=False,') self.dedent() - self.write('))') + self.write(f'))') if lname == 'source': bklink = source_link @@ -287,9 +323,11 @@ def render_link_object(self, spec, modules): ) self.write( - f'{lname}: {pyname} = ' - f'Relationship(back_populates={bklink!r})' - ) + f'{lname}: {pyname} = sm.Relationship(') + self.indent() + self.write(f'back_populates={bklink!r},') + self.dedent() + self.write(f')') if spec['properties']: self.write() @@ -315,7 +353,7 @@ def render_type(self, spec, modules): return self.write() - self.write(f'class {name}(SQLModel, table=True):') + self.write(f'class {name}(sm.SQLModel, table=True):') self.indent() self.write(f'__tablename__ = {sql_name!r}') if mod != 'default': @@ -327,10 +365,10 @@ def render_type(self, spec, modules): self.write() # Add two fields that all objects have - self.write(f'id: uuid.UUID | None = Field(') + self.write(f'id: uuid.UUID | None = sm.Field(') self.indent() - self.write( - f"default=None, primary_key=True,") + self.write(f"default=None,") + self.write(f"primary_key=True,") self.write( f"sa_column_kwargs=dict(server_default='uuid_generate_v4()'),") self.dedent() @@ -338,12 +376,11 @@ def render_type(self, spec, modules): # This is maintained entirely by Gel, the server_default simply # indicates to SQLAlchemy that this value may be omitted. - self.write(f'gel_type_id: uuid.UUID | None = Field(') + self.write(f'gel_type_id: uuid.UUID | None = sm.Field(') self.indent() + self.write(f"default=None,") self.write( - f"default=None,") - self.write( - f"sa_column=Column('__type__', server_default='PLACEHOLDER'),") + f"sa_column=sa.Column('__type__', server_default='PLACEHOLDER'),") self.dedent() self.write(')') @@ -351,7 +388,8 @@ def render_type(self, spec, modules): self.write() self.write('# Properties:') - for prop in spec['properties']: + properties = sorted(spec['properties'], key=field_name_sort) + for prop in properties: if prop['name'] != 'id': self.render_prop(prop, mod, name, modules) @@ -359,14 +397,16 @@ def render_type(self, spec, modules): self.write() self.write('# Links:') - for link in spec['links']: + links = sorted(spec['links'], key=field_name_sort) + for link in links: self.render_link(link, mod, name, modules) if spec['backlinks']: self.write() self.write('# Back-links:') - for link in spec['backlinks']: + backlinks = sorted(spec['backlinks'], key=field_name_sort) + for link in backlinks: self.render_backlink(link, mod, modules) self.dedent() @@ -378,7 +418,8 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): target = spec['target']['name'] try: - pytype = GEL_SCALAR_MAP[target] + pytype, sa_col = GEL_SCALAR_MAP[target] + except KeyError: warnings.warn( f'Scalar type {target} is not supported', @@ -387,21 +428,23 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): # Skip rendering this one return - if is_pk: - # special case of a primary key property (should only happen to - # 'target' in multi property table) - self.write( - f'{name}: {pytype} = Field(primary_key=True, nullable=False)' - ) - elif cardinality == 'Many': + if cardinality == 'Many': # skip it return else: # plain property - self.write( - f'{name}: {pytype} = Field(nullable={nullable})' - ) + if sa_col: + self.write(f'{name}: {pytype} = sm.Field(sa_column=sa.Column(') + self.indent() + self.write(f'{sa_col},') + self.write(f'nullable={nullable},') + self.dedent() + self.write(f'))') + else: + self.write( + f'{name}: {pytype} = sm.Field(nullable={nullable})' + ) def render_link(self, spec, mod, parent, modules): name = spec['name'] @@ -431,12 +474,12 @@ def render_link(self, spec, mod, parent, modules): if cardinality == 'One': self.write( f'{name}: {pyname} = ' - f"Relationship(back_populates='source')" + f"sm.Relationship(back_populates='source')" ) elif cardinality == 'Many': self.write( f'{name}: list[{pyname}] = ' - f"Relationship(back_populates='source')" + f"sm.Relationship(back_populates='source')" ) if cardinality == 'One': @@ -445,7 +488,7 @@ def render_link(self, spec, mod, parent, modules): tmap = f'list[{pyname}]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without source. - self.write(f'{name}: {tmap} = Relationship(') + self.write(f'{name}: {tmap} = sm.Relationship(') self.indent() self.write(f"back_populates='source',") self.write(f"cascade_delete=True,") @@ -457,19 +500,22 @@ def render_link(self, spec, mod, parent, modules): pyname = self.get_py_name(tmod, target, mod) if cardinality == 'One': - self.write( - f'{name}_id: uuid.UUID = Field({fk}, nullable={nullable})' - ) - self.write( - f'{name}: {pyname} = ' - f'Relationship(back_populates={bklink!r})' - ) + self.write(f'{name}_id: uuid.UUID = sm.Field(') + self.indent() + self.write(f'{fk},') + self.write(f'nullable={nullable},') + self.dedent() + self.write(')') + + self.write(f'{name}: {pyname} = sm.Relationship(') + self.indent() + self.write(f'back_populates={bklink!r},') + self.dedent() + self.write(')') elif cardinality == 'Many': secondary = f'{parent}_{name}_table' - self.write( - f'{name}: list[{pyname}] = Relationship(' - ) + self.write(f'{name}: list[{pyname}] = sm.Relationship(') self.indent() self.write(f'back_populates={bklink!r},') self.write(f'link_model={secondary},') @@ -506,7 +552,7 @@ def render_backlink(self, spec, mod, modules): tmap = f'list[{pyname}]' # We want the cascade to delete orphans here as the intermediate # objects represent links and must not exist without target. - self.write(f'{name}: {tmap} = Relationship(') + self.write(f'{name}: {tmap} = sm.Relationship(') self.indent() self.write(f"back_populates='target',") self.write(f"cascade_delete=True,") @@ -519,22 +565,24 @@ def render_backlink(self, spec, mod, modules): # This is a backlink from a single link. There is no link table # involved. if cardinality == 'One': - self.write( - f'{name}: {pyname} = ' - f'Relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: {pyname} = sm.Relationship(') + self.indent() + self.write(f"back_populates={bklink!r},") + self.dedent() + self.write(')') elif cardinality == 'Many': - self.write( - f'{name}: list[{pyname}] = ' - f'Relationship(back_populates={bklink!r})' - ) + self.write(f'{name}: list[{pyname}] = sm.Relationship(') + self.indent() + self.write(f"back_populates={bklink!r},") + self.dedent() + self.write(')') else: # This backlink involves a link table, so we still treat it as # a Many-to-Many. secondary = f'{target}_{bklink}_table' self.write( - f'{name}: list[{pyname}] = Relationship(' + f'{name}: list[{pyname}] = sm.Relationship(' ) self.indent() self.write(f'back_populates={bklink!r},') diff --git a/tests/dbsetup/base.edgeql b/tests/dbsetup/base.edgeql index cfbf46e8..582bb254 100644 --- a/tests/dbsetup/base.edgeql +++ b/tests/dbsetup/base.edgeql @@ -42,3 +42,13 @@ insert Post { author := assert_single((select User filter .name = 'Elsa')), body := '*magic stuff*', }; + +insert AssortedScalars { + name:= 'hello world', + vals := ['brown', 'fox'], + bstr := b'word\x00\x0b', + time := '20:13:45.678', + date:= '2025-01-26', + ts:='2025-01-26T20:13:45+00:00', + lts:='2025-01-26T20:13:45', +}; \ No newline at end of file diff --git a/tests/dbsetup/base.esdl b/tests/dbsetup/base.esdl index 4a5c02c0..735749b1 100644 --- a/tests/dbsetup/base.esdl +++ b/tests/dbsetup/base.esdl @@ -15,9 +15,23 @@ type GameSession { }; } -type User extending Named; +type User extending Named { + # test computed backlink + groups := .; + + date: cal::local_date; + time: cal::local_time; + ts: datetime; + lts: cal::local_datetime; + bstr: bytes; +} \ No newline at end of file diff --git a/tests/dbsetup/sqlmodel.edgeql b/tests/dbsetup/sqlmodel.edgeql index 20f699aa..2d7bee80 100644 --- a/tests/dbsetup/sqlmodel.edgeql +++ b/tests/dbsetup/sqlmodel.edgeql @@ -58,4 +58,14 @@ set { update HasLinkPropsB set { children += (select Child{@b := 'world'} filter .num = 1) +}; + +insert AssortedScalars { + name:= 'hello world', + vals := ['brown', 'fox'], + bstr := b'word\x00\x0b', + time := '20:13:45.678', + date:= '2025-01-26', + ts:='2025-01-26T20:13:45+00:00', + lts:='2025-01-26T20:13:45', }; \ No newline at end of file diff --git a/tests/dbsetup/sqlmodel.esdl b/tests/dbsetup/sqlmodel.esdl index 8e1347f4..d456dd81 100644 --- a/tests/dbsetup/sqlmodel.esdl +++ b/tests/dbsetup/sqlmodel.esdl @@ -38,4 +38,15 @@ type HasLinkPropsB { multi link children: Child { property b: str; } +} + +type AssortedScalars { + required name: str; + vals: array; + + date: cal::local_date; + time: cal::local_time; + ts: datetime; + lts: cal::local_datetime; + bstr: bytes; } \ No newline at end of file diff --git a/tests/test_django_basic.py b/tests/test_django_basic.py index 069f9c8f..488e0f4c 100644 --- a/tests/test_django_basic.py +++ b/tests/test_django_basic.py @@ -16,6 +16,7 @@ # limitations under the License. # +import datetime as dt import os import uuid import unittest @@ -299,6 +300,28 @@ def test_django_read_models_07(self): } ) + def test_django_read_models_08(self): + # test arrays, bytes and various date/time scalars + + res = self.m.AssortedScalars.objects.all()[0] + + self.assertEqual(res.name, 'hello world') + self.assertEqual(res.vals, ['brown', 'fox']) + self.assertEqual(bytes(res.bstr), b'word\x00\x0b') + self.assertEqual( + res.time, + dt.time(20, 13, 45, 678_000), + ) + self.assertEqual( + res.date, + dt.date(2025, 1, 26), + ) + # time zone aware (default for Django) + self.assertEqual( + res.ts, + dt.datetime.fromisoformat('2025-01-26T20:13:45+00:00'), + ) + def test_django_create_models_01(self): vals = self.m.User.objects.filter(name='Yvonne').all() self.assertEqual(list(vals), []) @@ -492,3 +515,39 @@ def test_django_update_models_04(self): post = self.m.Post.objects.get(id=post_id) self.assertEqual(post.author.name, 'Zoe') + + def test_django_update_models_05(self): + # test arrays, bytes and various date/time scalars + # + # For the purpose of sending data creating and updating a model are + # both testing accurate data transfer. + + res = self.m.AssortedScalars.objects.all()[0] + + res.name = 'New Name' + res.vals.append('jumped') + res.bstr = b'\x01success\x02' + res.time = dt.time(8, 23, 54, 999_000) + res.date = dt.date(2020, 2, 14) + res.ts = res.ts - dt.timedelta(days=6) + + res.save() + + upd = self.m.AssortedScalars.objects.all()[0] + + self.assertEqual(upd.name, 'New Name') + self.assertEqual(upd.vals, ['brown', 'fox', 'jumped']) + self.assertEqual(bytes(upd.bstr), b'\x01success\x02') + self.assertEqual( + upd.time, + dt.time(8, 23, 54, 999_000), + ) + self.assertEqual( + upd.date, + dt.date(2020, 2, 14), + ) + # time zone aware (default for Django) + self.assertEqual( + upd.ts, + dt.datetime.fromisoformat('2025-01-20T20:13:45+00:00'), + ) diff --git a/tests/test_sqla_basic.py b/tests/test_sqla_basic.py index 35934a6b..00394561 100644 --- a/tests/test_sqla_basic.py +++ b/tests/test_sqla_basic.py @@ -16,6 +16,7 @@ # limitations under the License. # +import datetime as dt import os import uuid import unittest @@ -344,6 +345,33 @@ def test_sqla_read_models_07(self): } ) + def test_sqla_read_models_08(self): + # test arrays, bytes and various date/time scalars + + res = self.sess.query(self.sm.AssortedScalars).one() + + self.assertEqual(res.name, 'hello world') + self.assertEqual(res.vals, ['brown', 'fox']) + self.assertEqual(res.bstr, b'word\x00\x0b') + self.assertEqual( + res.time, + dt.time(20, 13, 45, 678_000), + ) + self.assertEqual( + res.date, + dt.date(2025, 1, 26), + ) + # time zone aware + self.assertEqual( + res.ts, + dt.datetime.fromisoformat('2025-01-26T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + res.lts, + dt.datetime.fromisoformat('2025-01-26T20:13:45'), + ) + def test_sqla_create_models_01(self): vals = self.sess.query(self.sm.User).filter_by(name='Yvonne').all() self.assertEqual(list(vals), []) @@ -578,3 +606,46 @@ def test_sqla_update_models_04(self): post = self.sess.get(self.sm.Post, post_id) self.assertEqual(post.author.name, 'Zoe') + + def test_sqla_update_models_05(self): + # test arrays, bytes and various date/time scalars + # + # For the purpose of sending data creating and updating a model are + # both testing accurate data transfer. + + res = self.sess.query(self.sm.AssortedScalars).one() + + res.name = 'New Name' + res.vals.append('jumped') + res.bstr = b'\x01success\x02' + res.time = dt.time(8, 23, 54, 999_000) + res.date = dt.date(2020, 2, 14) + res.ts = res.ts - dt.timedelta(days=6) + res.lts = res.lts + dt.timedelta(days=6) + + self.sess.add(res) + self.sess.flush() + + upd = self.sess.query(self.sm.AssortedScalars).one() + + self.assertEqual(upd.name, 'New Name') + self.assertEqual(upd.vals, ['brown', 'fox', 'jumped']) + self.assertEqual(upd.bstr, b'\x01success\x02') + self.assertEqual( + upd.time, + dt.time(8, 23, 54, 999_000), + ) + self.assertEqual( + upd.date, + dt.date(2020, 2, 14), + ) + # time zone aware + self.assertEqual( + upd.ts, + dt.datetime.fromisoformat('2025-01-20T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + upd.lts, + dt.datetime.fromisoformat('2025-02-01T20:13:45'), + ) diff --git a/tests/test_sqlmodel_basic.py b/tests/test_sqlmodel_basic.py index ba8a1fc5..9a201535 100644 --- a/tests/test_sqlmodel_basic.py +++ b/tests/test_sqlmodel_basic.py @@ -16,6 +16,7 @@ # limitations under the License. # +import datetime as dt import os import uuid import unittest @@ -323,6 +324,34 @@ def test_sqlmodel_read_models_07(self): } ) + def test_sqlmodel_read_models_08(self): + # test arrays, bytes and various date/time scalars + + res = self.sess.exec( + select(self.sm.AssortedScalars) + ).one() + + self.assertEqual(res.name, 'hello world') + self.assertEqual(res.bstr, b'word\x00\x0b') + self.assertEqual( + res.time, + dt.time(20, 13, 45, 678_000), + ) + self.assertEqual( + res.date, + dt.date(2025, 1, 26), + ) + # time zone aware + self.assertEqual( + res.ts, + dt.datetime.fromisoformat('2025-01-26T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + res.lts, + dt.datetime.fromisoformat('2025-01-26T20:13:45'), + ) + def test_sqlmodel_create_models_01(self): vals = self.sess.exec( select(self.sm.User).where( @@ -591,6 +620,52 @@ def test_sqlmodel_update_models_04(self): post = self.sess.get(self.sm.Post, post_id) self.assertEqual(post.author.name, 'Zoe') + def test_sqlmodel_update_models_05(self): + # test arrays, bytes and various date/time scalars + # + # For the purpose of sending data creating and updating a model are + # both testing accurate data transfer. + + res = self.sess.exec( + select(self.sm.AssortedScalars) + ).one() + + res.name = 'New Name' + # res.vals.append('jumped') + res.bstr = b'\x01success\x02' + res.time = dt.time(8, 23, 54, 999_000) + res.date = dt.date(2020, 2, 14) + res.ts = res.ts - dt.timedelta(days=6) + res.lts = res.lts + dt.timedelta(days=6) + + self.sess.add(res) + self.sess.flush() + + upd = self.sess.exec( + select(self.sm.AssortedScalars) + ).one() + + self.assertEqual(upd.name, 'New Name') + self.assertEqual(upd.bstr, b'\x01success\x02') + self.assertEqual( + upd.time, + dt.time(8, 23, 54, 999_000), + ) + self.assertEqual( + upd.date, + dt.date(2020, 2, 14), + ) + # time zone aware + self.assertEqual( + upd.ts, + dt.datetime.fromisoformat('2025-01-20T20:13:45+00:00'), + ) + # naive datetime + self.assertEqual( + upd.lts, + dt.datetime.fromisoformat('2025-02-01T20:13:45'), + ) + def test_sqlmodel_linkprops_01(self): val = self.sess.exec(select(self.sm.HasLinkPropsA)).one() self.assertEqual(val.child.target.num, 0)