Skip to content

Commit

Permalink
Improve docstrings and update template instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
gb119 committed Dec 21, 2024
1 parent ef5d131 commit 776f748
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Docstrings should be written in ReStructured Text format.
1. **Examples**

- Include real-world examples of how to use the functionality, both basic and advanced.
- This section may be omitted for functions, methods and classes that are not intended
- This section should be omitted for functions, methods and classes that are not intended
for third parties to use - for example, base classes, internal or private functions or methods.

### Example Template for a Class
Expand Down
176 changes: 163 additions & 13 deletions src/stonerplots/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,59 @@


class _RavelList(list):
"""A list with additional flattening and fake 2D indexing capabilities."""
"""A list with additional flattening and fake 2D indexing capabilities.
This class extends the standard list to provide additional functionalities for
flattening nested lists and supporting 2D-style indexing.
Methods:
flatten(): Flattens a nested list into a single-level list.
_flatten_recursive(items): Recursively flattens a list.
__getitem__(index): Supports 2D-style indexing using tuples.
"""

def flatten(self) -> List[Any]:
"""Flattens a nested list into a single-level list.
Returns:
List[Any]: A flattened list.
Examples:
>>> lst = _RavelList([[1, 2], [3, 4]])
>>> lst.flatten()
[1, 2, 3, 4]
"""
return self._flatten_recursive(self)

@staticmethod
def _flatten_recursive(items: Union[list, Any]) -> List[Any]:
"""Help to recursively flatten a list."""
"""Help to recursively flatten a list.
Args:
items (Union[list, Any]): The list (or single item) to flatten.
Returns:
List[Any]: A flattened list.
Examples:
>>> _RavelList._flatten_recursive([[1, 2], [3, [4, 5]]])
[1, 2, 3, 4, 5]
"""
return [element for sublist in items for element in (sublist if isinstance(sublist, list) else [sublist])]

