Skip to content

Commit

Permalink
storage works
Browse files Browse the repository at this point in the history
  • Loading branch information
juanbc committed Jan 9, 2025
1 parent 2311251 commit cf1ba3f
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 38 deletions.
6 changes: 4 additions & 2 deletions ajiaco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import importlib

from .app import AjcApplication

__all__ = ["Application"]
__all__ = ["AjcApplication"]


VERSION = 0.2
VERSION = importlib.metadata.version("ajiaco")
7 changes: 6 additions & 1 deletion ajiaco/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import pathlib
import sys

import attrs
from attrs import validators as valids
Expand Down Expand Up @@ -92,6 +91,12 @@ def _experiment_sessions_default(self):

# API =====================================================================

@property
def version(self):
from . import VERSION

return VERSION

@property
def app_path(self):
path = pathlib.Path(self.filename)
Expand Down
81 changes: 77 additions & 4 deletions ajiaco/cli/builtins.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import collections
import code
import pprint

from IPython import start_ipython

from .register import AjcCommandRegister

from ..utils import sysinfo


CLI_BUILTINS = AjcCommandRegister("BUILTINS")


@CLI_BUILTINS.register
def version(app):
"Show the version of Ajiaco and exit"
from .. import VERSION

print(f"Ajiaco v.{VERSION}")
print(f"Ajiaco v.{app.version}")


@CLI_BUILTINS.register(name="reset-storage")
Expand Down Expand Up @@ -39,12 +45,79 @@ def reset_storage(app, noinput: bool = False):
storage.create_schema()

print(" - Stamping...")
storage.stamp()
stamp_data = sysinfo.info_dict()
stamp_data["AJIACO_VERSION"] = app.version
with storage.transaction() as conn:
conn.stamp(stamp_data)

print("DONE!")


@CLI_BUILTINS.register(name="storage-stamp")
def storage_stamp(app):
"""Show the stamp inside the storage"""
with app.storage.transaction() as conn:
stamp = conn.get_stamp()
pprint.pprint(stamp)


@CLI_BUILTINS.register()
def serve(app):
"""Run the uvicorn webserver"""
return app.webapp.run(app)


# =============================================================================
# SHELL
# =============================================================================


def _run_plain(slocals, banner):
console = code.InteractiveConsole(slocals)
console.interact(banner)


def _run_ipython(slocals, banner):
start_ipython(
argv=["--TerminalInteractiveShell.banner2={}".format(banner)],
user_ns=slocals,
)


def _create_banner(app, slocals):

by_module = collections.defaultdict(list)
for k, v in slocals.items():
module_name = getattr(v, "__module__", None) or ""
by_module[module_name].append(k)

lines = []
for module_name, imported in sorted(by_module.items()):
prefix = ", ".join(imported)
suffix = "({})".format(module_name) if module_name else ""
line = "\t>>> {} {}".format(prefix, suffix)
lines.append(line)

banner_parts = (
[f"Ajiaco Version: \n\t{app.version}"]
+ [f"Running inside: \n\t{app.app_path}"]
+ ["Ajiaco Variables:"]
+ lines
+ [""]
)

banner = "\n".join(banner_parts)

return banner


@CLI_BUILTINS.register()
def shell(app, plain: bool = False):
"""Run the Python shell inside Ajiaco environment"""
slocals = {"app": app}

with app.storage.transaction() as conn:
slocals["conn"] = conn
banner = _create_banner(app, slocals)
shell = _run_plain if plain else _run_ipython
return shell(slocals, banner)
43 changes: 25 additions & 18 deletions ajiaco/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import sqlalchemy as sa
import sqlalchemy.orm as orm

from ..utils import sysinfo


# =============================================================================
# MODELS CONTAINER
Expand All @@ -18,10 +20,11 @@
@attrs.define(frozen=True, repr=False)
class AjcModelsContainer(Mapping):
BaseModel: ...
Stamp: ...
models: frozenset

def items(self):
for model in it.chain([self.BaseModel], self.models):
for model in it.chain([self.BaseModel, self.Stamp], self.models):
name = model.__name__
yield name, model

Expand All @@ -44,7 +47,7 @@ def __iter__(self):
return (name for name, _ in self.items())

def __len__(self):
return len(self.models) + 1 # + BaseModel
return len(self.models) + 2 # + BaseModel + StampModel

def __repr__(self):
models = set(self.keys())
Expand All @@ -56,10 +59,6 @@ def __repr__(self):
# =============================================================================


def _utcnow():
return dt.datetime.now(dt.timezone.utc)


def create_models(metadata: sa.MetaData):

