2.4. Network-PyTorch Integration

While the methods run, train, and test of the rectipy.Network class provide convenient wrappers for RNN simulation, training, and testing via torch, the Network class and each of its layers can also be integrated directly into custom torch code.

We will demonstrate this below for the Network class, using a simple optimization problem. Specifically, we will perform online parameter optimization in a model of rate-coupled leaky integrator neurons of the form

\[
\begin{aligned}
\dot v_i &= -\frac{v_i}{\tau} + I_i(t) + k\, r_i^{in},\\
r_i &= \tanh(v_i).
\end{aligned}
\]

This rate neuron model is described in detail in a separate use example. For our optimization problem, we will focus on the global leakage time constant \(\tau\) and the global coupling constant \(k\). We will set up two separate rectipy.Network instances: a target network with the target values of \(\tau\) and \(k\), and a learner network with a different, randomly sampled set of values for \(\tau\) and \(k\). We will then simulate the dynamic response of both networks to a periodic extrinsic driving signal and optimize \(\tau\) and \(k\) of the learner network such that its dynamics resemble those of the target network.
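To make these dynamics concrete, a single integration step of the model can be sketched with a forward-Euler update (a minimal illustration under our own naming, not the solver that rectipy uses internally):

import numpy as np

def euler_step(v, I_ext, J, k, tau, dt):
    """One forward-Euler step of the rate-coupled leaky integrator (illustrative)."""
    r = np.tanh(v)                     # firing rates r_i = tanh(v_i)
    r_in = J @ r                       # recurrent input r_i^in via the coupling matrix J
    dv = -v / tau + I_ext + k * r_in   # dv_i/dt = -v_i/tau + I_i(t) + k*r_i^in
    return v + dt * dv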

2.4.1. Step 1: Network initialization

First, let's set up both networks with different parametrizations for \(\tau\) and \(k\):

import numpy as np
from rectipy import Network

# network parameters
node = "neuron_model_templates.rate_neurons.leaky_integrator.tanh"
N = 5
dt = 1e-3
J = np.random.randn(N, N)
k_t = np.random.uniform(0.25, 4.0)
tau_t = np.random.uniform(0.25, 4.0)
k_0 = np.random.uniform(0.25, 4.0)
tau_0 = np.random.uniform(0.25, 4.0)

# target model initialization
target = Network(dt=dt)
target.add_diffeq_node("tanh", node=node, weights=J, source_var="tanh_op/r", target_var="li_op/r_in",
                       input_var="li_op/I_ext", output_var="li_op/v", clear=True,
                       node_vars={"all/li_op/k": k_t, "all/li_op/tau": tau_t})

# learner model initialization
learner = Network(dt=dt)
learner.add_diffeq_node("tanh", node=node, weights=J, source_var="tanh_op/r", target_var="li_op/r_in",
                        input_var="li_op/I_ext", output_var="li_op/v", clear=True,
                        node_vars={"all/li_op/k": k_0, "all/li_op/tau": tau_0},
                        train_params=["li_op/k", "li_op/tau"])

print("Target network parameters: " + r"$k_t$ = " + f"{k_t}" + r", $\tau_t$ = " + f"{tau_t}.")
print("Learner network parameters: " + r"$k_0$ = " + f"{k_0}" + r", $\tau_0$ = " + f"{tau_0}.")

As can be seen, we drew two different sets of parameters for our networks.

2.4.2. Step 2: Perform online optimization

Now, we would like to optimize the parameters of our learner network via a custom online optimization procedure implemented in torch. As a first step, we need to compile both networks so that they can be used as torch modules:

target.compile()
learner.compile()

Next, we are going to choose an optimization algorithm:

import torch

opt = torch.optim.Rprop(learner.parameters(), lr=0.01, etas=(0.5, 1.1), step_sizes=(1e-5, 1e-1))

We chose the resilient backpropagation (Rprop) algorithm and tweaked some of its default parameters: etas controls the factors by which the per-parameter step size is decreased or increased depending on whether the gradient changes sign, and step_sizes bounds the allowed step sizes.
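Since the compiled learner exposes its trainable parameters via the standard parameters() method, any other torch optimizer works as a drop-in replacement. For instance (an alternative we name purely for illustration, not part of the optimization below):

# any torch optimizer accepts the network parameters, e.g. plain Adam
opt_adam = torch.optim.Adam(learner.parameters(), lr=0.01)

In addition, we need to specify a loss function: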

loss = torch.nn.MSELoss()

Here, we just chose the vanilla mean-squared error.
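For intuition, MSELoss simply averages the squared element-wise differences between its two arguments; a quick sanity check with made-up numbers:

x = torch.tensor([1.0, 2.0])
y = torch.tensor([1.5, 1.0])
print(loss(x, y))  # mean((x - y)**2) = (0.25 + 1.0) / 2 = 0.625

Finally, let's initialize a figure in which we will plot the progress of the online optimization: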

# matplotlib settings
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
plt.style.use('dark_background')
plt.ion()

# figure layout
fig, ax = plt.subplots(ncols=2, figsize=(12, 6))
ax[0].set_xlabel("training steps")
ax[0].set_ylabel("MSE")
ax[1].set_xlabel("training steps")
ax[1].set_ylabel(r"$v$")

Now, we are ready to perform the optimization. The code below implements a torch optimization procedure that adjusts the parameters \(\tau\) and \(k\) of the learner network every update_steps steps, based on the automatically backpropagated mean-squared error between the outputs of the target and learner network in response to a sinusoidal extrinsic input. The optimization will run until convergence, or until a maximum number of optimization steps has been reached.

# model fitting
error, tol, step, update_steps, plot_steps, max_step = 10.0, 1e-5, 0, 1000, 100, 1000000
mse_col, target_col, prediction_col = [], [], []
l = torch.zeros(1)
while error > tol and step < max_step:

    # calculate network outputs
    I_ext = np.sin(np.pi * step * dt) * 0.5
    targ = target.forward(I_ext)
    pred = learner.forward(I_ext)
    step += 1

    # calculate loss
    l += loss(targ, pred)

    # make optimization step
    if step % update_steps == 0:
        l.backward()
        opt.step()
        opt.zero_grad()
        l = torch.zeros(1)
        learner.detach()

    # update average error
    error = 0.95 * error + 0.05 * l.item()

    # collect data
    mse_col.append(error)
    target_col.append(targ.detach().numpy()[0])
    prediction_col.append(pred.detach().numpy()[0])

    # update the figure for online plotting
    if step % plot_steps == 0:
        ax[0].plot(mse_col, "red")
        ax[1].plot(target_col, "blue")
        ax[1].plot(prediction_col, "orange")
        fig.canvas.draw()
        fig.canvas.flush_events()
        plt.show()

# retrieve optimized parameters
params = list(learner.parameters())
k = params[0].clone().detach().numpy()
tau = params[1].clone().detach().numpy()
print("Optimized parameters: " + r"$k_*$ = " + f"{k[0]}" + r", $\tau_*$ = " + f"{tau[0]}.")

The code above demonstrates how any rectipy.Network instance can be integrated into custom torch code. After calling Network.compile, the Network instance provides the standard torch.nn.Module.forward and torch.nn.Module.parameters methods, which you can use to calculate the network output and to access the trainable parameters, respectively. The same holds for each node and edge in the Network. This makes it possible to implement more complex optimization procedures that go beyond what Network.fit_bptt provides. One final note: at the end of each optimization step in the code above, we call the Network.detach() method to implement truncated backpropagation through time. In scenarios where parameter optimization steps are performed online (as we do above), this method is crucial to ensure proper gradient calculation in a compute graph with changing parameter values.
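To summarize the pattern, the essential skeleton of such an online optimization, stripped of the plotting and bookkeeping above, could look as follows (a sketch under our own naming; input_fn and target_fn are placeholders for whatever generates the driving input and the target output at a given step):

import torch

def online_fit(net, input_fn, target_fn, n_steps, update_steps):
    """Online optimization with truncated BPTT for a compiled rectipy.Network."""
    opt = torch.optim.Rprop(net.parameters())
    loss_fn = torch.nn.MSELoss()
    l = torch.zeros(1)
    for step in range(1, n_steps + 1):
        pred = net.forward(input_fn(step))      # standard torch forward pass
        l = l + loss_fn(target_fn(step), pred)  # accumulate the loss
        if step % update_steps == 0:
            l.backward()                        # backpropagate the accumulated loss
            opt.step()
            opt.zero_grad()
            l = torch.zeros(1)                  # reset the loss accumulator
            net.detach()                        # truncate the compute graph (TBPTT)
    return net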
