NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Blank context lines were placed so that every
hunk matches the old/new line counts in its @@ header. Indentation depth of
context lines is inferred from Python syntax and is an assumption; verify it
against the base revision, or apply with whitespace-insensitive context
matching, before relying on this patch.

diff --git a/cfgrib/dataset.py b/cfgrib/dataset.py
index c8ea895d..7bc262b3 100644
--- a/cfgrib/dataset.py
+++ b/cfgrib/dataset.py
@@ -345,11 +345,11 @@ class OnDiskArray:
     )
     missing_value: float
     geo_ndim: int = attr.attrib(default=1, repr=False)
-    dtype = np.dtype("float32")
+    dtype: np.dtype = attr.attrib(default=messages.DEFAULT_VALUES_DTYPE, repr=False)
 
     def build_array(self) -> np.ndarray:
         """Helper method used to test __getitem__"""
-        array = np.full(self.shape, fill_value=np.nan, dtype="float32")
+        array = np.full(self.shape, fill_value=np.nan, dtype=self.dtype)
         for header_indexes, message_ids in self.field_id_index.items():
             # NOTE: fill a single field as found in the message
             message = self.index.get_field(message_ids[0])  # type: ignore
@@ -363,7 +363,7 @@ def __getitem__(self, item):
         header_item_list = expand_item(item[: -self.geo_ndim], self.shape)
         header_item = [{ix: i for i, ix in enumerate(it)} for it in header_item_list]
         array_field_shape = tuple(len(i) for i in header_item_list) + self.shape[-self.geo_ndim :]
-        array_field = np.full(array_field_shape, fill_value=np.nan, dtype="float32")
+        array_field = np.full(array_field_shape, fill_value=np.nan, dtype=self.dtype)
         for header_indexes, message_ids in self.field_id_index.items():
             try:
                 array_field_indexes = [it[ix] for it, ix in zip(header_item, header_indexes)]
@@ -497,6 +497,7 @@ def build_variable_components(
     extra_coords: T.Dict[str, str] = {},
     coords_as_attributes: T.Dict[str, str] = {},
     cache_geo_coords: bool = True,
+    values_dtype: np.dtype = messages.DEFAULT_VALUES_DTYPE,
 ) -> T.Tuple[T.Dict[str, int], Variable, T.Dict[str, Variable]]:
     data_var_attrs = enforce_unique_attributes(index, DATA_ATTRIBUTES_KEYS, filter_by_keys)
     grid_type_keys = GRID_TYPE_MAP.get(index.getone("gridType"), [])
@@ -601,6 +602,7 @@ def build_variable_components(
         field_id_index=offsets,
         missing_value=missing_value,
         geo_ndim=len(geo_dims),
+        dtype=values_dtype,
     )
 
     if "time" in coord_vars and "step" in coord_vars:
@@ -673,6 +675,7 @@ def build_dataset_components(
     extra_coords: T.Dict[str, str] = {},
     coords_as_attributes: T.Dict[str, str] = {},
     cache_geo_coords: bool = True,
+    values_dtype: np.dtype = messages.DEFAULT_VALUES_DTYPE,
 ) -> T.Tuple[T.Dict[str, int], T.Dict[str, Variable], T.Dict[str, T.Any], T.Dict[str, T.Any]]:
     dimensions = {}  # type: T.Dict[str, int]
     variables = {}  # type: T.Dict[str, Variable]
@@ -700,6 +703,7 @@ def build_dataset_components(
                 extra_coords=extra_coords,
                 coords_as_attributes=coords_as_attributes,
                 cache_geo_coords=cache_geo_coords,
+                values_dtype=values_dtype,
             )
         except DatasetBuildError as ex:
             # NOTE: When a variable has more than one value for an attribute we need to raise all
diff --git a/cfgrib/messages.py b/cfgrib/messages.py
index d2d2000a..17a4e752 100644
--- a/cfgrib/messages.py
+++ b/cfgrib/messages.py
@@ -69,6 +69,7 @@ def multi_enabled(file: T.IO[bytes]) -> T.Iterator[None]:
 }
 
 DEFAULT_INDEXPATH = "{path}.{short_hash}.idx"