# First the Base Model ====================================================
Expand All @@ -74,7 +73,7 @@ class IDAndCreatedAtModelABC(BaseModel):
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
created_at: orm.Mapped[dt.datetime] = orm.mapped_column(
sa.DateTime(timezone=True),
default=_utcnow,
default=sysinfo.utcnow,
nullable=False,
)

Expand All @@ -91,6 +90,14 @@ class CodeAndExtraModelABC(IDAndCreatedAtModelABC):
sa.String, nullable=True
)

# Third the stamp =========================================================

class Stamp(IDAndCreatedAtModelABC):

__tablename__ = "ajc_stamp"

data: orm.Mapped[dict] = orm.mapped_column(sa.JSON)

# Store all models for easy return ========================================
the_models = set()

Expand Down Expand Up @@ -128,7 +135,7 @@ class SubjectModel(CodeAndExtraModelABC):
sa.ForeignKey("ajc_sessions.id"), nullable=False
)
session: orm.Mapped[ExperimentSessionModel] = orm.relationship(
back_populates="subjects", lazy=False
backref="subjects", lazy=False
)

the_models.add(SubjectModel)
Expand Down Expand Up @@ -156,7 +163,7 @@ class RoundModel(CodeAndExtraModelABC):
sa.ForeignKey("ajc_sessions.id"), nullable=False
)
session: orm.Mapped[ExperimentSessionModel] = orm.relationship(
back_populates="rounds", lazy=False
backref="rounds", lazy=False
)

the_models.add(RoundModel)
Expand All @@ -172,7 +179,7 @@ class GroupModel(CodeAndExtraModelABC):
sa.ForeignKey("ajc_rounds.id"), nullable=False
)
round: orm.Mapped[RoundModel] = orm.relationship(
back_populates="groups", lazy=False
backref="groups", lazy=False
)

the_models.add(GroupModel)
Expand All @@ -191,15 +198,13 @@ class Role(CodeAndExtraModelABC):
group_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("ajc_groups.id"), nullable=False
)
group: orm.Mapped[GroupModel] = orm.relationship(
back_populates="roles"
)
group: orm.Mapped[GroupModel] = orm.relationship(backref="roles")

subject_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("ajc_subjects.id"), nullable=False
)
subject: orm.Mapped[SubjectModel] = orm.relationship(
back_populates="roles", lazy=False
backref="roles", lazy=False
)

the_models.add(Role)
Expand Down Expand Up @@ -235,13 +240,13 @@ class StageHistory(IDAndCreatedAtModelABC):
role_id: orm.Mapped[id] = orm.mapped_column(
sa.ForeignKey("ajc_roles.id"), nullable=False
)
role: orm.Mapped[Role] = orm.relationship(back_populates="stages")
role: orm.Mapped[Role] = orm.relationship(backref="stages")

subject_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("ajc_subjects.id"), nullable=False
)
subject: orm.Mapped[SubjectModel] = orm.relationship(
back_populates="stages", lazy=False
backref="stages", lazy=False
)

@property
Expand All @@ -250,13 +255,15 @@ def expired(self):
self.timeout
and self.enter_dt
and self.expire_dt
and _utcnow() >= self.expire_dt
and sysinfo.utcnow() >= self.expire_dt
)

the_models.add(StageHistory)

models_container = AjcModelsContainer(
BaseModel=BaseModel, models=frozenset(the_models)
BaseModel=BaseModel,
Stamp=Stamp,
models=frozenset(the_models),
)

return models_container
40 changes: 36 additions & 4 deletions ajiaco/storage/storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextlib

import attrs

import sqlalchemy as sa
Expand All @@ -13,6 +15,39 @@ def __init__(self, storage, *args, **kwargs):
super().__init__(*args, **kwargs)
self.storage = storage

@property
def models(self):
return self.storage.models

def get_stamp(self):
stamp = self.query(self.models.Stamp).first()
return stamp.data if stamp else None

def stamp(self, data):
Stamp = self.models.Stamp
if self.get_stamp() is not None:
raise ValueError("Database already stamped")
stamp = Stamp(data=data)
self.add(stamp)


class AjcTransactionConectManager(contextlib.AbstractContextManager):
"""Provide a transactional scope around a series of operations."""

def __init__(self, session_maker):
self._session_maker = session_maker

def __enter__(self):
self._session = self._session_maker()
return self._session

def __exit__(self, exc_type, exc_value, traceback):
if exc_type:
self._session.rollback()
else:
self._session.commit()
self._session.close()


@attrs.define(frozen=True)
class AjcStorage:
Expand Down Expand Up @@ -63,7 +98,4 @@ def create_schema(self):
return BaseModel.metadata.create_all(self.engine)

def transaction(self):
return self.session_maker()

def stamp(self):
print("'stamp()' implemente me")
return AjcTransactionConectManager(self.session_maker)
Loading

0 comments on commit cf1ba3f

Please sign in to comment.