Skip to content

Commit

Permalink
Change over axis to vertical_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Sep 3, 2020
1 parent 7dcccce commit 4d50ffa
Showing 1 changed file with 85 additions and 55 deletions.
140 changes: 85 additions & 55 deletions src/metpy/calc/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Distributed under the terms of the BSD 3-Clause License.
# SPDX-License-Identifier: BSD-3-Clause
"""Contains a collection of thermodynamic calculations."""
import functools
from inspect import signature
import warnings

import numpy as np
Expand All @@ -23,6 +25,39 @@
sat_pressure_0c = 6.112 * units.millibar


def add_vertical_dim_from_xarray(func):
"""Fill in optional vertical_dim from DataArray argument."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
bound_args = signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()

# Search for DataArray with valid latitude and longitude coordinates to find grid
# deltas and any other needed parameter
dataarray_arguments = [
value for value in bound_args.arguments.values()
if isinstance(value, xr.DataArray)
]

# Fill in vertical_dim
if (
len(dataarray_arguments) > 0
and 'vertical_dim' in bound_args.arguments
):
try:
bound_args.arguments['vertical_dim'] = (
dataarray_arguments[0].metpy.find_axis_number('vertical')
)
except AttributeError:
# If axis number not found, fall back to default but warn.
warnings.warn(
'Vertical dimension number not found. Defaulting to initial dimension.'
)

return func(*bound_args.args, **bound_args.kwargs)
return wrapper


@exporter.export
@preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'dewpoint'))
@check_units('[temperature]', '[temperature]')
Expand Down Expand Up @@ -1790,8 +1825,9 @@ def most_unstable_parcel(pressure, temperature, dewpoint, height=None,

@exporter.export
@preprocess_and_wrap()
@add_vertical_dim_from_xarray
@check_units('[temperature]', '[pressure]', '[temperature]')
def isentropic_interpolation(levels, pressure, temperature, *args, axis=0,
def isentropic_interpolation(levels, pressure, temperature, *args, vertical_dim=0,
temperature_out=False, max_iters=50, eps=1e-6,
bottom_up_search=True, **kwargs):
r"""Interpolate data in isobaric coordinates to isentropic coordinates.
Expand All @@ -1804,7 +1840,7 @@ def isentropic_interpolation(levels, pressure, temperature, *args, axis=0,
One-dimensional array of pressure levels
temperature : array
Array of temperature
axis : int, optional
vertical_dim : int, optional
The axis corresponding to the vertical in the temperature array, defaults to 0.
temperature_out : bool, optional
If true, will calculate temperature and output as the last item in the output list.
Expand Down Expand Up @@ -1860,14 +1896,14 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok):
temperature = temperature.to('kelvin')

slices = [np.newaxis] * ndim
slices[axis] = slice(None)
slices[vertical_dim] = slice(None)
slices = tuple(slices)
pres = np.broadcast_to(pres[slices].magnitude, temperature.shape) * pres.units

# Sort input data
sort_pres = np.argsort(pres.m, axis=axis)
sort_pres = np.swapaxes(np.swapaxes(sort_pres, 0, axis)[::-1], 0, axis)
sorter = broadcast_indices(pres, sort_pres, ndim, axis)
sort_pres = np.argsort(pres.m, axis=vertical_dim)
sort_pres = np.swapaxes(np.swapaxes(sort_pres, 0, vertical_dim)[::-1], 0, vertical_dim)
sorter = broadcast_indices(pres, sort_pres, ndim, vertical_dim)
levs = pres[sorter]
tmpk = temperature[sorter]

Expand All @@ -1876,7 +1912,7 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok):

# Make the desired isentropic levels the same shape as temperature
shape = list(temperature.shape)
shape[axis] = isentlevels.size
shape[vertical_dim] = isentlevels.size
isentlevs_nd = np.broadcast_to(isentlevels[slices], shape)

# exponent to Poisson's Equation, which is imported above
Expand All @@ -1897,7 +1933,7 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok):
pok = mpconsts.P0 ** ka

# index values for each point for the pressure level nearest to the desired theta level
above, below, good = find_bounding_indices(pres_theta.m, levels, axis,
above, below, good = find_bounding_indices(pres_theta.m, levels, vertical_dim,
from_below=bottom_up_search)

# calculate constants for the interpolation
Expand Down Expand Up @@ -1932,7 +1968,7 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok):
# do an interpolation for each additional argument
if args:
others = interpolate_1d(isentlevels, pres_theta.m, *(arr[sorter] for arr in args),
axis=axis, return_list_always=True)
axis=vertical_dim, return_list_always=True)
ret.extend(others)