+DEFAULT_VALUES_DTYPE = np.dtype("float32")
 
 OffsetType = T.Union[int, T.Tuple[int, int]]
 
diff --git a/cfgrib/xarray_plugin.py b/cfgrib/xarray_plugin.py
index b5eea32c..2e6d0a2d 100644
--- a/cfgrib/xarray_plugin.py
+++ b/cfgrib/xarray_plugin.py
@@ -107,6 +107,7 @@ def open_dataset(
         extra_coords: T.Dict[str, str] = {},
         coords_as_attributes: T.Dict[str, str] = {},
         cache_geo_coords: bool = True,
+        values_dtype: np.dtype = messages.DEFAULT_VALUES_DTYPE,
     ) -> xr.Dataset:
         store = CfGribDataStore(
             filename_or_obj,
@@ -122,6 +123,7 @@ def open_dataset(
             extra_coords=extra_coords,
             coords_as_attributes=coords_as_attributes,
             cache_geo_coords=cache_geo_coords,
+            values_dtype=values_dtype,
         )
         with xr.core.utils.close_on_error(store):
             vars, attrs = store.load()  # type: ignore
diff --git a/tests/test_30_dataset.py b/tests/test_30_dataset.py
index 9914e482..ae9d22f9 100644
--- a/tests/test_30_dataset.py
+++ b/tests/test_30_dataset.py
@@ -380,3 +380,17 @@ def test_missing_field_values() -> None:
     t2 = res.variables["t2m"]
     assert np.isclose(np.nanmean(t2.data[0, :, :]), 268.375)
     assert np.isclose(np.nanmean(t2.data[1, :, :]), 270.716)
+
+
+def test_default_values_dtype() -> None:
+    res = dataset.open_file(TEST_DATA_MISSING_VALS)
+    assert res.variables["t2m"].data.dtype == np.dtype("float32")
+    assert res.variables["latitude"].data.dtype == np.dtype("float64")
+    assert res.variables["longitude"].data.dtype == np.dtype("float64")
+
+
+def test_float64_values_dtype() -> None:
+    res = dataset.open_file(TEST_DATA_MISSING_VALS, values_dtype=np.dtype("float64"))
+    assert res.variables["t2m"].data.dtype == np.dtype("float64")
+    assert res.variables["latitude"].data.dtype == np.dtype("float64")
+    assert res.variables["longitude"].data.dtype == np.dtype("float64")
diff --git a/tests/test_50_xarray_plugin.py b/tests/test_50_xarray_plugin.py
index 2009198a..3a9ecc6f 100644
--- a/tests/test_50_xarray_plugin.py
+++ b/tests/test_50_xarray_plugin.py
@@ -175,4 +175,18 @@ def test_xr_open_dataset_coords_to_attributes() -> None:
     assert "depthBelowLandLayer" not in ds.coords
 
     assert "GRIB_surface" in ds["t2m"].attrs
-    assert "GRIB_depthBelowLandLayer" in ds["stl1"].attrs
\ No newline at end of file
+    assert "GRIB_depthBelowLandLayer" in ds["stl1"].attrs
+
+
+def test_xr_open_dataset_default_values_dtype() -> None:
+    ds = xr.open_dataset(TEST_DATA_MISSING_VALS, engine="cfgrib")
+    assert ds["t2m"].data.dtype == np.dtype("float32")
+    assert ds["latitude"].data.dtype == np.dtype("float64")
+    assert ds["longitude"].data.dtype == np.dtype("float64")
+
+
+def test_xr_open_dataset_float64_values_dtype() -> None:
+    ds = xr.open_dataset(TEST_DATA_MISSING_VALS, engine="cfgrib", values_dtype=np.dtype("float64"))
+    assert ds["t2m"].data.dtype == np.dtype("float64")
+    assert ds["latitude"].data.dtype == np.dtype("float64")
+    assert ds["longitude"].data.dtype == np.dtype("float64")