Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change IQmod to dataclass #1001

Draft
wants to merge 2 commits into
base: next
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 203 additions & 83 deletions src/drtsans/dataobjects.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import namedtuple
from collections.abc import Iterable
from dataclasses import asdict, dataclass, fields

import h5py
from enum import Enum
import numpy as np
Expand Down Expand Up @@ -80,16 +82,21 @@ def _nary_operation(iq_objects, operation, unpack=True, **kwargs):
"""
reference_object = iq_objects[0]
assert len(set([type(iq_object) for iq_object in iq_objects])) == 1 # check all objects of same type
new_components = list()
for i in range(len(reference_object)): # iterate over the IQ object components
i_components = [iq_object[i] for iq_object in iq_objects] # collect the ith components of each object
if True in [i_component is None for i_component in i_components]: # is any of these None?
new_components.append(None)
elif unpack is True:
new_components.append(operation(*i_components, **kwargs))
new_components = {}
for f in fields(reference_object): # iterate over the IQ object components
component_name = f.name
values = [
getattr(iq_object, component_name) for iq_object in iq_objects
] # collect this component of each object

if any(value is None for value in values): # is any of these None?
new_components[component_name] = None
elif unpack:
new_components[component_name] = operation(*values, **kwargs)
else:
new_components.append(operation(i_components, **kwargs))
return reference_object.__class__(*new_components)
new_components[component_name] = operation(values, **kwargs)

return reference_object.__class__(**new_components)


def _extract(iq_object, selection):
Expand All @@ -104,15 +111,16 @@ def _extract(iq_object, selection):
Parameters
----------
iq_object: ~drtsans.dataobjects.IQmod, ~drtsans.dataobjects.IQazimuthal, ~drtsans.dataobjects.IQcrystal
selection: int, slice, :ref:`~numpy.ndarray`
selection: int, slice, np.ndarray, :ref:`~numpy.ndarray`
Any selection that can be passed onto a :ref:`~numpy.ndarray`

Returns
-------
~drtsans.dataobjects.IQmod, ~drtsans.dataobjects.IQazimuthal, ~drtsans.dataobjects.IQcrystal
"""
component_fragments = list()
for component in iq_object:
for _field in fields(iq_object):
component = getattr(iq_object, _field.name)
if component is None:
component_fragments.append(None)
else:
Expand All @@ -122,7 +130,10 @@ def _extract(iq_object, selection):
fragment,
]
component_fragments.append(fragment)
return iq_object.__class__(*component_fragments)
# Rebuild the dataclass with the modified fragments
return iq_object.__class__(
**{_field.name: fragment for _field, fragment in zip(fields(iq_object), component_fragments)}
)


def scale_intensity(iq_object, scaling):
Expand All @@ -138,9 +149,10 @@ def scale_intensity(iq_object, scaling):
-------
~drtsans.dataobjects.IQmod, ~drtsans.dataobjects.IQazimuthal, ~drtsans.dataobjects.IQcrystal
"""
intensity = scaling * iq_object.intensity
error = scaling * iq_object.error
return iq_object.__class__(intensity, error, *[iq_object[i] for i in range(2, len(iq_object))])
iq_fields = {field.name: getattr(iq_object, field.name) for field in fields(iq_object)}
iq_fields["intensity"] = scaling * iq_fields["intensity"]
iq_fields["error"] = scaling * iq_fields["error"]
return iq_object.__class__(**iq_fields)


def verify_same_q_bins(iq0, iq1, raise_exception_if_diffrent=False, tolerance=None):
Expand Down Expand Up @@ -245,9 +257,125 @@ def concatenate(iq_objects):
return _nary_operation(iq_objects, np.concatenate, unpack=False)


class IQmod(namedtuple("IQmod", "intensity error mod_q delta_mod_q wavelength")):
# separate arguments with and without defaults into separate base classes to allow
# child classes to have non-default arguments
# see https://stackoverflow.com/a/53085935/23095774


@dataclass(frozen=True)
class _I1DBase:
intensity: np.ndarray | list
error: np.ndarray | list


@dataclass(frozen=True)
class _I1DDefaultsBase:
wavelength: np.ndarray | list = None


@dataclass(frozen=True)
class BaseI1D(_I1DDefaultsBase, _I1DBase):
"""Base class for 1D intensity profiles"""

def __post_init__(self):
# these conversions do nothing if the supplied information is already a numpy.ndarray
# use setattr to subvert 'frozen'
object.__setattr__(self, "intensity", np.array(self.intensity))
object.__setattr__(self, "error", np.array(self.error))

# if intensity is 1d, then everything else will be if they are parallel
if len(self.intensity.shape) != 1:
raise TypeError('"intensity" must be a 1-dimensional array, found shape={}'.format(self.intensity.shape))

# check that the mandatory fields are parallel
_check_parallel(self.intensity, self.error)

# check optional field
if self.wavelength is not None:
# use setattr to subvert 'frozen'
object.__setattr__(self, "wavelength", np.array(self.wavelength))
_check_parallel(self.intensity, self.wavelength)


@dataclass(frozen=True)
class _I1DannularBase(_I1DBase):
phi: np.ndarray | list


@dataclass(frozen=True)
class I1Dannular(BaseI1D, _I1DannularBase):
"""This class holds the information for I(phi) scalar.

