diff --git a/causy/cli.py b/causy/cli.py
index f1ade70..caa2cc8 100644
--- a/causy/cli.py
+++ b/causy/cli.py
@@ -30,6 +30,10 @@
 @app.command()
 def eject(algorithm: str, output_file: str):
+    logging.warning(
+        f"Ejecting pipelines outside of workspace context is deprecated. Please use workspaces instead."
+    )
+
     typer.echo(f"💾 Loading algorithm {algorithm}")
     model = AVAILABLE_ALGORITHMS[algorithm]()
     result = serialize_algorithm(model, algorithm_name=algorithm)
@@ -46,6 +50,9 @@ def execute(
     output_file: str = None,
     log_level: str = "ERROR",
 ):
+    logging.warning(
+        f"Executing outside of workspaces is deprecated and will be removed in future versions. Please use workspaces instead."
+    )
     logging.basicConfig(level=log_level)
     if pipeline:
         typer.echo(f"💾 Loading pipeline from {pipeline}")
diff --git a/causy/graph.py b/causy/graph.py
index 41a1c1f..c8242ae 100644
--- a/causy/graph.py
+++ b/causy/graph.py
@@ -13,8 +13,10 @@
     NodeInterface,
     EdgeInterface,
     EdgeTypeInterface,
+    MetadataType,
 )
 from causy.models import TestResultAction, TestResult, ActionHistoryStep
+from causy.variables import VariableType
 
 logger = logging.getLogger(__name__)
 
@@ -28,7 +30,7 @@ class Node(NodeInterface):
     name: str
     id: str
     values: Optional[torch.Tensor] = None
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: Optional[Dict[str, MetadataType]] = None
 
     def __hash__(self):
         return hash(self.id)
@@ -46,7 +48,7 @@ class Edge(EdgeInterface):
     u: NodeInterface
     v: NodeInterface
     edge_type: EdgeTypeInterface
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: Optional[Dict[str, MetadataType]] = None
     deleted: Optional[bool] = False
 
     def __init__(self, *args, **kwargs):
diff --git a/causy/graph_model.py b/causy/graph_model.py
index 6d389fa..09d3368 100644
--- a/causy/graph_model.py
+++ b/causy/graph_model.py
@@ -3,7 +3,7 @@
 from abc import ABC
 from copy import deepcopy
 import time
-from typing import Optional, List, Dict, Callable, Union, Any
+from typing import Optional, List, Dict, Callable, Union, Any, Generator
 
 import torch.multiprocessing as mp
 
@@ -21,6 +21,8 @@ from causy.variables import (
     resolve_variables_to_algorithm_for_pipeline_steps,
     resolve_variables,
+    VariableType,
+    VariableTypes,
 )
 
 logger = logging.getLogger(__name__)
 
@@ -188,19 +190,26 @@ def create_all_possible_edges(self):
                 continue
             self.graph.add_edge(u, v, {})
 
-    def execute_pipeline_steps(self):
+    def execute_pipeline_steps(self) -> List[ActionHistoryStep]:
         """
         Execute all pipeline_steps
         :return: the steps taken during the step execution
         """
+        all(self.execute_pipeline_step_with_progress())
+        return self.graph.action_history
+
+    def execute_pipeline_step_with_progress(self) -> Generator:
+        started = time.time()
         for filter in self.pipeline_steps:
             logger.info(f"Executing pipeline step {filter.__class__.__name__}")
             if isinstance(filter, LogicStepInterface):
                 actions_taken = filter.execute(self.graph.graph, self)
                 self.graph.graph.action_history.append(actions_taken)
                 continue
-
+            yield {
+                "step": filter.__class__.__name__,
+                "previous_duration": time.time() - started,
+            }
             started = time.time()
             actions_taken = self.execute_pipeline_step(filter)
             self.graph.graph.action_history.append(
@@ -213,8 +222,6 @@
 
         self.graph.purge_soft_deleted_edges()
 
-        return self.graph.action_history
-
     def _format_yield(self, test_fn, graph, generator):
         """
         Format the yield for the parallel processing
@@ -386,7 +393,7 @@ def execute_pipeline_step(
 
 def graph_model_factory(
     algorithm: Algorithm = None,
-    variables: Dict[str, Any] = None,
+    variables: Dict[str, VariableTypes] = None,
 ) -> type[AbstractGraphModel]:
     """
     Create a graph model based on a List of pipeline_steps
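
Note (not part of the patch): the new execute_pipeline_step_with_progress() generator can also be consumed directly, outside the Rich progress bar wired up in causy/workspaces/cli.py below. A minimal sketch, assuming a `pipeline` (an Algorithm) and a `data_loader` already exist:

    from causy.graph_model import graph_model_factory

    model = graph_model_factory(pipeline)()
    model.create_graph_from_data(data_loader)
    model.create_all_possible_edges()
    for update in model.execute_pipeline_step_with_progress():
        # Each yielded dict names the step about to run and reports how long
        # the previous step took, in seconds.
        print(f"{update['step']} (previous: {update['previous_duration']:.2f}s)")
    history = model.graph.action_history  # same value execute_pipeline_steps() returns

Draining the generator (e.g. via all(...), as execute_pipeline_steps() now does) reproduces the old one-call behaviour.
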
diff --git a/causy/interfaces.py b/causy/interfaces.py
index 543b14e..9c4483a 100644
--- a/causy/interfaces.py
+++ b/causy/interfaces.py
@@ -24,6 +24,11 @@
 AS_MANY_AS_FIELDS = 0
 
+MetadataBaseType = Union[str, int, float, bool]
+MetadataType = Union[
+    str, int, float, bool, List[MetadataBaseType], Dict[str, MetadataBaseType]
+]
+
 
 class ComparisonSettingsInterface(BaseModel, ABC):
     min: IntegerParameter
@@ -85,7 +90,7 @@ class EdgeInterface(BaseModel, ABC):
     u: NodeInterface
     v: NodeInterface
     edge_type: EdgeTypeInterface
-    metadata: Dict[str, any] = None
+    metadata: Dict[str, MetadataType] = None
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/causy/serialization.py b/causy/serialization.py
index a47dd6d..d907434 100644
--- a/causy/serialization.py
+++ b/causy/serialization.py
@@ -3,12 +3,13 @@
 import importlib
 import json
 from json import JSONEncoder
-from typing import Dict, Any
+from typing import Dict, Any, List
 import os
 
 import torch
 import yaml
-from pydantic import parse_obj_as
+from pydantic import TypeAdapter
+
 
 from causy.edge_types import EDGE_TYPES
 from causy.graph_utils import load_pipeline_steps_by_definition
@@ -38,7 +39,7 @@ def load_algorithm_from_specification(algorithm_dict: Dict[str, Any]):
     ]
     from causy.models import Algorithm
 
-    return parse_obj_as(Algorithm, algorithm_dict)
+    return TypeAdapter(Algorithm).validate_python(algorithm_dict)
 
 
 def load_algorithm_by_reference(reference_type: str, algorithm: str):
@@ -102,4 +103,5 @@ def deserialize_result(result: Dict[str, Any], klass=Result):
         result["edges"][i]["edge_type"] = EDGE_TYPES[edge["edge_type"]["name"]](
             **edge["edge_type"]
         )
-    return parse_obj_as(klass, result)
+
+    return TypeAdapter(klass).validate_python(result)
diff --git a/causy/ui/server.py b/causy/ui/server.py
index f75ac7e..202f3bb 100644
--- a/causy/ui/server.py
+++ b/causy/ui/server.py
@@ -182,6 +182,13 @@ def _set_workspace(workspace: Workspace):
     WORKSPACE = workspace
 
 
+def is_port_in_use(host: str, port: int) -> bool:
+    import socket
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex((host, port)) == 0
+
+
 def server(result: Dict[str, Any] = None, workspace: Workspace = None):
     """Create the FastAPI server."""
     app = _create_ui_app()
@@ -196,7 +203,16 @@ def server(result: Dict[str, Any] = None, workspace: Workspace = None):
         raise ValueError("No model or workspace provided")
 
     host = os.getenv("HOST", "localhost")
-    port = int(os.getenv("PORT", "8000"))
+    is_port_from_env = os.getenv("PORT")
+    if is_port_from_env:
+        port = int(is_port_from_env)
+    else:
+        port = int(os.getenv("PORT", "8000"))
+        while is_port_in_use(host, port):
+            port += 1
+            if port > 65535:
+                raise ValueError("No free port available")
+
     cors_enabled = os.getenv("CORS_ENABLED", "false").lower() == "true"
 
     # cors e.g. for development of separate frontend
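
Note (not part of the patch): is_port_in_use() leans on socket.connect_ex(), which returns an error code instead of raising and yields 0 exactly when something already accepts connections on the address. A standalone sketch of the same probing loop, with hypothetical host/port values:

    import socket

    def is_port_in_use(host: str, port: int) -> bool:
        # 0 from connect_ex() means the connection succeeded, i.e. the port is taken.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex((host, port)) == 0

    port = 8000
    while is_port_in_use("localhost", port):
        port += 1
        if port > 65535:
            raise ValueError("No free port available")
    print(f"first free port: {port}")

As written in the patch, the scan only runs when PORT is not set in the environment; an explicitly configured PORT is respected even if it is already taken.
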
diff --git a/causy/variables.py b/causy/variables.py
index 2d3236f..1b74d31 100644
--- a/causy/variables.py
+++ b/causy/variables.py
@@ -8,6 +8,8 @@
 
 VariableInterfaceType = TypeVar("VariableInterfaceType")
 
+VariableType = Union[str, int, float, bool]
+
 
 class VariableTypes(enum.Enum):
     String = "string"
@@ -149,7 +151,7 @@ def type(self) -> str:
 CausyParameter = Union[BoolParameter, IntegerParameter, FloatParameter, StringParameter]
 
 
-def validate_variable_values(algorithm, variable_values: Dict[str, Any]):
+def validate_variable_values(algorithm, variable_values: Dict[str, VariableType]):
     """
     Validate the variable values for the algorithm.
     :param algorithm:
@@ -169,8 +171,8 @@
 
 
 def resolve_variables(
-    variables: List[BaseVariable], variable_values: Dict[str, Any]
-) -> Dict[str, Any]:
+    variables: List[BaseVariable], variable_values: Dict[str, VariableType]
+) -> Dict[str, VariableType]:
     """
     Resolve the variables from the list of variables and the variable values coming from the user.
     :param variables:
@@ -225,7 +227,7 @@ def resolve_variables_to_algorithm_for_pipeline_steps(pipeline_steps, variables)
     return pipeline_steps
 
 
-def deserialize_variable(variable_dict: Dict[str, Any]) -> BaseVariable:
+def deserialize_variable(variable_dict: Dict[str, VariableType]) -> BaseVariable:
     """
     Deserialize the variable from the dictionary.
     :param variable_dict:
diff --git a/causy/workspaces/cli.py b/causy/workspaces/cli.py
index 01206e7..f486b7b 100644
--- a/causy/workspaces/cli.py
+++ b/causy/workspaces/cli.py
@@ -21,6 +21,7 @@
     PackageLoader,
 )
 from rich.console import Console
+from rich.progress import Progress, SpinnerColumn, TextColumn
 from rich.table import Table
 
 from causy.graph_model import graph_model_factory
@@ -280,7 +281,26 @@ def _execute_experiment(workspace: Workspace, experiment: Experiment) -> Result:
     model = graph_model_factory(pipeline, experiment.variables)()
     model.create_graph_from_data(data_loader)
     model.create_all_possible_edges()
-    model.execute_pipeline_steps()
+    task_count = len(model.pipeline_steps)
+    current = 0
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        prev_task = None
+        prev_task_data = None
+        for task in model.execute_pipeline_step_with_progress():
+            current += 1
+            if prev_task is not None:
+                progress.update(
+                    prev_task,
+                    completed=True,
+                    current=1,
+                    description=f"✅ {prev_task_data['step']} ({round(task['previous_duration'])}s)",
+                )
+            prev_task = progress.add_task(description=task["step"], total=1)
+            prev_task_data = task
 
     return Result(
         algorithm=workspace.pipelines[experiment.pipeline],
@@ -650,10 +670,12 @@ def init():
             "Enter the name of the data loader"
         ).ask()
         data_loader_slug = slugify(data_loader_name, "_")
-        workspace.dataloaders[data_loader_slug] = {
-            "type": data_loader_type,
-            "reference": data_loader_path,
-        }
+        workspace.dataloaders[data_loader_slug] = DataLoaderReference(
+            **{
+                "type": data_loader_type,
+                "reference": data_loader_path,
+            }
+        )
     elif data_loader_type == "dynamic":
         data_loader_name = questionary.text(
             "Enter the name of the data loader"
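
Note (not part of the patch): the progress handling added to _execute_experiment boils down to the following Rich pattern: one transient spinner task per pipeline step, rewritten to a checkmark line once the step finishes. A reduced sketch with placeholder step names:

    import time
    from rich.progress import Progress, SpinnerColumn, TextColumn

    steps = ["CorrelationCoefficientTest", "PartialCorrelationTest"]  # placeholders
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        transient=True,
    ) as progress:
        for step in steps:
            task = progress.add_task(description=step, total=1)
            time.sleep(0.5)  # stand-in for executing the pipeline step
            progress.update(task, completed=True, description=f"✅ {step}")

In the patch itself the previous task is only marked done when the generator yields the next step, since the step's duration is reported one iteration later.
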
diff --git a/causy/workspaces/models.py b/causy/workspaces/models.py
index 18d881f..f404ab3 100644
--- a/causy/workspaces/models.py
+++ b/causy/workspaces/models.py
@@ -1,9 +1,10 @@
-from typing import Optional, Dict, Any
+from typing import Optional, Dict
 
 from pydantic import BaseModel
 
 from causy.data_loader import DataLoaderReference
 from causy.models import AlgorithmReference, Algorithm
+from causy.variables import VariableType
 
 
 class Experiment(BaseModel):
@@ -15,7 +16,7 @@ class Experiment(BaseModel):
 
     pipeline: str
     dataloader: str
-    variables: Optional[Dict[str, Any]] = None
+    variables: Optional[Dict[str, VariableType]] = None
 
 
 class Workspace(BaseModel):
diff --git a/poetry.lock b/poetry.lock
index 88ec134..b8b354a 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -863,13 +863,13 @@ files = [
 
 [[package]]
 name = "pdoc"
-version = "14.4.0"
+version = "14.5.1"
 description = "API Documentation for Python Projects"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pdoc-14.4.0-py3-none-any.whl", hash = "sha256:6ea4fe07620b1f7601e2708a307a257636ec206e20b5611640b30f2e3cab47d6"},
-    {file = "pdoc-14.4.0.tar.gz", hash = "sha256:c92edc425429ccbe287ace2a027953c24f13de53eab484c1a6d31ca72dd2fda9"},
+    {file = "pdoc-14.5.1-py3-none-any.whl", hash = "sha256:fda6365a06e438b43ca72235b58a2e2ecd66445fcc444313f6ebbde4b0abd94b"},
+    {file = "pdoc-14.5.1.tar.gz", hash = "sha256:4ddd9c5123a79f511cedffd7231bf91a6e0bd0968610f768342ec5d00b5eefee"},
 ]
 
 [package.dependencies]
diff --git a/tests/utils.py b/tests/utils.py
index 39bde51..b473334 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,11 +1,12 @@
 import json
 import random
+from typing import List
 from unittest import TestCase
 from unittest.util import safe_repr
 
 import numpy as np
 import torch
-from pydantic import parse_obj_as
+from pydantic import TypeAdapter
 
 from causy.graph import Graph
 from causy.serialization import CausyJSONEncoder
@@ -25,7 +26,7 @@ def dump_fixture_graph(graph, file_path):
 def load_fixture_graph(file_path):
     with open(file_path, "r") as f:
         data = json.loads(f.read())
-        return parse_obj_as(Graph, data)
+        return TypeAdapter(Graph).validate_python(data)
 
 
 class CausyTestCase(TestCase):
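
Note (not part of the patch): the recurring parse_obj_as -> TypeAdapter(...).validate_python(...) rewrites are the pydantic v2 migration; parse_obj_as is deprecated there. A self-contained before/after, using a hypothetical Point model for illustration:

    from pydantic import BaseModel, TypeAdapter

    class Point(BaseModel):
        x: int
        y: int

    data = {"x": 1, "y": 2}
    # pydantic v1 style (deprecated in v2):
    #   point = parse_obj_as(Point, data)
    point = TypeAdapter(Point).validate_python(data)
    assert point == Point(x=1, y=2)
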