import torch
from networkx.classes.reportviews import NodeView
from torch.nn import Module
from typing import Union, Iterator, Callable, Tuple, Optional
from .nodes import RateNet, SpikeNet, InstantNode, SpikeResetNet
from .edges import RLS, Linear, LinearMasked, LinearMemory, LinearFilter, LinearMemoryFilter
from .utility import retrieve_from_dict, add_op_name
from .observer import Observer
from pyrates import NodeTemplate, CircuitTemplate
import numpy as np
from time import perf_counter
from networkx import DiGraph
from multipledispatch import dispatch
[docs]
class Network(Module):
"""Main user interface for initializing, training, testing, and running networks consisting of rnn, input, and
output layers.
"""
def __init__(self, dt: float, device: str = "cpu"):
    """Create an empty network without any nodes or edges.

    Parameters
    ----------
    dt
        Time-step that is stored on the instance and handed to differential equation nodes.
    device
        Device on which tensors created by the `Network` instance are placed.
    """
    super().__init__()

    # simulation settings
    self.dt = dt
    self.device = device

    # graph that stores all nodes and edges of the network
    self.graph = DiGraph()

    # book-keeping that is filled in later on (see `compile` and `add_edge`)
    self._in_node = None
    self._out_node = None
    self._bwd_graph = {}
    self._train_edge = ()

    # recording and variable name mappings
    self._record = {}
    self._var_map = {}
@dispatch(str)
def __getitem__(self, item: str):
    """Return the attribute dictionary of the node with label `item`."""
    node_view = self.graph.nodes
    return node_view[item]
@dispatch(tuple)
def __getitem__(self, nodes: tuple):
    """Return the attribute dictionary of the edge given as `(source, target)`."""
    source, target = nodes[0], nodes[1]
    return self.graph[source][target]
def __iter__(self):
    """Iterate over the attribute dictionaries of all nodes in the network graph."""
    yield from (self[label] for label in self.graph.nodes)
def __len__(self):
    """Return the number of nodes in the network graph."""
    node_view = self.graph.nodes
    return len(node_view)
def __call__(self, *args, **kwargs):
    """Delegate calls on the instance directly to `forward`."""
    result = self.forward(*args, **kwargs)
    return result
@property
def n_out(self) -> int:
    """Current output dimensionality of the network.

    Returns
    -------
    int
        The `n_out` attribute of the output node, or `0` as long as no output node has been identified
        (`_out_node` is only set by `compile`).
    """
    # `_out_node` is `None` until `compile` has run. Indexing the network with `None` is not answered with the
    # `AttributeError` handled below, so the "no output node yet" case is handled explicitly.
    if self._out_node is None:
        return 0
    try:
        return self[self._out_node]["n_out"]
    except AttributeError:
        return 0
@property
def n_in(self) -> int:
    """Current input dimensionality of the network.

    Returns
    -------
    int
        The `n_in` attribute of the input node, or `0` as long as no input node has been identified
        (`_in_node` is only set by `compile`).
    """
    # `_in_node` is `None` until `compile` has run. Indexing the network with `None` is not answered with the
    # `AttributeError` handled below, so the "no input node yet" case is handled explicitly.
    if self._in_node is None:
        return 0
    try:
        return self[self._in_node]["n_in"]
    except AttributeError:
        return 0
@property
def nodes(self) -> NodeView:
    """View on the nodes of the underlying network graph."""
    graph = self.graph
    return graph.nodes
@property
def state(self) -> dict:
    """Dictionary that maps node labels to the state vector `y` of the respective node.

    Nodes for which accessing `y` raises an `AttributeError` are left out.
    """
    collected = {}
    for label in self.nodes:
        try:
            vector = self.get_node(label).y
        except AttributeError:
            # node without a state vector
            continue
        collected[label] = vector
    return collected
[docs]
def get_node(self, node: str) -> Union[InstantNode, RateNet]:
    """Look up the node instance that is stored under a label in the network graph.

    Parameters
    ----------
    node
        Label of the node.

    Returns
    -------
    Union[InstantNode, RateNet]
        Object stored under the `"node"` key of the node attributes.
    """
    attributes = self[node]
    return attributes["node"]
[docs]
def get_edge(self, source: str, target: str) -> Linear:
    """Look up the edge instance that connects two nodes of the network graph.

    Parameters
    ----------
    source
        Label of the source node.
    target
        Label of the target node.

    Returns
    -------
    Linear
        Object stored under the `"edge"` key of the edge attributes.
    """
    attributes = self[source, target]
    return attributes["edge"]
[docs]
def get_var(self, node: str, var: str) -> Union[torch.Tensor, float]:
    """Read a variable from a network node.

    The (relabeled) variable is looked up on the node instance first. If that raises a `KeyError`, the variable
    is looked up in the attribute dictionary of the node in the network graph instead.

    Parameters
    ----------
    node
        Label of the network node.
    var
        Name of the variable.

    Returns
    -------
    Union[torch.Tensor, float]
        Value of the variable.
    """
    try:
        instance = self.get_node(node)
        value = instance[self._relabel_var(var)]
    except KeyError:
        # fall back to the attributes stored on the graph node
        value = self[node][var]
    return value
[docs]
def set_var(self, node: str, var: str, val: Union[torch.Tensor, float]):
    """Assign a new value to a variable of a network node.

    The value is first passed to `set_param` of the node. If that raises a `KeyError`, the variable is looked up
    on the node via item access and overwritten in place.

    Parameters
    ----------
    node
        Label of the network node.
    var
        Name of the variable.
    val
        New value of the variable.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If the variable could not be found on the node.
    """
    try:
        instance = self.get_node(node)
        try:
            instance.set_param(var, val)
        except KeyError:
            # not a parameter: overwrite the content of the variable in place
            instance[var][:] = val
    except KeyError:
        raise KeyError(f"Variable {var} was not found on node {node}.")
[docs]
def add_node(self, label: str, node: Union[InstantNode, RateNet], node_type: str, op: str = None,
             **node_attrs) -> None:
    """Register an existing node instance in the network graph.

    Parameters
    ----------
    label
        Label of the node in the network graph.
    node
        Node instance. Its `n_in` and `n_out` attributes are stored as node attributes of the graph.
    node_type
        Type of the node that is stored as node attribute, e.g. "diff_eq" for differential equation nodes.
    op
        Optional operator name. If provided, all parameter and variable names of the node are passed to
        `add_op_name` together with the operator name and the variable map of the network.
    node_attrs
        Additional keyword arguments passed to `networkx.DiGraph.add_node`.

    Returns
    -------
    None
    """
    if op:
        # parameters first, then state variables
        for attr in ("parameter_names", "variable_names"):
            for name in getattr(node, attr):
                add_op_name(op, name, self._var_map)

    out_dim = node.n_out
    in_dim = node.n_in
    # placeholder for the node output, used until the node is evaluated for the first time
    initial_out = torch.zeros(out_dim, device=self.device)
    self.graph.add_node(label, node=node, node_type=node_type, n_out=out_dim, n_in=in_dim, eval=True,
                        out=initial_out, **node_attrs)
