Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Hopkins <thomashopkins000@gmail.com>
  • Loading branch information
maffettone and thomashopkins32 authored Feb 6, 2025
1 parent 8aeffb7 commit 1c31f2e
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ def train_inputs(self, index=None, **subset_kwargs):
dof = self.dofs[index]
raw_inputs = self.raw_inputs(index=index, **subset_kwargs)

# check that inputs values are inside acceptable values
valid = (raw_inputs >= dof._trust_domain[0]) & (raw_inputs <= dof._trust_domain[1])
raw_inputs = torch.where(valid, raw_inputs, np.nan)

return dof._transform(raw_inputs)

Expand All @@ -196,9 +193,6 @@ def train_targets(self, concatenate=False, **subset_kwargs):
for obj in self.objectives(**subset_kwargs):
y = raw_targets_dict[obj.name]

# check that targets values are inside acceptable values
valid = (y >= obj._trust_domain[0]) & (y <= obj._trust_domain[1])
y = torch.where(valid, y, np.nan)

targets_dict[obj.name] = obj._transform(y)

Expand Down Expand Up @@ -587,8 +581,7 @@ def ask(self, batch_size) -> Tuple[Sequence[Dict[str, ArrayLike]], Sequence[Arra
points: Dict = default_result.pop("points")
acqf_obj: List[ArrayLike] = default_result.pop("acqf_obj")
# Turn dict of list of points into list of consistently sized points
keys = list(points.keys())
points: List[Tuple[ArrayLike]] = list(zip(*[points[key] for key in keys]))
points: List[Tuple[ArrayLike]] = list(zip(*[value for _, value in points.items()]))
dicts = []
for point, obj in zip(points, acqf_obj):
d = default_result.copy()
Expand Down Expand Up @@ -620,7 +613,7 @@ def unpack_run(self, run):
dependent_var :
The measured data, processed for relevance
"""
if self.digestion == default_digestion_function:
if not self.digestion or self.digestion == default_digestion_function:
# Assume all raw data is available in primary stream as keys
return (
[run.primary.data[key].read() for key in self.dofs.names],
Expand Down

0 comments on commit 1c31f2e

Please sign in to comment.