.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_interfaces/torch_integration.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_interfaces_torch_integration.py:


Network-PyTorch Integration
===========================

While the methods `run`, `train`, and `test` of the `rectipy.Network` class provide wrappers for RNN simulation, training, and testing via `torch`, the `Network` class and each of its layers can also be integrated into custom `torch` code.
We will demonstrate this below for the `Network` class, using a simple optimization problem.
Specifically, we will perform online parameter optimization in a model of rate-coupled leaky integrator neurons of the form

.. math::

    \dot v_i &= -\frac{v_i}{\tau} + I_i(t) + k r_i^{in}, \\
    r_i &= \tanh(v_i).

This rate neuron model is described in detail in `this use example `_.
For our optimization problem, we will focus on the global leakage time constant :math:`\tau` and the global coupling constant :math:`k`.
We will set up two separate :code:`rectipy.Network` instances: a target network with the target values of :math:`\tau` and :math:`k`, and a learner network with a different, randomly sampled set of values for :math:`\tau` and :math:`k`.
We will then simulate the dynamic response of both networks to a periodic extrinsic driving signal and optimize :math:`\tau` and :math:`k` of the learner network such that its dynamics resemble those of the target network.

Step 1: Network initialization
------------------------------

First, let's set up both networks with different parametrizations for :math:`\tau` and :math:`k`.

.. GENERATED FROM PYTHON SOURCE LINES 31-61

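
Before doing so, it may help to see the dynamics that both networks will simulate written out directly.
The following is a hypothetical, framework-independent NumPy sketch of the model equation above: it assumes explicit Euler integration with step size :code:`dt` and a coupling input :math:`r_i^{in} = \sum_j J_{ij} r_j`.
rectipy's actual solver and weight convention may differ, and :code:`euler_step` and :code:`simulate` are illustrative helpers, not rectipy functions.

.. code-block:: Python

    import numpy as np


    def euler_step(v, I_ext, J, k, tau, dt):
        """One explicit Euler step of dv/dt = -v/tau + I_ext + k * r_in with r = tanh(v).

        Assumption: r_in = J @ r, i.e. J[i, j] weights the input from neuron j to neuron i.
        """
        r = np.tanh(v)
        r_in = J @ r
        return v + dt * (-v / tau + I_ext + k * r_in)


    def simulate(J, k, tau, dt, inputs):
        """Integrate from v = 0 for len(inputs) steps and return the state after each step."""
        v = np.zeros(J.shape[0])
        states = []
        for I_ext in inputs:
            v = euler_step(v, I_ext, J, k, tau, dt)
            states.append(v)
        return np.array(states)


    # Example: 5 neurons driven by the same sinusoidal input for 5000 steps of size dt.
    rng = np.random.default_rng(0)
    N, dt = 5, 1e-3
    J = rng.standard_normal((N, N))
    t = np.arange(1, 5001) * dt
    inputs = np.outer(0.5 * np.sin(np.pi * t), np.ones(N))
    states = simulate(J, k=1.0, tau=1.0, dt=dt, inputs=inputs)
    print(states.shape)  # (5000, 5)

A quick sanity check for the leak term: without coupling (:math:`J = 0`) and with a constant input :math:`I`, the state settles at the fixed point :math:`v = \tau I`.
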

.. code-block:: Python

    import numpy as np
    from rectipy import Network

    # network parameters
    node = "neuron_model_templates.rate_neurons.leaky_integrator.tanh"
    N = 5
    dt = 1e-3
    J = np.random.randn(N, N)
    k_t = np.random.uniform(0.25, 4.0)
    tau_t = np.random.uniform(0.25, 4.0)
    k_0 = np.random.uniform(0.25, 4.0)
    tau_0 = np.random.uniform(0.25, 4.0)

    # target model initialization
    target = Network(dt=dt)
    target.add_diffeq_node("tanh", node=node, weights=J, source_var="tanh_op/r", target_var="li_op/r_in",
                           input_var="li_op/I_ext", output_var="li_op/v", clear=True,
                           node_vars={"all/li_op/k": k_t, "all/li_op/tau": tau_t})

    # learner model initialization (k and tau are marked as trainable)
    learner = Network(dt=dt)
    learner.add_diffeq_node("tanh", node=node, weights=J, source_var="tanh_op/r", target_var="li_op/r_in",
                            input_var="li_op/I_ext", output_var="li_op/v", clear=True,
                            node_vars={"all/li_op/k": k_0, "all/li_op/tau": tau_0},
                            train_params=["li_op/k", "li_op/tau"])

    print(rf"Target network parameters: $k_t$ = {k_t}, $\tau_t$ = {tau_t}.")
    print(rf"Learner network parameters: $k_0$ = {k_0}, $\tau_0$ = {tau_0}.")

.. GENERATED FROM PYTHON SOURCE LINES 62-63

As can be seen, we drew two different sets of parameters for our networks.

.. GENERATED FROM PYTHON SOURCE LINES 65-71

Step 2: Perform online optimization
-----------------------------------

Now, we would like to optimize the parameters of our learner network with an online optimization algorithm.
We will do this via a custom `torch` optimization procedure.
As a first step, we need to compile both networks to be able to use them as `torch` modules:

.. GENERATED FROM PYTHON SOURCE LINES 71-75

.. code-block:: Python

    target.compile()
    learner.compile()

.. GENERATED FROM PYTHON SOURCE LINES 76-77

Next, we are going to choose an optimization algorithm:

.. GENERATED FROM PYTHON SOURCE LINES 77-82

.. code-block:: Python

    import torch

    opt = torch.optim.Rprop(learner.parameters(), lr=0.01, etas=(0.5, 1.1), step_sizes=(1e-5, 1e-1))

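
The optimizer configured above adapts an individual step size for each trainable parameter and uses only the sign of the gradient, not its magnitude.
As a hedged illustration of how :code:`etas` and :code:`step_sizes` enter that adaptation, the sketch below reimplements one Rprop update by hand, following our reading of the update rule documented for :code:`torch.optim.Rprop`: the step size grows by :code:`etas[1]` while the gradient sign stays the same, shrinks by :code:`etas[0]` (and the update is skipped) when the sign flips, and is always clipped to the range given by :code:`step_sizes`.
The helper :code:`rprop_step` and the quadratic toy loss are hypothetical and not part of rectipy; the script checks the hand-written update against the library optimizer.

.. code-block:: Python

    import torch


    def rprop_step(param, grad, prev_grad, step_size, etas=(0.5, 1.1), step_sizes=(1e-5, 1e-1)):
        """One hand-written Rprop update; returns (new_param, grad_to_remember, new_step_size)."""
        eta_minus, eta_plus = etas
        step_min, step_max = step_sizes
        # Compare the sign of the current gradient with the previous one.
        sign_change = (grad * prev_grad).sign()
        factor = torch.ones_like(grad)
        factor[sign_change > 0] = eta_plus   # same sign: accelerate
        factor[sign_change < 0] = eta_minus  # sign flip: we overshot, so slow down
        step_size = (step_size * factor).clamp(step_min, step_max)
        # After a sign flip, skip this update and forget the gradient.
        grad = torch.where(sign_change < 0, torch.zeros_like(grad), grad)
        # Only the gradient's sign is used, never its magnitude.
        new_param = param - grad.sign() * step_size
        return new_param, grad, step_size


    # Toy problem: minimize sum((x - c)^2), whose gradient is 2 * (x - c).
    c = torch.tensor([1.0, -2.0, 0.5])

    # Reference: PyTorch's optimizer with the same settings as in the example.
    x_ref = torch.zeros(3, requires_grad=True)
    opt_ref = torch.optim.Rprop([x_ref], lr=0.01, etas=(0.5, 1.1), step_sizes=(1e-5, 1e-1))

    # Hand-written version; the initial step size equals lr.
    x_man = torch.zeros(3)
    prev_grad = torch.zeros(3)
    step_size = torch.full((3,), 0.01)

    for _ in range(100):
        opt_ref.zero_grad()
        ((x_ref - c) ** 2).sum().backward()
        opt_ref.step()
        x_man, prev_grad, step_size = rprop_step(x_man, 2 * (x_man - c), prev_grad, step_size)

    print(x_ref.detach(), x_man)

If the two versions agree, both printed tensors should be identical and should have moved from zero towards :code:`c`.
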

.. GENERATED FROM PYTHON SOURCE LINES 83-86

We chose the `resilient backpropagation algorithm `_ (Rprop) and tweaked its step-size adaptation factors (:code:`etas`) and step-size bounds (:code:`step_sizes`), which control how the per-parameter learning rates are adjusted automatically.
In addition, we need to specify a loss function:

.. GENERATED FROM PYTHON SOURCE LINES 86-89

.. code-block:: Python

    loss = torch.nn.MSELoss()

.. GENERATED FROM PYTHON SOURCE LINES 90-92

Here, we simply chose the standard `mean-squared error `_.
Finally, let's initialize a figure in which we are going to plot the progress of the online optimization:

.. GENERATED FROM PYTHON SOURCE LINES 92-104

.. code-block:: Python

    # matplotlib settings
    import matplotlib
    import matplotlib.pyplot as plt

    # figure layout
    fig, ax = plt.subplots(ncols=2, figsize=(12, 6))
    ax[0].set_xlabel("training steps")
    ax[0].set_ylabel("MSE")
    ax[1].set_xlabel("training steps")
    ax[1].set_ylabel(r"$v$")

.. GENERATED FROM PYTHON SOURCE LINES 105-110

Now, we are ready to perform the optimization.
The code below implements a `torch` optimization procedure that adjusts the parameters :math:`\tau` and :math:`k` of the learner network every :code:`update_steps` steps, based on the automatically backpropagated mean-squared error between the outputs of the target and learner networks in response to a sinusoidal extrinsic input.
The optimization runs until the exponentially smoothed training error drops below :code:`tol`, or until :code:`max_step` simulation steps have been reached.

.. GENERATED FROM PYTHON SOURCE LINES 110-158


.. code-block:: Python

    # model fitting
    error, tol, step, update_steps, plot_steps, max_step = 10.0, 1e-5, 0, 1000, 100, 20000
    mse_col, target_col, prediction_col = [], [], []
    l = torch.zeros(1)  # loss accumulated over the current update window
    while error > tol and step < max_step:

        # calculate network outputs in response to a sinusoidal input
        I_ext = np.sin(np.pi * step * dt) * 0.5
        targ = target.forward(I_ext)
        pred = learner.forward(I_ext)
        step += 1

        # calculate the loss of the current step and add it to the window loss
        step_loss = loss(pred, targ)
        l = l + step_loss

        # make an optimization step every `update_steps` steps
        if step % update_steps == 0:
            l.backward()
            opt.step()
            opt.zero_grad()
            l = torch.zeros(1)
            learner.detach()  # truncate backpropagation through time here

        # update the exponentially smoothed per-step error
        error = 0.95 * error + 0.05 * step_loss.item()

        # collect data
        mse_col.append(error)
        target_col.append(targ.detach().numpy()[0])
        prediction_col.append(pred.detach().numpy()[0])

        # update the figure for online plotting
        if step % plot_steps == 0:
            ax[0].plot(mse_col, "red")
            ax[1].plot(target_col, "blue")
            ax[1].plot(prediction_col, "orange")
            fig.canvas.draw()
            print(f"Training error at step {step}: {error}")

    plt.show()

    # retrieve optimized parameters (assumed to be ordered as in `train_params`)
    params = list(learner.parameters())
    k = params[0].clone().detach().numpy()
    tau = params[1].clone().detach().numpy()
    print(rf"Optimized parameters: $k_*$ = {k[0]}, $\tau_*$ = {tau[0]}.")

.. GENERATED FROM PYTHON SOURCE LINES 159-169

The code above demonstrates how any :code:`rectipy.Network` instance can be integrated into custom `torch` code.
After calling :code:`Network.compile`, the :code:`Network` instance provides the standard :code:`torch.nn.Module.forward` and :code:`torch.nn.Module.parameters` methods, which you can use to calculate the network output and to access the trainable parameters, respectively.
The same holds for each node and edge in the :code:`Network`.
This allows you to implement more complex optimization procedures that go beyond what :code:`Network.fit_bptt` provides.

One final note: as the final part of each optimization step in the code above, we call :code:`Network.detach()` to implement truncated backpropagation through time.
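
A minimal plain-PyTorch sketch of this truncated-BPTT pattern follows, assuming a hand-written stand-in for the compiled networks.
:code:`LeakyTanhNet` and :code:`train` are hypothetical (they are not part of rectipy); they integrate the model equation with explicit Euler steps and assume :math:`r_i^{in} = \sum_j J_{ij} r_j`.
The learner keeps its state :code:`v` across calls, accumulates the per-step MSE over :code:`update_steps` steps, takes an Rprop step, and then detaches :code:`v` so that the next backward pass stops at the window boundary.

.. code-block:: Python

    import math

    import torch


    class LeakyTanhNet(torch.nn.Module):
        """Hypothetical stand-in for a compiled rectipy network (not the rectipy API).

        Each call performs one explicit Euler step of
        dv/dt = -v/tau + I_ext + k * (J @ tanh(v)).
        """

        def __init__(self, J, k, tau, dt, trainable):
            super().__init__()
            self.J, self.dt = J, dt
            k_t, tau_t = torch.tensor([float(k)]), torch.tensor([float(tau)])
            # Only the learner registers k and tau as trainable parameters.
            self.k = torch.nn.Parameter(k_t) if trainable else k_t
            self.tau = torch.nn.Parameter(tau_t) if trainable else tau_t
            self.v = torch.zeros(J.shape[0])  # state carried across forward calls

        def forward(self, I_ext):
            r_in = self.J @ torch.tanh(self.v)
            self.v = self.v + self.dt * (-self.v / self.tau + I_ext + self.k * r_in)
            return self.v

        def detach(self):
            # Keep the state's value but cut its link to the compute graph.
            self.v = self.v.detach()


    def train(use_detach=True, n_steps=4000, update_steps=200, dt=1e-3, seed=0):
        torch.manual_seed(seed)
        J = torch.randn(5, 5)
        target = LeakyTanhNet(J, k=2.0, tau=1.0, dt=dt, trainable=False)
        learner = LeakyTanhNet(J, k=0.5, tau=3.0, dt=dt, trainable=True)
        opt = torch.optim.Rprop(learner.parameters(), lr=0.01, etas=(0.5, 1.1), step_sizes=(1e-5, 1e-1))
        loss_fn = torch.nn.MSELoss()
        window_loss, history = torch.zeros(()), []
        for step in range(1, n_steps + 1):
            I_ext = 0.5 * math.sin(math.pi * step * dt)  # sinusoidal drive
            with torch.no_grad():
                targ = target(I_ext)
            window_loss = window_loss + loss_fn(learner(I_ext), targ)
            if step % update_steps == 0:
                window_loss.backward()  # gradients flow back to the last detach only
                opt.step()
                opt.zero_grad()
                history.append(window_loss.item() / update_steps)
                window_loss = torch.zeros(())
                if use_detach:
                    learner.detach()
        return learner, history


    learner, history = train()
    print(learner.k.item(), learner.tau.item(), history[-1])

With :code:`use_detach=False`, the second call to :code:`backward()` should fail with a :code:`RuntimeError`, because the new window's graph still leads back into the previous window's graph, whose saved tensors were already freed by the first backward pass.
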

In scenarios where parameter optimization steps are performed online (as we do above), calling :code:`Network.detach()` is crucial to ensure proper gradient calculation in a compute graph with changing parameter values: truncating the graph after each update means that the next backward pass only covers the most recent :code:`update_steps` simulation steps, rather than reaching back into parts of the graph that were built with the old parameter values and already consumed by the previous backward pass.

.. _sphx_glr_download_auto_interfaces_torch_integration.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_integration.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_integration.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: torch_integration.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_