Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Animate polygon -- create animations using a polygon plotting tool #312

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 256 additions & 0 deletions xbout/boutdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import animatplot as amp
from matplotlib import pyplot as plt
from matplotlib.animation import PillowWriter
from matplotlib.animation import FuncAnimation

from mpl_toolkits.axes_grid1 import make_axes_locatable

Expand All @@ -22,6 +23,7 @@
animate_poloidal,
animate_pcolormesh,
animate_line,
animate_polygon,
_add_controls,
_normalise_time_coord,
_parse_coord_option,
Expand Down Expand Up @@ -1345,6 +1347,260 @@ def is_list(variable):

return anim

def animate_polygon_list(
self,
variables,
animate_over=None,
save_as=None,
show=False,
fps=10,
nrows=None,
ncols=None,
poloidal_plot=False,
axis_coords=None,
subplots_adjust=None,
vmin=None,
vmax=None,
logscale=None,
titles=None,
aspect=None,
extend=None,
# controls="both",
tight_layout=True,
**kwargs,
):
"""
Parameters
----------
variables : list of str or BoutDataArray
The variables to plot. For any string passed, the corresponding
variable in this DataSet is used - then the calling DataSet must
have only 3 dimensions. It is possible to pass BoutDataArrays to
allow more flexible plots, e.g. with different variables being
plotted against different axes.
animate_over : str, optional
Dimension over which to animate, defaults to the time dimension
save_as : str, optional
If passed, a gif is created with this filename
show : bool, optional
Call pyplot.show() to display the animation
fps : float, optional
Indicates the number of frames per second to play
nrows : int, optional
Specify the number of rows of plots
ncols : int, optional
Specify the number of columns of plots
poloidal_plot : bool or sequence of bool, optional
If set to True, make all 2D animations in the poloidal plane instead of using
grid coordinates, per variable if sequence is given
axis_coords : None, str, dict or list of None, str or dict
Coordinates to use for axis labelling.

- None: Use the dimension coordinate for each axis, if it exists.
- "index": Use the integer index values.
- dict: keys are dimension names, values set axis_coords for each axis
separately. Values can be: None, "index", the name of a 1d variable or
coordinate (which must have the dimension given by 'key'), or a 1d
numpy array, dask array or DataArray whose length matches the length of
the dimension given by 'key'.

Only affects time coordinate for plots with poloidal_plot=True.
If a list is passed, it must have the same length as 'variables' and gives
the axis_coords setting for each plot individually.
The setting to use for the 'animate_over' coordinate can be passed in one or
more dict values, but must be the same in all dicts if given more than once.
subplots_adjust : dict, optional
Arguments passed to fig.subplots_adjust()()
vmin : float or sequence of floats
Minimum value for color scale, per variable if a sequence is given
vmax : float or sequence of floats
Maximum value for color scale, per variable if a sequence is given
logscale : bool or float, sequence of bool or float, optional
If True, default to a logarithmic color scale instead of a linear one.
If a non-bool type is passed it is treated as a float used to set the linear
threshold of a symmetric logarithmic scale as
linthresh=min(abs(vmin),abs(vmax))*logscale, defaults to 1e-5 if True is
passed.
Per variable if sequence is given.
titles : sequence of str or None, optional
Custom titles for each plot. Pass None in the sequence to use the default for
a certain variable
aspect : str or None, or sequence of str or None, optional
Argument to set_aspect() for each plot. Defaults to "equal" for poloidal
plots and "auto" for others.
extend : str or None, optional
Passed to fig.colorbar()
controls : string or None, default "both"
By default, add both the timeline and play/pause toggle to the animation. If
"timeline" is passed add only the timeline, if "toggle" is passed add only
the play/pause toggle. If None or an empty string is passed, add neither.
tight_layout : bool or dict, optional
If set to False, don't call tight_layout() on the figure.
If a dict is passed, the dict entries are passed as arguments to
tight_layout()
**kwargs : dict, optional
Additional keyword arguments are passed on to each animation function, per
variable if a sequence is given.

Returns
-------
animation
An animatplot.Animation object.
"""

