Skip to content

Commit

Permalink
Refactored plotting parallel coordinates into classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-hellerstein committed May 29, 2024
1 parent e45dc51 commit a255da1
Show file tree
Hide file tree
Showing 7 changed files with 1,967 additions and 622 deletions.
1,993 changes: 1,580 additions & 413 deletions examples/Controlling-Glucose-Metabolis-in-Yeast.ipynb

Large diffs are not rendered by default.

45 changes: 38 additions & 7 deletions src/controlSBML/control_sbml.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import controlSBML.constants as cn
import controlSBML.msgs as msgs
from controlSBML.grid import Grid
from controlSBML.parallel_coordinates import ParallelCoordinates

from collections import namedtuple
import os
Expand Down Expand Up @@ -237,6 +238,7 @@ def __init__(self, model_reference:str,
is_fixed_input_species=is_fixed_input_species, is_steady_state=is_steady_state) # type: ignore
self._fitter_result = cn.FitterResult()
self.antimony_builder:Optional[AntimonyBuilder] = None # Last antimony builder used
self.design_result:Optional[DesignResult] = None # Result of last design

def copy(self):
ctlsb = ControlSBML(self.model_reference,
Expand Down Expand Up @@ -653,7 +655,8 @@ def _plotClosedLoop(self,
msg = "System is unstable for kP={}, kI={}, kD={}, kF={}".format(kP, kI, kD, kF)
msgs.warn(msg)
return response_ts, builder
self._sbml_system.plotSISOClosedLoop(response_ts, setpoint, times=times, **plot_dct)
if plot_dct[cn.O_IS_PLOT]:
self._sbml_system.plotSISOClosedLoop(response_ts, setpoint, times=times, **plot_dct)
return response_ts, builder

def plotGridDesign(self,
Expand Down Expand Up @@ -703,11 +706,13 @@ def plotGridDesign(self,
input_name=self.input_name,
output_name=self.output_name)
# Translate axis names
designs = designer.designAlongGrid(grid, is_greedy=self.is_greedy, num_process=num_process, # type: ignore
self.design_result = designer.designAlongGrid(grid,
num_process=num_process, # type: ignore
num_restart=num_restart) # type: ignore
if designer.residual_mse is None:
msgs.warn("No design found!")
return GridDesignResult(timeseries=None, antimony_builder=None, designs=designs)
return GridDesignResult(timeseries=None, antimony_builder=None,
designs=self.design_result)
# Persist the design parameters
self.setOption(cn.CP_KP, designer.kP)
self.setOption(cn.CP_KI, designer.kI)
Expand All @@ -733,7 +738,31 @@ def plotGridDesign(self,
selections=selections,
**plot_dct)
return GridDesignResult(timeseries=response_ts, antimony_builder=antimony_builder,
designs=designs)
designs=self.design_result)

def plotAllDesignResults(self, design_result:Optional[DesignResult]=None,
title=None, figsize=None, columns=None,
round_digit:int=4, is_plot=True):
"""
Does a parallel coordinate plot of the design results.
Args:
design_result: DesignResult
title: str
figsize: tuple
columns: list-str (order in whch columns appear on plot)
round_digit: int (number of digits to round to)
"""
if design_result is None:
design_result = self.design_result
df = design_result.dataframe.copy() # type: ignore
del df[cn.REASON]
ParallelCoordinates.plotParallelCoordinates(df, # type: ignore
title=title, figsize=figsize,
columns=columns,
value_column=cn.SCORE,
round_digit=round_digit,
is_plot=is_plot)

@staticmethod
def setSpec(val):
Expand Down Expand Up @@ -836,7 +865,7 @@ def setSpec(key, val):
setpoint=setpoint,
input_name=self.input_name,
output_name=self.output_name)
designs = designer.design(
self.design_result = designer.design(
kP_spec=kP_spec,
kI_spec=kI_spec,
kD_spec=kD_spec,
Expand All @@ -850,7 +879,8 @@ def setSpec(key, val):
num_process=num_process) # type: ignore
if designer.residual_mse is None:
msgs.warn("No design found!")
return DesignResult(timeseries=None, antimony_builder=None, designs=designs)
return DesignResult(timeseries=None, antimony_builder=None,
designs=self.design_result)
# Persist the design parameters
self.setOption(cn.CP_KP, designer.kP)
self.setOption(cn.CP_KI, designer.kI)
Expand Down Expand Up @@ -878,7 +908,8 @@ def setSpec(key, val):
kF=designer.kF,
selections=selections,
**plot_dct)
return DesignResult(timeseries=response_ts, antimony_builder=antimony_builder, designs=designs)
return DesignResult(timeseries=response_ts, antimony_builder=antimony_builder,
designs=self.design_result)

def _plotDesignResult(self, **kwargs):
"""
Expand Down
249 changes: 249 additions & 0 deletions src/controlSBML/parallel_coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
"""Does parallel coordinate plots for a dataframe."""

from colour import Color # type: ignore
from matplotlib import ticker # type: ignore
import numpy as np
import matplotlib.pyplot as plt # type: ignore
import pandas as pd # type: ignore
from typing import Tuple, Dict, List

C_DUMMY = "dummy"
C_INDEX = "index"
C_UNNAMED = "Unnamed: 0"
LABEL_NAN = "nan"

class ParallelCoordinates(object):

def __init__(self, df, value_column=None, num_category=10, columns=None,
figsize=(15,5), round_digit=None, title=None,
noise_std=0.01, is_plot=True):
"""
Does a parallel coordinate plots for with lines colored based on categorizing
the value_column. One category is for nans.
Args:
df (pd.DataFrame): contains parameters to plot and the value_column
value_column (object, optional): Column of dataframe that has the values to
categorize. Defaults to first column.
columns (str, optional): Columns in order
num_category (int, optional): Number of categories to split the value_column,
including the nan category
title (str, optional): Title of the plot
figsize (tuple, optional): Size of the plot
noise_std (float, optional): Standard deviation of noise to add to the data
"""
# Save inputs
self.unscaled_df = df.copy()
if value_column is None:
value_column = columns[0]
self.value_column = value_column
self.num_category = num_category
self.noise_std = noise_std
#
self.figsize = figsize
self.round_digit = round_digit
self.title = title
self.is_plot = is_plot
self.values = self.unscaled_df[self.value_column]
# Handle columns
if columns is None:
columns = list(self.unscaled_df.columns)
self.columns = columns
# Ensure that value_column is the last column
if self.value_column in columns:
self.columns.remove(value_column)
self.columns.append(value_column)
self.unscaled_df = self.unscaled_df[self.columns]
# Manipulate data
self.scaled_df = self._makeScaledDF()
self.unscaled_df = self._fixupDF(self.unscaled_df)
self.scaled_df = self._fixupDF(self.scaled_df)
self.columns = list(self.scaled_df.columns)

def _makeScaledDF(self)->pd.DataFrame:
"""
Creates a dataframe that in which each column is scaled between 0 and 1.
Returns:
pd.DataFrame
"""
scaled_df = self.unscaled_df.copy()
scaled_df = self._fixupDF(scaled_df)
# Adjust the value ranges
for column in self.columns:
ser = self.unscaled_df[column]
min_val, range_val = ser.min(), np.ptp(ser)
if range_val > 0:
values = (ser - min_val)/range_val
else:
values = 0
scaled_df[column] = values + np.random.normal(0, self.noise_std, len(ser))
#
return scaled_df

def _fixupDF(self, df)->pd.DataFrame:
"""Cleans_ up the dataframe"""
df = df.reset_index(drop=False)
for column in [C_INDEX, C_UNNAMED]:
if column in df.columns:
del df[column]
return df

def _setAxisTicks(self, column:str, ax, num_tick:int=6):
"""
Set the tick positions and labels on y axis for each plot
Tick positions based on normalised data
Tick labels are based on original data
Args:
column: str
ax (matplotlib.Axis): _description_
num_tick (int): Number of ticks
"""
def min_step(ser):
min_val = ser.min()
range_val = np.ptp(ser)
step = range_val/float(num_tick-1)
return min_val, step
#
scaled_min, scaled_step = min_step(self.scaled_df[column])
unscaled_min, unscaled_step = min_step(self.unscaled_df[column])
if self.round_digit is None:
tick_labels = [unscaled_min + unscaled_step * i for i in range(num_tick)]
ticks = [scaled_min + scaled_step * i for i in range(num_tick)]
else:
tick_labels = [np.round(unscaled_min + unscaled_step * i, self.round_digit) for i in range(num_tick)]
ticks = [np.round(scaled_min + scaled_step * i, self.round_digit) for i in range(num_tick)]
is_ticks = [l != tick_labels[0] for l in tick_labels]
is_ticks[0] = True
ticks = [t for i, t in enumerate(ticks) if is_ticks[i]]
tick_labels = [t for i, t in enumerate(tick_labels) if is_ticks[i]]
ax.yaxis.set_ticks(ticks)
ax.set_yticklabels(tick_labels)

def _makeCategoriesAndLabels(self)->Tuple[np.array, np.array, np.array]:
"""
Create the categories for each row.
Returns:
np.array - categories for rows
np.array - untrimmed labels for each categories
np.array - trimmed labels for each categories
"""
NAN_CATEGORY = self.num_category - 1
# Construct the categories for each row
classify_vals = self.unscaled_df[self.value_column]
nan_idx = np.array([np.isnan(v) or np.isinf(v) for v in classify_vals])
classify_arr = np.array(classify_vals)
classify_arr[nan_idx] = np.mean(classify_arr[~nan_idx])
# Make the categories accounting for the nans
categories = pd.cut(classify_arr, self.num_category-1,
labels=range(self.num_category-1)).tolist()
categories = np.array(categories)
categories[nan_idx] = NAN_CATEGORY
# Create the categories and labels
category_intervals = pd.cut(classify_arr, self.num_category)
trimmed_category_labels = list(np.repeat("", self.num_category))
untrimmed_category_labels = list(np.repeat("", self.num_category))
for idx, category in enumerate(categories):
left_value = category_intervals[idx].left
right_value = category_intervals[idx].right
label = f"{left_value} to {right_value}"
untrimmed_category_labels[category] = label
#
if self.round_digit is not None:
left_value = np.round(left_value, self.round_digit)
right_value = np.round(right_value, self.round_digit)
label = f"{left_value} to {right_value}"
trimmed_category_labels[category] = label
else:
trimmed_category_labels[category] = untrimmed_category_labels[category]
untrimmed_category_labels[NAN_CATEGORY] = LABEL_NAN
trimmed_category_labels[NAN_CATEGORY] = LABEL_NAN
#
return categories, untrimmed_category_labels, trimmed_category_labels

def plot(self):
"""
Plot the parallel coordinates
"""
categories, untrimmed_labels, trimmed_labels = self._makeCategoriesAndLabels()

# Create the colors for the categories
red = Color("red")
colors = list(red.range_to(Color("blue"), self.num_category-1))
colors.reverse()
colors.append("grey")

# Create sublots along x axis
_, axes = plt.subplots(1, len(self.columns), sharey=False, figsize=self.figsize)

# Plot each row
x_vals = list(range(len(self.columns)))
for i, ax in enumerate(axes):
for idx in self.scaled_df.index:
category = categories[idx]
ax.plot(x_vals, self.scaled_df.loc[idx, self.columns], c=str(colors[category]))
ax.set_xlim([x_vals[i], x_vals[i] + 1])

# Set the ticks for each axis
for idx, ax in enumerate(axes):
column = self.columns[idx]
ax.xaxis.set_major_locator(ticker.FixedLocator([idx]))
self._setAxisTicks(column, ax)
ax.set_xticklabels([column])

# Move the final axis' ticks to the right-hand side
axes[-1].remove() # Remove the dummy
axes = axes[:-1]
x_vals = x_vals[:-1]
columns = self.columns[:-1]
ax = plt.twinx(axes[-2])
idx = len(axes) - 1
ax.xaxis.set_major_locator(ticker.FixedLocator([x_vals[-2], x_vals[-1]]))
self._setAxisTicks(columns[idx], ax)
ax.set_xticklabels([columns[-2], columns[-1]])

# Remove space between subplots
plt.subplots_adjust(wspace=0)
axes[-1].remove()

# Add legend to plot
valid_idxs = np.array([v for v in range(self.num_category) if v in categories])
trimmed_category_arr = np.array(untrimmed_labels)[valid_idxs]
plt.legend(
[plt.Line2D((0,1),(0,0), color=str(colors[v])) for v in range(self.num_category)
if v in categories],
trimmed_category_arr,
bbox_to_anchor=(1.75, 1), loc=2, borderaxespad=0.)

# Add title
if self.title is None:
value_column_str = str(self.value_column)
title = f"Values by {value_column_str} category"
else:
title = f"{self.title}"
plt.title(title)
if self.is_plot:
plt.show()

@classmethod
def plotParallelCoordinates(cls, df, value_column=None, num_category=10, columns=None,
figsize=(15,5), round_digit=None, title=None, is_plot=True):
"""
Does a parallel coordinate plots for with lines colored based on categorizing
the value_column. One category is for nans.
Args:
df (pd.DataFrame): contains parameters to plot and the value_column
value_column (object, optional): Column of dataframe that has the values to
categorize. Defaults to first column.
columns (str, optional): Columns in order
num_category (int, optional): Number of categories to split the value_column,
including the nan category
title (str, optional): Title of the plot
figsize (tuple, optional): Size of the plot
"""
plotter = cls(df, value_column=value_column, num_category=num_category, columns=columns,
figsize=figsize, round_digit=round_digit, title=title, is_plot=is_plot)
plotter.plot()
Loading

0 comments on commit a255da1

Please sign in to comment.