[docs]
def add_diffeq_node(self, label: str, node: Union[str, NodeTemplate, CircuitTemplate], input_var: str,
                    output_var: str, weights: np.ndarray = None, source_var: str = None, target_var: str = None,
                    spike_var: Union[str, list] = None, reset_var: Union[str, list] = None, reset: bool = True,
                    op: str = None, train_params: list = None, **kwargs) -> RateNet:
    """Adds a differential equation-based RNN node to the `Network` instance.

    The node is created via the `from_pyrates` constructor of `RateNet` (no `spike_var` given), `SpikeResetNet`
    (`spike_var` given and `reset=True`) or `SpikeNet` (`spike_var` given and `reset=False`) and is then added to
    the network graph with node type "diff_eq".

    Parameters
    ----------
    label
        The label of the node in the network graph.
    node
        Path to the YAML template or an instance of a `pyrates.NodeTemplate` or `pyrates.CircuitTemplate`.
    input_var
        Name of the parameter in the node equations that input should be projected to.
    output_var
        Name of the variable in the node equations that should be used as output of the RNN node.
    weights
        Determines the number of neurons in the network as well as their connectivity. Given an `N x N` weights
        matrix, `N` neurons will be added to the RNN node, each of which is governed by the equations defined in the
        `NodeTemplate` (see argument `node`). Neurons will be labeled `n0` to `n<N>` and every non-zero entry in the
        matrix will be realized by an edge between the corresponding neurons in the network.
    source_var
        Source variable that will be used for each connection in the network.
    target_var
        Target variable that will be used for each connection in the network.
    spike_var
        Name of the input variable in the node equations that recurrent input from the RNN should be projected to.
    reset_var
        Either the name of the input variable in the node equations that is used for membrane potential resetting
        after spiking within the node equations (if `reset=False`) or the name of the state variable that should be
        reset after a spike (if `reset=True`).
    reset
        If true, an additional spike resetting mechanism is added that will reset the variable defined by
        `reset_var` after a spike occurred.
    op
        Name of the operator in which all the above variables can be found. If not provided, it is assumed that
        the operator name is provided together with the variable names, e.g. `source_var = <op>/<var>`.
    train_params
        Names of all RNN parameters that should be made available for optimization.
    kwargs
        Additional keyword arguments forwarded to the `from_pyrates` constructor of the selected node class.
        The entries `record_vars` and `node_vars` are pre-processed in this method if present.

    Returns
    -------
    RateNet
        Instance of the RNN node that was added to the network.

    Raises
    ------
    ValueError
        If `spike_var` is provided without `reset_var`.
    """
    # add operator key to variable names
    var_dict = {'svar': source_var, 'tvar': target_var, 'in_ext': input_var, 'out': output_var, 'spike': spike_var,
                'reset': reset_var}
    if "record_vars" in kwargs:
        var_dict["record_vars"] = kwargs.pop("record_vars")
    # NOTE(review): the variable map is replaced by an empty dictionary on every call, which discards entries that
    # were registered for previously added nodes - confirm that this is intended.
    self._var_map = {}
    if op is not None:
        # iterate over a copy, since entries of `var_dict` are overwritten inside the loop
        for key, var in var_dict.copy().items():
            if key == "record_vars":
                kwargs["var_mapping"] = {}
                for v in var:
                    v_new = add_op_name(op, v, self._var_map)
                    kwargs["var_mapping"][v_new] = v_new
            # NOTE(review): this is a separate `if`, so `record_vars` entries are passed to `add_op_name` a second
            # time by the branches below - verify against the semantics of `add_op_name`.
            if type(var) is list:
                var_dict[key] = [add_op_name(op, v, self._var_map) for v in var]
            else:
                var_dict[key] = add_op_name(op, var, self._var_map)
        if train_params:
            train_params = [add_op_name(op, p, self._var_map) for p in train_params]
        if "node_vars" in kwargs:
            # keys without a "/" are re-keyed as `all/<op>/<key>`; iterate over a copy because the dictionary is
            # modified inside the loop
            for key in kwargs["node_vars"].copy():
                if "/" not in key:
                    val = kwargs["node_vars"].pop(key)
                    kwargs["node_vars"][f"all/{op}/{key}"] = val
    # initialize rnn layer
    args = (node, var_dict['in_ext'], var_dict['out'])
    kwargs_tmp = {"weights": weights, "source_var": var_dict["svar"], "target_var": var_dict["tvar"],
                  "train_params": train_params, "device": self.device, "dt": self.dt}
    if spike_var is None:
        NodeClass = RateNet
    elif reset_var is None:
        raise ValueError('To define a reservoir with a spiking neural network layer, please provide the '
                         'name of the variable that should be reset after a spike occurred (`reset_var`).')
    else:
        kwargs_tmp["spike_var"] = var_dict["spike"]
        kwargs_tmp["reset_var"] = var_dict["reset"]
        NodeClass = SpikeResetNet if reset else SpikeNet
    # explicitly named arguments take precedence over entries of the same name in `kwargs`
    kwargs.update(kwargs_tmp)
    node = NodeClass.from_pyrates(*args, **kwargs)
    # add node to the network graph
    self.add_node(label, node=node, node_type="diff_eq", op=op)
    return node
[docs]
def add_func_node(self, label: str, n: int, activation_function: str, **kwargs) -> InstantNode:
    """Create an `InstantNode` and add it to the network graph with node type "func_instant".

    Parameters
    ----------
    label
        The label of the node in the network graph.
    n
        Dimensionality of the node.
    activation_function
        Name of the activation function, passed to `InstantNode`. Options listed for this argument are:
        - 'tanh' for `torch.nn.Tanh()`
        - 'sigmoid' for `torch.nn.Sigmoid()`
        - 'softmax' for `torch.nn.Softmax(dim=0)`
        - 'softmin' for `torch.nn.Softmin(dim=0)`
        - 'log_softmax' for `torch.nn.LogSoftmax(dim=0)`
        - 'identity' for `torch.nn.Identity`
    kwargs
        Additional keyword arguments passed to `InstantNode`.

    Returns
    -------
    InstantNode
        The node instance that was added to the network graph.
    """
    func_node = InstantNode(n, activation_function, **kwargs)
    self.add_node(label, node=func_node, node_type="func_instant")
    return func_node