Parameters
----------
intensity: list | np.ndarray
Intensity
error: list | np.ndarray
Error in intensity
phi: list | np.ndarray
Annular angle
wavelength: list | np.ndarray, default None
Wavelength
"""

phi: np.ndarray | list

def __post_init__(self):
super().__post_init__()

# use setattr to subvert 'frozen'
object.__setattr__(self, "phi", np.array(self.phi))

# check that the mandatory fields are parallel
_check_parallel(self.intensity, self.phi)


@dataclass(frozen=True)
class _IQmodBase(_I1DBase):
mod_q: np.ndarray | list


@dataclass(frozen=True)
class _IQmodDefaultsBase:
delta_mod_q: np.ndarray | list = None


@dataclass(frozen=True)
class IQmod(BaseI1D, _IQmodDefaultsBase, _IQmodBase):
r"""This class holds the information for I(Q) scalar. All of the arrays must be 1-dimensional
and parallel (same length). The ``delta_mod_q`` and ``wavelength`` fields are optional."""
and parallel (same length). The ``delta_mod_q`` and ``wavelength`` fields are optional.

Parameters
----------
intensity: list | np.ndarray
Intensity
error: list | np.ndarray
Error in intensity
mod_q: list | np.ndarray
Q modulus
delta_mod_q: list | np.ndarray, default None
Error in Q modulus
wavelength: list | np.ndarray, default None
Wavelength
"""

mod_q: np.ndarray | list
delta_mod_q: np.ndarray | list | None = None

def __post_init__(self):
super().__post_init__()

# use setattr to subvert 'frozen'
object.__setattr__(self, "mod_q", np.array(self.mod_q))

# check that the mandatory fields are parallel
_check_parallel(self.intensity, self.mod_q)

# check optional field
if self.delta_mod_q is not None:
# use setattr to subvert 'frozen'
object.__setattr__(self, "delta_mod_q", np.array(self.delta_mod_q))
_check_parallel(self.intensity, self.delta_mod_q)

@staticmethod
def read_csv(file, sep=" "):
Expand Down Expand Up @@ -295,30 +423,6 @@ def read_csv(file, sep=" "):
}
return IQmod(*args, **kwargs)

def __new__(cls, intensity, error, mod_q, delta_mod_q=None, wavelength=None):
# these conversions do nothing if the supplied information is already a numpy.ndarray
intensity = np.array(intensity)
error = np.array(error)
mod_q = np.array(mod_q)

# if intensity is 1d, then everything else will be if they are parallel
if len(intensity.shape) != 1:
raise TypeError('"intensity" must be a 1-dimensional array, found shape={}'.format(intensity.shape))

# check that the mandatory fields are parallel
_check_parallel(intensity, error, mod_q)

