Skip to content

Commit

Permalink
feature: expand objects returned by LoadScans.load_topostats()
Browse files Browse the repository at this point in the history
Closes #1067

Modifies `LoadScans.load_topostats()` to take an argument `extract: str = "all"` so that, by default, the cleaned
(post-Filter) image stored at `image`, together with the `px_to_nm_scaling` and `data` stored in the `.topostats` HDF5
file, are returned.

To assist with #517 it is also possible to specify other data to extract: `raw` returns the original image array and
`pixel_to_nm_scaling`, should the user want to re-run the `Filter` stage, and `filter` returns the cleaned
(post-Filter) array, should the user wish to re-run grain detection on it.

The user options are mapped to the keys used in the HDF5 structure by means of a dictionary (which is local to the
`.load_topostats()` function) and will be extended as required in subsequent work.

Tests are expanded.
  • Loading branch information
ns-rse committed Jan 6, 2025
1 parent cf46824 commit b353a4a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
33 changes: 30 additions & 3 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,11 +582,11 @@ def test_load_scan_asd(load_scan_asd: LoadScans) -> None:
assert px_to_nm_scaling == 2.0


def test_load_scan_topostats(load_scan_topostats: LoadScans) -> None:
"""Test loading of a .topostats file."""
def test_load_scan_topostats_all(load_scan_topostats: LoadScans) -> None:
"""Test loading all data from a .topostats file."""
load_scan_topostats.img_path = load_scan_topostats.img_paths[0]
load_scan_topostats.filename = load_scan_topostats.img_paths[0].stem
image, px_to_nm_scaling, data = load_scan_topostats.load_topostats()
image, px_to_nm_scaling, data = load_scan_topostats.load_topostats(extract="all")
above_grain_mask = data["grain_masks"]["above"]
grain_trace_data = data["grain_trace_data"]
assert isinstance(image, np.ndarray)
Expand All @@ -601,6 +601,33 @@ def test_load_scan_topostats(load_scan_topostats: LoadScans) -> None:
assert grain_trace_data.keys() == {"above"}


@pytest.mark.parametrize(
    ("extract", "array_sum"),
    [
        pytest.param("raw", 30695369.188316286, id="loading raw data"),
        pytest.param("filter", 184140.8593819073, id="loading filtered data"),
    ],
)
def test_load_scan_topostats_components(load_scan_topostats: LoadScans, extract: str, array_sum: float) -> None:
    """
    Test loading different components from a .topostats file.

    Parameters
    ----------
    load_scan_topostats : LoadScans
        Fixture providing a ``LoadScans`` object whose ``img_paths`` point at a ``.topostats`` file.
    extract : str
        Stage to extract, passed through to ``load_topostats()``.
    array_sum : float
        Expected sum of the extracted image array.
    """
    load_scan_topostats.img_path = load_scan_topostats.img_paths[0]
    load_scan_topostats.filename = load_scan_topostats.img_paths[0].stem
    # Pass ``extract`` by keyword so the call matches the other ``load_topostats()`` tests.
    image, px_to_nm_scaling, data = load_scan_topostats.load_topostats(extract=extract)
    assert isinstance(image, np.ndarray)
    assert image.shape == (1024, 1024)
    # The sum over ~1e6 floats depends on summation order, so compare with a tolerance rather than bit-for-bit.
    assert image.sum() == pytest.approx(array_sum)
    assert isinstance(px_to_nm_scaling, float)
    assert px_to_nm_scaling == 0.4940029296875
    # Only ``extract="all"`` returns the data dictionary; single-stage extraction returns ``None`` in its place.
    assert data is None


def test_load_scan_topostats_keyerror(load_scan_topostats: LoadScans) -> None:
    """
    Test KeyError is raised when invalid extract is provided.

    Parameters
    ----------
    load_scan_topostats : LoadScans
        Fixture providing a ``LoadScans`` object whose ``img_paths`` point at a ``.topostats`` file.
    """
    load_scan_topostats.img_path = load_scan_topostats.img_paths[0]
    load_scan_topostats.filename = load_scan_topostats.img_paths[0].stem
    # Match on the message so an unrelated KeyError (e.g. a missing HDF5 key) cannot satisfy the test.
    with pytest.raises(KeyError, match="Can not extract array of type 'nothing'"):
        load_scan_topostats.load_topostats(extract="nothing")


@pytest.mark.parametrize(
("load_scan_object", "length", "image_shape", "image_sum", "filename", "pixel_to_nm_scaling"),
[
Expand Down
22 changes: 19 additions & 3 deletions topostats/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def load_spm(self) -> tuple[npt.NDArray, float]:
LOGGER.error(f"File Not Found : {self.img_path}")
raise

def load_topostats(self) -> tuple[npt.NDArray, float]:
def load_topostats(self, extract: str = "all") -> tuple[npt.NDArray, float, Any]:
"""
Load a .topostats file (hdf5 format).
Expand All @@ -650,17 +650,33 @@ def load_topostats(self) -> tuple[npt.NDArray, float]:
Note that grain masks are stored via self.grain_masks rather than returned due to how we extract information for
all other file loading functions.
Parameters
----------
extract : str
Which image (Numpy array) and data to extract. The default, 'all', returns the cleaned (post-Filter) image,
`pixel_to_nm_scaling` and all `data`. Image arrays from other stages of processing can be extracted with
'raw' or 'filter', in which case `None` is returned in place of `data`.
Returns
-------
tuple[npt.NDArray, float]
tuple[npt.NDArray, float, Any]
A tuple containing the image, its pixel to nanometre scaling value and the data (``None`` unless ``extract`` is 'all').
"""
map_stage_to_image = {"raw": "image_original"}
try:
LOGGER.debug(f"Loading image from : {self.img_path}")
return topostats.load_topostats(self.img_path)
image, px_to_nm_scaling, data = topostats.load_topostats(self.img_path)
except FileNotFoundError:
LOGGER.error(f"File Not Found : {self.img_path}")
raise
try:
if extract == "all":
return (image, px_to_nm_scaling, data)
if extract == "filter":
return (image, px_to_nm_scaling, None)
return (data[map_stage_to_image[extract]], px_to_nm_scaling, None)
except KeyError as ke:
raise KeyError(f"Can not extract array of type '{extract}' from .topostats objects.") from ke

def load_asd(self) -> tuple[npt.NDArray, float]:
"""
Expand Down

0 comments on commit b353a4a

Please sign in to comment.