Commit
made initialization more aligned to previous ones
Showing 5 changed files with 111 additions and 115 deletions.
initializers.py (new file)
@@ -0,0 +1,32 @@
import functools
from typing import Dict, Any

from cirkit.backend.torch.compiler import TorchCompiler
from cirkit.backend.torch.initializers import InitializerFunc
from cirkit.symbolic.initializers import Initializer
from torch import nn
from torch import Tensor


class ExpUniformInitializer(Initializer):
    # Symbolic initializer whose compiled form fills a tensor with log(U(a, b)) samples.
    def __init__(self, a: float = 0.0, b: float = 1.0) -> None:
        if a >= b:
            raise ValueError("The minimum should be strictly less than the maximum")
        self.a = a
        self.b = b

    @property
    def config(self) -> Dict[str, Any]:
        return dict(a=self.a, b=self.b)


def exp_uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> Tensor:
    # Sample uniformly in [a, b), then take the log in place, so that
    # exponentiating the parameter later recovers uniformly distributed values.
    # Note that a == 0 can yield -inf entries after the log.
    nn.init.uniform_(tensor, a=a, b=b)
    tensor.log_()
    return tensor


def compile_exp_uniform_initializer(
    compiler: TorchCompiler, init: ExpUniformInitializer
) -> InitializerFunc:
    # Compilation rule mapping the symbolic initializer to its torch in-place function.
    return functools.partial(exp_uniform_, a=init.a, b=init.b)
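As a quick sanity check (not part of the commit): since exp_uniform_ fills the tensor with log-uniform values, exponentiating the initialized tensor should land back inside [a, b]. The sketch below assumes only torch and the exp_uniform_ function above, and uses a > 0 to sidestep log(0).

import torch

# Minimal round-trip sketch: initialize in log-space, then check that
# exponentiation recovers values in the original uniform range.
t = torch.empty(1000)
exp_uniform_(t, a=1e-3, b=1.0)  # a > 0 avoids log(0) = -inf
recovered = t.exp()
assert recovered.min().item() >= 1e-3 - 1e-6
assert recovered.max().item() <= 1.0 + 1e-6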
This file was deleted.
parameters.py (new file)
@@ -0,0 +1,48 @@
from typing import Tuple, Dict, Any

import torch
from cirkit.backend.torch.compiler import TorchCompiler
from cirkit.backend.torch.parameters.nodes import TorchEntrywiseParameterOp
from cirkit.symbolic.parameters import EntrywiseParameterOp
from torch import Tensor


class ScaledSigmoidParameter(EntrywiseParameterOp):
    # Symbolic entrywise op: a scaled sigmoid that squashes values into (vmin, vmax).
    def __init__(self, in_shape: Tuple[int, ...], vmin: float, vmax: float, scale: float = 1.0):
        super().__init__(in_shape)
        self.vmin = vmin
        self.vmax = vmax
        self.scale = scale

    @property
    def config(self) -> Dict[str, Any]:
        return dict(vmin=self.vmin, vmax=self.vmax, scale=self.scale)


class TorchScaledSigmoidParameter(TorchEntrywiseParameterOp):
    def __init__(
        self, in_shape: Tuple[int, ...], *, vmin: float, vmax: float, scale: float, num_folds: int = 1
    ) -> None:
        super().__init__(in_shape, num_folds=num_folds)
        assert 0 <= vmin < vmax, "Must provide 0 <= vmin < vmax."
        assert scale > 0.0
        self.vmin = vmin
        self.vmax = vmax
        self.scale = scale

    @property
    def config(self) -> Dict[str, Any]:
        config = super().config
        config.update(vmin=self.vmin, vmax=self.vmax, scale=self.scale)
        return config

    @torch.compile()
    def forward(self, x: Tensor) -> Tensor:
        # Map unconstrained inputs into the open interval (vmin, vmax).
        return torch.sigmoid(x * self.scale) * (self.vmax - self.vmin) + self.vmin


def compile_scaled_sigmoid_parameter(
    compiler: TorchCompiler, p: ScaledSigmoidParameter
) -> TorchScaledSigmoidParameter:
    # Compilation rule: the op is unary, so exactly one input shape is expected.
    (in_shape,) = p.in_shapes
    return TorchScaledSigmoidParameter(in_shape, vmin=p.vmin, vmax=p.vmax, scale=p.scale)
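The forward pass is a rescaled sigmoid, so every real input is squashed strictly inside (vmin, vmax); this is a common way to keep an unconstrained tensor inside a valid range during optimization. A standalone numeric sketch of the mapping, assuming nothing beyond torch:

import torch

# The mapping used by TorchScaledSigmoidParameter.forward:
# sigmoid(x * scale) * (vmax - vmin) + vmin squashes reals into (vmin, vmax).
vmin, vmax, scale = 1e-5, 1.0, 2.0
x = torch.linspace(-5.0, 5.0, steps=5)
y = torch.sigmoid(x * scale) * (vmax - vmin) + vmin
assert ((y > vmin) & (y < vmax)).all()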
@@ -0,0 +1,19 @@
from cirkit.pipeline import PipelineContext

from initializers import compile_exp_uniform_initializer
from parameters import compile_scaled_sigmoid_parameter


def setup_pipeline_context(
    *,
    backend: str = 'torch',
    semiring: str = 'lse-sum',
    fold: bool = True,
    optimize: bool = True,
) -> PipelineContext:
    # Build a pipeline context and register the custom compilation rules
    # defined above, so the symbolic ExpUniformInitializer and
    # ScaledSigmoidParameter can be compiled to the torch backend.
    ctx = PipelineContext(
        backend=backend, semiring=semiring, fold=fold, optimize=optimize
    )
    ctx.add_parameter_compilation_rule(compile_scaled_sigmoid_parameter)
    ctx.add_initializer_compilation_rule(compile_exp_uniform_initializer)
    return ctx
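A hypothetical usage sketch (the compile call on the context is an assumption about the cirkit API, not shown in this diff):

# Build the context once; the custom rules registered above let the
# torch backend compile circuits that use ExpUniformInitializer and
# ScaledSigmoidParameter.
ctx = setup_pipeline_context(semiring='lse-sum', fold=True, optimize=True)
# compiled = ctx.compile(symbolic_circuit)  # assumed PipelineContext API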