Skip to content

Commit

Permalink
Correct bug in linear function
Browse files Browse the repository at this point in the history
  • Loading branch information
achiefa committed Feb 17, 2025
1 parent 9f286c6 commit 1b40c4c
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,17 @@ def linear_bin_function(
a: npt.ArrayLike, y_shift: npt.ArrayLike, bin_edges: npt.ArrayLike
) -> np.ndarray:
"""
This function defines the linear bin function used to construct the prior. The bins of the
function are constructed using pairs of consecutive points. For instance, given the set of
points [0.0, 0.1, 0.3, 0.5], there will be three bins with edges [[0.0, 0.1], [0.1, 0.3],
0.3, 0.5]]. Each bin is coupled with a shift, which correspond to the y-value of the bin.
This function defines the linear bin function used to construct the prior. Specifically,
the prior is constructed using a triangular function whose value at the peak of the node
is linked to the right and left nodes using a straight line.
Parameters
----------
a: ArrayLike of float
A one-dimensional array of points at which the function is evaluated.
y_shift: ArrayLike of float
A one-dimensional array whose elements represent the y-value of each bin
bin_edges: ArrayLike of float
bin_nodes: ArrayLike of float
A one-dimensional array containing the edges of the bins. The bins are
constructed using pairs of consecutive points.
Expand Down Expand Up @@ -153,7 +152,9 @@ def linear_bin_function(
bin_high = bin_mid
m1 = shift / (bin_mid - bin_low)
m2 = 0.0
cond_low = np.multiply(a >= bin_low, a < bin_mid)
cond_low = np.multiply(
a >= bin_low, a < bin_mid if shift_pos != len(y_shift) - 1 else a <= bin_mid
)
cond_high = np.multiply(
a >= bin_mid, a < bin_high if shift_pos != len(y_shift) - 1 else a <= bin_high
)
Expand Down Expand Up @@ -1052,7 +1053,6 @@ def average(y_values_pc2_p, y_values_pcL_p, y_values_pc3_p):
# When this happens, this part must be updated.
eta = cd_table['kin1'].to_numpy()
pT = cd_table['kin2'].to_numpy()
q2 = pT * pT

pc_func = JET_pc(pc_jet_nodes, pT, eta, pc_func_type)
for pars_pc in pars_combs:
Expand Down

0 comments on commit 1b40c4c

Please sign in to comment.