[docs]
def add_edge(self, source: str, target: str, weights: Union[torch.Tensor, np.ndarray] = None,
             train: Optional[str] = None, dtype: torch.dtype = torch.float64, edge_attrs: dict = None,
             **kwargs) -> Linear:
    """Add a feed-forward layer to the network.

    Parameters
    ----------
    source
        Label of the source node.
    target
        Label of the target node.
    weights
        `k x n` weight matrix that realizes the linear projection of the `n` source outputs to
        the `k` target inputs.
    train
        Can be used to make the edge weights trainable. The following options are available:
        - `None` for a static edge
        - 'gd' for training of the edge weights via standard pytorch gradient descent
        - 'rls' for recursive least squares training of the edge weights
    dtype
        Data type of the edge weights.
    edge_attrs
        Additional edge attributes passed to `networkx.DiGraph.add_edge`.
    kwargs
        Additional keyword arguments to be passed to the edge class initialization method. The presence of the
        keys `mask`, `delays` and `filter_weights` determines which edge class is used.

    Returns
    -------
    Linear
        Instance of the edge class.

    Raises
    ------
    ValueError
        If `train` is not one of the supported options.
    """
    if not edge_attrs:
        edge_attrs = {}

    # choose edge class
    if "mask" in kwargs:
        LinEdge = LinearMasked
    elif "delays" in kwargs:
        LinEdge = LinearMemoryFilter if "filter_weights" in kwargs else LinearMemory
    elif "filter_weights" in kwargs:
        LinEdge = LinearFilter
    else:
        LinEdge = Linear

    # initialize edge; the edge dimensions follow from the connected nodes
    kwargs.update({"n_in": self[source]["n_out"], "n_out": self[target]["n_in"],
                   "weights": weights, "dtype": dtype})
    trainable = True
    if train is None:
        trainable = False
        edge = LinEdge(**kwargs, detach=True)
    elif train == "gd":
        edge = LinEdge(**kwargs, detach=False)
    elif train == "rls":
        # RLS edges are always created from the `RLS` class, irrespective of the edge class selected above
        edge = RLS(**kwargs)
        # remember the edge, such that the RLS fitting routines can find it
        self._train_edge = (source, target)
    else:
        raise ValueError("Invalid option for keyword argument `train`. Please see the docstring of "
                         "`Network.add_edge` for valid options.")

    # add connecting edge to graph
    self.graph.add_edge(source, target, edge=edge.to(self.device), trainable=trainable, n_in=edge.n_in,
                        n_out=edge.n_out, **edge_attrs)
    return edge
[docs]
def pop_node(self, node: str) -> Union[InstantNode, RateNet]:
    """Remove a node from the network graph and hand back its node instance.

    Parameters
    ----------
    node
        Label of the node to remove.

    Returns
    -------
    Union[InstantNode, RateNet]
        Node instance that was stored under the label.
    """
    # fetch the instance first, since it is no longer accessible after the removal
    removed = self.get_node(node)
    self.graph.remove_node(node)
    return removed
[docs]
def pop_edge(self, source: str, target: str) -> Linear:
    """Remove an edge from the network graph and hand back its edge instance.

    Parameters
    ----------
    source
        Label of the source node.
    target
        Label of the target node.

    Returns
    -------
    Linear
        Edge instance that connected `source` and `target`.
    """
    # fetch the instance first, since it is no longer accessible after the removal
    removed = self.get_edge(source, target)
    self.graph.remove_edge(source, target)
    return removed
[docs]
def compile(self):
    """Identify the input and output node of the network and derive the backward graph used by `forward`.

    The input node is the single node without incoming edges, the output node is the single node without
    outgoing edges.

    Raises
    ------
    ValueError
        If the number of nodes without incoming edges or the number of nodes without outgoing edges is not
        exactly one.
    """
    graph = self.graph

    # input node: the only node that does not receive any input
    sources = [label for label in graph.nodes if graph.in_degree(label) == 0]
    if len(sources) != 1:
        raise ValueError(f"Unable to identify the input node of the Network. "
                         f"Nodes that have no input edges: {sources}."
                         f"Make sure that exactly one such node without input edges exists in the network.")
    self._in_node = sources[0]

    # output node: the only node that does not project to any other node
    sinks = [label for label in graph.nodes if graph.out_degree(label) == 0]
    if len(sinks) != 1:
        raise ValueError(f"Unable to identify the output node of the Network. "
                         f"Nodes that have no outgoing edges: {sinks}."
                         f"Make sure that exactly one such node without outgoing edges exists in the network.")
    self._out_node = sinks[0]

    # walk from the output node back through the network
    self._bwd_graph = self._compile_bwd_graph(self._out_node, {})