return ret
Expand Down Expand Up @@ -1998,7 +2034,7 @@ def isentropic_interpolation_as_dataset(
all_args[0].metpy.vertical,
all_args[0].metpy.unit_array,
*(arg.metpy.unit_array for arg in all_args[1:]),
axis=all_args[0].metpy.find_axis_number('vertical'),
vertical_dim=all_args[0].metpy.find_axis_number('vertical'),
temperature_out=True,
max_iters=max_iters,
eps=eps,
Expand Down Expand Up @@ -2535,8 +2571,10 @@ def thickness_hydrostatic_from_relative_humidity(pressure, temperature, relative


@exporter.export
@add_vertical_dim_from_xarray
@preprocess_and_wrap(wrap_like='height', broadcast=('height', 'potential_temperature'))
@check_units('[length]', '[temperature]')
def brunt_vaisala_frequency_squared(height, potential_temperature, axis=0):
def brunt_vaisala_frequency_squared(height, potential_temperature, vertical_dim=0):
r"""Calculate the square of the Brunt-Vaisala frequency.
Brunt-Vaisala frequency squared (a measure of atmospheric stability) is given by the
Expand All @@ -2552,7 +2590,7 @@ def brunt_vaisala_frequency_squared(height, potential_temperature, axis=0):
Atmospheric (geopotential) height
potential_temperature : `xarray.DataArray` or `pint.Quantity`
Atmospheric potential temperature
axis : int, optional
vertical_dim : int, optional
The axis corresponding to vertical in the potential temperature array, defaults to 0,
unless `height` and `potential_temperature` given as `xarray.DataArray`, in which case
it is automatically determined from the coordinate metadata.
Expand All @@ -2569,32 +2607,22 @@ def brunt_vaisala_frequency_squared(height, potential_temperature, axis=0):
brunt_vaisala_frequency, brunt_vaisala_period, potential_temperature
"""
if isinstance(height, xr.DataArray) and isinstance(potential_temperature, xr.DataArray):
# Prepare arguments when DataArrays
potential_temperature = potential_temperature.metpy.convert_units('K')
height, potential_temperature = xr.broadcast(
height.metpy.quantify(),
potential_temperature.metpy.quantify()
)
dtheta_dz = first_derivative(
potential_temperature,
x=height,
axis=potential_temperature.metpy.find_axis_number('vertical')
)
else:
if isinstance(potential_temperature, xr.DataArray):
potential_temperature = potential_temperature.metpy.unit_array.to('K')
else:
potential_temperature = potential_temperature.to('K')
dtheta_dz = first_derivative(potential_temperature, x=height, axis=axis)
# Ensure validity of temperature units
potential_temperature = potential_temperature.to('K')

# Calculate and return the square of Brunt-Vaisala frequency
return mpconsts.g / potential_temperature * dtheta_dz
return mpconsts.g / potential_temperature * first_derivative(
potential_temperature,
x=height,
axis=vertical_dim
)


@exporter.export
@add_vertical_dim_from_xarray
@preprocess_and_wrap(wrap_like='height', broadcast=('height', 'potential_temperature'))
@check_units('[length]', '[temperature]')
def brunt_vaisala_frequency(height, potential_temperature, axis=0):
def brunt_vaisala_frequency(height, potential_temperature, vertical_dim=0):
r"""Calculate the Brunt-Vaisala frequency.
This function will calculate the Brunt-Vaisala frequency as follows:
Expand All @@ -2612,7 +2640,7 @@ def brunt_vaisala_frequency(height, potential_temperature, axis=0):
Atmospheric (geopotential) height
potential_temperature : `xarray.DataArray` or `pint.Quantity`
Atmospheric potential temperature
axis : int, optional
vertical_dim : int, optional
The axis corresponding to vertical in the potential temperature array, defaults to 0,
unless `height` and `potential_temperature` given as `xarray.DataArray`, in which case
it is automatically determined from the coordinate metadata.
Expand All @@ -2630,18 +2658,17 @@ def brunt_vaisala_frequency(height, potential_temperature, axis=0):
"""
bv_freq_squared = brunt_vaisala_frequency_squared(height, potential_temperature,
axis=axis)
if isinstance(bv_freq_squared, xr.DataArray):
bv_freq_squared.data[bv_freq_squared.data.magnitude < 0] = np.nan
else:
bv_freq_squared[bv_freq_squared.magnitude < 0] = np.nan
axis=vertical_dim)
bv_freq_squared[bv_freq_squared.magnitude < 0] = np.nan

return np.sqrt(bv_freq_squared)


@exporter.export
@add_vertical_dim_from_xarray
@preprocess_and_wrap(wrap_like='height', broadcast=('height', 'potential_temperature'))
@check_units('[length]', '[temperature]')
def brunt_vaisala_period(height, potential_temperature, axis=0):
def brunt_vaisala_period(height, potential_temperature, vertical_dim=0):
r"""Calculate the Brunt-Vaisala period.
This function is a helper function for `brunt_vaisala_frequency` that calculates the
Expand All @@ -2657,7 +2684,7 @@ def brunt_vaisala_period(height, potential_temperature, axis=0):
Atmospheric (geopotential) height
potential_temperature : `xarray.DataArray` or `pint.Quantity`
Atmospheric potential temperature
axis : int, optional
vertical_dim : int, optional
The axis corresponding to vertical in the potential temperature array, defaults to 0,
unless `height` and `potential_temperature` given as `xarray.DataArray`, in which case
it is automatically determined from the coordinate metadata.
Expand All @@ -2676,11 +2703,8 @@ def brunt_vaisala_period(height, potential_temperature, axis=0):
"""
bv_freq_squared = brunt_vaisala_frequency_squared(height, potential_temperature,
axis=axis)
if isinstance(bv_freq_squared, xr.DataArray):
bv_freq_squared.data[bv_freq_squared.data.magnitude <= 0] = np.nan
else:
bv_freq_squared[bv_freq_squared.magnitude <= 0] = np.nan
axis=vertical_dim)
bv_freq_squared[bv_freq_squared.magnitude <= 0] = np.nan

