Commit 1cbbd6d

made initialization more aligned to previous ones
loreloc committed Jul 31, 2024
1 parent 71def17 commit 1cbbd6d
Showing 5 changed files with 111 additions and 115 deletions.
32 changes: 32 additions & 0 deletions src/initializers.py
@@ -0,0 +1,32 @@
import functools
from typing import Dict, Any

from cirkit.backend.torch.compiler import TorchCompiler
from cirkit.backend.torch.initializers import InitializerFunc
from cirkit.symbolic.initializers import Initializer
from torch import Tensor, nn


class ExpUniformInitializer(Initializer):
    def __init__(self, a: float = 0.0, b: float = 1.0) -> None:
        if a >= b:
            raise ValueError("The minimum should be strictly less than the maximum")
        self.a = a
        self.b = b

    @property
    def config(self) -> Dict[str, Any]:
        return dict(a=self.a, b=self.b)


def exp_uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> Tensor:
    # Fill the tensor in place with Uniform(a, b) samples, then take the
    # entrywise log, so that tensor.exp() is uniformly distributed in [a, b].
    nn.init.uniform_(tensor, a=a, b=b)
    tensor.log_()
    return tensor


def compile_exp_uniform_initializer(
    compiler: TorchCompiler, init: ExpUniformInitializer
) -> InitializerFunc:
    return functools.partial(exp_uniform_, a=init.a, b=init.b)
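
For reference, a minimal plain-PyTorch check of what the new in-place initializer produces (not part of the commit, runnable as-is):

    import torch

    t = torch.empty(2, 3)
    exp_uniform_(t, a=0.0, b=1.0)
    # Entries are log(u) with u ~ Uniform(0, 1), hence non-positive,
    # and t.exp() recovers values uniform in (0, 1].
    assert (t <= 0.0).all()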
91 changes: 0 additions & 91 deletions src/layers.py

This file was deleted.

36 changes: 12 additions & 24 deletions src/models.py
@@ -7,8 +7,7 @@
 import torch
 from cirkit.backend.torch.circuits import TorchCircuit, TorchConstantCircuit
 from cirkit.backend.torch.layers import TorchLayer, TorchSumLayer
-from cirkit.backend.torch.optimization.layers import DenseKroneckerPattern
-from cirkit.pipeline import PipelineContext, compile
+from cirkit.pipeline import compile
 from cirkit.symbolic.circuit import Circuit
 from cirkit.symbolic.dtypes import DataType
 from cirkit.symbolic.initializers import NormalInitializer, UniformInitializer
@@ -21,7 +20,6 @@
 from cirkit.symbolic.parameters import (
     ExpParameter,
     Parameter,
-    ScaledSigmoidParameter,
     TensorParameter,
 )
 from cirkit.templates.region_graph import (
@@ -32,9 +30,9 @@
 from cirkit.utils.scope import Scope
 from torch import Tensor, nn
 
-from layers import (
-    apply_dense_product,
-)
+from initializers import ExpUniformInitializer
+from parameters import ScaledSigmoidParameter
+from pipeline import setup_pipeline_context
 
 
 class PC(nn.Module, ABC):
@@ -91,9 +89,7 @@ def __init__(
     ) -> None:
         assert num_components > 0
         super().__init__(num_variables)
-        self._pipeline = PipelineContext(
-            backend="torch", semiring="lse-sum", fold=True, optimize=True
-        )
+        self._pipeline = setup_pipeline_context(semiring='lse-sum')
         self._circuit, self._int_circuit = self._build_circuits(
             num_input_units,
             num_sum_units,
@@ -182,13 +178,7 @@ def __init__(
     ) -> None:
         assert num_squares > 0
         super().__init__(num_variables)
-        self._pipeline = PipelineContext(
-            backend="torch", semiring="complex-lse-sum", fold=True, optimize=True
-        )
-        # Use a different optimization rule for the dense-kronecker pattern
-        self._pipeline._compiler._optimization_registry["layer_shatter"].add_rule(
-            apply_dense_product, signature=DenseKroneckerPattern
-        )
+        self._pipeline = setup_pipeline_context(semiring='complex-lse-sum')
         self._circuit, self._int_sq_circuit = self._build_circuits(
             num_input_units,
             num_sum_units,
@@ -288,9 +278,7 @@ def __init__(
         seed: int = 42,
     ) -> None:
         super().__init__(num_variables)
-        self._pipeline = PipelineContext(
-            backend="torch", semiring="complex-lse-sum", fold=True, optimize=True
-        )
+        self._pipeline = setup_pipeline_context(semiring='complex-lse-sum')
         # Introduce optimization rules
         self._circuit, self._mono_circuit, self._int_circuit = self._build_circuits(
             num_input_units,
@@ -465,8 +453,8 @@ def gaussian_layer_factory(
                 TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
             ),
             stddev_factory=lambda shape: Parameter.from_sequence(
-                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1e-1)),
-                ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0),
+                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
+                ScaledSigmoidParameter(shape, vmin=1e-5, vmax=2.0, scale=2.0),
             ),
         )

@@ -484,7 +472,7 @@ def dense_layer_factory(
         num_output_units,
         weight_factory=lambda shape: Parameter.from_unary(
             ExpParameter(shape),
-            TensorParameter(*shape, initializer=NormalInitializer(0.0, 1e-1)),
+            TensorParameter(*shape, initializer=ExpUniformInitializer(0.0, 1.0)),
         ),
     )

@@ -545,8 +533,8 @@ def gaussian_layer_factory(
                 TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
             ),
             stddev_factory=lambda shape: Parameter.from_sequence(
-                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1e-1)),
-                ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0),
+                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
+                ScaledSigmoidParameter(shape, vmin=1e-5, vmax=2.0, scale=2.0),
             ),
         )

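The net effect of the swap in dense_layer_factory, as a plain-PyTorch sketch (this assumes ExpParameter applies an entrywise exp to its input, as its use with the log-space semirings suggests, and that NormalInitializer's second argument is a standard deviation):

    import torch

    # Old scheme: raw weights ~ Normal(0, 0.1), so exp(raw) clusters around 1.0.
    old_raw = torch.randn(10_000) * 1e-1
    # New scheme: raw = log(u) with u ~ Uniform(0, 1), so exp(raw) is uniform in (0, 1].
    new_raw = torch.rand(10_000).log()
    print(old_raw.exp().mean().item())  # ~1.0
    print(new_raw.exp().mean().item())  # ~0.5
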
48 changes: 48 additions & 0 deletions src/parameters.py
@@ -0,0 +1,48 @@
from typing import Tuple, Dict, Any

import torch
from cirkit.backend.torch.compiler import TorchCompiler
from cirkit.backend.torch.parameters.nodes import TorchEntrywiseParameterOp
from cirkit.symbolic.parameters import EntrywiseParameterOp
from torch import Tensor


class ScaledSigmoidParameter(EntrywiseParameterOp):
    def __init__(self, in_shape: Tuple[int, ...], vmin: float, vmax: float, scale: float = 1.0):
        super().__init__(in_shape)
        self.vmin = vmin
        self.vmax = vmax
        self.scale = scale

    @property
    def config(self) -> Dict[str, Any]:
        return dict(vmin=self.vmin, vmax=self.vmax, scale=self.scale)


class TorchScaledSigmoidParameter(TorchEntrywiseParameterOp):
    def __init__(
        self, in_shape: Tuple[int, ...], *, vmin: float, vmax: float, scale: float, num_folds: int = 1
    ) -> None:
        super().__init__(in_shape, num_folds=num_folds)
        assert 0 <= vmin < vmax, "Must provide 0 <= vmin < vmax."
        assert scale > 0.0
        self.vmin = vmin
        self.vmax = vmax
        self.scale = scale

    @property
    def config(self) -> Dict[str, Any]:
        config = super().config
        config.update(vmin=self.vmin, vmax=self.vmax, scale=self.scale)
        return config

    @torch.compile()
    def forward(self, x: Tensor) -> Tensor:
        # Squash x entrywise into (vmin, vmax): the sigmoid maps scale * x
        # into (0, 1), and the affine map rescales it to the target range.
        return torch.sigmoid(x * self.scale) * (self.vmax - self.vmin) + self.vmin


def compile_scaled_sigmoid_parameter(
    compiler: TorchCompiler, p: ScaledSigmoidParameter
) -> TorchScaledSigmoidParameter:
    (in_shape,) = p.in_shapes
    return TorchScaledSigmoidParameter(in_shape, vmin=p.vmin, vmax=p.vmax, scale=p.scale)
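
A quick numerical sanity check of the forward mapping, in plain PyTorch (independent of the cirkit wrapper):

    import torch

    x = torch.linspace(-10.0, 10.0, steps=5)
    vmin, vmax, scale = 1e-5, 2.0, 2.0
    y = torch.sigmoid(x * scale) * (vmax - vmin) + vmin
    # y is monotone in x and bounded: y -> vmin as x -> -inf, y -> vmax as x -> +inf.
    assert ((y > 0.0) & (y <= vmax)).all()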
19 changes: 19 additions & 0 deletions src/pipeline.py
@@ -0,0 +1,19 @@
from cirkit.pipeline import PipelineContext

from initializers import compile_exp_uniform_initializer
from parameters import compile_scaled_sigmoid_parameter


def setup_pipeline_context(
    *,
    backend: str = 'torch',
    semiring: str = 'lse-sum',
    fold: bool = True,
    optimize: bool = True,
) -> PipelineContext:
    # Build a pipeline context and register the compilation rules for the
    # custom symbolic parameter and initializer defined in this package.
    ctx = PipelineContext(
        backend=backend, semiring=semiring, fold=fold, optimize=optimize
    )
    ctx.add_parameter_compilation_rule(compile_scaled_sigmoid_parameter)
    ctx.add_initializer_compilation_rule(compile_exp_uniform_initializer)
    return ctx
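
For context, a rough usage sketch of the new helper (the compile call pattern below is an assumption based on the compile import in models.py, not something this commit shows; symbolic_circuit stands in for a cirkit symbolic Circuit):

    from cirkit.pipeline import compile
    from pipeline import setup_pipeline_context

    ctx = setup_pipeline_context(semiring='lse-sum')
    # Compiling inside this context should pick up the custom rules registered above.
    torch_circuit = compile(symbolic_circuit, ctx)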