if animate_over is None:
animate_over = self.metadata.get("bout_tdim", "t")

nvars = len(variables)

if nrows is None and ncols is None:
ncols = int(np.ceil(np.sqrt(nvars)))
nrows = int(np.ceil(nvars / ncols))
elif nrows is None:
nrows = int(np.ceil(nvars / ncols))
elif ncols is None:
ncols = int(np.ceil(nvars / nrows))
else:
if nrows * ncols < nvars:
raise ValueError("Not enough rows*columns to fit all variables")

fig, axes = plt.subplots(nrows, ncols, squeeze=False)
axes = axes.flatten()

ncells = nrows * ncols

if nvars < ncells:
for index in range(ncells - nvars):
fig.delaxes(axes[ncells - index - 1])

if subplots_adjust is not None:
fig.subplots_adjust(**subplots_adjust)

def _expand_list_arg(arg, arg_name):
if isinstance(arg, collections.abc.Sequence) and not isinstance(arg, str):
if len(arg) != len(variables):
raise ValueError(
"if %s is a sequence, it must have the same "
'number of elements as "variables"' % arg_name
)
else:
arg = [arg] * len(variables)
return arg

poloidal_plot = _expand_list_arg(poloidal_plot, "poloidal_plot")
vmin = _expand_list_arg(vmin, "vmin")
vmax = _expand_list_arg(vmax, "vmax")
logscale = _expand_list_arg(logscale, "logscale")
titles = _expand_list_arg(titles, "titles")
aspect = _expand_list_arg(aspect, "aspect")
extend = _expand_list_arg(extend, "extend")
axis_coords = _expand_list_arg(axis_coords, "axis_coords")
for k in kwargs:
kwargs[k] = _expand_list_arg(kwargs[k], k)

animate_data = []

def is_list(variable):
return (
isinstance(variable, list)
or isinstance(variable, tuple)
or isinstance(variable, set)
)

for i, subplot_args in enumerate(
zip(
variables,
axes,
poloidal_plot,
vmin,
vmax,
logscale,
titles,
)
):
(
v,
ax,
this_poloidal_plot,
this_vmin,
this_vmax,
this_logscale,
this_title,
) = subplot_args

this_kwargs = {k: v[i] for k, v in kwargs.items()}

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)

if isinstance(v, str):
v = self.data[v]

data = v.bout.data
ndims = len(data.dims)
ax.set_title(data.name)

if ndims == 3:
if this_poloidal_plot:
update_func = animate_polygon(
data,
ax=ax,
cax=cax,
vmin=this_vmin,
vmax=this_vmax,
logscale=this_logscale,
animate=False,
**this_kwargs,
)
animate_data.append(update_func)
else:
raise ValueError(
"Unsupported option "
+ ". this_poloidal_plot "
+ str(this_poloidal_plot)
)
else:
raise ValueError(
"Unsupported number of dimensions "
+ str(ndims)
+ ". Dims are "
+ str(v.dims)
)

if this_title is not None:
# Replace default title with user-specified one
ax.set_title(this_title)

def update(frame):
for update_func in animate_data:
# call update function for each axes
update_func(frame)

# make the animation for all the subplots simultaneously
# use time data array "t" to choose the number of frames
# assumes time dimension same length for all variables
anim = FuncAnimation(
fig=fig, func=update, frames=self.data["t"].data.size, interval=30
)
if tight_layout:
if subplots_adjust is not None:
warnings.warn(
"tight_layout argument to animate_list() is True, but "
"subplots_adjust argument is not None. subplots_adjust "
"is being ignored."
)
if not isinstance(tight_layout, dict):
tight_layout = {}
fig.tight_layout(**tight_layout)

if save_as is not None:
anim.save(save_as + ".gif", writer=PillowWriter(fps=fps))

if show:
plt.show()

return anim

def with_cherab_grid(self):
"""
Returns a new DataSet with a 'cherab_grid' attribute.
Expand Down
Loading