return 2 * np.pi / np.sqrt(bv_freq_squared)

Expand Down Expand Up @@ -2747,9 +2771,10 @@ def wet_bulb_temperature(pressure, temperature, dewpoint):


@exporter.export
@add_vertical_dim_from_xarray
@preprocess_and_wrap(wrap_like='temperature', broadcast=('pressure', 'temperature'))
@check_units('[pressure]', '[temperature]')
def static_stability(pressure, temperature, axis=0):
def static_stability(pressure, temperature, vertical_dim=0):
r"""Calculate the static stability within a vertical profile.
.. math:: \sigma = -\frac{RT}{p} \frac{\partial \ln \theta}{\partial p}
Expand All @@ -2762,7 +2787,7 @@ def static_stability(pressure, temperature, axis=0):
Profile of atmospheric pressure
temperature : `pint.Quantity`
Profile of temperature
axis : int, optional
vertical_dim : int, optional
The axis corresponding to vertical in the pressure and temperature arrays, defaults
to 0.
Expand All @@ -2774,8 +2799,11 @@ def static_stability(pressure, temperature, axis=0):
"""
theta = potential_temperature(pressure, temperature)

return - mpconsts.Rd * temperature / pressure * first_derivative(np.log(theta.m_as('K')),
x=pressure, axis=axis)
return - mpconsts.Rd * temperature / pressure * first_derivative(
np.log(theta.m_as('K')),
x=pressure,
axis=vertical_dim
)


@exporter.export
Expand Down Expand Up @@ -2975,12 +3003,13 @@ def lifted_index(pressure, temperature, parcel_profile):


@exporter.export
@add_vertical_dim_from_xarray
@preprocess_and_wrap(
wrap_like='potential_temperature',
broadcast=('height', 'potential_temperature', 'u', 'v')
)
@check_units('[length]', '[temperature]', '[speed]', '[speed]')
def gradient_richardson_number(height, potential_temperature, u, v, axis=0):
def gradient_richardson_number(height, potential_temperature, u, v, vertical_dim=0):
r"""Calculate the gradient (or flux) Richardson number.
.. math:: Ri = (g/\theta) * \frac{\left(\partial \theta/\partial z\)}
Expand All @@ -2999,16 +3028,17 @@ def gradient_richardson_number(height, potential_temperature, u, v, axis=0):
x component of the wind
v : `pint.Quantity`
y component of the wind
axis : int, optional
The axis corresponding to vertical, defaults to 0.
vertical_dim : int, optional
The axis corresponding to vertical, defaults to 0. Automatically determined from
xarray DataArray arguments.
Returns
-------
`pint.Quantity`
Gradient Richardson number
"""
dthetadz = first_derivative(potential_temperature, x=height, axis=axis)
dudz = first_derivative(u, x=height, axis=axis)
dvdz = first_derivative(v, x=height, axis=axis)
dthetadz = first_derivative(potential_temperature, x=height, axis=vertical_dim)
dudz = first_derivative(u, x=height, axis=vertical_dim)
dvdz = first_derivative(v, x=height, axis=vertical_dim)

return (mpconsts.g / potential_temperature) * (dthetadz / (dudz ** 2 + dvdz ** 2))

0 comments on commit 4d50ffa

Please sign in to comment.