Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman223 committed Nov 3, 2023
1 parent a1c7c1e commit 6c27969
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os.path as path
import random
import re
from copy import deepcopy
from typing import Dict, Tuple, List, Callable, Optional, Type, Union, Any, Sequence

import numpy as np
Expand Down Expand Up @@ -457,21 +458,23 @@ def save(self, bn_name, models_dir: str = "models_dir"):
:return: saving status.
"""
distributions = self.distributions.copy()
distributions = deepcopy(self.distributions)
new_weights = {str(key): self.weights[key] for key in self.weights}

to_serialize = {}
# separate logit and gaussian nodes from distributions to serialize bn's models
for node_name in self.distributions.keys():
for node_name in distributions.keys():
if self[node_name].type.startswith("Gaussian"):
if not distributions[node_name]["regressor"]:
continue
if (
"Gaussian" in self[node_name].type
or "Logit" in self[node_name].type
or "ConditionalLogit" in self[node_name].type
or "ConditionalGaussian" in self[node_name].type
):
to_serialize[node_name] = [
self[node_name].type,
self.distributions[node_name],
distributions[node_name],
]

serializer = serialization_utils.ModelsSerializer(
Expand Down

0 comments on commit 6c27969

Please sign in to comment.