# work with optional fields
if delta_mod_q is not None:
delta_mod_q = np.array(delta_mod_q)
_check_parallel(intensity, delta_mod_q)
if wavelength is not None:
wavelength = np.array(wavelength)
_check_parallel(intensity, wavelength)

# pass everything to namedtuple
return super(IQmod, cls).__new__(cls, intensity, error, mod_q, delta_mod_q, wavelength)

def __mul__(self, scaling):
r"""Scale intensities and their uncertainties by a number"""
return scale_intensity(self, scaling)
Expand Down Expand Up @@ -434,7 +538,7 @@ def to_csv(self, file_name, sep=" ", float_format="%.6E", skip_nan=True):
# Convert to dictionary to construct a pandas DataFrame instance
from pandas import DataFrame

frame = DataFrame({label: value for label, value in self._asdict().items() if value is not None})
frame = DataFrame({label: value for label, value in asdict(self).items() if value is not None})

# Create the order of the columns
i_q_mod_cols = ["mod_q", "intensity", "error"] # 3 mandatory columns
Expand Down Expand Up @@ -552,7 +656,8 @@ def save_iqmod(
iq.to_csv(file, sep=sep, float_format=float_format, skip_nan=skip_nan)


class IQazimuthal(namedtuple("IQazimuthal", "intensity error qx qy delta_qx delta_qy wavelength")):
@dataclass(frozen=True)
class IQazimuthal:
r"""
This class holds the information for the azimuthal projection, I(Qx, Qy). The resolution terms,
(``delta_qx``, ``delta_qy``) and ``wavelength`` fields are optional.
Expand All @@ -570,64 +675,76 @@ class IQazimuthal(namedtuple("IQazimuthal", "intensity error qx qy delta_qx delt
because qx and qy will be created in such style.
"""

def __new__(cls, intensity, error, qx, qy, delta_qx=None, delta_qy=None, wavelength=None): # noqa: C901
intensity: np.ndarray | list
error: np.ndarray | list
qx: np.ndarray | list
qy: np.ndarray | list
delta_qx: np.ndarray | list = None
delta_qy: np.ndarray | list = None
wavelength: np.ndarray | list = None

def __post_init__(self):
# these conversions do nothing if the supplied information is already a numpy.ndarray
intensity = np.array(intensity)
error = np.array(error)
qx = np.array(qx)
qy = np.array(qy)
# use setattr to subvert 'frozen'
object.__setattr__(self, "intensity", np.array(self.intensity))
object.__setattr__(self, "error", np.array(self.error))
object.__setattr__(self, "qx", np.array(self.qx))
object.__setattr__(self, "qy", np.array(self.qy))

# check that the mandatory fields are parallel
if len(intensity.shape) == 1:
_check_parallel(intensity, error, qx, qy)
elif len(intensity.shape) == 2:
if len(qx.shape) == 1:
if len(self.intensity.shape) == 1:
_check_parallel(self.intensity, self.error, self.qx, self.qy)
elif len(self.intensity.shape) == 2:
if len(self.qx.shape) == 1:
# Qx and Qy are given in 1D array (not meshed)
_check_parallel(intensity, error)
if intensity.shape[0] != qx.shape[0]:
_check_parallel(self.intensity, self.error)
if self.intensity.shape[0] != self.qx.shape[0]:
raise TypeError(
"Incompatible dimensions intensity[{}] and qx[{}]".format(intensity.shape, qx.shape[0])
"Incompatible dimensions intensity[{}] and qx[{}]".format(
self.intensity.shape, self.qx.shape[0]
)
)
if intensity.shape[1] != qy.shape[0]:
if self.intensity.shape[1] != self.qy.shape[0]:
raise TypeError(
"Incompatible dimensions intensity[{}] and qy[{}]".format(intensity.shape, qy.shape[0])
"Incompatible dimensions intensity[{}] and qy[{}]".format(
self.intensity.shape, self.qy.shape[0]
)
)
elif len(qx.shape) == 2:
elif len(self.qx.shape) == 2:
# Qx and Qy are given in meshed 2D
_check_parallel(intensity, error, qx, qy)
_check_parallel(self.intensity, self.error, self.qx, self.qy)
else:
raise TypeError("Qx can only be of dimension 1 or 2, found {}".format(len(qx.shape)))
raise TypeError("Qx can only be of dimension 1 or 2, found {}".format(len(self.qx.shape)))
else:
raise TypeError("intensity can only be of dimension 1 or 2, found {}".format(len(intensity.shape)))
raise TypeError("intensity can only be of dimension 1 or 2, found {}".format(len(self.intensity.shape)))

# work with optional fields
if np.logical_xor(delta_qx is None, delta_qy is None):
if np.logical_xor(self.delta_qx is None, self.delta_qy is None):
raise TypeError("Must specify either both or neither of delta_qx and delta_qy")
if delta_qx is not None:
delta_qx = np.array(delta_qx)
delta_qy = np.array(delta_qy)
_check_parallel(intensity, delta_qx, delta_qy)
if wavelength is not None:
wavelength = np.array(wavelength)
_check_parallel(intensity, wavelength)
if self.delta_qx is not None:
object.__setattr__(self, "delta_qx", np.array(self.delta_qx))
object.__setattr__(self, "delta_qy", np.array(self.delta_qy))
_check_parallel(self.intensity, self.delta_qx, self.delta_qy)
if self.wavelength is not None:
object.__setattr__(self, "wavelength", np.array(self.wavelength))
_check_parallel(self.intensity, self.wavelength)

# make the qx and qy have the same shape as the data
if len(intensity.shape) == 2 and len(qx.shape) == 1 and len(qy.shape) == 1:
if len(self.intensity.shape) == 2 and len(self.qx.shape) == 1 and len(self.qy.shape) == 1:
# Using meshgrid to construct the Qx and Qy 2D arrays. This is consistent with the algorithm
# that is used in bin_iq_2d()
qx, qy = np.meshgrid(qx, qy, indexing="ij")
_qx, _qy = np.meshgrid(self.qx, self.qy, indexing="ij")
object.__setattr__(self, "qx", _qx)
object.__setattr__(self, "qy", _qy)

# Sanity check
assert qx.shape == intensity.shape, (
f"qx and intensity must have same shapes. It is not now: {qx.shape} vs {intensity.shape}"
assert self.qx.shape == self.intensity.shape, (
f"qx and intensity must have same shapes. It is not now: {self.qx.shape} vs {self.intensity.shape}"
)
assert qy.shape == intensity.shape, (
f"qy and intensity must have same shapes. It is not now: {qy.shape} vs {intensity.shape}"
assert self.qy.shape == self.intensity.shape, (
f"qy and intensity must have same shapes. It is not now: {self.qy.shape} vs {self.intensity.shape}"
)

# pass everything to namedtuple
return super(IQazimuthal, cls).__new__(cls, intensity, error, qx, qy, delta_qx, delta_qy, wavelength)

def be_finite(self):
"""Remove NaN by flattening first

Expand Down Expand Up @@ -908,12 +1025,15 @@ def _nary_assertion(iq_objects, assertion_function, unpack=True, **kwargs):
"""
reference_object = iq_objects[0] # pick the first of the list as reference object
assert len(set([type(iq_object) for iq_object in iq_objects])) == 1 # check all objects of same type
for i in range(len(reference_object)): # iterate over the IQ object components
component_name = reference_object._fields[i]
# Iterate over I object fields/components
for f in fields(reference_object):
component_name = f.name
print(f"all_close on {component_name}")
i_components = [iq_object[i] for iq_object in iq_objects] # collect the ith components of each object
if True in [i_component is None for i_component in i_components]: # is any of these None?
if set(i_components) == set([None]):
i_components = [
getattr(iq_object, component_name) for iq_object in iq_objects
] # collect the component of each object
if any(value is None for value in i_components): # is any of these None?
if all(value is None for value in i_components):
continue # all arrays are actually None, so they are identical
else:
raise AssertionError(f"field {component_name} is None for some of the iQ objects")
Expand Down
Loading
Loading