Skip to content

Commit

Permalink
Comment correlation algorithm, remove flattening of Velocity observab…
Browse files Browse the repository at this point in the history
…le (#369)

* Comment correlator.py

* Don't flatten Velocity values

* Add Attributes section to correlator
  • Loading branch information
harveydevereux authored Jan 13, 2025
1 parent 1d79227 commit b727ff3
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 6 deletions.
62 changes: 57 additions & 5 deletions janus_core/processing/correlator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ class Correlator:
"""
Correlate scalar real values, <ab>.
Implements the algorithm detailed in https://doi.org/10.1063/1.3491098.
Data pairs are observed iteratively and stored in a set of rolling hierarchical data
blocks.
Once a block is filled, coarse graining may be applied to update coarser block
levels by updating the coarser block with the average of values accumulated up to
that point in the filled block.
The correlation is continuously updated when any block is updated with new data.
Parameters
----------
blocks : int
Expand All @@ -22,6 +33,27 @@ class Correlator:
Number of points per block.
averaging : int
Averaging window per block level.
Attributes
----------
_max_block_used : int
Which levels have been updated with data.
_min_dist : int
First point in coarse-grained block relevant for correlation updates.
_accumulator : NDArray[float64]
Sum of data seen for calculating the average between blocks.
_count_accumulated : NDArray[int]
Data points accumulated at this block.
_shift_index : NDArray[int]
Current position in each block's data store.
_shift : NDArray[float64]
Rolling data store for each block.
_shift_not_null : NDArray[bool]
If data is stored in this block's rolling data store, at a given index.
_correlation : NDArray[float64]
Running correlation values.
_count_correlated : NDArray[int]
Count of correlation updates for each block.
"""

def __init__(self, *, blocks: int, points: int, averaging: int) -> None:
Expand All @@ -31,11 +63,11 @@ def __init__(self, *, blocks: int, points: int, averaging: int) -> None:
Parameters
----------
blocks : int
Number of correlation blocks.
Number of resolution levels.
points : int
Number of points per block.
Data points at each resolution.
averaging : int
Averaging window per block level.
Coarse-graining between resolution levels.
"""
self._blocks = blocks
self._points = points
Expand Down Expand Up @@ -78,46 +110,59 @@ def _propagate(self, a: float, b: float, block: int) -> None:
Block in the hierachy being updated.
"""
if block == self._blocks:
# Hit the end of the data structure.
return

shift = self._shift_index[block]
self._max_block_used = max(self._max_block_used, block)

# Update the rolling data store, and accumulate
self._shift[block, shift, :] = a, b
self._accumulator[block, :] += a, b
self._shift_not_null[block, shift] = True
self._count_accumulated[block] += 1

if self._count_accumulated[block] == self._averaging:
# Hit the coarse graining threshold, the next block can be updated.
self._propagate(
self._accumulator[block, 0] / self._averaging,
self._accumulator[block, 1] / self._averaging,
block + 1,
)
# Reset the accumulator at this block level.
self._accumulator[block, :] = 0.0
self._count_accumulated[block] = 0

# Update the correlation.
i = self._shift_index[block]
if block == 0:
# Need to multiply by all in this block (full resolution).
j = i
for point in range(self._points):
if self._shifts_valid(block, i, j):
# Correlate at this lag.
self._correlation[block, point] += (
self._shift[block, i, 0] * self._shift[block, j, 1]
)
self._count_correlated[block, point] += 1
j -= 1
if j < 0:
# Wrap to start of rolling data store.
j += self._points
else:
# Only need to update after points/averaging.
# The previous block already accounts for those points.
for point in range(self._min_dist, self._points):
if j < 0:
j = j + self._points
if self._shifts_valid(block, i, j):
# Correlate at this lag.
self._correlation[block, point] += (
self._shift[block, i, 0] * self._shift[block, j, 1]
)
self._count_correlated[block, point] += 1
j = j - 1
# Update the rolling data store.
self._shift_index[block] = (self._shift_index[block] + 1) % self._points

def _shifts_valid(self, block: int, p_i: int, p_j: int) -> bool:
Expand Down Expand Up @@ -154,11 +199,13 @@ def get_lags(self) -> Iterable[float]:
lag = 0
for i in range(self._points):
if self._count_correlated[0, i] > 0:
# Data has been correlated, at full resolution.
lags[lag] = i
lag += 1
for k in range(1, self._max_block_used):
for i in range(self._min_dist, self._points):
if self._count_correlated[k, i] > 0:
# Data has been correlated at a coarse-grained level.
lags[lag] = float(i) * float(self._averaging) ** k
lag += 1
return lags[0:lag]
Expand All @@ -177,13 +224,16 @@ def get_value(self) -> Iterable[float]:
lag = 0
for i in range(self._points):
if self._count_correlated[0, i] > 0:
# Data has been correlated at full resolution.
correlation[lag] = (
self._correlation[0, i] / self._count_correlated[0, i]
)
lag += 1
for k in range(1, self._max_block_used):
for i in range(self._min_dist, self._points):
# Indices less than points/averaging accounted in the previous block.
if self._count_correlated[k, i] > 0:
# Data has been correlated at a coarse-grained level.
correlation[lag] = (
self._correlation[k, i] / self._count_correlated[k, i]
)
Expand Down Expand Up @@ -275,13 +325,15 @@ def update(self, atoms: Atoms) -> None:
atoms : Atoms
Atoms object to observe values from.
"""
value_pairs = zip(self._get_a(atoms), self._get_b(atoms))
# All pairs of data to be correlated.
value_pairs = zip(self._get_a(atoms).flatten(), self._get_b(atoms).flatten())
if self._correlators is None:
# Initialise correlators automatically.
self._correlators = [
Correlator(
blocks=self.blocks, points=self.points, averaging=self.averaging
)
for _ in range(len(self._get_a(atoms)))
for _ in range(len(self._get_a(atoms).flatten()))
]
for corr, values in zip(self._correlators, value_pairs):
corr.update(*values)
Expand Down
2 changes: 1 addition & 1 deletion janus_core/processing/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,4 @@ def __call__(self, atoms: Atoms) -> list[float]:
list[float]
The velocity values.
"""
return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten()
return atoms.get_velocities()[self.atoms_slice, :][:, self._indices]

0 comments on commit b727ff3

Please sign in to comment.