From 6c27969359585c8f245ef96599d4f8f8b449a70c Mon Sep 17 00:00:00 2001 From: Roman223 Date: Fri, 3 Nov 2023 16:40:00 +0300 Subject: [PATCH] bug fix --- bamt/networks/base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bamt/networks/base.py b/bamt/networks/base.py index 09ddc7d..98e1361 100644 --- a/bamt/networks/base.py +++ b/bamt/networks/base.py @@ -2,6 +2,7 @@ import os.path as path import random import re +from copy import deepcopy from typing import Dict, Tuple, List, Callable, Optional, Type, Union, Any, Sequence import numpy as np @@ -457,21 +458,23 @@ def save(self, bn_name, models_dir: str = "models_dir"): :return: saving status. """ - distributions = self.distributions.copy() + distributions = deepcopy(self.distributions) new_weights = {str(key): self.weights[key] for key in self.weights} to_serialize = {} # separate logit and gaussian nodes from distributions to serialize bn's models - for node_name in self.distributions.keys(): + for node_name in distributions.keys(): + if self[node_name].type.startswith("Gaussian"): + if not distributions[node_name]["regressor"]: + continue if ( "Gaussian" in self[node_name].type or "Logit" in self[node_name].type or "ConditionalLogit" in self[node_name].type - or "ConditionalGaussian" in self[node_name].type ): to_serialize[node_name] = [ self[node_name].type, - self.distributions[node_name], + distributions[node_name], ] serializer = serialization_utils.ModelsSerializer(