[docs]
def forward(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
    """Forward method as implemented for any `torch.Module`.

    The output node is evaluated via `_backward`, which recursively evaluates all nodes and edges that the
    output node depends on. Requires that `compile` has been called.

    Parameters
    ----------
    x
        Input tensor.

    Returns
    -------
    torch.Tensor
        Output tensor.
    """
    try:
        return self._backward(x, self._out_node)
    finally:
        # Nodes are flagged as evaluated during the pass. The flags are restored even if the pass fails, such that
        # a subsequent call does not re-use cached node outputs of an aborted pass.
        self._reset_node_eval()
[docs]
def parameters(self, recurse: bool = True) -> Iterator:
    """Iterate over the parameters of all nodes and edges in the network graph.

    Parameters
    ----------
    recurse
        Passed on to the `parameters` method of each node.

    Yields
    ------
    Iterator
        Parameters of the nodes, followed by the parameters of the edges.
    """
    yield from self._get_parameters(self.graph, recurse=recurse)
[docs]
def detach(self, requires_grad: bool = True, detach_params: bool = False) -> None:
    """Call `detach` on every node of the network that has a state vector `y`.

    Parameters
    ----------
    requires_grad
        Passed on to the `detach` method of each node.
    detach_params
        Passed on to the `detach` method of each node.

    Returns
    -------
    None
    """
    for label in self.nodes:
        instance = self.get_node(label)
        if not hasattr(instance, "y"):
            # nodes without state variables have nothing to detach
            continue
        instance.detach(requires_grad=requires_grad, detach_params=detach_params)
[docs]
def reset(self, state: dict = None):
    """Call `reset` on every node of the network that has a state vector `y`.

    Parameters
    ----------
    state
        Optional dictionary with node labels as keys. For each node that is listed in the dictionary, the
        corresponding value is passed to the `reset` method of the node as `y`.

    Returns
    -------
    None
    """
    for label in self.nodes:
        instance = self.get_node(label)
        if not hasattr(instance, "y"):
            continue
        # only pass a state vector if one was provided for this node
        if state and label in state:
            reset_args = {"y": state[label]}
        else:
            reset_args = {}
        instance.reset(**reset_args)
[docs]
def clear(self):
    """Remove every node from the network graph via `pop_node`."""
    # snapshot of the labels, since the node view changes while nodes are removed
    for label in tuple(self.nodes):
        self.pop_node(label)
[docs]
def run(self, inputs: Union[np.ndarray, torch.Tensor], sampling_steps: int = 1, cutoff: int = 0,
        verbose: bool = True, enable_grad: bool = True, **kwargs) -> Observer:
    """Drive the network with an input time series and record its output.

    The network is compiled, each row of `inputs` is passed through `forward`, and the outputs since the last
    recording are averaged and stored in the observer at every recording step.

    Parameters
    ----------
    inputs
        `T x m` array of inputs fed to the model, where `T` is the number of integration steps and `m` is the
        number of input dimensions of the network.
    sampling_steps
        Observables are recorded at every step that is a multiple of `sampling_steps`.
    cutoff
        Number of initial steps for which no outputs are collected or recorded.
    verbose
        If true, the progress is printed at every recording step.
    enable_grad
        If true, the simulation is performed within `torch.enable_grad`, otherwise within `torch.no_grad`.
    kwargs
        Additional keyword arguments. `truncate_steps` controls after how many steps `detach` is called (never,
        if not provided). `obs` can be used to pass an existing observer. All remaining keyword arguments are
        used to create a new `Observer`, if no observer is passed.

    Returns
    -------
    Observer
        Instance of the `Observer`.
    """
    # preparations on input arguments
    n_steps = inputs.shape[0]
    if type(inputs) is np.ndarray:
        inputs = torch.tensor(inputs, device=self.device)
    truncate_steps = kwargs.pop("truncate_steps", n_steps)
    truncate = truncate_steps < n_steps

    # compile network
    self.compile()

    # use the provided observer or create a new one
    if "obs" in kwargs:
        obs = kwargs.pop("obs")
    else:
        record_loss = kwargs.pop("record_loss", False)
        obs = Observer(dt=self.dt, record_loss=record_loss, **kwargs)
    recorded = list(obs.recorded_state_variables)

    # choose the gradient mode of the simulation
    if enable_grad:
        grad_mode = torch.enable_grad
    else:
        grad_mode = torch.no_grad

    # simulation loop
    window = []
    with grad_mode():
        for step in range(n_steps):
            out = self.forward(inputs[step, :])
            if step >= cutoff:
                window.append(out)
                if step % sampling_steps == 0:
                    if verbose:
                        print(f'Progress: {step}/{n_steps} integration steps finished.')
                    # record the average over all outputs collected since the last recording
                    window_mean = torch.stack(window, dim=0).mean(dim=0)
                    state_values = [self.get_var(v[0], v[1]) for v in recorded]
                    obs.record(step, window_mean, 0.0, state_values)
                    window = []
            if truncate and step % truncate_steps == truncate_steps-1:
                self.detach()
    return obs
[docs]
def fit_bptt(self, inputs: Union[np.ndarray, list], targets: Union[np.ndarray, list], optimizer: str = 'sgd',
             optimizer_kwargs: dict = None, loss: str = 'mse', loss_kwargs: dict = None, lr: float = 1e-3,
             sampling_steps: int = 1, update_steps: int = 100, verbose: bool = True, **kwargs) -> Observer:
    """Optimize model parameters via backpropagation through time.

    If `inputs` is a list, each list entry is treated as a separate epoch and one parameter update is performed
    per epoch. Otherwise, a single time series is processed and a parameter update is performed every
    `update_steps` steps.

    Parameters
    ----------
    inputs
        `T x m` array of inputs fed to the model, where `T` is the number of training steps and `m` is the number
        of input dimensions of the network. Alternatively, a list with one such array per epoch.
    targets
        `T x k` array of targets, where `T` is the number of training steps and `k` is the number of outputs of the
        network. Alternatively, a list with one such array per epoch.
    optimizer
        Name of the optimization algorithm to use. Available options are:
        - 'sgd' for `torch.optim.SGD`
        - 'adam' for `torch.optim.Adam`
        - 'adamw' for torch.optim.AdamW
        - 'adagrad' for `torch.optim.Adagrad`
        - 'adadelta' for `torch.optim.Adadelta`
        - 'rmsprop' for `torch.optim.RMSprop`
        - 'rprop' for `torch.optim.Rprop`
    optimizer_kwargs
        Additional keyword arguments provided to the initialization of the optimizer.
    loss
        Name of the loss function that should be used for optimization. Available options are:
        - 'mse' for `torch.nn.MSELoss`
        - 'l1' for `torch.nn.L1Loss`
        - 'nll' for `torch.nn.NLLLoss`
        - 'ce' for `torch.nn.CrossEntropyLoss`
        - 'kld' for `torch.nn.KLDivLoss`
        - 'hinge' for `torch.nn.HingeEmbeddingLoss`
    loss_kwargs
        Additional keyword arguments provided to the initialization of the loss.
    lr
        Learning rate.
    sampling_steps
        Number of training steps at which to record observables.
    update_steps
        Number of training steps after which to perform an update of the trainable parameters based on the
        accumulated gradients. Only used if `inputs` is not a list.
    verbose
        If true, the training progress will be displayed.
    kwargs
        Additional keyword arguments used for the optimization, loss calculation and observation.

    Returns
    -------
    Observer
        Instance of the `observer`.

    Raises
    ------
    ValueError
        If `inputs` and `targets` do not agree in their first dimension.
    """
    # preparations
    ##############
    # compile network
    self.compile()
    # initialize loss function
    loss = self._get_loss_function(loss, loss_kwargs=loss_kwargs)
    # initialize optimizer
    optimizer = self._get_optimizer(optimizer, lr, self.parameters(), optimizer_kwargs=optimizer_kwargs)
    # retrieve keyword arguments for optimization
    step_kwargs = retrieve_from_dict(['closure'], kwargs)
    error_kwargs = retrieve_from_dict(['retain_graph'], kwargs)
    # initialize observer
    obs_kwargs = retrieve_from_dict(['record_output', 'record_loss', 'record_vars'], kwargs)
    obs = Observer(dt=self.dt, **obs_kwargs)
    # optimization
    ##############
    t0 = perf_counter()
    if type(inputs) is list:
        # epoch-wise training: each list entry is one epoch
        if len(inputs) != len(targets):
            raise ValueError('Wrong dimensions of input and target output. Please make sure that `inputs` and '
                             '`targets` agree in the first dimension (epochs).')
        # perform optimization (`update_steps` is not used on this path)
        obs = self._bptt_epochs(inputs, targets, loss=loss, optimizer=optimizer,
                                obs=obs, error_kwargs=error_kwargs, step_kwargs=step_kwargs,
                                sampling_steps=sampling_steps, verbose=verbose)
    else:
        # transform inputs into tensors
        inp_tensor = torch.tensor(inputs, device=self.device)
        target_tensor = torch.tensor(targets, device=self.device)
        if inp_tensor.shape[0] != target_tensor.shape[0]:
            raise ValueError('Wrong dimensions of input and target output. Please make sure that `inputs` and '
                             '`targets` agree in the first dimension.')
        # perform optimization
        obs = self._bptt(inp_tensor, target_tensor, loss, optimizer, obs, error_kwargs, step_kwargs,
                         sampling_steps=sampling_steps, optim_steps=update_steps, verbose=verbose)
    t1 = perf_counter()
    # the duration is reported irrespective of `verbose`
    print(f'Finished optimization after {t1-t0} s.')
    return obs
[docs]
def fit_ridge(self, inputs: np.ndarray, targets: np.ndarray, sampling_steps: int = 100, alpha: float = 1e-4,
              verbose: bool = True, add_readout_node: bool = True, **kwargs) -> Observer:
    """Train readout weights on top of the input-driven model dynamics via ridge regression.

    The network is simulated via `run`, the recorded network outputs `X` are collected, and the readout weights
    are calculated as `(X^T X + alpha*I)^-1 X^T Y`, where `Y` are the targets.

    Parameters
    ----------
    inputs
        `T x m` array of inputs fed to the model, where `T` is the number of training steps and `m` is the number
        of input dimensions of the network.
    targets
        `T x k` array of targets, where `T` is the number of training steps and `k` is the number of outputs of the
        network.
    sampling_steps
        Number of training steps at which to record observables.
    alpha
        Ridge regression regularization constant.
    verbose
        If true, the training progress will be displayed.
    add_readout_node
        If true, a readout node is added to the network, which will be connected to the current output node of the
        network via the trained readout weights.
    kwargs
        Additional keyword arguments used for the observation and network simulations.

    Returns
    -------
    Observer
        Instance of the `observer`, on which the fitted outputs (`y`) and the readout weights (`w_out`) have
        been saved.

    Raises
    ------
    ValueError
        If `inputs` and `targets` do not agree in their first dimension.
    """
    # preparations
    ##############
    # transform targets into tensors
    target_tensor = torch.tensor(targets, device=self.device)
    if inputs.shape[0] != target_tensor.shape[0]:
        raise ValueError('Wrong dimensions of input and target output. Please make sure that `inputs` and '
                         '`targets` agree in the first dimension.')
    # compile network
    self.compile()

    # collect network states
    ########################
    t0 = perf_counter()
    obs = self.run(inputs=inputs, sampling_steps=sampling_steps, verbose=verbose, **kwargs)
    t1 = perf_counter()
    print(f'Finished network state collection after {t1-t0} s.')

    # train read-out classifier
    ###########################
    t0 = perf_counter()
    # ridge regression formula
    X = torch.stack(obs["out"])
    X_t = X.T
    # the regularization matrix has to live on the same device (and use the same dtype) as the recorded outputs,
    # otherwise the sum below fails for networks that do not run on the CPU
    regularizer = alpha*torch.eye(X.shape[1], device=X.device, dtype=X.dtype)
    w_out = torch.inverse(X_t @ X + regularizer) @ X_t @ target_tensor
    y = X @ w_out
    # progress report
    t1 = perf_counter()
    print(f'Finished fitting of read-out weights after {t1 - t0} s.')

    # add read-out layer
    ####################
    if add_readout_node:
        # NOTE(review): `node_type` is not a named argument of `add_func_node` and is thus forwarded to the
        # `InstantNode` constructor - confirm that the constructor accepts it.
        self.add_func_node("readout", node_type="function", n=w_out.shape[1], activation_function="identity")
        self.add_edge(self._out_node, target="readout", weights=w_out.T)
    obs.save("y", y)
    obs.save("w_out", w_out)
    return obs
[docs]
def fit_rls(self, inputs: Union[list, np.ndarray], targets: Union[list, np.ndarray], update_steps: int = 1,
            sampling_steps: int = 100, verbose: bool = True, **kwargs) -> Observer:
    r"""Finds model parameters $w$ such that $||Xw - y||_2$ is minimized, where $X$ contains the neural activity and
    $y$ contains the targets.

    The edge that was added with `train="rls"` is updated during the simulation. If `inputs` is a list, each list
    entry is treated as a separate epoch.

    Parameters
    ----------
    inputs
        `T x m` array of inputs fed to the model, where `T` is the number of training steps and `m` is the number
        of input dimensions of the network. Alternatively, a list with one such array per epoch.
    targets
        `T x k` array of targets, where `T` is the number of training steps and `k` is the number of outputs of the
        network. Alternatively, a list with one such array per epoch.
    update_steps
        Each `update_steps` an update of the trainable parameters will be performed.
    sampling_steps
        Number of training steps at which to record observables. Only used if `inputs` is not a list.
    verbose
        If true, the training progress will be displayed.
    kwargs
        Additional keyword arguments used for the optimization, loss calculation and observation.

    Returns
    -------
    Observer
        Instance of the `observer`.

    Raises
    ------
    ValueError
        If `inputs` and `targets` do not agree in their first dimension.
    """
    # preparations
    ##############
    # compile network
    self.compile()
    # initialize observer
    obs_kwargs = retrieve_from_dict(['record_output', 'record_loss', 'record_vars'], kwargs)
    obs = Observer(dt=self.dt, **obs_kwargs)

    # optimization
    ##############
    t0 = perf_counter()
    if type(inputs) is list:
        # check input and target dimensions
        if len(inputs) != len(targets):
            raise ValueError('Wrong dimensions of input and target output. Please make sure that `inputs` and '
                             '`targets` agree in the first dimension (epochs).')
        # fit weights epoch by epoch
        obs = self._rls_epoch(inputs, targets, obs, optim_steps=update_steps, verbose=verbose)
    else:
        # test correct dimensionality of inputs
        if inputs.shape[0] != targets.shape[0]:
            raise ValueError('Wrong dimensions of input and target output. Please make sure that `inputs` and '
                             '`targets` agree in the first dimension.')
        # transform inputs into tensors
        inp_tensor = torch.tensor(inputs, device=self.device)
        target_tensor = torch.tensor(targets, device=self.device)
        # fit weights
        obs = self._rls(inp_tensor, target_tensor, obs, optim_steps=update_steps, sampling_steps=sampling_steps,
                        verbose=verbose)
    t1 = perf_counter()
    # the duration is reported irrespective of `verbose`
    print(f'Finished optimization after {t1 - t0} s.')
    return obs
[docs]
def fit_eprop(self, inputs: np.ndarray, targets: np.ndarray, feedback_weights: np.ndarray = None,
              epsilon: float = 0.99, delta: float = 0.9, update_steps: int = 1, sampling_steps: int = 100,
              verbose: bool = True, **kwargs) -> Observer:
    r"""Placeholder for a reinforcement learning algorithm that slowly adjusts the feedback weights to the RNN layer
    based on a running average of the residuals.

    This method is not implemented yet and always raises a `NotImplementedError`.

    Parameters
    ----------
    inputs
        `T x m` array of inputs fed to the model, where `T` is the number of training steps and `m` is the number
        of input dimensions of the network.
    targets
        `T x k` array of targets, where `T` is the number of training steps and `k` is the number of outputs of the
        network.
    feedback_weights
        `m x k` array of synaptic weights. If provided, a feedback connection is established with these weights,
        that projects the network output back to the RNN layer.
    epsilon
        Scalar in (0, 1] that controls how quickly the loss used for reinforcement learning can change.
    delta
        Scalar in (0, 1] that controls how quickly the feedback weights can change.
    update_steps
        Each `update_steps` an update of the trainable parameters will be performed.
    sampling_steps
        Number of training steps at which to record observables.
    verbose
        If true, the training progress will be displayed.
    kwargs
        Additional keyword arguments used for the optimization, loss calculation and observation.

    Returns
    -------
    Observer
        Instance of the `observer`.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # TODO: Implement e-prop as defined in Bellec et al. (2020) Nature Communications
    # TODO: Make sure that this fitting method allows for reinforcement learning schemes
    message = "Method is currently not implemented"
    raise NotImplementedError(message)
[docs]
def test(self, inputs: np.ndarray, targets: np.ndarray, loss: str = 'mse',
         loss_kwargs: dict = None, sampling_steps: int = 100, verbose: bool = True, **kwargs) -> tuple:
    """Simulate the network on test inputs and compare the recorded outputs against the targets.

    Parameters
    ----------
    inputs
        `T x m` array of inputs fed to the model, where `T` is the number of testing steps and `m` is the number
        of input dimensions of the network.
    targets
        Array of targets that the recorded network outputs are compared against.
    loss
        Name of the loss function that should be used to calculate the loss on the test data.
    loss_kwargs
        Additional keyword arguments provided to the initialization of the loss.
    sampling_steps
        Number of testing steps at which to record observables.
    verbose
        If true, the progress of the test run will be displayed.
    kwargs
        Additional keyword arguments passed to `run`.

    Returns
    -------
    Tuple[Observer,float]
        The `Observer` instance and the total loss on the test data.
    """
    # targets as tensor on the network device
    target_tensor = torch.tensor(targets, device=self.device)

    # loss function used for the evaluation
    criterion = self._get_loss_function(loss, loss_kwargs=loss_kwargs)

    # simulate the network and collect its recorded outputs
    obs = self.run(inputs=inputs, sampling_steps=sampling_steps, verbose=verbose, **kwargs)
    predictions = torch.stack(obs["out"])

    return obs, criterion(predictions, target_tensor).item()
def _get_parameters(self, g: DiGraph, recurse: bool = True) -> Iterator:
    """Yield the parameters of all nodes in `g`, followed by the parameters of all edges in `g`."""
    for label in g:
        yield from self.get_node(label).parameters(recurse=recurse)
    for source, target in g.edges:
        yield from g[source][target]["edge"].parameters()
def _compile_bwd_graph(self, n: str, graph: dict) -> dict:
    """Recursively collect the predecessors of node `n` and of all nodes that `n` depends on.

    Parameters
    ----------
    n
        Label of the node to start from.
    graph
        Dictionary that is filled in place. Maps node labels to the list of their predecessors. Nodes without
        predecessors are not added.

    Returns
    -------
    dict
        The filled dictionary.
    """
    # Nodes that have been processed already are skipped. This avoids repeated traversals of shared ancestors
    # and prevents an endless recursion if the graph contains a cycle.
    if n in graph:
        return graph
    sources = list(self.graph.predecessors(n))
    if sources:
        graph[n] = sources
        for s in sources:
            graph = self._compile_bwd_graph(s, graph)
    return graph
def _backward(self, x: Union[torch.Tensor, np.ndarray], n: str) -> torch.Tensor:
    """Evaluate node `n`, after recursively evaluating all nodes and edges it receives input from.

    Parameters
    ----------
    x
        Input to the network. It is passed unchanged to nodes that have no entry in the backward graph.
    n
        Label of the node to evaluate.

    Returns
    -------
    torch.Tensor
        Output of node `n`.
    """
    if n in self._bwd_graph:
        sources = self._bwd_graph[n]
        if len(sources) == 1:
            x = self._edge_forward(x, sources[0], n)
        else:
            # Sum up the projections of all source nodes. The projections are stacked rather than wrapped into a
            # new tensor, which works for vector-valued projections and keeps them in the autograd graph.
            x = torch.stack([self._edge_forward(x, s, n) for s in sources], dim=0).sum(dim=0)
    node = self[n]
    # each node is evaluated only once per forward pass, afterwards its cached output is returned
    if node["eval"]:
        node["out"] = node["node"].forward(x)
        node["eval"] = False
    return node["out"]
def _edge_forward(self, x: Union[torch.Tensor, np.ndarray], u: str, v: str) -> torch.Tensor:
    """Evaluate the source node `u` and project its output through the edge from `u` to `v`."""
    upstream = self._backward(x, u)
    edge = self.get_edge(u, v)
    return edge.forward(upstream)
def _reset_node_eval(self):
    """Flag every node of the network as not yet evaluated, i.e. set its `eval` attribute to `True`."""
    for attributes in self:
        attributes.update(eval=True)
def _bptt_epochs(self, inp: list, target: list, loss: Callable,
                 optimizer: torch.optim.Optimizer, obs: Observer, error_kwargs: dict, step_kwargs: dict,
                 sampling_steps: int = 1, verbose: bool = False, **kwargs) -> Observer:
    """Epoch-wise backpropagation through time: one simulation and one parameter update per list entry.

    After each epoch, the network is reset to the state it had when this method was called. The observer that is
    returned is the one created by `run` in the last epoch (the observer passed as `obs` is only returned if
    there are no epochs), with the epoch losses saved on it.
    """
    initial_state = self.state
    n_epochs = len(inp)
    losses = []

    for epoch in range(n_epochs):

        # simulate network dynamics
        epoch_input = torch.tensor(inp[epoch], device=self.device)
        obs = self.run(epoch_input, verbose=False, sampling_steps=sampling_steps, enable_grad=True, **kwargs)

        # perform gradient descent step
        predictions = torch.stack(obs["out"])
        epoch_target = torch.tensor(target[epoch], device=self.device)
        epoch_loss = self._bptt_step(predictions, epoch_target, optimizer=optimizer, loss=loss,
                                     error_kwargs=error_kwargs, step_kwargs=step_kwargs)
        losses.append(epoch_loss)

        # reset network
        self.reset(initial_state)
        torch.cuda.empty_cache()

        # display progress
        if verbose:
            print(f'Progress: {epoch+1}/{n_epochs} training epochs finished.')
            print(f'Epoch loss: {epoch_loss}.')
            print('')

    obs.save("epoch_loss", losses)
    obs.save("epochs", np.arange(n_epochs))
    return obs
def _bptt(self, inp: torch.Tensor, target: torch.Tensor, loss: Callable, optimizer: torch.optim.Optimizer,
          obs: Observer, error_kwargs: dict, step_kwargs: dict, sampling_steps: int = 100,
          optim_steps: int = 1000, verbose: bool = False) -> Observer:
    """Backpropagation through time on a single time series.

    Predictions are accumulated and a gradient descent step is performed every `optim_steps` steps, after which
    the network state is detached. Predictions made after the last full window of `optim_steps` steps do not
    enter a parameter update.

    Parameters
    ----------
    inp
        Input tensor, where the first dimension is time.
    target
        Target tensor, where the first dimension is time.
    loss
        Loss function.
    optimizer
        Optimizer instance.
    obs
        Observer used for recording.
    error_kwargs
        Keyword arguments passed to the backward call on the loss.
    step_kwargs
        Keyword arguments passed to the optimizer step.
    sampling_steps
        Observables are recorded at every step that is a multiple of `sampling_steps`.
    optim_steps
        Number of steps between two parameter updates.
    verbose
        If true, the progress is printed at every recording step.

    Returns
    -------
    Observer
        The observer passed as `obs`.
    """
    # preparations
    # recorded variables are treated as (node, variable) pairs and read via `get_var`, consistent with `run`
    rec_vars = list(obs.recorded_state_variables)
    steps = inp.shape[0]
    error = 0.0
    predictions = []
    old_step = 0

    # optimization loop
    for step in range(steps):

        # forward pass
        pred = self.forward(inp[step, :])
        predictions.append(pred)

        # gradient descent optimization step
        if step % optim_steps == optim_steps-1:
            error = self._bptt_step(torch.stack(predictions), target[old_step:step+1], optimizer=optimizer,
                                    loss=loss, error_kwargs=error_kwargs, step_kwargs=step_kwargs)
            # truncate the gradient graph and start a new window of predictions
            self.detach()
            old_step = step+1
            predictions.clear()

        # results storage
        if step % sampling_steps == 0:
            if verbose:
                print(f'Progress: {step}/{steps} training steps finished. Current loss: {error}.')
            obs.record(step, pred, error, [self.get_var(v[0], v[1]) for v in rec_vars])

    return obs
def _rls_epoch(self, inp: list, target: list, obs: Observer, optim_steps: int = 1, verbose: bool = False
               ) -> Observer:
    """Epoch-wise recursive least squares training of the edge registered via `add_edge(..., train="rls")`.

    After each epoch, the network is reset to the state it had when this method was called.

    Parameters
    ----------
    inp
        List with one input array per epoch.
    target
        List with one target array per epoch.
    obs
        Observer on which the epoch losses are saved.
    optim_steps
        An RLS update is performed at every step that is a multiple of `optim_steps`.
    verbose
        If true, the progress is printed after every epoch.

    Returns
    -------
    Observer
        The observer passed as `obs`.
    """
    # preparations
    rls_edge = self.get_edge(self._train_edge[0], self._train_edge[1])
    rls_source = self[self._train_edge[0]]
    rls_target = self[self._train_edge[1]]
    y0 = self.state
    epochs = len(inp)
    epoch_losses = []

    # fitting
    for epoch in range(epochs):

        # turn input and target into tensors
        inp_tmp = torch.tensor(inp[epoch], device=self.device)
        target_tmp = torch.tensor(target[epoch], device=self.device)

        # optimization loop
        loss = 0.0
        for step in range(inp_tmp.shape[0]):

            # forward pass
            self.forward(inp_tmp[step, :])

            # RLS update
            if step % optim_steps == 0:
                rls_edge.update(rls_source["out"], target_tmp[step, :], rls_target["out"])
                loss = rls_edge.loss

        # store the loss reported by the edge after the last update of the epoch, such that it is available for
        # the progress report and the observer
        epoch_losses.append(loss)

        # reset network
        self.reset(y0)
        torch.cuda.empty_cache()

        # display progress
        if verbose:
            print(f'Progress: {epoch + 1}/{epochs} training epochs finished.')
            print(f'Epoch loss: {epoch_losses[-1]}.')
            print('')

    obs.save("epoch_loss", epoch_losses)
    obs.save("epochs", np.arange(epochs))
    return obs
def _rls(self, inp: torch.Tensor, target: torch.Tensor, obs: Observer, sampling_steps: int = 100,
         optim_steps: int = 1, verbose: bool = False) -> Observer:
    """Fit the trained edge to a single time series via recursive least squares.

    Parameters
    ----------
    inp
        Input time series, indexed as `inp[step, :]`.
    target
        Target time series, indexed as `target[step, :]`.
    obs
        Observer that receives the recordings.
    sampling_steps
        A recording is made on every step that is a multiple of this number.
    optim_steps
        An RLS update is performed on every step that is a multiple of this number.
    verbose
        If true, progress is printed whenever a recording is made.

    Returns
    -------
    Observer
        The observer instance that was passed as `obs`.
    """

    # resolve everything that stays fixed during the loop
    recorded = [self._relabel_var(name) for name in obs.recorded_state_variables]
    n_steps = inp.shape[0]
    source_label, target_label = self._train_edge[0], self._train_edge[1]
    edge = self.get_edge(source_label, target_label)
    pre = self[source_label]
    post = self[target_label]
    current_loss = 0.0

    for t in range(n_steps):

        prediction = self.forward(inp[t, :])

        # weight update of the trained edge
        if not t % optim_steps:
            edge.update(pre["out"], target[t, :], post["out"])
            current_loss = edge.loss

        # nothing to record on this step
        if t % sampling_steps:
            continue

        if verbose:
            print(f'Progress: {t}/{n_steps} training steps finished. Current loss: {current_loss}.')
        obs.record(t, prediction, current_loss, [self[name] for name in recorded])

    return obs
@staticmethod
def _bptt_step(predictions: torch.Tensor, targets: torch.Tensor, optimizer: torch.optim.Optimizer,
               loss: Callable, error_kwargs: dict, step_kwargs: dict) -> float:
    """Perform a single gradient descent step and return the error as a Python number.

    Parameters
    ----------
    predictions
        Network predictions that the error is computed on.
    targets
        Targets that the predictions are compared against.
    optimizer
        Optimizer whose gradients are zeroed and whose `step` method is called.
    loss
        Callable that maps `(predictions, targets)` to the error.
    error_kwargs
        Keyword arguments passed to the `backward` call of the error.
    step_kwargs
        Keyword arguments passed to `optimizer.step`.

    Returns
    -------
    float
        Value of the error, obtained via its `item` method.
    """
    loss_value = loss(predictions, targets)

    # clear old gradients, back-propagate the new error, update the parameters
    optimizer.zero_grad()
    loss_value.backward(**error_kwargs)
    optimizer.step(**step_kwargs)

    return loss_value.item()
def _relabel_var(self, var: str) -> str:
    """Translate a variable name via the internal variable map.

    Returns the entry of `self._var_map` stored under `var`, or `var` itself if there is no such entry.
    """
    return self._var_map.get(var, var)
@staticmethod
def _get_optimizer(optimizer: str, lr: float, model_params: Iterator, optimizer_kwargs: dict = None
                   ) -> torch.optim.Optimizer:
    """Instantiate a `torch.optim` optimizer from its short name.

    Parameters
    ----------
    optimizer
        One of 'sgd', 'adam', 'adamw', 'adagrad', 'adadelta', 'adamax', 'rmsprop', 'rprop'.
    lr
        Learning rate passed to the optimizer.
    model_params
        Parameters that the optimizer should update.
    optimizer_kwargs
        Additional keyword arguments passed to the optimizer class.

    Returns
    -------
    torch.optim.Optimizer
        The optimizer instance.

    Raises
    ------
    ValueError
        If `optimizer` is not one of the names listed above.
    """
    # short name -> name of the class in `torch.optim`
    class_names = {
        'sgd': 'SGD',
        'adam': 'Adam',
        'adamw': 'AdamW',
        'adagrad': 'Adagrad',
        'adadelta': 'Adadelta',
        'adamax': 'Adamax',
        'rmsprop': 'RMSprop',
        'rprop': 'Rprop',
    }
    if optimizer not in class_names:
        raise ValueError('Invalid optimizer choice. Please see the documentation of the `Network.train()` '
                         'method for valid options.')
    extra = {} if optimizer_kwargs is None else optimizer_kwargs
    optimizer_class = getattr(torch.optim, class_names[optimizer])
    return optimizer_class(model_params, lr=lr, **extra)
@staticmethod
def _get_loss_function(loss: str, loss_kwargs: dict = None) -> Callable:
    """Instantiate a `torch.nn` loss function from its short name.

    Parameters
    ----------
    loss
        One of 'mse', 'l1', 'nll', 'ce', 'kld', 'hinge'.
    loss_kwargs
        Keyword arguments passed to the loss class.

    Returns
    -------
    Callable
        Instance of the selected loss class.

    Raises
    ------
    ValueError
        If `loss` is not one of the names listed above.
    """
    # short name -> name of the class in `torch.nn`
    class_names = {
        'mse': 'MSELoss',
        'l1': 'L1Loss',
        'nll': 'NLLLoss',
        'ce': 'CrossEntropyLoss',
        'kld': 'KLDivLoss',
        'hinge': 'HingeEmbeddingLoss',
    }
    if loss not in class_names:
        raise ValueError('Invalid loss function choice. Please see the documentation of the `Network.train()` '
                         'method for valid options.')
    extra = {} if loss_kwargs is None else loss_kwargs
    loss_class = getattr(torch.nn, class_names[loss])
    return loss_class(**extra)
[docs]
class FeedbackNetwork(Network):
def __init__(self, dt: float, device: str = "cpu"):
    """Instantiates the network and clears both graph attributes used for feedback handling.

    Parameters
    ----------
    dt
        Time-step, passed on to the parent class.
    device
        Device name, passed on to the parent class.
    """
    super().__init__(dt=dt, device=device)

    # both start out unset; `_fb_graph` is created by `compile`
    self._bwd_graph = self._fb_graph = None
[docs]
def compile(self):
    """Split the graph into a feedforward graph and a feedback graph, then run the parent `compile`.

    Edges whose "feedback" attribute is truthy end up in `self._fb_graph`; all other edges stay in
    `self.graph`. If the network was compiled before, the feedback edges that were split off back then
    are merged into `self.graph` again first, so that the split is always done on the full set of edges.
    """

    # restore feedback edges that a previous call has removed from the main graph
    old_fb = self._fb_graph
    if old_fb is not None:
        for src, tgt in old_fb.edges:
            self.graph.add_edge(src, tgt, **old_fb[src][tgt])
        self._fb_graph = None

    # partition the edges by their feedback flag
    forward_edges, feedback_edges = [], []
    for src, tgt in self.graph.edges:
        is_feedback = self.graph[src][tgt].get("feedback")
        bucket = feedback_edges if is_feedback else forward_edges
        bucket.append((src, tgt))

    # the main graph becomes an independent copy of the feedforward part,
    # the feedback graph is a view on the full graph
    full_graph = self.graph
    forward_graph = DiGraph(full_graph.edge_subgraph(forward_edges))
    self._fb_graph = full_graph.edge_subgraph(feedback_edges)
    self.graph = forward_graph

    super().compile()
[docs]
def add_edge(self, source: str, target: str, weights: Union[torch.Tensor, np.ndarray] = None,
             train: Optional[str] = None, feedback: bool = False, dtype: torch.dtype = torch.float64,
             edge_attrs: dict = None, **kwargs) -> Linear:
    """Add a feed-forward or feedback edge to the network.

    Parameters
    ----------
    source
        Label of the source node.
    target
        Label of the target node.
    weights
        `k x n` weight matrix that realizes the linear projection of the `n` source outputs to
        the `k` target inputs.
    train
        Can be used to make the edge weights trainable. The following options are available:
        - `None` for a static edge
        - 'gd' for training of the edge weights via standard pytorch gradient descent
        - 'rls' for recursive least squares training of the edge weights
    feedback
        If true, this edge is treated as a feedback edge, meaning that it does not affect the feedforward path that
        connects the network input to its output.
    dtype
        Data type of the edge weights.
    edge_attrs
        Additional edge attributes passed to `networkx.DiGraph.add_edge`. The dictionary is not modified; a
        "feedback" entry in it is overridden by the `feedback` argument.
    kwargs
        Additional keyword arguments to be passed to the edge class initialization method.

    Returns
    -------
    Linear
        Instance of the edge class.
    """
    # work on a copy, so that the feedback flag is not written into the dictionary owned by the caller
    edge_attrs = dict(edge_attrs) if edge_attrs else {}
    edge_attrs["feedback"] = feedback
    return super().add_edge(source, target, weights=weights, train=train, dtype=dtype, edge_attrs=edge_attrs,
                            **kwargs)
[docs]
def get_edge(self, source: str, target: str) -> Linear:
    """Returns edge instance from the network.

    The feedforward graph is searched first. If the lookup there fails with a `KeyError`, the edge is
    looked up in the feedback graph.

    Parameters
    ----------
    source
        Name of the source node.
    target
        Name of the target node.

    Returns
    -------
    Linear
        Instance of the edge class.

    Raises
    ------
    KeyError
        If the edge is found in neither graph.
    """
    try:
        return super().get_edge(source, target)
    except KeyError:
        # no feedback graph exists before `compile` was called; re-raise the lookup error
        # instead of failing with a `TypeError` on `None`
        if self._fb_graph is None:
            raise
        return self._fb_graph[source][target]["edge"]
[docs]
def get_node(self, node: str) -> Union[InstantNode, RateNet]:
    """Returns node instance from the network.

    The feedforward graph is searched first. If the lookup there fails with a `KeyError`, the node is
    looked up in the feedback graph.

    Parameters
    ----------
    node
        Name of the node.

    Returns
    -------
    Union[InstantNode, RateNet]
        Instance of a node class.

    Raises
    ------
    KeyError
        If the node is found in neither graph.
    """
    try:
        return super().get_node(node)
    except KeyError:
        # no feedback graph exists before `compile` was called; re-raise the lookup error
        # instead of failing with a `TypeError` on `None`
        if self._fb_graph is None:
            raise
        return self._fb_graph.nodes[node]["node"]
[docs]
def parameters(self, recurse: bool = True) -> Iterator:
    """Yields the trainable parameters of the network model.

    Parameters are collected from the feedforward graph and, if it exists, from the feedback graph.
    Each parameter object is yielded only once, even if it is reachable through both graphs.

    Parameters
    ----------
    recurse
        If true, yields parameters of all submodules.

    Yields
    ------
    Iterator
        Trainable model parameters.
    """
    # identity-based bookkeeping; `==` on tensors is element-wise and cannot be used for membership tests
    seen = set()
    for g in (self.graph, self._fb_graph):
        # the feedback graph is `None` until `compile` has been called
        if g is None:
            continue
        for p in self._get_parameters(g, recurse=recurse):
            if id(p) in seen:
                continue
            seen.add(id(p))
            yield p
def _backward(self, x: Union[torch.Tensor, np.ndarray], n: str) -> torch.Tensor:
    """Collect the input to node `n` and return the output of that node.

    Parameters
    ----------
    x
        Input that is handed to `_edge_forward` for every feedforward source of `n`. It is used directly if
        `n` has no feedforward sources.
    n
        Name of the node.

    Returns
    -------
    torch.Tensor
        The entry stored under "out" for node `n`. It is only re-computed if the "eval" flag of the node is
        set; the flag is cleared afterwards.
    """

    # get feedforward input
    if n in self._bwd_graph:
        sources = self._bwd_graph[n]
        if len(sources) == 1:
            x = self._edge_forward(x, sources[0], n)
        elif sources:
            # stack instead of `torch.tensor(list)`: the latter fails for multi-element tensors and
            # would cut the result off from the autograd graph
            x = torch.stack([self._edge_forward(x, src, n) for src in sources]).sum(dim=0)
        # an empty source list leaves `x` unchanged, same as a node without an entry

    # get feedback input
    if n in self._fb_graph:
        fb_sources = list(self._fb_graph.predecessors(n))
        if len(fb_sources) == 1:
            x = x + self._edge_bwd(fb_sources[0], n)
        elif fb_sources:
            x = x + torch.stack([self._edge_bwd(src, n) for src in fb_sources]).sum(dim=0)

    # calculate node output
    node = self[n]
    if node["eval"]:
        node["out"] = node["node"].forward(x)
        node["eval"] = False
    return node["out"]
def _edge_bwd(self, source: str, target: str):
    """Pass the stored output of node `source` through the feedback edge from `source` to `target`."""
    source_out = self.get_node(source)["out"]
    return self._fb_graph[source][target]["edge"].forward(source_out)