diff --git a/janus_core/processing/correlator.py b/janus_core/processing/correlator.py index 10f18c66..4328d20d 100644 --- a/janus_core/processing/correlator.py +++ b/janus_core/processing/correlator.py @@ -14,6 +14,17 @@ class Correlator: """ Correlate scalar real values, . + Implements the algorithm detailed in https://doi.org/10.1063/1.3491098. + + Data pairs are observed iteratively and stored in a set of rolling hierarchical data + blocks. + + Once a block is filled, coarse graining may be applied to update coarser block + levels by updating the coarser block with the average of values accumulated up to + that point in the filled block. + + The correlation is continuously updated when any block is updated with new data. + Parameters ---------- blocks : int @@ -22,6 +33,27 @@ class Correlator: Number of points per block. averaging : int Averaging window per block level. + + Attributes + ---------- + _max_block_used : int + Which levels have been updated with data. + _min_dist : int + First point in coarse-grained block relevant for correlation updates. + _accumulator : NDArray[float64] + Sum of data seen for calculating the average between blocks. + _count_accumulated : NDArray[int] + Data points accumulated at this block. + _shift_index : NDArray[int] + Current position in each block's data store. + _shift : NDArray[float64] + Rolling data store for each block. + _shift_not_null : NDArray[bool] + If data is stored in this block's rolling data store, at a given index. + _correlation : NDArray[float64] + Running correlation values. + _count_correlated : NDArray[int] + Count of correlation updates for each block. """ def __init__(self, *, blocks: int, points: int, averaging: int) -> None: @@ -31,11 +63,11 @@ def __init__(self, *, blocks: int, points: int, averaging: int) -> None: Parameters ---------- blocks : int - Number of correlation blocks. + Number of resolution levels. points : int - Number of points per block. + Data points at each resolution. averaging : int - Averaging window per block level. + Coarse-graining between resolution levels. """ self._blocks = blocks self._points = points @@ -78,46 +110,59 @@ def _propagate(self, a: float, b: float, block: int) -> None: Block in the hierachy being updated. """ if block == self._blocks: + # Hit the end of the data structure. return shift = self._shift_index[block] self._max_block_used = max(self._max_block_used, block) + + # Update the rolling data store, and accumulate self._shift[block, shift, :] = a, b self._accumulator[block, :] += a, b self._shift_not_null[block, shift] = True self._count_accumulated[block] += 1 if self._count_accumulated[block] == self._averaging: + # Hit the coarse graining threshold, the next block can be updated. self._propagate( self._accumulator[block, 0] / self._averaging, self._accumulator[block, 1] / self._averaging, block + 1, ) + # Reset the accumulator at this block level. self._accumulator[block, :] = 0.0 self._count_accumulated[block] = 0 + # Update the correlation. i = self._shift_index[block] if block == 0: + # Need to multiply by all in this block (full resolution). j = i for point in range(self._points): if self._shifts_valid(block, i, j): + # Correlate at this lag. self._correlation[block, point] += ( self._shift[block, i, 0] * self._shift[block, j, 1] ) self._count_correlated[block, point] += 1 j -= 1 if j < 0: + # Wrap to start of rolling data store. j += self._points else: + # Only need to update after points/averaging. + # The previous block already accounts for those points. for point in range(self._min_dist, self._points): if j < 0: j = j + self._points if self._shifts_valid(block, i, j): + # Correlate at this lag. self._correlation[block, point] += ( self._shift[block, i, 0] * self._shift[block, j, 1] ) self._count_correlated[block, point] += 1 j = j - 1 + # Update the rolling data store. self._shift_index[block] = (self._shift_index[block] + 1) % self._points def _shifts_valid(self, block: int, p_i: int, p_j: int) -> bool: @@ -154,11 +199,13 @@ def get_lags(self) -> Iterable[float]: lag = 0 for i in range(self._points): if self._count_correlated[0, i] > 0: + # Data has been correlated, at full resolution. lags[lag] = i lag += 1 for k in range(1, self._max_block_used): for i in range(self._min_dist, self._points): if self._count_correlated[k, i] > 0: + # Data has been correlated at a coarse-grained level. lags[lag] = float(i) * float(self._averaging) ** k lag += 1 return lags[0:lag] @@ -177,13 +224,16 @@ def get_value(self) -> Iterable[float]: lag = 0 for i in range(self._points): if self._count_correlated[0, i] > 0: + # Data has been correlated at full resolution. correlation[lag] = ( self._correlation[0, i] / self._count_correlated[0, i] ) lag += 1 for k in range(1, self._max_block_used): for i in range(self._min_dist, self._points): + # Indices less than points/averaging accounted in the previous block. if self._count_correlated[k, i] > 0: + # Data has been correlated at a coarse-grained level. correlation[lag] = ( self._correlation[k, i] / self._count_correlated[k, i] ) @@ -275,13 +325,15 @@ def update(self, atoms: Atoms) -> None: atoms : Atoms Atoms object to observe values from. """ - value_pairs = zip(self._get_a(atoms), self._get_b(atoms)) + # All pairs of data to be correlated. + value_pairs = zip(self._get_a(atoms).flatten(), self._get_b(atoms).flatten()) if self._correlators is None: + # Initialise correlators automatically. self._correlators = [ Correlator( blocks=self.blocks, points=self.points, averaging=self.averaging ) - for _ in range(len(self._get_a(atoms))) + for _ in range(len(self._get_a(atoms).flatten())) ] for corr, values in zip(self._correlators, value_pairs): corr.update(*values) diff --git a/janus_core/processing/observables.py b/janus_core/processing/observables.py index 52fe0b91..188abe45 100644 --- a/janus_core/processing/observables.py +++ b/janus_core/processing/observables.py @@ -266,4 +266,4 @@ def __call__(self, atoms: Atoms) -> list[float]: list[float] The velocity values. """ - return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten() + return atoms.get_velocities()[self.atoms_slice, :][:, self._indices]