Skip to content

Commit

Permalink
add blox plot about squared complex circuits + logged checkpoints bas…
Browse files Browse the repository at this point in the history
…ed on trial id
  • Loading branch information
loreloc committed Jul 21, 2024
1 parent 11ec3fe commit 0d5e4b6
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 257 deletions.
19 changes: 13 additions & 6 deletions src/graphics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,26 @@ def setup_tueplots(
ncols: int,
rel_width: float = 1.0,
hw_ratio: Optional[float] = None,
inc_font_size: int = 0,
default_smaller: int = -1,
use_tex: bool = True,
**kwargs
):
if use_tex:
font_config = fonts.iclr2023_tex(family="serif")
font_config = fonts.neurips2024_tex(family="serif")
else:
font_config = fonts.iclr2023(family="serif")
font_config = fonts.neurips2024(family="serif")
if hw_ratio is not None:
kwargs["height_to_width_ratio"] = hw_ratio
size = figsizes.iclr2023(rel_width=rel_width, nrows=nrows, ncols=ncols, **kwargs)
fontsize_config = fontsizes.iclr2023(default_smaller=-inc_font_size)
rc_params = {**font_config, **size, **fontsize_config}
size = figsizes.neurips2024(rel_width=rel_width, nrows=nrows, ncols=ncols, **kwargs)
fontsize_config = fontsizes.neurips2024(default_smaller=default_smaller)
rc_params = {
**font_config,
**size,
**fontsize_config,
}
rc_params.update({
'text.latex.preamble': r'\usepackage{amsfonts}'
})
plt.rcParams.update(rc_params)
# plt.rcParams.update({
# "axes.prop_cycle": plt.cycler(
Expand Down
4 changes: 2 additions & 2 deletions src/scripts/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, args: Namespace):
}
os.makedirs(kwargs["wandb_path"], exist_ok=True)

self.logger = Logger(args.verbose, **kwargs)
self.logger = Logger(self._trial_unique_id, args.verbose, **kwargs)
self.metadata: Dict[str, Any] = dict()

self.dataloaders: Dict[str, Optional[DataLoader]] = {
Expand Down Expand Up @@ -263,7 +263,7 @@ def _eval_step(
"ppl": metrics["test_ppl"],
},
},
"checkpoint.pt",
f"checkpoint-{self._trial_unique_id}.pt",
)
return metrics

Expand Down
10 changes: 6 additions & 4 deletions src/scripts/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
class Logger:
def __init__(
self,
trail_id: str,
verbose: bool,
*,
checkpoint_path: Optional[str] = None,
tboard_path: Optional[str] = None,
wandb_path: Optional[str] = None,
wandb_kwargs: Optional[Dict[str, Any]] = None
):
self.trial_id = trail_id
self.verbose = verbose
self.checkpoint_path = checkpoint_path
self._tboard_writer: Optional[SummaryWriter] = None
Expand Down Expand Up @@ -150,17 +152,17 @@ def log_step_distribution(

def close(self):
if self._logged_distributions:
self.save_array(self._best_distribution, "distbest.npy")
self.save_array(self._best_distribution, f"distbest-{self.trial_id}.npy")
self.save_array(
np.stack(self._logged_distributions, axis=0), "diststeps.npy"
np.stack(self._logged_distributions, axis=0), f"diststeps-{self.trial_id}.npy"
)
if self._logged_wcoords:
self.save_array(np.stack(self._logged_wcoords, axis=0), "wcoords.npy")
self.save_array(np.stack(self._logged_wcoords, axis=0), f"wcoords-{self.trial_id}.npy")
if self._tboard_writer is not None:
self._tboard_writer.close()
if wandb.run:
wandb.finish(quiet=True)
self.save_dict(self._logged_scalars, "scalars.json")
self.save_dict(self._logged_scalars, f"scalars-{self.trial_id}.json")

def save_checkpoint(self, data: Dict[str, Any], filepath: str):
if self.checkpoint_path:
Expand Down
Empty file.
245 changes: 0 additions & 245 deletions src/scripts/plots/gpt2dist/lines.py

This file was deleted.

Loading

0 comments on commit 0d5e4b6

Please sign in to comment.