
Commit
Merge pull request #43 from causy-dev/progress-preview
Progress preview
LilithWittmann authored Jun 29, 2024
2 parents c3a6068 + 9b9e8a6 commit 118ed30
Showing 11 changed files with 96 additions and 31 deletions.
7 changes: 7 additions & 0 deletions causy/cli.py
@@ -30,6 +30,10 @@
 
 @app.command()
 def eject(algorithm: str, output_file: str):
+    logging.warning(
+        f"Ejecting pipelines outside of workspace context is deprecated. Please use workspaces instead."
+    )
+
     typer.echo(f"💾 Loading algorithm {algorithm}")
     model = AVAILABLE_ALGORITHMS[algorithm]()
     result = serialize_algorithm(model, algorithm_name=algorithm)
@@ -46,6 +50,9 @@ def execute(
     output_file: str = None,
     log_level: str = "ERROR",
 ):
+    logging.warning(
+        f"Executing outside of workspaces is deprecated and will be removed in future versions. Please use workspaces instead."
+    )
     logging.basicConfig(level=log_level)
     if pipeline:
         typer.echo(f"💾 Loading pipeline from {pipeline}")
6 changes: 4 additions & 2 deletions causy/graph.py
@@ -13,8 +13,10 @@
     NodeInterface,
     EdgeInterface,
     EdgeTypeInterface,
+    MetadataType,
 )
 from causy.models import TestResultAction, TestResult, ActionHistoryStep
+from causy.variables import VariableType
 
 logger = logging.getLogger(__name__)
 
@@ -28,7 +30,7 @@ class Node(NodeInterface):
     name: str
     id: str
     values: Optional[torch.Tensor] = None
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: Optional[Dict[str, MetadataType]] = None
 
     def __hash__(self):
         return hash(self.id)
@@ -46,7 +48,7 @@ class Edge(EdgeInterface):
     u: NodeInterface
     v: NodeInterface
     edge_type: EdgeTypeInterface
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: Optional[Dict[str, MetadataType]] = None
     deleted: Optional[bool] = False
 
     def __init__(self, *args, **kwargs):
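
For context on the two hunks above: node and edge metadata is now constrained to JSON-like values instead of arbitrary `Any`. A minimal sketch of what the tightened annotation admits (hypothetical field values; assumes the `Node` model as defined above):

from causy.graph import Node

# Sketch: values permitted by Dict[str, MetadataType] -- scalars,
# lists of scalars, and flat string-keyed dicts of scalars.
node = Node(
    name="rain",
    id="n1",
    metadata={
        "unit": "mm",         # str
        "lag": 2,             # int
        "observed": True,     # bool
        "tags": ["weather"],  # list of MetadataBaseType scalars
    },
)
# Deeper nesting, e.g. {"a": {"b": {"c": 1}}}, is not covered by MetadataType.
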
19 changes: 13 additions & 6 deletions causy/graph_model.py
@@ -3,7 +3,7 @@
 from abc import ABC
 from copy import deepcopy
 import time
-from typing import Optional, List, Dict, Callable, Union, Any
+from typing import Optional, List, Dict, Callable, Union, Any, Generator
 
 import torch.multiprocessing as mp
 
@@ -21,6 +21,8 @@
 from causy.variables import (
     resolve_variables_to_algorithm_for_pipeline_steps,
     resolve_variables,
+    VariableType,
+    VariableTypes,
 )
 
 logger = logging.getLogger(__name__)
@@ -188,19 +190,26 @@ def create_all_possible_edges(self):
                     continue
                 self.graph.add_edge(u, v, {})
 
-    def execute_pipeline_steps(self):
+    def execute_pipeline_steps(self) -> List[ActionHistoryStep]:
         """
         Execute all pipeline_steps
         :return: the steps taken during the step execution
         """
+        all(self.execute_pipeline_step_with_progress())
+        return self.graph.action_history
+
+    def execute_pipeline_step_with_progress(self) -> Generator:
         started = time.time()
         for filter in self.pipeline_steps:
             logger.info(f"Executing pipeline step {filter.__class__.__name__}")
             if isinstance(filter, LogicStepInterface):
                 actions_taken = filter.execute(self.graph.graph, self)
                 self.graph.graph.action_history.append(actions_taken)
                 continue
 
+            yield {
+                "step": filter.__class__.__name__,
+                "previous_duration": time.time() - started,
+            }
             started = time.time()
             actions_taken = self.execute_pipeline_step(filter)
             self.graph.graph.action_history.append(
@@ -213,8 +222,6 @@
 
         self.graph.purge_soft_deleted_edges()
 
-        return self.graph.action_history
-
     def _format_yield(self, test_fn, graph, generator):
         """
         Format the yield for the parallel processing
@@ -386,7 +393,7 @@ def execute_pipeline_step(
 
 def graph_model_factory(
     algorithm: Algorithm = None,
-    variables: Dict[str, Any] = None,
+    variables: Dict[str, VariableTypes] = None,
 ) -> type[AbstractGraphModel]:
     """
     Create a graph model based on a List of pipeline_steps
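
`execute_pipeline_steps` keeps its old return value but now delegates to the new generator, draining it with `all(...)` (safe here because every yielded dict is non-empty, hence truthy). Callers that want live progress can consume the generator themselves; a minimal sketch, assuming `model` is an instantiated graph model with data and edges already prepared:

# Sketch: drive a custom progress display from the new generator.
for update in model.execute_pipeline_step_with_progress():
    # Each yielded dict has the shape:
    #   {"step": <pipeline step class name>,
    #    "previous_duration": <seconds spent on the previous step>}
    print(f"starting {update['step']} "
          f"(previous step took {update['previous_duration']:.2f}s)")

history = model.graph.action_history  # filled in as the steps run
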
7 changes: 6 additions & 1 deletion causy/interfaces.py
@@ -24,6 +24,11 @@
 
 AS_MANY_AS_FIELDS = 0
 
+MetadataBaseType = Union[str, int, float, bool]
+MetadataType = Union[
+    str, int, float, bool, List[MetadataBaseType], Dict[str, MetadataBaseType]
+]
+
 
 class ComparisonSettingsInterface(BaseModel, ABC):
     min: IntegerParameter
@@ -85,7 +90,7 @@ class EdgeInterface(BaseModel, ABC):
     u: NodeInterface
     v: NodeInterface
     edge_type: EdgeTypeInterface
-    metadata: Dict[str, any] = None
+    metadata: Dict[str, MetadataType] = None
 
     class Config:
         arbitrary_types_allowed = True
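
By construction `MetadataType` is one container level deep: lists and dicts may only hold `MetadataBaseType` scalars. A quick validation sketch using pydantic v2 (self-contained; it mirrors the definition above rather than importing it):

from typing import Dict, List, Union
from pydantic import TypeAdapter, ValidationError

MetadataBaseType = Union[str, int, float, bool]
MetadataType = Union[
    str, int, float, bool, List[MetadataBaseType], Dict[str, MetadataBaseType]
]

adapter = TypeAdapter(MetadataType)
adapter.validate_python(["a", 1, 2.5])  # ok: list of scalars
adapter.validate_python({"depth": 3})   # ok: flat dict of scalars
try:
    adapter.validate_python({"nested": {"x": 1}})  # not allowed: nested dict
except ValidationError as e:
    print(f"rejected with {e.error_count()} error(s)")
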
10 changes: 6 additions & 4 deletions causy/serialization.py
@@ -3,12 +3,13 @@
 import importlib
 import json
 from json import JSONEncoder
-from typing import Dict, Any
+from typing import Dict, Any, List
 import os
 
 import torch
 import yaml
-from pydantic import parse_obj_as
+from pydantic import TypeAdapter
+
 
 from causy.edge_types import EDGE_TYPES
 from causy.graph_utils import load_pipeline_steps_by_definition
@@ -38,7 +39,7 @@ def load_algorithm_from_specification(algorithm_dict: Dict[str, Any]):
     ]
     from causy.models import Algorithm
 
-    return parse_obj_as(Algorithm, algorithm_dict)
+    return TypeAdapter(Algorithm).validate_python(algorithm_dict)
 
 
 def load_algorithm_by_reference(reference_type: str, algorithm: str):
@@ -102,4 +103,5 @@ def deserialize_result(result: Dict[str, Any], klass=Result):
         result["edges"][i]["edge_type"] = EDGE_TYPES[edge["edge_type"]["name"]](
             **edge["edge_type"]
         )
-    return parse_obj_as(klass, result)
+
+    return TypeAdapter(klass).validate_python(result)
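
These three hunks are the pydantic v1 → v2 migration: `parse_obj_as` is deprecated in v2 and `TypeAdapter(...).validate_python(...)` is its replacement. A one-to-one sketch on a toy model (hypothetical `Point`, not part of causy):

from pydantic import BaseModel, TypeAdapter

class Point(BaseModel):
    x: int
    y: int

data = {"x": 1, "y": 2}

# pydantic v1 style, now deprecated:
#   from pydantic import parse_obj_as
#   point = parse_obj_as(Point, data)

# pydantic v2 style, as used in this commit:
point = TypeAdapter(Point).validate_python(data)
assert point == Point(x=1, y=2)

Like `parse_obj_as`, `TypeAdapter` also handles non-model types such as unions and lists.
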
18 changes: 17 additions & 1 deletion causy/ui/server.py
@@ -182,6 +182,13 @@ def _set_workspace(workspace: Workspace):
     WORKSPACE = workspace
 
 
+def is_port_in_use(host: str, port: int) -> bool:
+    import socket
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex((host, port)) == 0
+
+
 def server(result: Dict[str, Any] = None, workspace: Workspace = None):
     """Create the FastAPI server."""
     app = _create_ui_app()
@@ -196,7 +203,16 @@ def server(result: Dict[str, Any] = None, workspace: Workspace = None):
         raise ValueError("No model or workspace provided")
 
     host = os.getenv("HOST", "localhost")
-    port = int(os.getenv("PORT", "8000"))
+    is_port_from_env = os.getenv("PORT")
+    if is_port_from_env:
+        port = int(is_port_from_env)
+    else:
+        port = int(os.getenv("PORT", "8000"))
+        while is_port_in_use(host, port):
+            port += 1
+            if port > 65535:
+                raise ValueError("No free port available")
+
     cors_enabled = os.getenv("CORS_ENABLED", "false").lower() == "true"
 
     # cors e.g. for development of separate frontend
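
`connect_ex` returns 0 when the TCP connection succeeds, i.e. when something is already listening, so when PORT is unset the loop walks upward from 8000 until it finds a free port; an explicit PORT from the environment is used as-is. The same probe in isolation (stdlib only, a sketch):

import socket

def is_port_in_use(host: str, port: int) -> bool:
    # connect_ex returns 0 on success and an error number otherwise
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex((host, port)) == 0

port = 8000
while is_port_in_use("localhost", port):
    port += 1
    if port > 65535:
        raise ValueError("No free port available")
print(f"would bind to port {port}")

Note the probe is best-effort rather than a reservation: another process can still claim the port between the check and the actual bind.
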
10 changes: 6 additions & 4 deletions causy/variables.py
@@ -8,6 +8,8 @@
 
 VariableInterfaceType = TypeVar("VariableInterfaceType")
 
+VariableType = Union[str, int, float, bool]
+
 
 class VariableTypes(enum.Enum):
     String = "string"
@@ -149,7 +151,7 @@ def type(self) -> str:
 CausyParameter = Union[BoolParameter, IntegerParameter, FloatParameter, StringParameter]
 
 
-def validate_variable_values(algorithm, variable_values: Dict[str, Any]):
+def validate_variable_values(algorithm, variable_values: Dict[str, VariableType]):
     """
     Validate the variable values for the algorithm.
     :param algorithm:
@@ -169,8 +171,8 @@ def validate_variable_values(algorithm, variable_values: Dict[str, Any]):
 
 
 def resolve_variables(
-    variables: List[BaseVariable], variable_values: Dict[str, Any]
-) -> Dict[str, Any]:
+    variables: List[BaseVariable], variable_values: Dict[str, VariableType]
+) -> Dict[str, VariableType]:
     """
     Resolve the variables from the list of variables and the variable values coming from the user.
     :param variables:
@@ -225,7 +227,7 @@ def resolve_variables_to_algorithm_for_pipeline_steps(pipeline_steps, variables):
     return pipeline_steps
 
 
-def deserialize_variable(variable_dict: Dict[str, Any]) -> BaseVariable:
+def deserialize_variable(variable_dict: Dict[str, VariableType]) -> BaseVariable:
     """
     Deserialize the variable from the dictionary.
     :param variable_dict:
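
Unlike `MetadataType`, `VariableType` allows scalars only, so user-supplied variable values cannot be containers. A short sketch of a values dict that satisfies the updated signatures (hypothetical variable names):

from typing import Dict, Union

VariableType = Union[str, int, float, bool]  # mirrors the definition above

# What a user may pass into validate_variable_values / resolve_variables:
variable_values: Dict[str, VariableType] = {
    "threshold": 0.05,    # float
    "max_depth": 3,       # int
    "method": "pearson",  # str
    "shuffle": True,      # bool
}
# Lists and dicts are not valid variable values.
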
32 changes: 27 additions & 5 deletions causy/workspaces/cli.py
@@ -21,6 +21,7 @@
     PackageLoader,
 )
 from rich.console import Console
+from rich.progress import Progress, SpinnerColumn, TextColumn
 from rich.table import Table
 
 from causy.graph_model import graph_model_factory
@@ -280,7 +281,26 @@ def _execute_experiment(workspace: Workspace, experiment: Experiment) -> Result:
     model = graph_model_factory(pipeline, experiment.variables)()
     model.create_graph_from_data(data_loader)
     model.create_all_possible_edges()
-    model.execute_pipeline_steps()
+    task_count = len(model.pipeline_steps)
+    current = 0
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        transient=True,
+    ) as progress:
+        prev_task = None
+        prev_task_data = None
+        for task in model.execute_pipeline_step_with_progress():
+            current += 1
+            if prev_task is not None:
+                progress.update(
+                    prev_task,
+                    completed=True,
+                    current=1,
+                    description=f"✅ {prev_task_data['step']} ({round(task['previous_duration'])}s)",
+                )
+            prev_task = progress.add_task(description=task["step"], total=1)
+            prev_task_data = task
 
     return Result(
         algorithm=workspace.pipelines[experiment.pipeline],
@@ -650,10 +670,12 @@ def init():
             "Enter the name of the data loader"
         ).ask()
         data_loader_slug = slugify(data_loader_name, "_")
-        workspace.dataloaders[data_loader_slug] = {
-            "type": data_loader_type,
-            "reference": data_loader_path,
-        }
+        workspace.dataloaders[data_loader_slug] = DataLoaderReference(
+            **{
+                "type": data_loader_type,
+                "reference": data_loader_path,
+            }
+        )
     elif data_loader_type == "dynamic":
         data_loader_name = questionary.text(
             "Enter the name of the data loader"
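
The runner keeps exactly one active spinner row: when the generator yields the next step, the previous task is marked completed and relabeled with a check mark and its duration. The same rich pattern in isolation, with stand-in steps (a sketch; step names and timings are fabricated):

import time
from rich.progress import Progress, SpinnerColumn, TextColumn

def fake_steps():
    # Stand-in for model.execute_pipeline_step_with_progress()
    started = time.time()
    for name in ["StepOne", "StepTwo", "StepThree"]:
        yield {"step": name, "previous_duration": time.time() - started}
        started = time.time()
        time.sleep(0.5)  # pretend the step does work

with Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    transient=True,
) as progress:
    prev_task, prev_data = None, None
    for task in fake_steps():
        if prev_task is not None:
            progress.update(
                prev_task,
                completed=True,
                description=f"✅ {prev_data['step']} ({round(task['previous_duration'])}s)",
            )
        prev_task = progress.add_task(description=task["step"], total=1)
        prev_data = task
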
5 changes: 3 additions & 2 deletions causy/workspaces/models.py
@@ -1,9 +1,10 @@
-from typing import Optional, Dict, Any
+from typing import Optional, Dict
 
 from pydantic import BaseModel
 
 from causy.data_loader import DataLoaderReference
 from causy.models import AlgorithmReference, Algorithm
+from causy.variables import VariableType
 
 
 class Experiment(BaseModel):
@@ -15,7 +16,7 @@ class Experiment(BaseModel):
 
     pipeline: str
     dataloader: str
-    variables: Optional[Dict[str, Any]] = None
+    variables: Optional[Dict[str, VariableType]]
 
 
 class Workspace(BaseModel):
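
One nuance of the new `variables` annotation: under pydantic v2, `Optional[...]` without a default is still a required field, so the field must now be passed explicitly (though it may be `None`). A quick check as a sketch (assumes pydantic >= 2):

from typing import Dict, Optional, Union
from pydantic import BaseModel, ValidationError

VariableType = Union[str, int, float, bool]

class Experiment(BaseModel):  # trimmed to the field in question
    variables: Optional[Dict[str, VariableType]]

Experiment(variables=None)             # ok: explicit None
Experiment(variables={"alpha": 0.01})  # ok: scalar values
try:
    Experiment()                       # raises: required without "= None"
except ValidationError:
    print("variables is now a required field")
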
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions tests/utils.py
@@ -1,11 +1,12 @@
 import json
 import random
+from typing import List
 from unittest import TestCase
 from unittest.util import safe_repr
 
 import numpy as np
 import torch
-from pydantic import parse_obj_as
+from pydantic import TypeAdapter
 
 from causy.graph import Graph
 from causy.serialization import CausyJSONEncoder
@@ -25,7 +26,7 @@ def dump_fixture_graph(graph, file_path):
 def load_fixture_graph(file_path):
     with open(file_path, "r") as f:
         data = json.loads(f.read())
-    return parse_obj_as(Graph, data)
+    return TypeAdapter(Graph).validate_python(data)
 
 
 class CausyTestCase(TestCase):
