Skip to content

Commit

Permalink
Merge pull request #49 from petebunting/main
Browse files Browse the repository at this point in the history
Updates to plotting functions
  • Loading branch information
petebunting authored May 19, 2022
2 parents 235eaff + b135c15 commit 7f4c054
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 21 deletions.
1 change: 1 addition & 0 deletions doc/python/source/rsgislib_tools_plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ RSGISLib Plotting Tools
Statistical Plots
-------------------
.. autofunction:: rsgislib.tools.plotting.residual_plot
.. autofunction:: rsgislib.tools.plotting.residual_density_plot
.. autofunction:: rsgislib.tools.plotting.quantile_plot


Expand Down
1 change: 1 addition & 0 deletions doc/python/source/rsgislib_tools_utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Numeric
Colours
---------
.. autofunction:: rsgislib.tools.utils.hex_to_rgb
.. autofunction:: rsgislib.tools.utils.rgb_to_hex

Dates
-------
Expand Down
200 changes: 182 additions & 18 deletions python/rsgislib/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
except ImportError:
have_matplotlib = False

have_mpl_scatter_density = True
try:
import mpl_scatter_density
except ImportError:
have_mpl_scatter_density = False


def plot_image_spectra(
input_img,
Expand Down Expand Up @@ -474,6 +480,119 @@ def residual_plot(y_true, residuals, out_file, out_format="PNG", title=None):
plt.close()


def residual_density_plot(
y_true: numpy.array,
residuals: numpy.array,
out_file: str,
out_format: str = "PNG",
out_dpi: int = 800,
title: str = None,
cmap_name: str = "viridis",
use_log_norm: bool = False,
density_norm_vmin: float = 1,
density_norm_vmax: float = None,
freq_nbins: int = 50,
val_plt_range: List[float] = None,
resid_plt_range: List[float] = None,
):
"""
A function to create a residual plot where the scatter plot will be represented
as a density plot. This plot allows the investigatation of the
normality and homoscedasticity of model residuals.
:param y_true: A numpy 1D array containing true/observed values.
:param residuals: A numpy 1D array containing model residuals.
:param out_file: Path to the output file.
:param out_format: Output format supported by matplotlib (e.g. "PNG" or "PDF").
Default: PNG
:param out_dpi: the output DPI of the save raster plot (default: 800)
:param title: A title for the plot. Optional, if None then ignored. (Default: None)
:param cmap_name: The name of the colour bar to use for the density plot
Default: viridis
:param use_log_norm: Specify whether to use log normalisation for the density plot
instead of linear. (Default: False)
:param density_norm_vmin: the minimum density value for the normalisation
(default: 1)
:param density_norm_vmax: the maximum density value for the normalisation
(default: None)
:param freq_nbins: the number of bins used for the frequency histogram (Default: 50)
:param val_plt_range: A user specified x-axis range of values (Default: None). If
specified then must be a list of 2 values.
:param resid_plt_range: A user specified y-axis range of values (Default: None) If
specified then must be a list of 2 values.
"""
if not have_matplotlib:
raise rsgislib.RSGISPyException(
"The matplotlib module is required and could not be imported."
)
if not have_mpl_scatter_density:
raise rsgislib.RSGISPyException(
"The mpl_scatter_density module is required and could not be imported."
)

if not isinstance(residuals, numpy.ndarray):
residuals = numpy.array(residuals)
if not isinstance(y_true, numpy.ndarray):
y_true = numpy.array(y_true)
if y_true.ndim != 1:
raise rsgislib.RSGISPyException("y_true has more than 1 dimension.")
if residuals.ndim != 1:
raise rsgislib.RSGISPyException("Residuals has more than 1 dimension.")
if residuals.size != y_true.size:
raise rsgislib.RSGISPyException("y_true.size != residuals.size.")
if val_plt_range is not None:
if len(val_plt_range) != 2:
raise rsgislib.RSGISPyException("val_plt_range must have len of 2")
if resid_plt_range is not None:
if len(resid_plt_range) != 2:
raise rsgislib.RSGISPyException("resid_plt_range must have len of 2")

c_cmap = plt.get_cmap(cmap_name)
mClrs.Colormap.set_under(c_cmap, color="white")
if use_log_norm:
c_norm = mClrs.LogNorm(vmin=density_norm_vmin, vmax=density_norm_vmax)
else:
c_norm = mClrs.Normalize(vmin=density_norm_vmin, vmax=density_norm_vmax)

# setup plot:
# rcParams.update({'font.family': 'cmr10'}) # use latex fonts.
# rcParams['axes.unicode_minus'] = False
rcParams.update({"font.size": 8.5})
rcParams["axes.linewidth"] = 0.5
rcParams["xtick.major.pad"] = "2"
rcParams["ytick.major.pad"] = "2"
fig = plt.figure(figsize=(10, 5))
gs = gridspec.GridSpec(nrows=1, ncols=2, width_ratios=[3.5, 1])
ax1 = plt.subplot(gs[0], projection="scatter_density")
ax2 = plt.subplot(gs[1])
plt.tight_layout(w_pad=-1, h_pad=0)

# draw scatterplot:
ax1.axhline(y=0.0, c="k", ls=":", lw=0.5, zorder=2)
ax1.scatter_density(y_true, residuals, norm=c_norm, cmap=c_cmap, zorder=1)
ax1.set_xlabel("Observed value", fontsize=9)
ax1.set_ylabel("Residuals", fontsize=9)
if val_plt_range is not None:
ax1.set_xlim(val_plt_range[0], val_plt_range[1])
if resid_plt_range is not None:
ax1.set_ylim(resid_plt_range[0], resid_plt_range[1])
if title is not None:
ax1.set_title(title)

# draw histogram:
ax2.get_xaxis().tick_bottom()
ax2.get_yaxis().tick_right()
ax2.get_yaxis().set_visible(False)
ax2.hist(residuals, bins=freq_nbins, orientation="horizontal", color="C0")
ax2.axhline(y=0.0, c="k", ls=":", lw=0.5, zorder=2)
ax2.set_xlabel("Frequency", fontsize=9)
if resid_plt_range is not None:
ax2.set_ylim(resid_plt_range[0], resid_plt_range[1])
plt.savefig(out_file, format=out_format, dpi=out_dpi, bbox_inches="tight")
plt.close()


def quantile_plot(residuals, ylabel, out_file, out_format="PNG", title=None):
"""
A function to create a Quantile-Quantile plot to investigate the
Expand Down Expand Up @@ -661,8 +780,9 @@ def get_gdal_thematic_raster_mpl_imshow(
input_img: str,
band: int = 1,
bbox: List[float] = None,
out_patches=False,
cls_names_lut=None,
out_patches: bool = False,
cls_names_lut: Dict = None,
alpha_lyr: bool = False,
) -> Tuple[numpy.array, List[float], list]:
"""
A function which retrieves thematic image data with a colour table as an
Expand All @@ -681,7 +801,11 @@ def get_gdal_thematic_raster_mpl_imshow(
create a legend.
:param cls_names_lut: A dictionary LUT with labels for the classes. The dict
key is the pixel value for the class and
:return: numpy.array either [n,m,3], a bbox (xmin, xmax, ymin, ymax)
:param alpha_lyr: a boolean specifying whether an alpha channel should be
created and therefore the returned array will have 4
rather than 3 dims. If an alpha channel is created then
then background will be transparent.
:return: numpy.array either [n,m,3 or 4], a bbox (xmin, xmax, ymin, ymax)
specifying the extent of the image data and list of matplotlib patches,
if out_patches=False then None is returned.
Expand Down Expand Up @@ -752,6 +876,8 @@ def get_gdal_thematic_raster_mpl_imshow(
red_arr = numpy.zeros_like(img_data_arr, dtype=numpy.uint8)
grn_arr = numpy.zeros_like(img_data_arr, dtype=numpy.uint8)
blu_arr = numpy.zeros_like(img_data_arr, dtype=numpy.uint8)
if alpha_lyr:
alp_arr = numpy.zeros_like(img_data_arr, dtype=numpy.uint8)

lgd_out_patches = None
if out_patches:
Expand All @@ -762,6 +888,8 @@ def get_gdal_thematic_raster_mpl_imshow(
red_arr[img_data_arr == i] = clr_tab_entry[0]
grn_arr[img_data_arr == i] = clr_tab_entry[1]
blu_arr[img_data_arr == i] = clr_tab_entry[2]
if alpha_lyr and (i > 0):
alp_arr[img_data_arr == i] = 255

if out_patches and (i > 0):
cls_name = f"{i}"
Expand All @@ -776,7 +904,10 @@ def get_gdal_thematic_raster_mpl_imshow(
Patch(facecolor=rgb_clr, edgecolor=rgb_clr, label=cls_name)
)

img_clr_data_arr = numpy.stack([red_arr, grn_arr, blu_arr], axis=-1)
if alpha_lyr:
img_clr_data_arr = numpy.stack([red_arr, grn_arr, blu_arr, alp_arr], axis=-1)
else:
img_clr_data_arr = numpy.stack([red_arr, grn_arr, blu_arr], axis=-1)

image_ds = None
img_data_arr = None
Expand Down Expand Up @@ -1231,20 +1362,20 @@ def manual_stretch_np_arr(


def create_legend_img(
legend_info,
out_img_file,
n_cols=1,
box_size=(10, 20),
title_str=None,
font_size=12,
font=None,
font_clr=(0, 0, 0, 255),
col_width=None,
img_height=None,
char_width=6,
bkgd_clr=(255, 255, 255, 255),
title_height=16,
margin=2,
legend_info: Dict,
out_img_file: str,
n_cols: int = 1,
box_size: Tuple[int] = (10, 20),
title_str: str = None,
font_size: int = 12,
font: str = None,
font_clr: Tuple[int] = (0, 0, 0, 255),
col_width: int = None,
img_height: int = None,
char_width: int = 6,
bkgd_clr: Tuple[int] = (255, 255, 255, 255),
title_height: int = 16,
margin: int = 2,
):
"""
A function which can generate a legend image file using the PIL module.
Expand Down Expand Up @@ -1346,3 +1477,36 @@ def create_legend_img(
break

img_obj.save(out_img_file)


def gen_colour_lst(cmap_name: str, n_clrs: int, reverse: bool = False) -> List[str]:
"""
A function which gets a list of colours as hex strings from a matplotlib colour
bar.
For available colour bars see:
https://matplotlib.org/stable/tutorials/colors/colormaps.html
:param cmap_name: The name of a matplotlib colour bar
:param n_clrs: The number of colours to be returned
:param reverse: Option to reverse the order of the colours
:return: List of hex colour presentations
"""
if not have_matplotlib:
raise rsgislib.RSGISPyException(
"The matplotlib module is required and could not be imported."
)

c_map = plt.cm.get_cmap(cmap_name)
vals_arr = numpy.linspace(0, 1, n_clrs)
clr_lst = list()
for c in vals_arr:
rgba = c_map(c)
clr = mClrs.rgb2hex(rgba) # convert to hex
clr_lst.append(str(clr)) # create a list of these colors

if reverse == True:
clr_lst.reverse()
return clr_lst
41 changes: 40 additions & 1 deletion python/rsgislib/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import string
import json
from typing import List
from typing import List, Union

import numpy
import rsgislib
Expand Down Expand Up @@ -450,6 +450,45 @@ def hex_to_rgb(hex_str: str):
return int(r_hex, 16), int(g_hex, 16), int(b_hex, 16)


def rgb_to_hex(
r: Union[int, float],
g: Union[int, float],
b: Union[int, float],
normalised: bool = False,
) -> str:
"""
A function which converts red, green, blue values to a hexadecimal colour
representation.
For example: 180, 50, 190 is equal to: #b432be
:param r: number with range either 0-255 or 0-1 if normalised
:param g: number with range either 0-255 or 0-1 if normalised
:param b: number with range either 0-255 or 0-1 if normalised
:param normalised: a boolean specifying the inputs are in range 0-1
:return: string with hexadecimal colour representation
"""
if normalised:
r = int(r * 255)
g = int(g * 255)
b = int(b * 255)

if (r < 0) or (r > 255):
raise rsgislib.RSGISPyException(
"Red value must be between 0-255 or 0-1 if normalised"
)
if (g < 0) or (g > 255):
raise rsgislib.RSGISPyException(
"Green value must be between 0-255 or 0-1 if normalised"
)
if (b < 0) or (b > 255):
raise rsgislib.RSGISPyException(
"Blue value must be between 0-255 or 0-1 if normalised"
)

return "#{:02x}{:02x}{:02x}".format(r, g, b)


def remove_repeated_chars(str_val: str, repeat_char: str):
"""
A function which removes repeated characters within a string for the
Expand Down
3 changes: 1 addition & 2 deletions python/rsgislib/vectorutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2269,11 +2269,10 @@ def vector_translate(
if os.path.exists(out_vec_file) and del_exist_vec:
delete_vector_file(out_vec_file)

n_feats = get_vec_feat_count(vec_file, vec_lyr)
try:
import tqdm

pbar = tqdm.tqdm(total=n_feats)
pbar = tqdm.tqdm(total=100)
callback = lambda *args, **kw: pbar.update()
except:
callback = gdal.TermProgress
Expand Down

0 comments on commit 7f4c054

Please sign in to comment.