import matplotlib.pyplot as plt
import numpy as np
import torch
from typing import Iterable, Union, Any, Tuple
from pandas import DataFrame
from .utility import retrieve_from_dict
# --- Observer: records network outputs, losses, and state variables ---
class Observer:
    """Class that is used to record state variables, outputs, and losses during calls of
    `Network.train`, `Network.test`, or `Network.run`.

    Per-step recordings are stored in lists (one list per recorded variable). Arbitrary additional
    objects can be attached via `save`. Both kinds of entries are retrievable via `observer[key]`.
    """

    def __init__(self, dt: float, record_output: bool = True, record_loss: bool = True,
                 record_vars: list = None):
        """Instantiates observer.

        Parameters
        ----------
        dt
            Step-size of training/testing/integration steps.
        record_output
            If true, the output of the `Network` instance is recorded.
        record_loss
            If true, the loss calculated during training/testing by the `Network` is recorded.
        record_vars
            Additional variables of the RNN layer that should be recorded. Each entry is indexed as
            a triple: the first two elements form the key under which the variable is stored, and
            the third element flags whether the recorded value is averaged (via `torch.mean`) at
            each recording step.
        """
        record_vars = list(record_vars) if record_vars else []
        self._dt = dt
        # keys are wrapped in `tuple` so that entries passed as lists remain usable as dict keys
        self._state_vars = [tuple(v[:2]) for v in record_vars]
        self._reduce_vars = [v[2] for v in record_vars]
        self._recordings = {v: [] for v in self._state_vars}
        self._record_loss = record_loss
        self._record_out = record_output
        if record_loss:
            self._recordings['loss'] = []
        if record_output:
            self._recordings['out'] = []
        self._recordings["steps"] = []
        self._additional_storage = {}

    def __getitem__(self, item: Union[str, Tuple[str, str]]) -> Any:
        """Returns the raw recording stored under `item`, or, if there is no such recording,
        the object that was stored under `item` via `save`.

        Raises
        ------
        KeyError
            If `item` is neither a recorded variable nor a saved object.
        """
        try:
            return self._recordings[item]
        except KeyError:
            return self._additional_storage[item]

    @property
    def recorded_state_variables(self) -> list:
        """RNN state variables that are recorded by this `Observer` instance.
        """
        return self._state_vars

    @property
    def recorded_variables(self) -> list:
        """Keys of all per-step recordings of this `Observer` instance. Besides the recorded RNN
        state variables, this includes 'loss' and 'out' (if recorded) as well as 'steps'.
        """
        return list(self._recordings.keys())

    @property
    def recordings(self) -> DataFrame:
        """All recorded state variables, plus 'out' and 'loss' if recorded, as a single
        `DataFrame` with one column per variable and one row per recorded step. The index is
        given by the recorded steps multiplied by `dt`.

        Each recorded value is expected to be a scalar; use `to_dataframe` for variables that
        hold more than one value per step.
        """
        # copy: appending the optional columns must not alter the list of recorded state variables
        columns = list(self._state_vars)
        if self._record_out:
            columns.append("out")
        if self._record_loss:
            columns.append("loss")
        data = np.asarray([self._to_array(self[c]) for c in columns]).T
        return DataFrame(index=self._time_index(), data=data, columns=columns)

    def to_dataframe(self, item: Union[str, Tuple[str, str]]) -> DataFrame:
        """Returns the values stored under `item` as a `DataFrame` with one row per recorded step.

        Parameters
        ----------
        item
            Key of a recorded variable or of an object stored via `save`.

        Returns
        -------
        DataFrame
            Values of `item`, indexed by the recorded steps multiplied by `dt`.

        Raises
        ------
        KeyError
            If `item` is neither a recorded variable nor a saved object.
        """
        return DataFrame(index=self._time_index(), data=self.to_numpy(item))

    def record(self, step: int, output: torch.Tensor, loss: Union[float, torch.Tensor],
               record_vars: Iterable[torch.Tensor]) -> None:
        """Performs a single recording step.

        Parameters
        ----------
        step
            Integration step.
        output
            Output of the `Network` model.
        loss
            Current loss of the `Network` model.
        record_vars
            Additional variables of the RNN layer that should be recorded. Values are matched by
            position to the variables that were passed to the constructor.

        Returns
        -------
        None
        """
        recs = self._recordings
        recs["steps"].append(step)
        for key, val, reduce in zip(self._state_vars, record_vars, self._reduce_vars):
            recs[key].append(torch.mean(val) if reduce else val)
        if self._record_out:
            recs['out'].append(output)
        if self._record_loss:
            recs['loss'].append(loss)

    def save(self, key: str, val: Any):
        """Saves object on observer. Can be retrieved via `key`.

        Parameters
        ----------
        key
            Used for storage/retrieval.
        val
            Object to be stored.
        """
        self._additional_storage[key] = val

    def to_numpy(self, item: Union[str, Tuple[str, str]]) -> np.ndarray:
        """Returns the values stored under `item` as a numpy array.

        Tensors are detached and moved to the CPU before conversion. Entries that are not tensors
        (for example losses recorded as plain floats) are converted via `numpy.asarray`.

        Parameters
        ----------
        item
            Key of a recorded variable or of an object stored via `save`.

        Returns
        -------
        np.ndarray
            Array whose first axis runs over the stored entries.

        Raises
        ------
        KeyError
            If `item` is neither a recorded variable nor a saved object.
        """
        return self._to_array(self[item])

    def plot(self, y: Union[str, Tuple[str, str]], x: Union[str, Tuple[str, str]] = None,
             ax: "plt.Axes" = None, **kwargs) -> "plt.Axes":
        """Create a line plot with variable `y` on the y-axis and `x` on the x-axis.

        Parameters
        ----------
        y
            Tuple that contains the names of the node and the node variable to be plotted on the y-axis.
        x
            Tuple that contains the names of the node and the node variable to be plotted on the x-axis.
            If not provided, `y` will be plotted against time steps.
        ax
            `matplotlib.pyplot.Axes` instance in which to plot.
        kwargs
            Additional keyword arguments for the `matplotlib.pyplot.plot` call.

        Returns
        -------
        plt.Axes
            Instance of `matplotlib.pyplot.Axes` that contains the line plot.
        """
        if ax is None:
            # 'figsize' is handed to the figure creation instead of the plot call
            subplot_kwargs = retrieve_from_dict(['figsize'], kwargs)
            _, ax = plt.subplots(**subplot_kwargs)
        if x is None:
            ax.plot(self.to_dataframe(y), **kwargs)
            ax.set_xlabel('time')
        else:
            ax.plot(self.to_numpy(x), self.to_numpy(y), **kwargs)
            ax.set_xlabel(self._axis_label(x))
        ax.set_ylabel(self._axis_label(y))
        return ax

    def matshow(self, v: Union[str, Tuple[str, str]], ax: "plt.Axes" = None, **kwargs) -> "plt.Axes":
        """Create a 2D color plot of variable `v`.

        Parameters
        ----------
        v
            Tuple that contains the names of the node and the node variable to be plotted.
        ax
            `matplotlib.pyplot.Axes` instance in which to plot.
        kwargs
            Additional keyword arguments for the `matplotlib.pyplot.imshow` call. The keyword
            `shrink` (default: 0.6) is passed to the colorbar instead.

        Returns
        -------
        plt.Axes
            Instance of `matplotlib.pyplot.Axes` that contains the color plot.
        """
        if ax is None:
            # 'figsize' is handed to the figure creation instead of the imshow call
            subplot_kwargs = retrieve_from_dict(['figsize'], kwargs)
            _, ax = plt.subplots(**subplot_kwargs)
        sig = np.asarray(self.to_dataframe(v))
        shrink = kwargs.pop("shrink", 0.6)
        # transpose such that time runs along the x-axis
        im = ax.imshow(sig.T, **kwargs)
        plt.colorbar(im, ax=ax, shrink=shrink)
        ax.set_xlabel('time')
        ax.set_ylabel(self._axis_label(v))
        return ax

    def _time_index(self) -> np.ndarray:
        """Returns the recorded steps multiplied by the step-size `dt`."""
        return np.asarray(self._recordings["steps"]) * self._dt

    @staticmethod
    def _to_array(values: Any) -> np.ndarray:
        """Converts a tensor, or an iterable of tensors and/or plain numbers, into a numpy array."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray([
            v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else np.asarray(v)
            for v in values
        ])

    @staticmethod
    def _axis_label(var: Union[str, Tuple[str, str]]) -> str:
        """Returns the axis label for a variable key (a plain name or a node/variable tuple)."""
        if isinstance(var, tuple):
            return f"Node: {var[0]}, variable: {var[-1]}"
        return var