Skip to content

Commit

Permalink
freezing initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Jul 31, 2024
1 parent 7c66cc7 commit c23ff28
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 55 deletions.
9 changes: 4 additions & 5 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from torch import Tensor, nn

from initializers import ExpUniformInitializer
from parameters import ScaledSigmoidParameter
from pipeline import setup_pipeline_context


Expand Down Expand Up @@ -438,7 +437,7 @@ def categorical_layer_factory(
num_channels,
num_categories=input_layer_kwargs["num_categories"],
logits_factory=lambda shape: Parameter.from_leaf(
TensorParameter(*shape, initializer=NormalInitializer(0.0, 1e-1))
TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
),
)

Expand All @@ -454,7 +453,7 @@ def gaussian_layer_factory(
),
stddev_factory=lambda shape: Parameter.from_sequence(
TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0, scale=1.0),
ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0),
),
)

Expand Down Expand Up @@ -518,7 +517,7 @@ def categorical_layer_factory(
num_channels,
num_categories=input_layer_kwargs["num_categories"],
logits_factory=lambda shape: Parameter.from_leaf(
TensorParameter(*shape, initializer=NormalInitializer(0.0, 1e-1))
TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
),
)

Expand All @@ -534,7 +533,7 @@ def gaussian_layer_factory(
),
stddev_factory=lambda shape: Parameter.from_sequence(
TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0, scale=1.0),
ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0),
),
)

Expand Down
48 changes: 0 additions & 48 deletions src/parameters.py

This file was deleted.

2 changes: 0 additions & 2 deletions src/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from cirkit.pipeline import PipelineContext

from initializers import compile_exp_uniform_initializer
from parameters import compile_scaled_sigmoid_parameter


def setup_pipeline_context(
Expand All @@ -14,6 +13,5 @@ def setup_pipeline_context(
ctx = PipelineContext(
backend=backend, semiring=semiring, fold=fold, optimize=optimize
)
ctx.add_parameter_compilation_rule(compile_scaled_sigmoid_parameter)
ctx.add_initializer_compilation_rule(compile_exp_uniform_initializer)
return ctx

0 comments on commit c23ff28

Please sign in to comment.