diff --git a/lnn/_utils.py b/lnn/_utils.py
index 1a28e3a..7709ce8 100644
--- a/lnn/_utils.py
+++ b/lnn/_utils.py
@@ -153,12 +153,14 @@ def average_time(multi_checkpoint: List[List[Tuple]]):
             result[name] += point
 
 
-def logger_setup(flush=False):
-    for level in ["INFO"]:
-        filename = f"LNN_{level}.log"
-        if flush:
-            with open(filename, "w"):
-                pass
-        logging.basicConfig(
-            filename=filename, encoding="utf-8", level=eval(f"logging.{level}")
-        )
+def get_logger(flush: bool = False):
+    ref = importlib.resources.files("lnn").joinpath("../config.yaml")
+    config = yaml.safe_load(ref.read_text())
+    logging.config.dictConfig(config)
+    logger = logging.getLogger(
+        "".join(config["handlers"]["file"]["filename"].split(".")[:-1])
+    )
+    if flush:
+        with open(logger.handlers[0].baseFilename, "w"):
+            pass
+    return logger
diff --git a/lnn/model.py b/lnn/model.py
index 5816c27..453a5e4 100644
--- a/lnn/model.py
+++ b/lnn/model.py
@@ -24,7 +24,7 @@
 from torch import nn
 import matplotlib.pyplot as plt
 
-_utils.logger_setup(flush=True)
+_utils.get_logger(flush=True)
 
 
 class Model(nn.Module):
@@ -117,7 +117,8 @@ def __init__(
             self.add_knowledge(knowledge)
         if data:
             self.add_data(data)
-        logging.info(f" {name} {datetime.datetime.now()} ".join(["*" * 22] * 2))
+        self.logger = _utils.get_logger(flush=True)
+        self.logger.info(f" {name} {datetime.datetime.now()} ".join(["*" * 22] * 2))
 
     def __getitem__(
         self, formula: Union[Formula, int]
@@ -438,7 +439,7 @@ def _traverse_execute(
             val = getattr(node, func)(**kwds) if hasattr(node, func) else None
             coalesce = coalesce + val if val is not None else coalesce
         if coalesce and func in [d.value.lower() for d in Direction]:
-            logging.info(f"{direction.value} INFERENCE RESULT:{coalesce}")
+            self.logger.info(f"{direction.value} INFERENCE RESULT:{coalesce}")
         return coalesce
 
     def infer(
@@ -488,14 +489,14 @@ def _infer(
 
         while not converged:
             if self.query and self.query.is_classically_resolved and not self._converge:
-                logging.info("=" * 22)
-                logging.info(
+                self.logger.info("=" * 22)
+                self.logger.info(
                     f"QUERY PROVED AS {self.query.world_state(True)} for "
                     f"'{self.query.name}'"
                 )
                 break
-            logging.info("-" * 22)
-            logging.info(f"REASONING STEP:{steps}")
+            self.logger.info("-" * 22)
+            self.logger.info(f"REASONING STEP:{steps}")
             bounds_diff = 0.0
             for d in direction:
                 bounds_diff += self._traverse_execute(
@@ -508,17 +509,17 @@ def _infer(
                 )
             if converged_bounds:
                 converged = True
-                logging.info("NO UPDATES AVAILABLE, TRYING A NEW AXIOM")
+                self.logger.info("NO UPDATES AVAILABLE, TRYING A NEW AXIOM")
             facts_inferred += bounds_diff
             steps += 1
             if max_steps and steps >= max_steps:
                 break
-        logging.info("=" * 22)
-        logging.info(
+        self.logger.info("=" * 22)
+        self.logger.info(
             f"INFERENCE CONVERGED WITH {facts_inferred} BOUNDS "
             f"UPDATES IN {steps} REASONING STEPS "
         )
-        logging.info("*" * 78)
+        self.logger.info("*" * 78)
         return steps, facts_inferred
 
     def upward(self, **kwds):
@@ -613,7 +614,7 @@ def train(self, losses: Union[Loss, List[Loss], Dict[List[Loss], float]], **kwds
         ):
             optimizer.zero_grad()
             if epoch > 0:
-                logging.info(" PARAMETER STEP ".join(["#" * 31] * 2))
+                self.logger.info(" PARAMETER STEP ".join(["#" * 31] * 2))
                 self.reset_bounds()
                 self.increment_param_history(kwds.get("parameter_history"))
             _, facts_inferred = self.infer(**kwds)
@@ -622,7 +623,7 @@ def train(self, losses: Union[Loss, List[Loss], Dict[List[Loss], float]], **kwds
                 if not loss.grad_fn:
                     break
             if loss and len(loss_fn) > 1:
-                logging.info(f"TOTAL LOSS: {loss}")
+                self.logger.info(f"TOTAL LOSS: {loss}")
             loss.backward()
             optimizer.step()
             self._project_params()
@@ -698,7 +699,7 @@ def loss_fn(self, losses):
                 self._traverse_execute(f"_{loss.value.lower()}_loss", **kwds)
             )
             if result[-1]:
-                logging.info(f"{loss.value.upper()} LOSS {result[-1]}")
+                self.logger.info(f"{loss.value.upper()} LOSS {result[-1]}")
         return result
 
     def print(
diff --git a/lnn/symbolic/logic/binary_neuron.py b/lnn/symbolic/logic/binary_neuron.py
index f5ace0d..b4488cd 100644
--- a/lnn/symbolic/logic/binary_neuron.py
+++ b/lnn/symbolic/logic/binary_neuron.py
@@ -13,7 +13,7 @@
 from ... import _utils
 from ...constants import Direction, NeuralActivation
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _BinaryNeuron(_ConnectiveNeuron):
diff --git a/lnn/symbolic/logic/connective_formula.py b/lnn/symbolic/logic/connective_formula.py
index bcaaa7f..d3ff8de 100644
--- a/lnn/symbolic/logic/connective_formula.py
+++ b/lnn/symbolic/logic/connective_formula.py
@@ -9,7 +9,7 @@
 from .formula import Formula
 from ... import _utils
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _ConnectiveFormula(Formula):
diff --git a/lnn/symbolic/logic/connective_neuron.py b/lnn/symbolic/logic/connective_neuron.py
index 05f8eac..34f92b6 100644
--- a/lnn/symbolic/logic/connective_neuron.py
+++ b/lnn/symbolic/logic/connective_neuron.py
@@ -17,7 +17,7 @@
 from ... import _utils
 from ...constants import Direction
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 subclasses = {}
@@ -79,14 +79,14 @@ def upward(
         )
         result = self.neuron.aggregate_bounds(grounding_rows, self.func(input_bounds))
         if result:
-            logging.info(
+            self.logger.info(
                 "↑ BOUNDS UPDATED "
                 f"TIGHTENED:{result} "
                 f"FOR:'{self.name}' "
                 f"FORMULA:{self.formula_number} "
             )
         if self.is_contradiction():
-            logging.info(
+            self.logger.info(
                 "↑ CONTRADICTION "
                 f"FOR:'{self.name}' "
                 f"FORMULA:{self.formula_number} "
@@ -148,7 +148,7 @@ def downward(
                 op_grounding_rows, new_bounds[..., op_index], duplicates=duplicates
             )
             if op_aggregate:
-                logging.info(
+                self.logger.info(
                     "↓ BOUNDS UPDATED "
                     f"TIGHTENED:{op_aggregate} "
                     f"FOR:'{op.name}' "
@@ -157,7 +157,7 @@ def downward(
                     f"PARENT:{self.formula_number} "
                 )
             if op.is_contradiction():
-                logging.info(
+                self.logger.info(
                     "↓ CONTRADICTION "
                     f"FOR:'{op.name}' "
                     f"FROM:'{self.name}' "
diff --git a/lnn/symbolic/logic/formula.py b/lnn/symbolic/logic/formula.py
index 91f40ec..4b631d5 100644
--- a/lnn/symbolic/logic/formula.py
+++ b/lnn/symbolic/logic/formula.py
@@ -22,7 +22,7 @@
 import torch
 import numpy as np
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 subclasses: typing.Dict[str, object] = {}
@@ -353,7 +353,7 @@ def recurse(formula):
         recurse(self)
         if store:
             if edge_replace:
-                logging.info(f"ABSORBED NEGATIONS INTO WEIGHTS FOR: '{self.name}'")
+                self.logger.info(f"ABSORBED NEGATIONS INTO WEIGHTS FOR: '{self.name}'")
             return edge_replace, n_negations
         else:
             return operands, edge_replace, n_negations
diff --git a/lnn/symbolic/logic/n_ary_neuron.py b/lnn/symbolic/logic/n_ary_neuron.py
index b3ff91c..80ccc15 100644
--- a/lnn/symbolic/logic/n_ary_neuron.py
+++ b/lnn/symbolic/logic/n_ary_neuron.py
@@ -16,7 +16,7 @@
 from ... import _utils
 from ...constants import Direction, NeuralActivation
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NAryNeuron(_ConnectiveNeuron):
diff --git a/lnn/symbolic/logic/n_ary_operator.py b/lnn/symbolic/logic/n_ary_operator.py
index 10e200f..b538079 100644
--- a/lnn/symbolic/logic/n_ary_operator.py
+++ b/lnn/symbolic/logic/n_ary_operator.py
@@ -18,7 +18,7 @@
 from ... import _utils
 from ...constants import Fact
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NAryOperator(_ConnectiveFormula):
@@ -90,7 +90,7 @@ def upward(
         )
         result = self.neuron.aggregate_bounds(grounding_rows, input_bounds)
         if result:
-            logging.info(
+            self.logger.info(
                 "↑ BOUNDS UPDATED "
                 f"TIGHTENED:{result} "
                 f"FOR:'{self.name}' "
@@ -146,7 +146,7 @@ def downward(
                 op_grounding_rows[g_i] = op.grounding_table.get(op_g)
             op_aggregate = op.neuron.aggregate_bounds(op_grounding_rows, parent)
             if op_aggregate:
-                logging.info(
+                self.logger.info(
                     "↓ BOUNDS UPDATED "
                     f"TIGHTENED:{op_aggregate} "
                     f"FOR:'{op.name}' "
@@ -192,7 +192,7 @@ def upward(
         )
         result = self.neuron.aggregate_bounds(grounding_rows, input_bounds)
         if result:
-            logging.info(
+            self.logger.info(
                 "↑ BOUNDS UPDATED "
                 f"TIGHTENED:{result} "
                 f"FOR:'{self.name}' "
diff --git a/lnn/symbolic/logic/neural_activation.py b/lnn/symbolic/logic/neural_activation.py
index 18305cc..9228ef5 100644
--- a/lnn/symbolic/logic/neural_activation.py
+++ b/lnn/symbolic/logic/neural_activation.py
@@ -11,7 +11,7 @@
 from ... import _utils, _exceptions
 from ...constants import NeuralActivation
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NeuralActivation:
diff --git a/lnn/symbolic/logic/node_activation.py b/lnn/symbolic/logic/node_activation.py
index 3e170d6..39914c1 100644
--- a/lnn/symbolic/logic/node_activation.py
+++ b/lnn/symbolic/logic/node_activation.py
@@ -10,7 +10,7 @@
 
 from ... import _utils
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NodeActivation:
diff --git a/lnn/symbolic/logic/unary_operator.py b/lnn/symbolic/logic/unary_operator.py
index 5e80b2c..da9a754 100644
--- a/lnn/symbolic/logic/unary_operator.py
+++ b/lnn/symbolic/logic/unary_operator.py
@@ -22,7 +22,7 @@
 from ...constants import Fact, Direction, Bound
 from torch.nn.parameter import Parameter
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _UnaryOperator(_ConnectiveFormula):
@@ -388,7 +388,7 @@ def upward(self, **kwds) -> float:
             None, _utils.negate_bounds(self.operands[0].get_data(*groundings))
         )
         if self.is_contradiction():
-            logging.info(
+            self.logger.info(
                 "↑ CONTRADICTION "
                 f"FOR:'{self.name}' "
                 f"FORMULA:{self.formula_number} "
@@ -419,7 +419,7 @@ def downward(self, **kwds) -> torch.Tensor:
             None, _utils.negate_bounds(self.get_data(*groundings))
         )
        if self.operands[0].is_contradiction():
-            logging.info(
+            self.logger.info(
                "↓ CONTRADICTION "
                f"FOR:'{self.operands[0].name}' "
                f"FROM:'{self.name}' "
diff --git a/lnn/symbolic/logic/variable.py b/lnn/symbolic/logic/variable.py
index 66db376..4cc684d 100644
--- a/lnn/symbolic/logic/variable.py
+++ b/lnn/symbolic/logic/variable.py
@@ -10,7 +10,7 @@
 
 from ... import _utils, utils
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class Variable:
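
Not part of the patch: a minimal, self-contained sketch of the dictConfig-driven pattern that get_logger switches to, for readers unfamiliar with it. The inlined YAML, the "LNN" logger name, and the LNN.log filename are illustrative assumptions, not the repository's actual config.yaml; only the overall flow (safe_load a YAML config, pass it to logging.config.dictConfig, look up the named logger, and optionally truncate its file handler when flush=True) mirrors the new helper.

# sketch.py -- hypothetical standalone example, assumptions noted above
import logging
import logging.config

import yaml  # PyYAML

CONFIG_YAML = """
version: 1
formatters:
  plain:
    format: "%(message)s"
handlers:
  file:
    class: logging.FileHandler
    filename: LNN.log
    formatter: plain
loggers:
  LNN:
    level: INFO
    handlers: [file]
"""


def get_logger(flush: bool = False) -> logging.Logger:
    # Load the logging configuration and install it via dictConfig.
    config = yaml.safe_load(CONFIG_YAML)
    logging.config.dictConfig(config)
    # Derive the logger name from the handler's filename, e.g. "LNN.log" -> "LNN".
    name = "".join(config["handlers"]["file"]["filename"].split(".")[:-1])
    logger = logging.getLogger(name)
    if flush:
        # Truncate the log file so each run starts with an empty log.
        with open(logger.handlers[0].baseFilename, "w"):
            pass
    return logger


logger = get_logger(flush=True)
logger.info("REASONING STEP:0")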