Skip to content

Commit

Permalink
add more logging info
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Jul 21, 2024
1 parent 0d5e4b6 commit 8b5ed96
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/scripts/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def run(self):
setup_data_loaders(
self.args.dataset,
self.args.data_path,
self.args.batch_size,
logger=self.logger,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
num_samples=self.args.num_samples,
standardize=self.args.standardize,
Expand All @@ -297,6 +298,7 @@ def run(self):
self.model = setup_model(
self.args.model,
self.metadata,
logger=self.logger,
region_graph=self.args.region_graph,
structured_decomposable=self.args.region_graph_sd,
num_components=self.args.num_components,
Expand Down
14 changes: 10 additions & 4 deletions src/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
plot_bivariate_discrete_samples_hmap,
)
from models import PC, MPC, SOS
from scripts.logger import Logger
from utilities import (
retrieve_default_dtype,
REGION_GRAPHS,
Expand Down Expand Up @@ -279,7 +280,8 @@ def setup_experiment_path(
def setup_data_loaders(
dataset: str,
path: str,
batch_size: int,
logger: Logger,
batch_size: int = 128,
num_workers: int = 0,
num_samples: int = 1000,
standardize: bool = False,
Expand All @@ -289,7 +291,8 @@ def setup_data_loaders(
discretize_bins: int = 32,
shuffle_bins: bool = False,
) -> Tuple[dict, Tuple[DataLoader, DataLoader, DataLoader]]:
seed = 123
logger.info(f"Loading dataset '{dataset}' ...")

numpy_dtype = retrieve_default_dtype(numpy=True)
metadata = dict()
# Load the dataset
Expand All @@ -300,7 +303,7 @@ def setup_data_loaders(
language_dataset = dataset in LANGUAGE_DATASETS
if small_uci_dataset:
train_data, valid_data, test_data = load_small_uci_dataset(
dataset, path=path, seed=seed
dataset, path=path, seed=123
)
metadata["image_shape"] = None
metadata["num_variables"] = train_data.shape[1]
Expand Down Expand Up @@ -371,7 +374,7 @@ def setup_data_loaders(
test_data = TensorDataset(torch.tensor(test_data))
elif language_dataset:
train_data, valid_data, test_data = load_language_dataset(
dataset, path=path, seed=seed
dataset, path=path, seed=123
)
seq_length = train_data.shape[1]
metadata["image_shape"] = None
Expand Down Expand Up @@ -436,6 +439,7 @@ def setup_data_loaders(
def setup_model(
model_name: str,
dataset_metadata: dict,
logger: Logger,
region_graph: str = "rnd",
structured_decomposable: bool = False,
num_components: int = 1,
Expand All @@ -446,6 +450,8 @@ def setup_model(
spline_order: int = 2,
seed: int = 123,
) -> Union[PC, Flow]:
logger.info(f"Building model '{model_name}' ...")

if complex and model_name != "SOS":
raise ValueError("--complex can only be used with SOS circuits")
assert model_name in MODELS
Expand Down

0 comments on commit 8b5ed96

Please sign in to comment.