diff --git a/ilamb3/__init__.py b/ilamb3/__init__.py index 9643736..88f815c 100644 --- a/ilamb3/__init__.py +++ b/ilamb3/__init__.py @@ -43,5 +43,13 @@ def ilamb_catalog() -> pooch.Pooch: return registry -__all__ = ["dataset", "compare", "analysis", "regions", "ilamb_catalog,", "conf"] +__all__ = [ + "dataset", + "compare", + "analysis", + "regions", + "output", + "ilamb_catalog,", + "conf", +] xr.set_options(keep_attrs=True) diff --git a/ilamb3/analysis/bias.py b/ilamb3/analysis/bias.py index d653b42..ee13771 100644 --- a/ilamb3/analysis/bias.py +++ b/ilamb3/analysis/bias.py @@ -14,6 +14,7 @@ import xarray as xr import ilamb3.plot as plt +import ilamb3.regions as ilr from ilamb3 import compare as cmp from ilamb3 import dataset as dset from ilamb3.analysis.base import ILAMBAnalysis @@ -226,9 +227,8 @@ def _scalar( weight=ref_ if (mass_weighting and weight) else None, ) elif dset.is_site(da): - site_dim = dset.get_dim_name(da, "site") - da = da.pint.dequantify() - da = da.mean(dim=site_dim) + da = ilr.Regions().restrict_to_region(da, region) + da = da.mean(dim=dset.get_dim_name(da, "site")) else: raise ValueError(f"Input is neither spatial nor site: {da}") da = da.pint.quantify() @@ -305,7 +305,6 @@ def plots( ref: xr.Dataset, com: dict[str, xr.Dataset], ) -> pd.DataFrame: - # Some initialization regions = [None if r == "None" else r for r in df["region"].unique()] com["Reference"] = ref diff --git a/ilamb3/compare.py b/ilamb3/compare.py index e6d48a5..3a1dd07 100644 --- a/ilamb3/compare.py +++ b/ilamb3/compare.py @@ -257,6 +257,10 @@ def extract_sites( + (ds_site[lon_site] - ds_spatial[lon_spatial]) ** 2 ) assert (dist < model_res).all() + + # Set these are coordinates though so that we can tell this is site data + ds_spatial = ds_spatial.set_coords([lat_spatial, lon_spatial]) + return ds_spatial diff --git a/ilamb3/regions.py b/ilamb3/regions.py index 5115162..9b4738b 100644 --- a/ilamb3/regions.py +++ b/ilamb3/regions.py @@ -18,21 +18,33 @@ def restrict_to_bbox( work well on unsorted indices. """ assert isinstance(da, xr.DataArray) - lat_name = dset.get_dim_name(da, "lat") - lon_name = dset.get_dim_name(da, "lon") - da = da.sortby(list(da.dims)) - da = da.sel( - { - lat_name: slice( - da[lat_name].sel({lat_name: lat0}, method="nearest"), - da[lat_name].sel({lat_name: latf}, method="nearest"), - ), - lon_name: slice( - da[lon_name].sel({lon_name: lon0}, method="nearest"), - da[lon_name].sel({lon_name: lonf}, method="nearest"), - ), - } - ) + lat_name = dset.get_coord_name(da, "lat") + lon_name = dset.get_coord_name(da, "lon") + if dset.is_site(da): + site_name = dset.get_dim_name(da, "site") + da = da.sel( + { + site_name: ( + (da[lat_name] >= lat0) + & (da[lat_name] <= latf) + & (da[lon_name] >= lon0) + & (da[lon_name] <= lonf) + ) + } + ) + else: + da = da.sel( + { + lat_name: slice( + da[lat_name].sel({lat_name: lat0}, method="nearest"), + da[lat_name].sel({lat_name: latf}, method="nearest"), + ), + lon_name: slice( + da[lon_name].sel({lon_name: lon0}, method="nearest"), + da[lon_name].sel({lon_name: lonf}, method="nearest"), + ), + } + ) return da diff --git a/ilamb3/run.py b/ilamb3/run.py new file mode 100644 index 0000000..18789aa --- /dev/null +++ b/ilamb3/run.py @@ -0,0 +1,103 @@ +"""Functions for rendering ilamb3 output.""" + +import importlib + +import pandas as pd +import xarray as xr +from jinja2 import Template + +import ilamb3 +import ilamb3.regions as ilr +from ilamb3.analysis.base import ILAMBAnalysis + + +def run_analyses( + ref: xr.Dataset, com: xr.Dataset, analyses: dict[str, ILAMBAnalysis] +) -> None: + dfs = [] + ds_refs = [] + ds_coms = [] + for _, a in analyses.items(): + df, ds_ref, ds_com = a(ref, com, regions=ilamb3.conf["regions"]) + dfs.append(df) + ds_refs.append(ds_ref) + ds_coms.append(ds_com) + dfs = pd.concat(dfs) + dfs["name"] = dfs["name"] + " [" + df["units"] + "]" + ds_ref = xr.merge(ds_refs) + ds_com = xr.merge(ds_coms) + return dfs, ds_ref, ds_com + + +def plot_analyses( + df: pd.DataFrame, + ref: xr.Dataset, + com: dict[str, xr.Dataset], + analyses: dict[str, ILAMBAnalysis], +) -> pd.DataFrame: + df_plots = [] + for name, a in analyses.items(): + dfp = a.plots(df, ref, com) + dfp["analysis"] = name + df_plots.append(dfp) + df_plots = pd.concat(df_plots) + for _, row in df_plots.iterrows(): + row["axis"].get_figure().savefig( + f"{row['source']}_{row['region']}_{row['name']}.png" + ) + return df_plots + + +def generate_html_page( + df: pd.DataFrame, + ref: xr.Dataset, + com: dict[str, xr.Dataset], + df_plots: pd.DataFrame, +) -> str: + """.""" + ilamb_regions = ilr.Regions() + + # Setup template analyses and plots + analyses = {analysis: {} for analysis in df["analysis"].unique()} + for (aname, pname), df_grp in df_plots.groupby(["analysis", "name"], sort=False): + analyses[aname][pname] = [] + if "Reference" in df_grp["source"].unique(): + analyses[aname][pname] += [{"Reference": f"Reference_RNAME_{pname}.png"}] + analyses[aname][pname] += [{"Model": f"MNAME_RNAME_{pname}.png"}] + ref_plots = list(df_plots[df_plots["source"] == "Reference"]["name"].unique()) + mod_plots = list(df_plots[df_plots["source"] != "Reference"]["name"].unique()) + all_plots = list(set(ref_plots) | set(mod_plots)) + + # Setup template dictionary + df = df.reset_index(drop=True) # ? + df["id"] = df.index + data = { + "page_header": ref.attrs["header"] if "header" in ref.attrs else "", + "analysis_list": list(analyses.keys()), + "model_names": [m for m in df["source"].unique() if m != "Reference"], + "ref_plots": ref_plots, + "mod_plots": mod_plots, + "all_plots": all_plots, + "regions": { + (None if key == "None" else key): ( + "All Data" if key == "None" else ilamb_regions.get_name(key) + ) + for key in df["region"].unique() + }, + "analyses": analyses, + "data_information": { + key.capitalize(): ref.attrs[key] + for key in ["title", "institutions", "version"] + if key in ref.attrs + }, + "table_data": str( + [row.to_dict() for _, row in df.drop(columns="units").iterrows()] + ), + } + + # Generate the html from the template + template = importlib.resources.open_text( + "ilamb3.templates", "dataset_page.html" + ).read() + html = Template(template).render(data) + return html diff --git a/ilamb3/templates/dataset_page.html b/ilamb3/templates/dataset_page.html index 6cea4a0..bb7658b 100644 --- a/ilamb3/templates/dataset_page.html +++ b/ilamb3/templates/dataset_page.html @@ -306,6 +306,15 @@

All Models

+ + +{% for mname in model_names %} + +{% endfor %}