Skip to content

Commit

Permalink
Merge pull request #30 from ScottWales/malloc_trim
Browse files Browse the repository at this point in the history
Allow setting malloc_threshold, add OOD Client
  • Loading branch information
Scott Wales authored Aug 27, 2021
2 parents b227403 + 17fb26c commit 3511e80
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/climtas/blocked.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def dask_approx_percentile(

def approx_percentile(
da: T.Union[xarray.DataArray, dask.array.Array, numpy.ndarray],
q: float,
q: T.Union[numbers.Real, T.List[numbers.Real]],
dim: str = None,
axis: int = None,
skipna: bool = True,
Expand All @@ -803,8 +803,10 @@ def approx_percentile(
Array of the same type as da, otherwise as :func:`numpy.percentile`
"""

if isinstance(q, numbers.Number):
q = [q]
if isinstance(q, numbers.Real):
qlist = [q]
else:
qlist = q

if skipna:
pctile = numpy.nanpercentile
Expand All @@ -818,7 +820,7 @@ def approx_percentile(
data = dask_approx_percentile(da.data, pcts=q, axis=axis, skipna=skipna)
dims = ["percentile", *[d for i, d in enumerate(da.dims) if i != axis]]
coords = {k: v for k, v in da.coords.items() if k in dims}
coords["percentile"] = q
coords["percentile"] = xarray.DataArray(qlist, dims="percentile")
return xarray.DataArray(
data,
name=da.name,
Expand All @@ -828,7 +830,7 @@ def approx_percentile(

if isinstance(da, xarray.DataArray):
# Xarray+Numpy
return da.quantile([p / 100 for p in q], dim=dim, skipna=skipna)
return da.quantile([p / 100 for p in qlist], dim=dim, skipna=skipna)

assert dim is None
assert axis is not None
Expand Down
2 changes: 1 addition & 1 deletion src/climtas/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def to_cdo_grid(self, outfile):
def to_netcdf(self, outfile):
ds = xarray.DataArray(
data=numpy.zeros((len(self.lats), len(self.lons))),
coords=[("lat", self.lats), ("lon", self.lons)],
coords=[("lat", self.lats.data), ("lon", self.lons.data)],
)
ds.lat.attrs["units"] = "degrees_north"
ds.lon.attrs["units"] = "degrees_east"
Expand Down
87 changes: 86 additions & 1 deletion src/climtas/nci/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,100 @@
_tmpdir = None


def GadiClient(threads=1):
def Client(threads=1, malloc_trim_threshold=None):
"""Start a Dask client at NCI
An appropriate client will be started for the current system
Args:
threads: Number of threads per worker process. The total number of
workers will be ncpus/threads, so that each thread gets its own
CPU
malloc_trim_threshold: Threshold for automatic memory trimming. Can be
either a string e.g. '64kib' or a number of bytes e.g. 65536.
Smaller values may reduce out of memory errors at the cost of
running slower
https://distributed.dask.org/en/latest/worker.html?highlight=worker#automatically-trim-memory
"""

if os.environ["HOSTNAME"].startswith("ood"):
return OODClient(threads, malloc_trim_threshold)
else:
return GadiClient(threads, malloc_trim_threshold)


def OODClient(threads=1, malloc_trim_threshold=None):
"""Start a Dask client on OOD
This function is mostly to be consistent with the Gadi version
Args:
threads: Number of threads per worker process. The total number of
workers will be ncpus/threads, so that each thread gets its own
CPU
malloc_trim_threshold: Threshold for automatic memory trimming. Can be
either a string e.g. '64kib' or a number of bytes e.g. 65536.
Smaller values may reduce out of memory errors at the cost of
running slower
https://distributed.dask.org/en/latest/worker.html?highlight=worker#automatically-trim-memory
"""
global _dask_client, _tmpdir

env = {}

if malloc_trim_threshold is not None:
env["MALLOC_TRIM_THRESHOLD_"] = str(
dask.utils.parse_bytes(malloc_trim_threshold)
)

if _dask_client is None:
try:
# Works in sidebar and can follow the link
dask.config.set(
{
"distributed.dashboard.link": f'/node/{os.environ["host"]}/{os.environ["port"]}/proxy/{{port}}/status'
}
)
except KeyError:
# Works in sidebar, but can't follow the link
dask.config.set({"distributed.dashboard.link": "/proxy/{port}/status"})

_dask_client = dask.distributed.Client(threads_per_worker=threads, env=env)

return _dask_client


def GadiClient(threads=1, malloc_trim_threshold=None):
"""Start a Dask client on Gadi
If run on a compute node it will check the PBS resources to know how many
CPUs and the amount of memory that is available.
If run on a login node it will ask for 2 workers each with a 1GB memory
limit
Args:
threads: Number of threads per worker process. The total number of
workers will be $PBS_NCPUS/threads, so that each thread gets its own
CPU
malloc_trim_threshold: Threshold for automatic memory trimming. Can be
either a string e.g. '64kib' or a number of bytes e.g. 65536.
Smaller values may reduce out of memory errors at the cost of
running slower
https://distributed.dask.org/en/latest/worker.html?highlight=worker#automatically-trim-memory
"""
global _dask_client, _tmpdir

env = {}

if malloc_trim_threshold is not None:
env["MALLOC_TRIM_THRESHOLD_"] = str(
dask.utils.parse_bytes(malloc_trim_threshold)
)

if _dask_client is None:
_tmpdir = tempfile.TemporaryDirectory("dask-worker-space")

Expand All @@ -45,6 +128,7 @@ def GadiClient(threads=1):
threads_per_worker=threads,
memory_limit="1000mb",
local_directory=_tmpdir.name,
env=env,
)
else:
workers = int(os.environ["PBS_NCPUS"]) // threads
Expand All @@ -53,5 +137,6 @@ def GadiClient(threads=1):
threads_per_worker=threads,
memory_limit=int(os.environ["PBS_VMEM"]) / workers,
local_directory=_tmpdir.name,
env=env,
)
return _dask_client

0 comments on commit 3511e80

Please sign in to comment.