def __getitem__(self, index: Union[int, tuple]) -> Any:
"""2D-style indexing using tuples. For simple indices, default list behavior is used.
"""2D-style indexing using tuples.
Args:
index (Union[int, tuple]): Index or tuple of indices.
Returns:
Any: Element at the specified index.
Examples:
>>> lst = _RavelList([[1, 2], [3, 4]])
>>> lst[0, 1]
2
"""
if not isinstance(index, tuple):
return super().__getitem__(index)
Expand Down Expand Up @@ -138,10 +168,33 @@ def counter(value, pattern="({alpha})", **kwargs):


class _TrackNewFiguresAndAxes:
"""A simple context manager to handle identiifying new figures or axes."""
"""A simple context manager to handle identifying new figures or axes.
This context manager tracks figures and axes created within its context, allowing
for operations on only the new figures and axes.
Args:
include_open (bool): If `True`, includes already open figures and axes. Defaults to `False`.
Attributes:
_existing_open_figs (list): List of weak references to existing open figures.
_existing_open_axes (dict): Dictionary of weak references to existing open axes.
Methods:
new_figures: Returns an iterator over figures created since the context was entered.
new_axes: Returns an iterator over axes created since the context was entered.
"""

def __init__(self, *args, **kwargs):
"""Set storage of figures and axes."""
"""Set storage of figures and axes.
Args:
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Keyword Args:
include_open (bool): If `True`, includes already open figures and axes. Defaults to `False`.
"""
super().__init__()
self._existing_open_figs = []
self._existing_open_axes = {}
Expand All @@ -156,7 +209,18 @@ def __enter__(self):

@property
def new_figures(self):
"""Return an iterator over figures created since the context manager was entered."""
"""Return an iterator over figures created since the context manager was entered.
Yields:
matplotlib.figure.Figure: New figures created within the context.
Examples:
>>> with _TrackNewFiguresAndAxes() as tracker:
... plt.figure()
... # New figure created here
>>> list(tracker.new_figures)
[<Figure size ...>]
"""
for num in plt.get_fignums():
fig = plt.figure(num)
if fig in self._existing_open_figs: # Skip figures opened before context
Expand All @@ -165,7 +229,18 @@ def new_figures(self):

@property
def new_axes(self):
"""Return an iterator over all new axes created since the context manager was entered."""
"""Return an iterator over all new axes created since the context manager was entered.
Yields:
matplotlib.axes.Axes: New axes created within the context.
Examples:
>>> with _TrackNewFiguresAndAxes() as tracker:
... fig, ax = plt.subplots()
... # New axes created here
>>> list(tracker.new_axes)
[<AxesSubplot:...>]
"""
for num in plt.get_fignums():
fig = plt.figure(num)
for ax in fig.axes:
Expand Down Expand Up @@ -309,12 +384,31 @@ def __init__(self, filename=None, style=None, autoclose=False, formats=None, inc

@property
def filename(self):
"""Return filename as a Path object without extension."""
"""Return filename as a Path object without extension.
Returns:
Path: The filename or directory path.
Examples:
>>> sf = SavedFigure(filename="plot.png")
>>> sf.filename
PosixPath('plot')
"""
return self._filename

@filename.setter
def filename(self, value):
"""Set filename and extract its extension if valid."""
"""Set filename and extract its extension if valid.
Args:
value (Union[str, Path]): The filename or directory path.
Examples:
>>> sf = SavedFigure()
>>> sf.filename = "plot.png"
>>> sf.filename
PosixPath('plot')
"""
if value is not None:
value = Path(value)
ext = value.suffix[1:]
Expand All @@ -325,12 +419,34 @@ def filename(self, value):

@property
def formats(self):
"""Return the output formats as a list of strings."""
"""Return the output formats as a list of strings.
Returns:
list[str]: The list of output formats.
Examples:
>>> sf = SavedFigure(formats="png,pdf")
>>> sf.formats
['png', 'pdf']
"""
return self._formats

@formats.setter
def formats(self, value):
"""Ensure formats are stored as a list of strings."""
"""Ensure formats are stored as a list of strings.
Args:
value (Union[str, Iterable[str], None]): The formats to store.
Raises:
TypeError: If the value is not str, iterable, or None.
Examples:
>>> sf = SavedFigure()
>>> sf.formats = "png,pdf"
>>> sf.formats
['png', 'pdf']
"""
if isinstance(value, str):
self._formats = [x.strip() for x in value.split(",") if x.strip()]
elif isinstance(value, Iterable):
Expand All @@ -343,12 +459,34 @@ def formats(self, value):

@property
def style(self):
"""Return the stylesheets as a list of strings."""
"""Return the stylesheets as a list of strings.
Returns:
list[str]: The list of stylesheets.
Examples:
>>> sf = SavedFigure(style="default")
>>> sf.style
['default']
"""
return self._style

@style.setter
def style(self, value):
"""Ensure style is stored as a list of strings."""
"""Ensure style is stored as a list of strings.
Args:
value (Union[str, Iterable[str], None]): The styles to store.
Raises:
TypeError: If the value is not str, iterable, or None.
Examples:
>>> sf = SavedFigure()
>>> sf.style = "default,ggplot"
>>> sf.style
['default', 'ggplot']
"""
if isinstance(value, str):
self._style = [x.strip() for x in value.split(",") if x.strip()]
elif isinstance(value, Iterable):
Expand Down Expand Up @@ -402,6 +540,18 @@ def generate_filename(self, label, counter):
Supports placeholders like {label}, {number}, and appends
a counter if multiple new figures are detected.
Args:
label (str): The figure label.
counter (int): The figure counter.
Returns:
str: The generated filename.
Examples:
>>> sf = SavedFigure(filename="plot_{label}.png")
>>> sf.generate_filename("test", 1)
'plot_test.png'
"""
if self.filename.is_dir():
filename: Path = self.filename / "{label}"
Expand Down

0 comments on commit 776f748

Please sign in to comment.