diff --git a/src/metpy/calc/thermo.py b/src/metpy/calc/thermo.py index e94de93451e..893757aff50 100644 --- a/src/metpy/calc/thermo.py +++ b/src/metpy/calc/thermo.py @@ -23,6 +23,39 @@ sat_pressure_0c = 6.112 * units.millibar +def add_vertical_dim_from_xarray(func): + """Fill in optional vertical_dim from DataArray argument.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + bound_args = signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + + # Search for DataArray with valid latitude and longitude coordinates to find grid + # deltas and any other needed parameter + dataarray_arguments = [ + value for value in bound_args.arguments.values() + if isinstance(value, xr.DataArray) + ] + + # Fill in vertical_dim + if ( + len(dataarray_arguments) > 0 + and 'vertical_dim' in bound_args.arguments + ): + try: + bound_args.arguments['vertical_dim'] = ( + dataarray_arguments[0].metpy.find_axis_number('vertical') + ) + except AttributeError: + # If axis number not found, fall back to default but warn. + warnings.warn( + 'Vertical dimension number not found. Defaulting to initial dimension.' + ) + + return func(*bound_args.args, **bound_args.kwargs) + return wrapper + + @exporter.export @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'dewpoint')) @check_units('[temperature]', '[temperature]') @@ -1790,8 +1823,9 @@ def most_unstable_parcel(pressure, temperature, dewpoint, height=None, @exporter.export @preprocess_and_wrap() +@add_vertical_dim_from_xarray @check_units('[temperature]', '[pressure]', '[temperature]') -def isentropic_interpolation(levels, pressure, temperature, *args, axis=0, +def isentropic_interpolation(levels, pressure, temperature, *args, vertical_dim=0, temperature_out=False, max_iters=50, eps=1e-6, bottom_up_search=True, **kwargs): r"""Interpolate data in isobaric coordinates to isentropic coordinates. @@ -1804,7 +1838,7 @@ def isentropic_interpolation(levels, pressure, temperature, *args, axis=0, One-dimensional array of pressure levels temperature : array Array of temperature - axis : int, optional + vertical_dim : int, optional The axis corresponding to the vertical in the temperature array, defaults to 0. temperature_out : bool, optional If true, will calculate temperature and output as the last item in the output list. @@ -1860,14 +1894,14 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok): temperature = temperature.to('kelvin') slices = [np.newaxis] * ndim - slices[axis] = slice(None) + slices[vertical_dim] = slice(None) slices = tuple(slices) pres = np.broadcast_to(pres[slices].magnitude, temperature.shape) * pres.units # Sort input data - sort_pres = np.argsort(pres.m, axis=axis) - sort_pres = np.swapaxes(np.swapaxes(sort_pres, 0, axis)[::-1], 0, axis) - sorter = broadcast_indices(pres, sort_pres, ndim, axis) + sort_pres = np.argsort(pres.m, axis=vertical_dim) + sort_pres = np.swapaxes(np.swapaxes(sort_pres, 0, vertical_dim)[::-1], 0, vertical_dim) + sorter = broadcast_indices(pres, sort_pres, ndim, vertical_dim) levs = pres[sorter] tmpk = temperature[sorter] @@ -1876,7 +1910,7 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok): # Make the desired isentropic levels the same shape as temperature shape = list(temperature.shape) - shape[axis] = isentlevels.size + shape[vertical_dim] = isentlevels.size isentlevs_nd = np.broadcast_to(isentlevels[slices], shape) # exponent to Poisson's Equation, which is imported above @@ -1897,7 +1931,7 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok): pok = mpconsts.P0 ** ka # index values for each point for the pressure level nearest to the desired theta level - above, below, good = find_bounding_indices(pres_theta.m, levels, axis, + above, below, good = find_bounding_indices(pres_theta.m, levels, vertical_dim, from_below=bottom_up_search) # calculate constants for the interpolation @@ -1932,7 +1966,7 @@ def _isen_iter(iter_log_p, isentlevs_nd, ka, a, b, pok): # do an interpolation for each additional argument if args: others = interpolate_1d(isentlevels, pres_theta.m, *(arr[sorter] for arr in args), - axis=axis, return_list_always=True) + axis=vertical_dim, return_list_always=True) ret.extend(others) return ret @@ -1998,7 +2032,7 @@ def isentropic_interpolation_as_dataset( all_args[0].metpy.vertical, all_args[0].metpy.unit_array, *(arg.metpy.unit_array for arg in all_args[1:]), - axis=all_args[0].metpy.find_axis_number('vertical'), + vertical_dim=all_args[0].metpy.find_axis_number('vertical'), temperature_out=True, max_iters=max_iters, eps=eps, @@ -2535,8 +2569,10 @@ def thickness_hydrostatic_from_relative_humidity(pressure, temperature, relative @exporter.export +@add_vertical_dim_from_xarray +@preprocess_and_wrap(wrap_like='height', broadcast=('height', 'potential_temperature')) @check_units('[length]', '[temperature]') -def brunt_vaisala_frequency_squared(height, potential_temperature, axis=0): +def brunt_vaisala_frequency_squared(height, potential_temperature, vertical_dim=0): r"""Calculate the square of the Brunt-Vaisala frequency. Brunt-Vaisala frequency squared (a measure of atmospheric stability) is given by the @@ -2552,7 +2588,7 @@ def brunt_vaisala_frequency_squared(height, potential_temperature, axis=0): Atmospheric (geopotential) height potential_temperature : `xarray.DataArray` or `pint.Quantity` Atmospheric potential temperature - axis : int, optional + vertical_dim : int, optional The axis corresponding to vertical in the potential temperature array, defaults to 0, unless `height` and `potential_temperature` given as `xarray.DataArray`, in which case it is automatically determined from the coordinate metadata. @@ -2569,32 +2605,21 @@ def brunt_vaisala_frequency_squared(height, potential_temperature, axis=0): brunt_vaisala_frequency, brunt_vaisala_period, potential_temperature """ - if isinstance(height, xr.DataArray) and isinstance(potential_temperature, xr.DataArray): - # Prepare arguments when DataArrays - potential_temperature = potential_temperature.metpy.convert_units('K') - height, potential_temperature = xr.broadcast( - height.metpy.quantify(), - potential_temperature.metpy.quantify() - ) - dtheta_dz = first_derivative( - potential_temperature, - x=height, - axis=potential_temperature.metpy.find_axis_number('vertical') - ) - else: - if isinstance(potential_temperature, xr.DataArray): - potential_temperature = potential_temperature.metpy.unit_array.to('K') - else: - potential_temperature = potential_temperature.to('K') - dtheta_dz = first_derivative(potential_temperature, x=height, axis=axis) + potential_temperature = potential_temperature.to('K') # Calculate and return the square of Brunt-Vaisala frequency - return mpconsts.g / potential_temperature * dtheta_dz + return mpconsts.g / potential_temperature * first_derivative( + potential_temperature, + x=height, + axis=vertical_dim + ) @exporter.export +@add_vertical_dim_from_xarray +@preprocess_and_wrap(wrap_like='height', broadcast=('height', 'potential_temperature')) @check_units('[length]', '[temperature]') -def brunt_vaisala_frequency(height, potential_temperature, axis=0): +def brunt_vaisala_frequency(height, potential_temperature, vertical_dim=0): r"""Calculate the Brunt-Vaisala frequency. This function will calculate the Brunt-Vaisala frequency as follows: @@ -2612,7 +2637,7 @@ def brunt_vaisala_frequency(height, potential_temperature, axis=0): Atmospheric (geopotential) height potential_temperature : `xarray.DataArray` or `pint.Quantity` Atmospheric potential temperature - axis : int, optional + vertical_dim : int, optional The axis corresponding to vertical in the potential temperature array, defaults to 0, unless `height` and `potential_temperature` given as `xarray.DataArray`, in which case it is automatically determined from the coordinate metadata. @@ -2630,18 +2655,17 @@ def brunt_vaisala_frequency(height, potential_temperature, axis=0): """ bv_freq_squared = brunt_vaisala_frequency_squared(height, potential_temperature, - axis=axis) - if isinstance(bv_freq_squared, xr.DataArray): - bv_freq_squared.data[bv_freq_squared.data.magnitude < 0] = np.nan - else: - bv_freq_squared[bv_freq_squared.magnitude < 0] = np.nan + axis=vertical_dim) + bv_freq_squared[bv_freq_squared.magnitude < 0] = np.nan return np.sqrt(bv_freq_squared) @exporter.export +@add_vertical_dim_from_xarray +@preprocess_and_wrap(wrap_like='height', broadcast=('height', 'potential_temperature')) @check_units('[length]', '[temperature]') -def brunt_vaisala_period(height, potential_temperature, axis=0): +def brunt_vaisala_period(height, potential_temperature, vertical_dim=0): r"""Calculate the Brunt-Vaisala period. This function is a helper function for `brunt_vaisala_frequency` that calculates the @@ -2657,7 +2681,7 @@ def brunt_vaisala_period(height, potential_temperature, axis=0): Atmospheric (geopotential) height potential_temperature : `xarray.DataArray` or `pint.Quantity` Atmospheric potential temperature - axis : int, optional + vertical_dim : int, optional The axis corresponding to vertical in the potential temperature array, defaults to 0, unless `height` and `potential_temperature` given as `xarray.DataArray`, in which case it is automatically determined from the coordinate metadata. @@ -2676,11 +2700,8 @@ def brunt_vaisala_period(height, potential_temperature, axis=0): """ bv_freq_squared = brunt_vaisala_frequency_squared(height, potential_temperature, - axis=axis) - if isinstance(bv_freq_squared, xr.DataArray): - bv_freq_squared.data[bv_freq_squared.data.magnitude <= 0] = np.nan - else: - bv_freq_squared[bv_freq_squared.magnitude <= 0] = np.nan + axis=vertical_dim) + bv_freq_squared[bv_freq_squared.magnitude <= 0] = np.nan return 2 * np.pi / np.sqrt(bv_freq_squared) @@ -2747,9 +2768,10 @@ def wet_bulb_temperature(pressure, temperature, dewpoint): @exporter.export +@add_vertical_dim_from_xarray @preprocess_and_wrap(wrap_like='temperature', broadcast=('pressure', 'temperature')) @check_units('[pressure]', '[temperature]') -def static_stability(pressure, temperature, axis=0): +def static_stability(pressure, temperature, vertical_dim=0): r"""Calculate the static stability within a vertical profile. .. math:: \sigma = -\frac{RT}{p} \frac{\partial \ln \theta}{\partial p} @@ -2762,7 +2784,7 @@ def static_stability(pressure, temperature, axis=0): Profile of atmospheric pressure temperature : `pint.Quantity` Profile of temperature - axis : int, optional + vertical_dim : int, optional The axis corresponding to vertical in the pressure and temperature arrays, defaults to 0. @@ -2774,8 +2796,11 @@ def static_stability(pressure, temperature, axis=0): """ theta = potential_temperature(pressure, temperature) - return - mpconsts.Rd * temperature / pressure * first_derivative(np.log(theta.m_as('K')), - x=pressure, axis=axis) + return - mpconsts.Rd * temperature / pressure * first_derivative( + np.log(theta.m_as('K')), + x=pressure, + axis=vertical_dim + ) @exporter.export @@ -2975,12 +3000,13 @@ def lifted_index(pressure, temperature, parcel_profile): @exporter.export +@add_vertical_dim_from_xarray @preprocess_and_wrap( wrap_like='potential_temperature', broadcast=('height', 'potential_temperature', 'u', 'v') ) @check_units('[length]', '[temperature]', '[speed]', '[speed]') -def gradient_richardson_number(height, potential_temperature, u, v, axis=0): +def gradient_richardson_number(height, potential_temperature, u, v, vertical_dim=0): r"""Calculate the gradient (or flux) Richardson number. .. math:: Ri = (g/\theta) * \frac{\left(\partial \theta/\partial z\)} @@ -2999,16 +3025,17 @@ def gradient_richardson_number(height, potential_temperature, u, v, axis=0): x component of the wind v : `pint.Quantity` y component of the wind - axis : int, optional - The axis corresponding to vertical, defaults to 0. + vertical_dim : int, optional + The axis corresponding to vertical, defaults to 0. Automatically determined from + xarray DataArray arguments. Returns ------- `pint.Quantity` Gradient Richardson number """ - dthetadz = first_derivative(potential_temperature, x=height, axis=axis) - dudz = first_derivative(u, x=height, axis=axis) - dvdz = first_derivative(v, x=height, axis=axis) + dthetadz = first_derivative(potential_temperature, x=height, axis=vertical_dim) + dudz = first_derivative(u, x=height, axis=vertical_dim) + dvdz = first_derivative(v, x=height, axis=vertical_dim) return (mpconsts.g / potential_temperature) * (dthetadz / (dudz ** 2 + dvdz ** 2))