Various fixes; basic plotting for histories
alcrene committed Aug 6, 2019
1 parent eb12da7 commit f154c08
Showing 5 changed files with 150 additions and 45 deletions.
8 changes: 6 additions & 2 deletions sinn/analyze/analyze.py
@@ -392,10 +392,11 @@ def subsample(series, amount=None, target_dt=None, aggregation='mean',
as the series' `dt`.
Returns
-------
Series instance
Series instance or array
The result will be `amount` times shorter than `history`. The result
of each new bin is the average over `amount` bins of the original
series. Bins are identified by the time at which they begin.
If an array is passed, a plain array is returned instead of a series.
"""
if isinstance(series, np.ndarray):
if amount is None:
@@ -408,7 +409,7 @@ def subsample(series, amount=None, target_dt=None, aggregation='mean',
# About 10x faster than passing a callable which computes mean|sum
normalizer = (amount if aggregation == 'mean' else 1)
return sum(data[i : (i+nbins)*amount : amount]
for i in range(amount))/normalizer,
for i in range(amount))/normalizer
# Can't use np.mean on a generator
else:
resdata = np.zeros((nbins,) + data.shape[1:])
@@ -417,6 +418,9 @@ def subsample(series, amount=None, target_dt=None, aggregation='mean',
return resdata

# else: Everything below is the `else` branch
if not isinstance(series, histories.HistoryBase):
raise ValueError("`series` argument should be a History, but is of "
"type `{}`".format(type(series)))
series = histories.DataView(series)

if (not isinstance(aggregation, Callable)
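A minimal standalone sketch of the bin-averaging described in the docstring above, assuming a 1-D NumPy array; `subsample_array` is an illustrative name, not part of the module, and the strided sum mirrors the generator expression in the hunk:

    import numpy as np

    def subsample_array(data, amount, aggregation='mean'):
        # Each output bin aggregates `amount` consecutive input bins.
        nbins = len(data) // amount
        normalizer = amount if aggregation == 'mean' else 1
        return sum(data[i:nbins*amount:amount] for i in range(amount)) / normalizer

    x = np.arange(12, dtype=float)
    print(subsample_array(x, 3))   # -> [ 1.  4.  7. 10.]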
73 changes: 50 additions & 23 deletions sinn/analyze/axisdata.py
@@ -1084,7 +1084,7 @@ class MarginalCollection:
'desc', 'longdesc',
'idx', 'flatidx',
'to_desc', 'back_desc'])
Marker = namedtuple('Marker', ['pos', 'color'])
Marker = namedtuple('Marker', ['pos', 'color', 'α', 'size'])
AxisFormat = namedtuple('AxesFormat', ['key', 'scale', 'visible', 'apply'])

# TODO: Remove flat_idx ?
@@ -1305,7 +1305,11 @@ def set_colorscheme(self, scheme="viridis"):
else:
self.colorscheme = scheme

def set_markers(self, markers=None, colors=None):
def set_markers(self, markers=None, colors=None, alphas=1., size=None):
"""
All parameters can be passed either as an iterable (one value per marker)
or as a scalar applied to every marker.
size=None => marker size is computed from the axes size (~0.2% of the plot area).
"""
# Standardize `markers` format
if isinstance(markers, (dict, ParameterSet)):
# `markers` is treated as a list of dicts, one dict per marker
@@ -1328,9 +1332,31 @@ def set_markers(self, markers=None, colors=None):
if len(colors) != 1:
raise ValueError("`colors` argument must either be of length 1 or "
"of the same length as `markers`.")
colors = tuple(colors)*len(markers)
colors = list(colors)*len(markers)
if not isinstance(alphas, Iterable):
alphas = [alphas]*len(markers)
else:
alphas = list(alphas) # In case `alphas` is a consumable iterator
if len(alphas) != len(markers):
if len(alphas) != 1:
raise ValueError(
"`alphas` argument must either be of length 1 or "
"of the same length as `markers`.")
alphas = alphas*len(markers)
if not isinstance(size, Iterable):
sizes = [size]*len(markers)
else:
sizes = list(size) # In case `size` is a consumable iterator
if len(sizes) != len(markers):
if len(sizes) != 1:
raise ValueError(
"`size` argument must either be of length 1 or "
"of the same length as `markers`.")
sizes = sizes*len(markers)

self.markers = [self.Marker(pos, color) for pos, color in zip(markers, colors)]
self.markers = [self.Marker(pos, color, α, s)
for pos, color, α, s
in zip(markers, colors, alphas, sizes)]

def set_transformed(self, transformed_axes):
"""
@@ -1397,7 +1423,7 @@ def _format_axis(self, key, axes, axis):
"""
Internal function that applies the parameters specified in `set_axis()`.
"""
if key in self._axes_format:
if self._axes_format.sanitize(key) in self._axes_format:
# Set which spines are visible
format = self._axes_format[key]
if format.visible is not None:
@@ -1475,11 +1501,8 @@ def plot_marginal2D(self, keyi, keyj, stddevs=None, marker_size=None,
of a dictionary.
If there is only a single ellipse to draw, it does not need to be wrapped
in a list.
marker_size: float
Value is passed on to the 's' keyword of `scatter()` which plots the markers.
If not specified, it is calculated from the axes width.
**kwargs:
Keyword arguments passed to `ScalarAxisData.histogram_plot()`.
Keyword arguments passed to `ScalarAxisData.plot()`.
"""
if ax is None: ax = plt.gca()

@@ -1501,8 +1524,16 @@ def plot_marginal2D(self, keyi, keyj, stddevs=None, marker_size=None,
if to_desc is not None:
transform = ml.parameters.Transform(to_desc)
paramj_markers = [transform(marker) for marker in paramj_markers]
# TODO: Don't recreate these lists every time we plot
colors = [marker.color for marker in self.markers]
# TODO: Don't recreate these lists every time we plot
αs = [marker.α for marker in self.markers]
if marker_size is None:
axsize = np.prod(ax.get_window_extent().bounds[2:])
marker_size = axsize / 500
# Base the marker size on the display size
# `s` argument specifies marker _area_, so we use display area
sizes = [marker.size if marker.size is not None else marker_size
for marker in self.markers]
maxexpand = self.maxexpand

if self.transformed[keyi] != self.transformed[keyj]:
@@ -1519,7 +1550,7 @@ def plot_marginal2D(self, keyi, keyj, stddevs=None, marker_size=None,
# Draw stddev ellipses
if stddevs is None:
stddevs = ()
elif not isinstance(stddevs, Iterable):
elif not isinstance(stddevs, Iterable) or isinstance(stddevs, dict):
stddevs = (stddevs,)
for stddev in stddevs:
# Check that required parameters are there
@@ -1540,14 +1571,10 @@ def plot_marginal2D(self, keyi, keyj, stddevs=None, marker_size=None,
# Possibly expand the plot region if it doesn't include some of the markers
xlim, ylim = np.array(plt.xlim()), np.array(plt.ylim())
xwidth, yheight = xlim[1]-xlim[0], ylim[1]-ylim[0]
if marker_size is None:
# Base the marker size on the smallest display size (either in x or y)
# TODO: does s give linear or area size ? if latter we should base
# it on width * height rather than min(width, height)
axsize = min(ax.get_window_extent().bounds[2:])
marker_size = axsize / 4
for marki, markj, color in zip(parami_markers, paramj_markers, colors):
ax.scatter(markj, marki, s=marker_size, c=color, zorder=2)
for marki, markj, color, α, s in zip(
parami_markers, paramj_markers, colors, αs, sizes):
ax.scatter(markj, marki, s=s, c=color, zorder=2, alpha=α,
edgecolors='none')
# Recall that histograms set their x-axis to j-parameter
newxlim, newylim = np.array(plt.xlim()), np.array(plt.ylim())
if maxexpand*(xlim[1] - xlim[0]) < (newxlim[1] - newxlim[0]):
@@ -1615,10 +1642,10 @@ def plot_grid(self, names_to_display, kwargs1D=None, kwargs2D=None, **kwargs):
kwargs1D = {}
if kwargs2D is None:
kwargs2D = {}
def plot_marginal1D(*args):
return self.plot_marginal1D(*args, **kwargs1D)
def plot_marginal2D(*args):
return self.plot_marginal2D(*args, **kwargs2D)
def plot_marginal1D(*args, **kwargs):
return self.plot_marginal1D(*args, **kwargs, **kwargs1D)
def plot_marginal2D(*args, **kwargs):
return self.plot_marginal2D(*args, **kwargs, **kwargs2D)

self._plot_grid_layout(gridkeys=params,
plot_diagonal=plot_marginal1D,
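The `plot_grid` change above makes the inner wrappers forward their own keyword arguments in addition to the preset `kwargs1D`/`kwargs2D`. A minimal sketch of that forwarding pattern, with illustrative names only:

    def make_plotter(plot_fn, preset_kwargs):
        # Forward both call-time kwargs and the preset dict; a duplicate key
        # raises TypeError, surfacing conflicting settings early.
        def wrapper(*args, **kwargs):
            return plot_fn(*args, **kwargs, **preset_kwargs)
        return wrapper

    plot1d = make_plotter(print, {'sep': ' | '})
    plot1d('a', 'b', end='!\n')   # prints "a | b!"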
67 changes: 54 additions & 13 deletions sinn/histories.py
@@ -1180,7 +1180,8 @@ def set_update_function(self, func, inputs=None, cast=True, _return_dtype=None,
logger.warning(
"You should specify the list of inputs with the function.")
if inputs is None: inputs = set()
self.inputs = inputs
elif not isinstance(inputs, Iterable): inputs = [inputs]
self.inputs = set(inputs)

def _update_function(self, t):
# TODO: If setting update function in __init__ is deprecated,
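A hedged sketch of the `inputs` convention introduced above, where a history's update function is registered together with the histories it reads from; the derived-series construction follows the `Series(self)` pattern used later in `_apply_op`, and everything else is assumed:

    # Assume `x` is an existing, filled Series instance.
    derived = Series(x)            # new series with the same time axis/shape as `x`
    derived.set_update_function(lambda t: 2 * x[t], inputs=x)
    # A bare history is wrapped into a set, i.e. derived.inputs == {x}.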
@@ -2423,6 +2424,37 @@ def _is_batch_computable(self, up_to='end'):

return retval

def plot(self, tslc=None, idx=None, ax=None, **plot_kwargs):
"""
tslc: time slice. Can use `np.s_[t0:tn]`.
idx: Indices of components to plot. Can be tuple to allow dims > 1.
Specified as either integer or slice.
ax: matplotlib axes. If not specified, use current axes.
plot_kwargs: keywords forwarded to ax.plot()
"""
if ax is None:
from matplotlib.pyplot import gca
ax = gca()
if tslc is None:
start = self.t0idx
stop = self.tnidx + 1
step = 1
else:
start, stop, step = tslc.start, tslc.stop, tslc.step
if start is not None: start = self.get_tidx(start)
if stop is not None: stop = self.get_tidx(stop)
if step is not None: step = self.index_interval(step)
time = self.timeaxis.stops[start:stop:step]
if idx is None:
idx = (np.s_[:],) * self.ndim
elif not isinstance(idx, tuple):
idx = (idx,) * self.ndim
# FIXME: Make an indexing interface to Spiketrain, so that this works
data = self._data[(np.s_[start:stop:step],) + idx]

ax.plot(time, data, **plot_kwargs)


class PopulationHistory(PopulationHistoryBase, History):
"""
History where traces are organized into populations.
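A hedged usage sketch of the new `plot` helper above; `hist` is assumed to be an already computed history, and the extra keywords are ordinary `ax.plot` options:

    import numpy as np
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    hist.plot(tslc=np.s_[0.:5.],   # time window, i.e. np.s_[t0:tn] from the docstring
              idx=0,               # first component only
              ax=ax, color='k', linewidth=1.)
    ax.set_xlabel('time')
    plt.show()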
@@ -2553,11 +2585,13 @@ def clear(self, init_data=-np.inf):
self.initialize(init_data)
super().clear()

def set_update_function(self, func, _return_dtype=None):
def set_update_function(self, func, *args, _return_dtype=None, **kwargs):
# FIXME: I don't know when this was written, but shouldn't we return a time type, not index type ?
# Could have been copy-pasted from Spiketrain...
if _return_dtype is None:
super().set_update_functin(func, self.idx_dtype)
super().set_update_function(func, *args, _return_dtype=self.idx_dtype, **kwargs)
else:
super().set_update_function(func, _return_dtype)
super().set_update_function(func, *args, _return_dtype, **kwargs)

def retrieve(self, key):
'''A function taking either an index or a splice and returning respectively
@@ -3052,11 +3086,11 @@ def clear(self, init_data=None):
self.initialize(None)
super().clear()

def set_update_function(self, func, _return_dtype=None):
def set_update_function(self, func, *args, _return_dtype=None, **kwargs):
if _return_dtype is None:
super().set_update_function(func, _return_dtype=self.idx_dtype)
super().set_update_function(func, *args, _return_dtype=self.idx_dtype, **kwargs)
else:
super().set_update_function(func, _return_dtype=_return_dtype)
super().set_update_function(func, *args, _return_dtype=_return_dtype, **kwargs)


def get_trace(self, pop=None, neuron=None, include_padding='none', time_slice=None):
@@ -4209,20 +4243,27 @@ def _convolve_op_batch(self, discretized_kernel, kernel_slice):
def _apply_op(self, op, b=None):
if b is None:
new_series = Series(self)
new_series.set_update_function(lambda t: op(self[t]))
new_series.set_range_update_function(lambda tarr: op(self[self.time_array_to_slice(tarr)]))
new_series.add_input(self)
new_series.set_update_function(
lambda t: op(self[t]),
inputs = self)
new_series.set_range_update_function(
lambda tarr: op(self[self.time_array_to_slice(tarr)]),
inputs = self)
# new_series.add_input(self)
elif isinstance(b, HistoryBase):
# HACK Should write function that doesn't create empty arrays
shape = np.broadcast(np.empty(self.shape), np.empty(b.shape)).shape
tnidx = min(self.tnidx, b.get_tidx_for(b.tnidx, self))
new_series = Series(self, shape=shape,
time_array=self._tarr[:tnidx+1])
new_series.set_update_function(lambda t: op(self[t], b[t]))
new_series.set_update_function(
lambda t: op(self[t], b[t]),
inputs = [self, b])
new_series.set_range_update_function(
lambda tarr: op(self[self.time_array_to_slice(tarr)],
b[b.time_array_to_slice(tarr)]))
new_series.add_input(self)
b[b.time_array_to_slice(tarr)]),
inputs = [self, b])
#new_series.add_input(self)
computable_tidx = min(
self.get_tidx_for(min(self.cur_tidx, self.tnidx), new_series),
b.get_tidx_for(min(b.cur_tidx, b.tnidx), new_series))
29 changes: 24 additions & 5 deletions sinn/models/driftdiffusion.py
@@ -14,7 +14,7 @@

import theano_shim as shim
import sinn
from sinn.histories import Series
from sinn.histories import Series, HistoryBase
from sinn.models.common import Model, register_model

from mackelab.utils import StringFunction
@@ -31,6 +31,7 @@ class DriftDiffusion(Model):
State = namedtuple('State', ['x'])

def __init__(self, params, x_history, random_stream=None,
t0=None, tn=None, dt=None,
drift=None, diffusion=None, namespace=None):
"""
Default `drift` and `diffusion` implement
@@ -59,14 +60,32 @@ def __init__(self, params, x_history, random_stream=None,
raise TypeError("`x_history` argument must be a sinn `Series`.")
Model.output_rng(x_history, self.rndstream)

super().__init__(params, public_histories=(x_history,))
if t0 is None: t0 = x_history.t0
if tn is None: tn = x_history.tn
if dt is None: dt = x_history.dt

super().__init__(params, t0=t0, tn=tn, dt=dt,
public_histories=(x_history,))
# NOTE: Do not use `params` beyond here. Always use self.params.

self.x = x_history
inputs = set([self.x])
# TODO: I haven't tested dependence on other histories.
# Also, how should we deal with dependence with a time lag ?
if namespace is not None:
for I in namespace.values():
if isinstance(I, HistoryBase):
inputs.add(I)
if len(inputs) > 1:
logger.warning(
"Support for dependence on other histories is experimental."
" You have specified update dependencies on the histories "
+ ', '.join(str(I) for I in inputs - {self.x}) + "."
)

self.add_history(self.x)
self.x.set_update_function(self.x_fn)
self.x.add_input(self.x)
self.x.set_update_function(self.x_fn, inputs=inputs)
#self.x.add_input(self.x)

if isinstance(drift, str):
# Evaluate string to get function
@@ -119,7 +138,7 @@ def dW(self):
# here: `sampler()`
return self.rndstream.normal(size=self.x.shape, avg=0, std=np.sqrt(self.x.dt))

def drift(self, t):
def drift(self, t, x):
tidx = x.get_t_idx(t)
return -x[tidx-1]

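The `drift`, `diffusion` and `dW` pieces above combine in the usual Euler–Maruyama update; the standalone sketch below shows that scheme and is not the model's actual update function, which lies outside this hunk:

    import numpy as np

    def euler_maruyama_step(x_prev, t, dt, drift, diffusion, rng):
        # dW ~ Normal(0, sqrt(dt)), matching the std used in dW() above
        dW = rng.normal(0., np.sqrt(dt), size=np.shape(x_prev))
        return x_prev + drift(t, x_prev) * dt + diffusion(t, x_prev) * dW

    rng = np.random.default_rng(0)
    x = np.zeros(3)
    for i in range(1000):
        x = euler_maruyama_step(x, i * 0.001, 0.001,
                                drift=lambda t, x: -x,       # default drift above: -x[t-1]
                                diffusion=lambda t, x: 1.0,
                                rng=rng)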
18 changes: 16 additions & 2 deletions sinn/optimize/gradient_descent.py
@@ -1790,6 +1790,7 @@ def _plot(self, ax, stops, traces,
else:
logLs = [fit.data.cost_trace[-1].logL for fit in self.fits]
maxlogL = max(logLs)
σlogLs = np.argsort(np.argsort(logLs))  # rank of each fit by logL (0 = lowest)

def get_color(logL):
# Do the interpolation in hsv space ensures intermediate colors are OK
@@ -1820,7 +1821,11 @@ def get_color(logL):
plot_traces = [trace[idcs] for trace, idcs in zip(traces, trace_idcs)]

# Loop over the traces
for trace, stops, logL in zip(plot_traces, trace_stops, logLs):
ntraces = len(plot_traces)
#transform = lambda rank: 0.5 - np.tanh(5 * (rank/ntraces - 0.4)) / 2
transform = lambda rank: 1 - rank/ntraces
for trace, stops, logL, rank in zip(
plot_traces, trace_stops, logLs, σlogLs):
# Set plotting parameters
if logL > maxlogL - keep_range:
if keep_color is None:
@@ -1832,8 +1837,17 @@ def get_color(logL):
if discard_color is None:
continue
kwargs = {'color': discard_color,
'zorder': -1,
'zorder': -5,
'linewidth': linewidth[1]}
# If there are a lot of discarded traces, vary their opacity
# according to logL, so that we don't just see a blob
if ntraces > 100: # HACK: Hard-coded arbitrary threshold
c = mpl.colors.to_rgb(kwargs['color'])
α = transform(rank)
# Using alpha directly doesn't always work, so blending with a white
# background is hard-coded instead
c = 1 + (np.array(c)-1)*α
kwargs['color'] = tuple(c)

# Draw plot
plot_kwargs.update(kwargs)
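A standalone sketch of the fade-toward-white blend used above: instead of relying on the alpha channel, the RGB color is interpolated toward a white background. Only matplotlib's `to_rgb` is assumed:

    import numpy as np
    import matplotlib as mpl

    def fade_to_white(color, alpha):
        # alpha = 1 -> original color, alpha = 0 -> pure white
        c = np.array(mpl.colors.to_rgb(color))
        return tuple(1 + (c - 1) * alpha)

    print(fade_to_white('tab:blue', 0.25))   # a pale blue, roughly (0.78, 0.87, 0.93)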
