Commit b303d26
ADD: logger loaded from file and stored per model
Signed-off-by: naweedkhan <naweed.khan@ibm.com>
NaweedAghmad committed Nov 30, 2023
1 parent aec7b50 commit b303d26
Showing 12 changed files with 46 additions and 43 deletions.
20 changes: 11 additions & 9 deletions lnn/_utils.py
@@ -153,12 +153,14 @@ def average_time(multi_checkpoint: List[List[Tuple]]):
             result[name] += point
 
 
-def logger_setup(flush=False):
-    for level in ["INFO"]:
-        filename = f"LNN_{level}.log"
-        if flush:
-            with open(filename, "w"):
-                pass
-        logging.basicConfig(
-            filename=filename, encoding="utf-8", level=eval(f"logging.{level}")
-        )
+def get_logger(flush: bool = False):
+    ref = importlib.resources.files("lnn").joinpath("../config.yaml")
+    config = yaml.safe_load(ref.read_text())
+    logging.config.dictConfig(config)
+    logger = logging.getLogger(
+        "".join(config["handlers"]["file"]["filename"].split(".")[:-1])
+    )
+    if flush:
+        with open(logger.handlers[0].baseFilename, "w"):
+            pass
+    return logger
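
The new get_logger no longer calls logging.basicConfig; it loads a logging.config.dictConfig schema from config.yaml and returns the logger named after the file handler's filename minus its extension. The config file itself is not part of this diff, so the following is only a minimal sketch, under standard dictConfig assumptions, of the shape get_logger expects. Note the flush branch reads logger.handlers[0].baseFilename, which assumes the named logger owns the FileHandler directly, as sketched here:

# config.yaml — hypothetical sketch, not the repository's actual file
version: 1
disable_existing_loggers: false
formatters:
  simple:
    format: "%(asctime)s %(levelname)s %(message)s"
handlers:
  file:
    class: logging.FileHandler
    filename: LNN.log        # get_logger derives the logger name "LNN" from this
    formatter: simple
    level: INFO
loggers:
  LNN:
    handlers: [file]
    level: INFO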
29 changes: 15 additions & 14 deletions lnn/model.py
@@ -24,7 +24,7 @@
 from torch import nn
 import matplotlib.pyplot as plt
 
-_utils.logger_setup(flush=True)
+_utils.get_logger(flush=True)
 
 
 class Model(nn.Module):
@@ -117,7 +117,8 @@ def __init__(
             self.add_knowledge(knowledge)
         if data:
             self.add_data(data)
-        logging.info(f" {name} {datetime.datetime.now()} ".join(["*" * 22] * 2))
+        self.logger = _utils.get_logger(flush=True)
+        self.logger.info(f" {name} {datetime.datetime.now()} ".join(["*" * 22] * 2))
 
     def __getitem__(
         self, formula: Union[Formula, int]
@@ -438,7 +439,7 @@ def _traverse_execute(
             val = getattr(node, func)(**kwds) if hasattr(node, func) else None
             coalesce = coalesce + val if val is not None else coalesce
         if coalesce and func in [d.value.lower() for d in Direction]:
-            logging.info(f"{direction.value} INFERENCE RESULT:{coalesce}")
+            self.logger.info(f"{direction.value} INFERENCE RESULT:{coalesce}")
         return coalesce
 
     def infer(
@@ -488,14 +489,14 @@ def _infer(
 
         while not converged:
             if self.query and self.query.is_classically_resolved and not self._converge:
-                logging.info("=" * 22)
-                logging.info(
+                self.logger.info("=" * 22)
+                self.logger.info(
                     f"QUERY PROVED AS {self.query.world_state(True)} for "
                     f"'{self.query.name}'"
                 )
                 break
-            logging.info("-" * 22)
-            logging.info(f"REASONING STEP:{steps}")
+            self.logger.info("-" * 22)
+            self.logger.info(f"REASONING STEP:{steps}")
             bounds_diff = 0.0
             for d in direction:
                 bounds_diff += self._traverse_execute(
@@ -508,17 +509,17 @@
                 )
             if converged_bounds:
                 converged = True
-                logging.info("NO UPDATES AVAILABLE, TRYING A NEW AXIOM")
+                self.logger.info("NO UPDATES AVAILABLE, TRYING A NEW AXIOM")
             facts_inferred += bounds_diff
             steps += 1
             if max_steps and steps >= max_steps:
                 break
-        logging.info("=" * 22)
-        logging.info(
+        self.logger.info("=" * 22)
+        self.logger.info(
             f"INFERENCE CONVERGED WITH {facts_inferred} BOUNDS "
             f"UPDATES IN {steps} REASONING STEPS "
         )
-        logging.info("*" * 78)
+        self.logger.info("*" * 78)
         return steps, facts_inferred
 
     def upward(self, **kwds):
@@ -613,7 +614,7 @@ def train(self, losses: Union[Loss, List[Loss], Dict[List[Loss], float]], **kwds
         ):
             optimizer.zero_grad()
             if epoch > 0:
-                logging.info(" PARAMETER STEP ".join(["#" * 31] * 2))
+                self.logger.info(" PARAMETER STEP ".join(["#" * 31] * 2))
                 self.reset_bounds()
                 self.increment_param_history(kwds.get("parameter_history"))
             _, facts_inferred = self.infer(**kwds)
@@ -622,7 +623,7 @@
                 if not loss.grad_fn:
                     break
             if loss and len(loss_fn) > 1:
-                logging.info(f"TOTAL LOSS: {loss}")
+                self.logger.info(f"TOTAL LOSS: {loss}")
             loss.backward()
             optimizer.step()
             self._project_params()
@@ -698,7 +699,7 @@ def loss_fn(self, losses):
                 self._traverse_execute(f"_{loss.value.lower()}_loss", **kwds)
             )
             if result[-1]:
-                logging.info(f"{loss.value.upper()} LOSS {result[-1]}")
+                self.logger.info(f"{loss.value.upper()} LOSS {result[-1]}")
         return result
 
     def print(
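Because the logger is now stored per model (self.logger), every inference and training message above is emitted through the handler set configured in config.yaml rather than the root logger configured by the old logging.basicConfig call. A minimal usage sketch, assuming Model is exported from the lnn package and accepts the name keyword used in __init__ above:

from lnn import Model

# Hypothetical example: knowledge/data constructor arguments elided.
model = Model(name="demo")

# The per-model logger writes through the file handler from config.yaml;
# LNN's own output (e.g. "REASONING STEP:...") lands in the same file.
model.logger.info("application message alongside LNN's own log lines")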
2 changes: 1 addition & 1 deletion lnn/symbolic/logic/binary_neuron.py
@@ -13,7 +13,7 @@
 from ... import _utils
 from ...constants import Direction, NeuralActivation
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _BinaryNeuron(_ConnectiveNeuron):
2 changes: 1 addition & 1 deletion lnn/symbolic/logic/connective_formula.py
@@ -9,7 +9,7 @@
 from .formula import Formula
 from ... import _utils
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _ConnectiveFormula(Formula):
10 changes: 5 additions & 5 deletions lnn/symbolic/logic/connective_neuron.py
@@ -17,7 +17,7 @@
 from ... import _utils
 from ...constants import Direction
 
-_utils.logger_setup()
+_utils.get_logger()
 subclasses = {}
 
 
@@ -79,14 +79,14 @@ def upward(
         )
         result = self.neuron.aggregate_bounds(grounding_rows, self.func(input_bounds))
         if result:
-            logging.info(
+            self.logger.info(
                 "↑ BOUNDS UPDATED "
                 f"TIGHTENED:{result} "
                 f"FOR:'{self.name}' "
                 f"FORMULA:{self.formula_number} "
             )
         if self.is_contradiction():
-            logging.info(
+            self.logger.info(
                 "↑ CONTRADICTION "
                 f"FOR:'{self.name}' "
                 f"FORMULA:{self.formula_number} "
@@ -148,7 +148,7 @@ def downward(
                     op_grounding_rows, new_bounds[..., op_index], duplicates=duplicates
                 )
                 if op_aggregate:
-                    logging.info(
+                    self.logger.info(
                         "↓ BOUNDS UPDATED "
                         f"TIGHTENED:{op_aggregate} "
                         f"FOR:'{op.name}' "
@@ -157,7 +157,7 @@
                         f"PARENT:{self.formula_number} "
                     )
                 if op.is_contradiction():
-                    logging.info(
+                    self.logger.info(
                         "↓ CONTRADICTION "
                         f"FOR:'{op.name}' "
                         f"FROM:'{self.name}' "
4 changes: 2 additions & 2 deletions lnn/symbolic/logic/formula.py
@@ -22,7 +22,7 @@
 import torch
 import numpy as np
 
-_utils.logger_setup()
+_utils.get_logger()
 subclasses: typing.Dict[str, object] = {}
 
 
@@ -353,7 +353,7 @@ def recurse(formula):
         recurse(self)
         if store:
             if edge_replace:
-                logging.info(f"ABSORBED NEGATIONS INTO WEIGHTS FOR: '{self.name}'")
+                self.logger.info(f"ABSORBED NEGATIONS INTO WEIGHTS FOR: '{self.name}'")
             return edge_replace, n_negations
         else:
             return operands, edge_replace, n_negations
2 changes: 1 addition & 1 deletion lnn/symbolic/logic/n_ary_neuron.py
@@ -16,7 +16,7 @@
 from ... import _utils
 from ...constants import Direction, NeuralActivation
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NAryNeuron(_ConnectiveNeuron):
8 changes: 4 additions & 4 deletions lnn/symbolic/logic/n_ary_operator.py
@@ -18,7 +18,7 @@
 from ... import _utils
 from ...constants import Fact
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NAryOperator(_ConnectiveFormula):
@@ -90,7 +90,7 @@ def upward(
         )
         result = self.neuron.aggregate_bounds(grounding_rows, input_bounds)
         if result:
-            logging.info(
+            self.logger.info(
                 "↑ BOUNDS UPDATED "
                 f"TIGHTENED:{result} "
                 f"FOR:'{self.name}' "
@@ -146,7 +146,7 @@ def downward(
                 op_grounding_rows[g_i] = op.grounding_table.get(op_g)
             op_aggregate = op.neuron.aggregate_bounds(op_grounding_rows, parent)
             if op_aggregate:
-                logging.info(
+                self.logger.info(
                     "↓ BOUNDS UPDATED "
                     f"TIGHTENED:{op_aggregate} "
                     f"FOR:'{op.name}' "
@@ -192,7 +192,7 @@ def upward(
         )
         result = self.neuron.aggregate_bounds(grounding_rows, input_bounds)
         if result:
-            logging.info(
+            self.logger.info(
                 "↑ BOUNDS UPDATED "
                 f"TIGHTENED:{result} "
                 f"FOR:'{self.name}' "
2 changes: 1 addition & 1 deletion lnn/symbolic/logic/neural_activation.py
@@ -11,7 +11,7 @@
 from ... import _utils, _exceptions
 from ...constants import NeuralActivation
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NeuralActivation:
2 changes: 1 addition & 1 deletion lnn/symbolic/logic/node_activation.py
@@ -10,7 +10,7 @@
 
 from ... import _utils
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _NodeActivation:
6 changes: 3 additions & 3 deletions lnn/symbolic/logic/unary_operator.py
@@ -22,7 +22,7 @@
 from ...constants import Fact, Direction, Bound
 from torch.nn.parameter import Parameter
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class _UnaryOperator(_ConnectiveFormula):
@@ -388,7 +388,7 @@ def upward(self, **kwds) -> float:
             None, _utils.negate_bounds(self.operands[0].get_data(*groundings))
         )
         if self.is_contradiction():
-            logging.info(
+            self.logger.info(
                 "↑ CONTRADICTION "
                 f"FOR:'{self.name}' "
                 f"FORMULA:{self.formula_number} "
@@ -419,7 +419,7 @@ def downward(self, **kwds) -> torch.Tensor:
             None, _utils.negate_bounds(self.get_data(*groundings))
         )
         if self.operands[0].is_contradiction():
-            logging.info(
+            self.logger.info(
                 "↓ CONTRADICTION "
                 f"FOR:'{self.operands[0].name}' "
                 f"FROM:'{self.name}' "
2 changes: 1 addition & 1 deletion lnn/symbolic/logic/variable.py
@@ -10,7 +10,7 @@
 
 from ... import _utils, utils
 
-_utils.logger_setup()
+_utils.get_logger()
 
 
 class Variable:
