diff --git a/AUTHORS b/AUTHORS index f1e24665..9d16953e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,6 +7,7 @@ Gregor Lenz Martino Sorbaro Martino Sorbaro Massimo Bortone +Mina Khoei MurphyWu Nogay Kuepelioglu Nogay Küpelioglu @@ -19,6 +20,8 @@ Sadique Sheik Vanessa Leite Vanessa Leite Vanessa Leite +Willian-Girao +WillianSG Yalun Hu Yalun_Hu allan @@ -35,6 +38,7 @@ qian.liu sadique.sheik sadique.sheik shynuie +unknown yalun.hu yannan xing yannan.xing diff --git a/examples/dynapcnn_network/snn_DVSLayer_given.ipynb b/examples/dynapcnn_network/snn_DVSLayer_given.ipynb new file mode 100644 index 00000000..6f4d88d9 --- /dev/null +++ b/examples/dynapcnn_network/snn_DVSLayer_given.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/samurai2077/anaconda3/envs/speck-rescnn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from sinabs.backend.dynapcnn import DynapcnnNetwork\n", + "from sinabs.backend.dynapcnn import DVSLayer\n", + "from sinabs.layers import Merge, IAFSqueeze, SumPool2d\n", + "from sinabs.activation.surrogate_gradient_fn import PeriodicExponential\n", + "import sinabs.layers as sl\n", + "\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "\n", + "from tonic.datasets.nmnist import NMNIST\n", + "from tonic.transforms import ToFrame\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "from statistics import mode\n", + "\n", + "device = torch.device('cpu')\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "channels = 1\n", + "height = 34\n", + "width = 34\n", + "batch_size = 1\n", + "\n", + "input_shape = (channels, height, width)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SNN(\n", + " (dvs): DVSLayer(\n", + " (pool_layer): SumPool2d(norm_type=1, kernel_size=(1, 1), stride=None, ceil_mode=False)\n", + " (crop_layer): Crop2d((0, 34), (0, 34))\n", + " (flip_layer): FlipDims()\n", + " )\n", + " (conv1): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf1): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (conv2): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (iaf2): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (conv3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf3): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc1): Linear(in_features=144, out_features=200, 
bias=False)\n", + " (iaf4): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc2): Linear(in_features=200, out_features=10, bias=False)\n", + " (iaf5): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class SNN(nn.Module):\n", + " def __init__(self, input_shape) -> None:\n", + " super().__init__()\n", + "\n", + " self.dvs = DVSLayer(input_shape=(input_shape[1], input_shape[2]))\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(1, 10, 2, 1, bias=False)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.iaf3 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", + " \n", + " dvs_out = self.dvs(x) # 0\n", + " \n", + " con1_out = self.conv1(dvs_out) # 4\n", + " iaf1_out = self.iaf1(con1_out) # 5\n", + " pool1_out = self.pool1(iaf1_out) # 6\n", + "\n", + " conv2_out = self.conv2(pool1_out) # 7\n", + " iaf2_out = self.iaf2(conv2_out) # 8\n", + "\n", + " conv3_out = self.conv3(iaf2_out) # 9\n", + " iaf3_out = self.iaf3(conv3_out) # 10\n", + "\n", + " flat_out = self.flat(iaf3_out) # 15\n", + " \n", + " fc1_out = self.fc1(flat_out) # 11\n", + " iaf4_out = self.iaf4(fc1_out) # 12\n", + " fc2_out = self.fc2(iaf4_out) # 13\n", + " iaf5_out = self.iaf5(fc2_out) # 14\n", + "\n", + " return iaf5_out\n", + " \n", + "snn = SNN(input_shape)\n", + "snn.init_weights()\n", + "snn.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " discretize=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + 
"outputs": [], + "source": [ + "input_dummy = torch.randn((batch_size, *input_shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "out = hw_model(input_dummy)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]]]], grad_fn=)\n" + ] + } + ], + "source": [ + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid\n" + ] + }, + { + "data": { + "text/plain": [ + "DynapcnnNetwork(\n", + " (_dynapcnn_module): DynapcnnNetworkModule(\n", + " (_dynapcnn_layers): ModuleDict(\n", + " (1): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(174.), min_v_mem=Parameter containing:\n", + " tensor(-174.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (5): DynapcnnLayer(\n", + " (_conv): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(419.), min_v_mem=Parameter containing:\n", + " tensor(-419.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (3): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(303.), min_v_mem=Parameter containing:\n", + " tensor(-303.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (2): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(2084.), min_v_mem=Parameter containing:\n", + " tensor(-2084.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (4): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(365.), min_v_mem=Parameter containing:\n", + " tensor(-365.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " )\n", + " (_dvs_layer): DVSLayer(\n", + " (pool_layer): SumPool2d(norm_type=1, kernel_size=(1, 1), stride=None, ceil_mode=False)\n", + " (crop_layer): Crop2d((0, 34), (0, 34))\n", + " (flip_layer): FlipDims()\n", + " )\n", + " (merge_layer): Merge()\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hw_model.to(device=\"speck2fdevkit\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speck-rescnn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/dynapcnn_network/snn_DVSLayer_given_followed_by_pool.ipynb b/examples/dynapcnn_network/snn_DVSLayer_given_followed_by_pool.ipynb new file mode 100644 index 00000000..3b261693 --- /dev/null +++ 
b/examples/dynapcnn_network/snn_DVSLayer_given_followed_by_pool.ipynb @@ -0,0 +1,341 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/samurai2077/anaconda3/envs/speck-rescnn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from sinabs.backend.dynapcnn import DynapcnnNetwork\n", + "from sinabs.backend.dynapcnn import DVSLayer\n", + "from sinabs.layers import Merge, IAFSqueeze, SumPool2d\n", + "from sinabs.activation.surrogate_gradient_fn import PeriodicExponential\n", + "import sinabs.layers as sl\n", + "\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "\n", + "from tonic.datasets.nmnist import NMNIST\n", + "from tonic.transforms import ToFrame\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "from statistics import mode\n", + "\n", + "device = torch.device('cpu')\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "channels = 1\n", + "height = 34\n", + "width = 34\n", + "batch_size = 1\n", + "\n", + "input_shape = (channels, height, width)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SNN(\n", + " (dvs): DVSLayer(\n", + " (pool_layer): SumPool2d(norm_type=1, kernel_size=(1, 1), stride=None, ceil_mode=False)\n", + " (crop_layer): Crop2d((0, 34), (0, 34))\n", + " (flip_layer): FlipDims()\n", + " )\n", + " (dvs_pool): AvgPool2d(kernel_size=1, stride=1, padding=0)\n", + " (conv1): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf1): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (conv2): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (iaf2): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (conv3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf3): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc1): Linear(in_features=144, out_features=200, bias=False)\n", + " (iaf4): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc2): Linear(in_features=200, out_features=10, bias=False)\n", + " (iaf5): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "class SNN(nn.Module):\n", + " def __init__(self, input_shape) -> None:\n", + " super().__init__()\n", + "\n", + " self.dvs = DVSLayer(input_shape=(input_shape[1], input_shape[2]))\n", + " self.dvs_pool = nn.AvgPool2d(1,1)\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(1, 10, 2, 1, bias=False)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.iaf3 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", + " \n", + " dvs_out = self.dvs(x) # 0\n", + "\n", + " dvs_pool_out = self.dvs_pool(dvs_out)\n", + " \n", + " con1_out = self.conv1(dvs_pool_out) # 4\n", + " iaf1_out = self.iaf1(con1_out) # 5\n", + " pool1_out = self.pool1(iaf1_out) # 6\n", + "\n", + " conv2_out = self.conv2(pool1_out) # 7\n", + " iaf2_out = self.iaf2(conv2_out) # 8\n", + "\n", + " conv3_out = self.conv3(iaf2_out) # 9\n", + " iaf3_out = self.iaf3(conv3_out) # 10\n", + "\n", + " flat_out = self.flat(iaf3_out) # 15\n", + " \n", + " fc1_out = self.fc1(flat_out) # 11\n", + " iaf4_out = self.iaf4(fc1_out) # 12\n", + " fc2_out = self.fc2(iaf4_out) # 13\n", + " iaf5_out = self.iaf5(fc2_out) # 14\n", + "\n", + " return iaf5_out\n", + " \n", + "snn = SNN(input_shape)\n", + "snn.init_weights()\n", + "snn.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " discretize=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "input_dummy = torch.randn((batch_size, *input_shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "out = hw_model(input_dummy)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " 
[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]]]], grad_fn=)\n" + ] + } + ], + "source": [ + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid\n" + ] + }, + { + "data": { + "text/plain": [ + "DynapcnnNetwork(\n", + " (_dynapcnn_module): DynapcnnNetworkModule(\n", + " (_dynapcnn_layers): ModuleDict(\n", + " (1): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(174.), min_v_mem=Parameter containing:\n", + " tensor(-174.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (5): DynapcnnLayer(\n", + " (_conv): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(419.), min_v_mem=Parameter containing:\n", + " tensor(-419.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (3): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(303.), min_v_mem=Parameter containing:\n", + " tensor(-303.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (2): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(2084.), min_v_mem=Parameter containing:\n", + " tensor(-2084.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (4): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(365.), min_v_mem=Parameter containing:\n", + " tensor(-365.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " )\n", + " (_dvs_layer): DVSLayer(\n", + " (pool_layer): SumPool2d(norm_type=1, kernel_size=(1, 1), stride=None, ceil_mode=False)\n", + " (crop_layer): Crop2d((0, 34), (0, 34))\n", + " (flip_layer): FlipDims()\n", + " )\n", + " (merge_layer): Merge()\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hw_model.to(device=\"speck2fdevkit\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speck-rescnn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/dynapcnn_network/snn_deployment.ipynb b/examples/dynapcnn_network/snn_deployment.ipynb new file mode 100644 index 00000000..9c643724 --- /dev/null +++ b/examples/dynapcnn_network/snn_deployment.ipynb @@ -0,0 +1,882 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/samurai2077/anaconda3/envs/speck-rescnn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from sinabs.backend.dynapcnn import DynapcnnNetwork\n", + "from sinabs.layers import Merge, IAFSqueeze, SumPool2d\n", + "from sinabs.activation.surrogate_gradient_fn import PeriodicExponential\n", + "import sinabs.layers as sl\n", + "\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "\n", + "from tonic.datasets.nmnist import NMNIST\n", + "from tonic.transforms import ToFrame\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "from statistics import mode" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Network Module\n", + "\n", + "We need to define a `nn.Module` implementing the Spiking Neural Network (SNN) we want to deploy on chip. The configuration of the network on the chip needs to know in advance the shape of the input data and the batch size that will be used." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "channels = 2\n", + "height = 34\n", + "width = 34\n", + "batch_size = 8\n", + "\n", + "input_shape = (channels, height, width)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "class SNN(nn.Module):\n", + " def __init__(self) -> None:\n", + " super().__init__()\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(2, 10, 2, 1, bias=False)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.iaf3 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", 
+ " \n", + " con1_out = self.conv1(x)\n", + " iaf1_out = self.iaf1(con1_out)\n", + " pool1_out = self.pool1(iaf1_out)\n", + "\n", + " conv2_out = self.conv2(pool1_out)\n", + " iaf2_out = self.iaf2(conv2_out)\n", + "\n", + " conv3_out = self.conv3(iaf2_out)\n", + " iaf3_out = self.iaf3(conv3_out)\n", + "\n", + " flat_out = self.flat(iaf3_out)\n", + " \n", + " fc1_out = self.fc1(flat_out)\n", + " iaf4_out = self.iaf4(fc1_out)\n", + " fc2_out = self.fc2(iaf4_out)\n", + " iaf5_out = self.iaf5(fc2_out)\n", + "\n", + " return iaf5_out" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "snn = SNN()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's train the model to see what kind of accuracy the software model gets:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to ./NMNIST/NMNIST/train.zip\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1011894272it [01:25, 11770103.44it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./NMNIST/NMNIST/train.zip to ./NMNIST/NMNIST\n", + "Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ./NMNIST/NMNIST/test.zip\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "169675776it [00:18, 9035099.75it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./NMNIST/NMNIST/test.zip to ./NMNIST/NMNIST\n", + "The transformed array is in shape [Time-Step, Channel, Height, Width] --> (50, 2, 34, 34)\n" + ] + } + ], + "source": [ + "_ = NMNIST(save_to='./NMNIST', train=True)\n", + "_ = NMNIST(save_to='./NMNIST', train=False)\n", + "\n", + "nb_time_steps = 50\n", + "to_raster = ToFrame(sensor_size=NMNIST.sensor_size, n_time_bins=nb_time_steps)\n", + "\n", + "snn_train_dataset = NMNIST(save_to='./NMNIST', train=True, transform=to_raster)\n", + "snn_test_dataset = NMNIST(save_to='./NMNIST', train=False, transform=to_raster)\n", + "\n", + "sample_data, label = snn_train_dataset[0]\n", + "print(f\"The transformed array is in shape [Time-Step, Channel, Height, Width] --> {sample_data.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "train_indices = [i for i in range(1000)]\n", + "test_indices = [i for i in range(100)]\n", + "\n", + "snn_train_dataset_subset = torch.utils.data.Subset(snn_train_dataset, train_indices)\n", + "snn_test_subset = torch.utils.data.Subset(snn_train_dataset, test_indices)\n", + "\n", + "snn_train_dataloader = DataLoader(snn_train_dataset_subset, batch_size=batch_size, num_workers=4, drop_last=True, shuffle=True)\n", + "snn_test_dataloader = DataLoader(snn_test_subset, batch_size=batch_size, num_workers=4, drop_last=True, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SNN(\n", + " (conv1): Conv2d(2, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf1): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=8, num_timesteps=-1)\n", + " (pool1): 
AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (conv2): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (iaf2): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=8, num_timesteps=-1)\n", + " (conv3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf3): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=8, num_timesteps=-1)\n", + " (fc1): Linear(in_features=144, out_features=200, bias=False)\n", + " (iaf4): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=8, num_timesteps=-1)\n", + " (fc2): Linear(in_features=200, out_features=10, bias=False)\n", + " (iaf5): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=8, num_timesteps=-1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = torch.device('cpu')\n", + "\n", + "snn.init_weights()\n", + "\n", + "snn.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = Adam(snn.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8)\n", + "loss_fn = CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "training the model..." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5579aac6828434dbac67d04236a87c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/125 [00:00 1\u001b[0m hw_model \u001b[38;5;241m=\u001b[39m \u001b[43mDynapcnnNetwork\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43msnn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msnn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_shape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiscretize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 6\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Github/sinabs/sinabs/backend/dynapcnn/dynapcnn_network.py:70\u001b[0m, in \u001b[0;36mDynapcnnNetwork.__init__\u001b[0;34m(self, snn, input_shape, batch_size, dvs_input, discretize, weight_rescaling_fn)\u001b[0m\n\u001b[1;32m 67\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m sinabs\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mget_smallest_compatible_time_dimension(snn)\n\u001b[1;32m 68\u001b[0m \u001b[38;5;66;03m# computational graph from original PyTorch module.\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_graph_extractor \u001b[38;5;241m=\u001b[39m GraphExtractor(\n\u001b[0;32m---> 70\u001b[0m snn, 
\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_shape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdvs_input\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 71\u001b[0m ) \u001b[38;5;66;03m# needs the batch dimension.\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;66;03m# Remove nodes of ignored classes (including merge nodes)\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_graph_extractor\u001b[38;5;241m.\u001b[39mremove_nodes_by_class(DEFAULT_IGNORED_LAYER_TYPES)\n", + "\u001b[0;31mTypeError\u001b[0m: randn() received an invalid combination of arguments - got (tuple, bool), but expected one of:\n * (tuple of ints size, *, torch.Generator generator, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)\n * (tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)\n * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)\n * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)\n" + ] + } + ], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " discretize=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice in the model below how each `DynapcnnLayer` has yet to be assigned to a chip core. This is only done once\n", + "`hw_model.to()` is called."
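Before `.to()` is called, the returned `DynapcnnNetwork` can still be executed on CPU like any other sinabs module, which makes for a convenient sanity check that the graph was extracted correctly. A minimal sketch, assuming the `hw_model`, `batch_size` and `input_shape` objects defined above (the random tensor merely stands in for a rasterized NMNIST sample, mirroring the dummy-input cells in the other notebooks of this PR):

# CPU sanity check prior to deployment: the DynapcnnLayer instances run as
# ordinary sinabs layers; only the chip-core assignment is still missing here.
input_dummy = torch.randn((batch_size, *input_shape))  # (batch, channels, height, width)
with torch.no_grad():
    out = hw_model(input_dummy)
print(out.shape)  # 10 output channels, one per NMNIST class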
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------- [ DynapcnnLayer 0 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 0): Conv2d(2, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + "(node 1): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(241.), min_v_mem=Parameter containing:\n", + "tensor(-241.), batch_size=8, num_timesteps=-1)\n", + "(node 2): SumPool2d(norm_type=1, kernel_size=(2, 2), stride=None, ceil_mode=False)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: True\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: None\n", + "> destination DynapcnnLayers: [1]\n", + "> node 2 feeds input to nodes [3]\n", + "\n", + "----------------------- [ DynapcnnLayer 1 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 3): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + "(node 4): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(1041.), min_v_mem=Parameter containing:\n", + "tensor(-1041.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: 2.0\n", + "> assigned core index: None\n", + "> destination DynapcnnLayers: [2]\n", + "> node 4 feeds input to nodes [5]\n", + "\n", + "----------------------- [ DynapcnnLayer 2 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 5): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + "(node 6): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(247.), min_v_mem=Parameter containing:\n", + "tensor(-247.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: None\n", + "> destination DynapcnnLayers: [3]\n", + "> node 6 feeds input to nodes [7]\n", + "\n", + "----------------------- [ DynapcnnLayer 3 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 7): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + "(node 8): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(363.), min_v_mem=Parameter containing:\n", + "tensor(-363.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: None\n", + "> destination DynapcnnLayers: [4]\n", + "> node 8 feeds input to nodes [9]\n", + "\n", + "----------------------- [ DynapcnnLayer 4 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 9): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + "(node 10): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(422.), min_v_mem=Parameter containing:\n", + "tensor(-422.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: None\n", + "> destination DynapcnnLayers: []\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(hw_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `hw_model.to()` call will figure out into which core each `DynapcnnLayer` instance will be assigned to. 
Once this assignment is made, each `DynapcnnLayer` instance is used to configure the `CNNLayerConfig` object of its assigned core.\n", + "\n", + "If the call is successful, the layers comprising the network and their associated metadata will be printed. To deploy the model, we need to provide the device string defining which Speck devkit is being used." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "speck_device = \"speck2fmodule:0\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid: \n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "DynapcnnNetwork()" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hw_model.to(device=speck_device)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------- [ DynapcnnLayer 0 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 0): Conv2d(2, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + "(node 1): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(241.), min_v_mem=Parameter containing:\n", + "tensor(-241.), batch_size=8, num_timesteps=-1)\n", + "(node 2): SumPool2d(norm_type=1, kernel_size=(2, 2), stride=None, ceil_mode=False)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: True\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: 0\n", + "> destination DynapcnnLayers: [1]\n", + "> node 2 feeds input to nodes [3]\n", + "\n", + "----------------------- [ DynapcnnLayer 1 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 3): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + "(node 4): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(1041.), min_v_mem=Parameter containing:\n", + "tensor(-1041.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: 2.0\n", + "> assigned core index: 1\n", + "> destination DynapcnnLayers: [2]\n", + "> node 4 feeds input to nodes [5]\n", + "\n", + "----------------------- [ DynapcnnLayer 2 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 5): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + "(node 6): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(247.), min_v_mem=Parameter containing:\n", + "tensor(-247.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: 2\n", + "> destination DynapcnnLayers: [3]\n", + "> node 6 feeds input to nodes [7]\n", + "\n", + "----------------------- [ DynapcnnLayer 3 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 7): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + "(node 8): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(363.), min_v_mem=Parameter containing:\n", + "tensor(-363.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: None\n", + "> 
assigned core index: 5\n", + "> destination DynapcnnLayers: [4]\n", + "> node 8 feeds input to nodes [9]\n", + "\n", + "----------------------- [ DynapcnnLayer 4 ] -----------------------\n", + "\n", + "COMPUTATIONAL NODES:\n", + "\n", + "(node 9): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + "(node 10): IAFSqueeze(spike_threshold=Parameter containing:\n", + "tensor(422.), min_v_mem=Parameter containing:\n", + "tensor(-422.), batch_size=8, num_timesteps=-1)\n", + "\n", + "METADATA:\n", + "\n", + "> network's entry point: False\n", + "> convolution's weight re-scaling factor: None\n", + "> assigned core index: 3\n", + "> destination DynapcnnLayers: []\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(hw_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Spikes IN/Out of the Chip\n", + "\n", + "Let's try to use our network configured on the chip to forward some data. We'll get a sample from the NMNIST dataset to do that:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "event_dataset = NMNIST(save_to='./NMNIST', train=False)\n", + "event_subset = torch.utils.data.Subset(event_dataset, test_indices)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "targets = np.array(event_dataset.targets)\n", + "target_indices = {idx: np.where(targets == idx)[0] for idx in range(10)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you have a tensor with data and want to convert it to input_events, you would instantiate a ChipFactory object providing the device string (\"speck2fsomethingsomething\") as instantiation argument. For further details consult the [documentation](https://sinabs.readthedocs.io/en/v2.0.0/tutorials/nir_to_speck.html#prepare-dataset)." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from sinabs.backend.dynapcnn.chip_factory import ChipFactory\n", + "\n", + "chip_factory = ChipFactory(speck_device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This object has a method raster_to_events (see more [here](https://sinabs.readthedocs.io/en/v2.0.0/speck/api/dynapcnn/chip_factory.html#sinabs.backend.dynapcnn.chip_factory.ChipFactory.raster_to_events)) that can convert your data to an event list, which is what the chip expects. This method requires a 4 dimensional tensor of spike events with the dimensions [Time, Channel, Height, Width]." 
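A minimal sketch of that conversion and of running the result through the deployed network follows. The variable names are illustrative; it assumes the `chip_factory`, `hw_model` and `snn_test_dataset` objects created above, that `raster_to_events` addresses the events to the chip layer passed as its `layer` argument, that the deployed model's forward call accepts the resulting event list, and that the returned spikes expose their output channel as `.feature`:

# Rasterize one test sample, convert it to chip input events and classify it
# with the network running on the devkit.
raster, label = snn_test_dataset[0]                  # numpy array, (Time, Channel, Height, Width)
raster = torch.from_numpy(raster).float()
input_core = hw_model.get_input_core_id()[-1]        # core that receives the input events
input_events = chip_factory.raster_to_events(raster, layer=input_core)
output_events = hw_model(input_events)               # spikes read back from the monitored output core
prediction = mode([ev.feature for ev in output_events])  # most active output channel
print(f"target: {label}, prediction: {prediction}")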
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "output core id: 3\n", + "input core id: 0\n" + ] + } + ], + "source": [ + "layer_out = hw_model.get_output_core_id() # core assigned to the output layer of the model\n", + "layer_in = hw_model.get_input_core_id()[-1] # core assigned to the input layyer of the model\n", + "\n", + "print(f'output core id: {layer_out}')\n", + "print(f'input core id: {layer_in}')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output layer monitoring: True\n" + ] + } + ], + "source": [ + "print(f'Output layer monitoring: {hw_model.samna_config.cnn_layers[layer_out].monitor_enable}')" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "646e50af00044f1ab9625886fe3f36ed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00 None:\n", + " super().__init__()\n", + "\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(1, 10, 2, 1, bias=False)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.iaf3 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", + " \n", + " con1_out = self.conv1(x) # 4\n", + " iaf1_out = self.iaf1(con1_out) # 5\n", + " pool1_out = self.pool1(iaf1_out) # 6\n", + "\n", + " conv2_out = self.conv2(pool1_out) # 7\n", + " iaf2_out = self.iaf2(conv2_out) # 8\n", + "\n", + " conv3_out = self.conv3(iaf2_out) # 9\n", + " iaf3_out = self.iaf3(conv3_out) # 10\n", + "\n", + " flat_out = self.flat(iaf3_out) # 15\n", + " \n", + " fc1_out = self.fc1(flat_out) # 11\n", + " iaf4_out = self.iaf4(fc1_out) # 12\n", + " fc2_out = self.fc2(iaf4_out) # 13\n", + " iaf5_out = self.iaf5(fc2_out) # 14\n", + "\n", + " return iaf5_out\n", + " \n", + "snn = SNN(input_shape)\n", + "snn.init_weights()\n", + "snn.to(device)" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " dvs_input=True,\n", + " discretize=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------- [ DynapcnnLayer 1 ] -----------------------\n", + "DynapcnnLayer(\n", + " (_conv): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(174.), min_v_mem=Parameter containing:\n", + " tensor(-174.), batch_size=1, num_timesteps=-1)\n", + ")\n", + "\n", + "----------------------- [ DynapcnnLayer 5 ] -----------------------\n", + "DynapcnnLayer(\n", + " (_conv): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(419.), min_v_mem=Parameter containing:\n", + " tensor(-419.), batch_size=1, num_timesteps=-1)\n", + ")\n", + "\n", + "----------------------- [ DynapcnnLayer 2 ] -----------------------\n", + "DynapcnnLayer(\n", + " (_conv): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(2084.), min_v_mem=Parameter containing:\n", + " tensor(-2084.), batch_size=1, num_timesteps=-1)\n", + ")\n", + "\n", + "----------------------- [ DynapcnnLayer 3 ] -----------------------\n", + "DynapcnnLayer(\n", + " (_conv): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(303.), min_v_mem=Parameter containing:\n", + " tensor(-303.), batch_size=1, num_timesteps=-1)\n", + ")\n", + "\n", + "----------------------- [ DynapcnnLayer 4 ] -----------------------\n", + "DynapcnnLayer(\n", + " (_conv): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(365.), min_v_mem=Parameter containing:\n", + " tensor(-365.), batch_size=1, num_timesteps=-1)\n", + ")\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(hw_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid\n" + ] + }, + { + "data": { + "text/plain": [ + "DynapcnnNetwork(\n", + " (_dynapcnn_module): DynapcnnNetworkModule(\n", + " (_dynapcnn_layers): ModuleDict(\n", + " (1): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(174.), min_v_mem=Parameter containing:\n", + " tensor(-174.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (5): DynapcnnLayer(\n", + " (_conv): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(419.), min_v_mem=Parameter containing:\n", + " tensor(-419.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (2): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(2084.), min_v_mem=Parameter containing:\n", + " tensor(-2084.), batch_size=1, num_timesteps=-1)\n", + " 
)\n", + " (3): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(303.), min_v_mem=Parameter containing:\n", + " tensor(-303.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (4): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(365.), min_v_mem=Parameter containing:\n", + " tensor(-365.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " )\n", + " (_dvs_layer): DVSLayer(\n", + " (pool_layer): SumPool2d(norm_type=1, kernel_size=(1, 1), stride=None, ceil_mode=False)\n", + " (crop_layer): Crop2d((0, 34), (0, 34))\n", + " (flip_layer): FlipDims()\n", + " )\n", + " (merge_layer): Merge()\n", + " )\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hw_model.to(device=\"speck2fdevkit\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speck-rescnn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/dynapcnn_network/snn_no_DVSLayer.ipynb b/examples/dynapcnn_network/snn_no_DVSLayer.ipynb new file mode 100644 index 00000000..1869df3c --- /dev/null +++ b/examples/dynapcnn_network/snn_no_DVSLayer.ipynb @@ -0,0 +1,324 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/samurai2077/anaconda3/envs/speck-rescnn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from sinabs.backend.dynapcnn import DynapcnnNetwork\n", + "from sinabs.backend.dynapcnn import DVSLayer\n", + "from sinabs.layers import Merge, IAFSqueeze, SumPool2d\n", + "from sinabs.activation.surrogate_gradient_fn import PeriodicExponential\n", + "import sinabs.layers as sl\n", + "\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "\n", + "from tonic.datasets.nmnist import NMNIST\n", + "from tonic.transforms import ToFrame\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "from statistics import mode\n", + "\n", + "device = torch.device('cpu')\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "channels = 1\n", + "height = 34\n", + "width = 34\n", + "batch_size = 1\n", + "\n", + "input_shape = (channels, height, width)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SNN(\n", + " (conv1): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf1): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (conv2): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (iaf2): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (conv3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (iaf3): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc1): Linear(in_features=144, out_features=200, bias=False)\n", + " (iaf4): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc2): Linear(in_features=200, out_features=10, bias=False)\n", + " (iaf5): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class SNN(nn.Module):\n", + " def __init__(self, input_shape) -> None:\n", + " super().__init__()\n", + "\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(1, 10, 2, 1, bias=False)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.iaf3 = 
IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", + " \n", + " con1_out = self.conv1(x) # 4\n", + " iaf1_out = self.iaf1(con1_out) # 5\n", + " pool1_out = self.pool1(iaf1_out) # 6\n", + "\n", + " conv2_out = self.conv2(pool1_out) # 7\n", + " iaf2_out = self.iaf2(conv2_out) # 8\n", + "\n", + " conv3_out = self.conv3(iaf2_out) # 9\n", + " iaf3_out = self.iaf3(conv3_out) # 10\n", + "\n", + " flat_out = self.flat(iaf3_out) # 15\n", + " \n", + " fc1_out = self.fc1(flat_out) # 11\n", + " iaf4_out = self.iaf4(fc1_out) # 12\n", + " fc2_out = self.fc2(iaf4_out) # 13\n", + " iaf5_out = self.iaf5(fc2_out) # 14\n", + "\n", + " return iaf5_out\n", + " \n", + "snn = SNN(input_shape)\n", + "snn.init_weights()\n", + "snn.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " discretize=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "input_dummy = torch.randn((batch_size, *input_shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "out = hw_model(input_dummy)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]]]], grad_fn=)\n" + ] + } + ], + "source": [ + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid\n" + ] + }, + { + "data": { + "text/plain": [ + "DynapcnnNetwork(\n", + " (_dynapcnn_module): DynapcnnNetworkModule(\n", + " (_dynapcnn_layers): ModuleDict(\n", + " (0): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(174.), min_v_mem=Parameter containing:\n", + " tensor(-174.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (4): DynapcnnLayer(\n", + " (_conv): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), 
bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(419.), min_v_mem=Parameter containing:\n", + " tensor(-419.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (1): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(2084.), min_v_mem=Parameter containing:\n", + " tensor(-2084.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (2): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(303.), min_v_mem=Parameter containing:\n", + " tensor(-303.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " (3): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(365.), min_v_mem=Parameter containing:\n", + " tensor(-365.), batch_size=1, num_timesteps=-1)\n", + " )\n", + " )\n", + " (merge_layer): Merge()\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hw_model.to(device=\"speck2fdevkit\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speck-rescnn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/dynapcnn_network/snn_with_batchnorm.ipynb b/examples/dynapcnn_network/snn_with_batchnorm.ipynb new file mode 100644 index 00000000..fdd6b337 --- /dev/null +++ b/examples/dynapcnn_network/snn_with_batchnorm.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/samurai2077/anaconda3/envs/speck-rescnn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from sinabs.backend.dynapcnn import DynapcnnNetwork\n", + "from sinabs.backend.dynapcnn import DVSLayer\n", + "from sinabs.layers import Merge, IAFSqueeze, SumPool2d\n", + "from sinabs.activation.surrogate_gradient_fn import PeriodicExponential\n", + "import sinabs.layers as sl\n", + "\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "\n", + "from tonic.datasets.nmnist import NMNIST\n", + "from tonic.transforms import ToFrame\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "from statistics import mode\n", + "\n", + "device = torch.device('cpu')\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "channels = 1\n", + "height = 34\n", + "width = 34\n", + "batch_size = 1\n", + "\n", + "input_shape = (channels, height, width)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SNN(\n", + " (conv1): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf1): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (conv2): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf2): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (conv3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (bn3): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf3): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc1): Linear(in_features=144, out_features=200, bias=False)\n", + " (iaf4): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (fc2): Linear(in_features=200, out_features=10, bias=False)\n", + " (iaf5): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=1, num_timesteps=-1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class SNN(nn.Module):\n", + " def __init__(self, input_shape) -> None:\n", + " super().__init__()\n", + "\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(1, 10, 2, 1, bias=False)\n", + " self.bn1 = nn.BatchNorm2d(10)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = 
nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.bn2 = nn.BatchNorm2d(10)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.bn3 = nn.BatchNorm2d(1)\n", + " self.iaf3 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", + " \n", + " con1_out = self.conv1(x) # 4\n", + " bn1_out = self.bn1(con1_out)\n", + " iaf1_out = self.iaf1(bn1_out) # 5\n", + " pool1_out = self.pool1(iaf1_out) # 6\n", + "\n", + " conv2_out = self.conv2(pool1_out) # 7\n", + " bn2_out = self.bn2(conv2_out)\n", + " iaf2_out = self.iaf2(bn2_out) # 8\n", + "\n", + " conv3_out = self.conv3(iaf2_out) # 9\n", + " bn3_out = self.bn3(conv3_out)\n", + " iaf3_out = self.iaf3(bn3_out) # 10\n", + "\n", + " flat_out = self.flat(iaf3_out) # 15\n", + " \n", + " fc1_out = self.fc1(flat_out) # 11\n", + " iaf4_out = self.iaf4(fc1_out) # 12\n", + " fc2_out = self.fc2(iaf4_out) # 13\n", + " iaf5_out = self.iaf5(fc2_out) # 14\n", + "\n", + " return iaf5_out\n", + " \n", + "snn = SNN(input_shape)\n", + "snn.init_weights()\n", + "snn.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " discretize=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "input_dummy = torch.randn((batch_size, *input_shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "out = hw_model(input_dummy)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]]]], grad_fn=)\n" + ] + } + ], + "source": [ + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid\n" + ] + }, + { + "ename": 
"KeyError", + "evalue": "'speck2fdevkit:0'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mhw_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mspeck2fdevkit\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Github/sinabs/sinabs/backend/dynapcnn/dynapcnn_network.py:339\u001b[0m, in \u001b[0;36mDynapcnnNetwork.to\u001b[0;34m(self, device, monitor_layers, config_modifier, slow_clk_frequency, layer2core_map, chip_layers_ordering)\u001b[0m\n\u001b[1;32m 330\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmake_config(\n\u001b[1;32m 331\u001b[0m layer2core_map\u001b[38;5;241m=\u001b[39mlayer2core_map,\n\u001b[1;32m 332\u001b[0m chip_layers_ordering\u001b[38;5;241m=\u001b[39mchip_layers_ordering,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 335\u001b[0m config_modifier\u001b[38;5;241m=\u001b[39mconfig_modifier,\n\u001b[1;32m 336\u001b[0m )\n\u001b[1;32m 338\u001b[0m \u001b[38;5;66;03m# apply configuration to device.\u001b[39;00m\n\u001b[0;32m--> 339\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msamna_device \u001b[38;5;241m=\u001b[39m \u001b[43mopen_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msamna_device\u001b[38;5;241m.\u001b[39mget_model()\u001b[38;5;241m.\u001b[39mapply_configuration(config)\n\u001b[1;32m 341\u001b[0m time\u001b[38;5;241m.\u001b[39msleep(\u001b[38;5;241m1\u001b[39m)\n", + "File \u001b[0;32m~/Github/sinabs/sinabs/backend/dynapcnn/io.py:255\u001b[0m, in \u001b[0;36mopen_device\u001b[0;34m(device_id)\u001b[0m\n\u001b[1;32m 253\u001b[0m device_id \u001b[38;5;241m=\u001b[39m standardize_device_id(device_id\u001b[38;5;241m=\u001b[39mdevice_id)\n\u001b[1;32m 254\u001b[0m device_map \u001b[38;5;241m=\u001b[39m get_device_map()\n\u001b[0;32m--> 255\u001b[0m device_info \u001b[38;5;241m=\u001b[39m \u001b[43mdevice_map\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdevice_id\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 256\u001b[0m device_handle \u001b[38;5;241m=\u001b[39m samna\u001b[38;5;241m.\u001b[39mdevice\u001b[38;5;241m.\u001b[39mopen_device(device_info)\n\u001b[1;32m 258\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m device_handle \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "\u001b[0;31mKeyError\u001b[0m: 'speck2fdevkit:0'" + ] + } + ], + "source": [ + "hw_model.to(device=\"speck2fdevkit\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speck-rescnn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/dynapcnn_network/snn_with_multiple_batchnorm.ipynb b/examples/dynapcnn_network/snn_with_multiple_batchnorm.ipynb new file mode 100644 index 00000000..ab24493c 
--- /dev/null +++ b/examples/dynapcnn_network/snn_with_multiple_batchnorm.ipynb @@ -0,0 +1,380 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/samurai2077/anaconda3/envs/speck-rescnn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from sinabs.backend.dynapcnn import DynapcnnNetwork\n", + "from sinabs.backend.dynapcnn import DVSLayer\n", + "from sinabs.layers import Merge, IAFSqueeze, SumPool2d\n", + "from sinabs.activation.surrogate_gradient_fn import PeriodicExponential\n", + "import sinabs.layers as sl\n", + "\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "\n", + "from tonic.datasets.nmnist import NMNIST\n", + "from tonic.transforms import ToFrame\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from tqdm.notebook import tqdm\n", + "from statistics import mode\n", + "\n", + "device = torch.device('cpu')\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "channels = 1\n", + "height = 34\n", + "width = 34\n", + "batch_size = 2\n", + "\n", + "input_shape = (channels, height, width)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SNN(\n", + " (conv1): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf1): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=2, num_timesteps=-1)\n", + " (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " (conv2): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf2): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=2, num_timesteps=-1)\n", + " (conv3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (bn3): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf3): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=2, num_timesteps=-1)\n", + " (fc1): Linear(in_features=144, out_features=200, bias=False)\n", + " (bn4): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf4): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=2, num_timesteps=-1)\n", + " (fc2): Linear(in_features=200, out_features=10, bias=False)\n", + " (bn5): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (iaf5): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1.), 
min_v_mem=Parameter containing:\n", + " tensor(-1.), batch_size=2, num_timesteps=-1)\n", + " (flat): Flatten(start_dim=1, end_dim=-1)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class SNN(nn.Module):\n", + " def __init__(self, input_shape) -> None:\n", + " super().__init__()\n", + "\n", + " # -- chip core A --\n", + " self.conv1 = nn.Conv2d(1, 10, 2, 1, bias=False)\n", + " self.bn1 = nn.BatchNorm2d(10)\n", + " self.iaf1 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " self.pool1 = nn.AvgPool2d(2,2)\n", + " # -- chip core B --\n", + " self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False)\n", + " self.bn2 = nn.BatchNorm2d(10)\n", + " self.iaf2 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core C --\n", + " self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False)\n", + " self.bn3 = nn.BatchNorm2d(1)\n", + " self.iaf3 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core D --\n", + " self.fc1 = nn.Linear(144, 200, bias=False)\n", + " self.bn4 = nn.BatchNorm1d(200)\n", + " self.iaf4 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + " # -- chip core E --\n", + " self.fc2 = nn.Linear(200, 10, bias=False)\n", + " self.bn5 = nn.BatchNorm1d(10)\n", + " self.iaf5 = IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, spike_threshold=1.0, surrogate_grad_fn=PeriodicExponential())\n", + "\n", + " # -- layers ignored during deployment --\n", + " self.flat = nn.Flatten()\n", + "\n", + " def init_weights(self):\n", + " for name, layer in self.named_modules():\n", + " if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n", + " nn.init.xavier_normal_(layer.weight.data)\n", + "\n", + " def detach_neuron_states(self):\n", + " for name, layer in self.named_modules():\n", + " if name != '':\n", + " if isinstance(layer, sl.StatefulLayer):\n", + " for name, buffer in layer.named_buffers():\n", + " buffer.detach_()\n", + "\n", + " def forward(self, x):\n", + " \n", + " con1_out = self.conv1(x) # 4\n", + " bn1_out = self.bn1(con1_out)\n", + " iaf1_out = self.iaf1(bn1_out) # 5\n", + " pool1_out = self.pool1(iaf1_out) # 6\n", + "\n", + " conv2_out = self.conv2(pool1_out) # 7\n", + " bn2_out = self.bn2(conv2_out)\n", + " iaf2_out = self.iaf2(bn2_out) # 8\n", + "\n", + " conv3_out = self.conv3(iaf2_out) # 9\n", + " bn3_out = self.bn3(conv3_out)\n", + " iaf3_out = self.iaf3(bn3_out) # 10\n", + "\n", + " flat_out = self.flat(iaf3_out) # 15\n", + " \n", + " fc1_out = self.fc1(flat_out) # 11\n", + " bn4_out = self.bn4(fc1_out)\n", + " iaf4_out = self.iaf4(bn4_out) # 12\n", + " fc2_out = self.fc2(iaf4_out) # 13\n", + " bn5_out = self.bn5(fc2_out)\n", + " iaf5_out = self.iaf5(bn5_out) # 14\n", + "\n", + " return iaf5_out\n", + " \n", + "snn = SNN(input_shape)\n", + "snn.init_weights()\n", + "snn.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "metadata": {} + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OpenBLAS Warning : Detect OpenMP Loop and this application may hang. Please rebuild the library with USE_OPENMP=1 option.\n", + "OpenBLAS Warning : Detect OpenMP Loop and this application may hang. 
Please rebuild the library with USE_OPENMP=1 option.\n", + "OpenBLAS Warning : Detect OpenMP Loop and this application may hang. Please rebuild the library with USE_OPENMP=1 option.\n", + "OpenBLAS Warning : Detect OpenMP Loop and this application may hang. Please rebuild the library with USE_OPENMP=1 option.\n" + ] + } + ], + "source": [ + "hw_model = DynapcnnNetwork(\n", + " snn=snn,\n", + " input_shape=input_shape,\n", + " batch_size=batch_size,\n", + " discretize=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "input_dummy = torch.randn((batch_size, *input_shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OpenBLAS Warning : Detect OpenMP Loop and this application may hang. Please rebuild the library with USE_OPENMP=1 option.\n", + "OpenBLAS Warning : Detect OpenMP Loop and this application may hang. Please rebuild the library with USE_OPENMP=1 option.\n" + ] + } + ], + "source": [ + "out = hw_model(input_dummy)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[1.]],\n", + "\n", + " [[0.]]],\n", + "\n", + "\n", + " [[[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[1.]],\n", + "\n", + " [[0.]],\n", + "\n", + " [[1.]],\n", + "\n", + " [[0.]]]], grad_fn=)\n" + ] + } + ], + "source": [ + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network is valid\n" + ] + }, + { + "data": { + "text/plain": [ + "DynapcnnNetwork(\n", + " (_dynapcnn_module): DynapcnnNetworkModule(\n", + " (_dynapcnn_layers): ModuleDict(\n", + " (0): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 10, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(171.), min_v_mem=Parameter containing:\n", + " tensor(-171.), batch_size=2, num_timesteps=-1)\n", + " )\n", + " (4): DynapcnnLayer(\n", + " (_conv): Conv2d(200, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(398.), min_v_mem=Parameter containing:\n", + " tensor(-398.), batch_size=2, num_timesteps=-1)\n", + " )\n", + " (1): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 10, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(1981.), min_v_mem=Parameter containing:\n", + " tensor(-1981.), batch_size=2, num_timesteps=-1)\n", + " )\n", + " (2): DynapcnnLayer(\n", + " (_conv): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(290.), min_v_mem=Parameter containing:\n", + " tensor(-290.), batch_size=2, num_timesteps=-1)\n", + " )\n", + " (3): DynapcnnLayer(\n", + " (_conv): Conv2d(1, 200, kernel_size=(12, 12), stride=(1, 1), bias=False)\n", + " (_spk): IAFSqueeze(spike_threshold=Parameter containing:\n", + " tensor(348.), min_v_mem=Parameter containing:\n", 
+ " tensor(-348.), batch_size=2, num_timesteps=-1)\n", + " )\n", + " )\n", + " (merge_layer): Merge()\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hw_model.to(device=\"speck2fdevkit\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "speck-rescnn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sinabs/backend/dynapcnn/__init__.py b/sinabs/backend/dynapcnn/__init__.py index 62e5e7ff..09f93756 100644 --- a/sinabs/backend/dynapcnn/__init__.py +++ b/sinabs/backend/dynapcnn/__init__.py @@ -1,5 +1,6 @@ -from .dynapcnn_network import ( # second one for compatibility purposes - DynapcnnCompatibleNetwork, - DynapcnnNetwork, -) +from .dvs_layer import DVSLayer +from .dynapcnn_layer import DynapcnnLayer +from .dynapcnn_network import DynapcnnCompatibleNetwork, DynapcnnNetwork from .dynapcnn_visualizer import DynapcnnVisualizer +from .dynapcnnnetwork_module import DynapcnnNetworkModule +from .nir_graph_extractor import GraphExtractor diff --git a/sinabs/backend/dynapcnn/chips/dynapcnn.py b/sinabs/backend/dynapcnn/chips/dynapcnn.py index 849932e9..dcc09663 100644 --- a/sinabs/backend/dynapcnn/chips/dynapcnn.py +++ b/sinabs/backend/dynapcnn/chips/dynapcnn.py @@ -1,14 +1,18 @@ import copy -from typing import List +from typing import Dict, List, Union from warnings import warn import samna import torch -from samna.dynapcnn.configuration import DynapcnnConfiguration +from samna.dynapcnn.configuration import ( + CNNLayerConfig, + DVSLayerConfig, + DynapcnnConfiguration, +) import sinabs from sinabs.backend.dynapcnn.config_builder import ConfigBuilder -from sinabs.backend.dynapcnn.dvs_layer import DVSLayer, expand_to_pair +from sinabs.backend.dynapcnn.dvs_layer import DVSLayer from sinabs.backend.dynapcnn.dynapcnn_layer import DynapcnnLayer from sinabs.backend.dynapcnn.mapping import LayerConstraints @@ -26,9 +30,38 @@ def get_default_config(cls) -> "DynapcnnConfiguration": def get_dvs_layer_config_dict(cls, layer: DVSLayer): ... @classmethod - def write_dvs_layer_config(cls, layer: DVSLayer, config: "DvsLayerConfig"): + def write_dvs_layer_config( + cls, + layer: DVSLayer, + layer2core_map: Dict[int, int], + destination_indices: List[int], + chip_layer: DVSLayerConfig, + ) -> None: + """Write a DVS layer configuration to the conf object. + + Uses the data in `layer` to configure a `DVSLayerConfig` to use the chip's DVS camera. + + Parameters + ---------- + - layer (DVSLayer): Layer instance from which to generate the config + - layer2core_map (Dict): Keys are layer indices, values are corresponding + cores on hardware. Needed to map the destinations. + - destination_indices (List): Indices of destination layers for `layer` + - chip_layer (DVSLayerConfig): Configuration object of the corrsesponding + on-chip core. Will be changed in-place based on `layer`. + """ for param, value in layer.get_config_dict().items(): - setattr(config, param, value) + setattr(chip_layer, param, value) + + # Set destinations. 
+ for dest_idx, dest in enumerate(destination_indices): + chip_layer.destinations[dest_idx].layer = layer2core_map[dest] + chip_layer.destinations[dest_idx].enable = True + + chip_layer.pass_sensor_events = not layer.disable_pixel_array + + if layer.merge_polarities: + chip_layer.merge = True @classmethod def set_kill_bits(cls, layer: DynapcnnLayer, config_dict: dict) -> dict: @@ -73,12 +106,30 @@ def set_kill_bits(cls, layer: DynapcnnLayer, config_dict: dict) -> dict: return config_dict @classmethod - def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): + def get_dynapcnn_layer_config_dict( + cls, + layer: DynapcnnLayer, + layer2core_map: Dict[int, int], + destination_indices: List[int], + ) -> dict: + """Generate config dict from DynapcnnLayer instance + + Parameters + ---------- + - layer (DynapcnnLayer): Layer instance from which to generate the config + - layer2core_map (Dict): Keys are layer indices, values are corresponding + cores on hardware. Needed to map the destinations.] + - destination_indices (List): Indices of destination layers for `layer` + + Returns + ------- + - Dict that holds the information to configure the on-chip core + """ config_dict = {} config_dict["destinations"] = [{}, {}] # Update the dimensions - channel_count, input_size_y, input_size_x = layer.input_shape + channel_count, input_size_y, input_size_x = layer.in_shape dimensions = {"input_shape": {}, "output_shape": {}} dimensions["input_shape"]["size"] = {"x": input_size_x, "y": input_size_y} dimensions["input_shape"]["feature_count"] = channel_count @@ -112,8 +163,6 @@ def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): config_dict["weights"] = weights.int().tolist() config_dict["biases"] = biases.int().tolist() config_dict["leak_enable"] = biases.bool().any() - # config_dict["weights_kill_bit"] = torch.zeros_like(weights).bool().tolist() - # config_dict["biases_kill_bit"] = torch.zeros_like(biases).bool().tolist() # Update parameters from the spiking layer @@ -140,10 +189,6 @@ def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): "Unknown reset mechanism. Only MembraneReset and MembraneSubtract are currently understood." ) - # if (not return_to_zero) and self.spk_layer.membrane_subtract != self.spk_layer.threshold: - # warn( - # "SpikingConv2dLayer: Subtraction of membrane potential is always by high threshold." - # ) if layer.spk_layer.min_v_mem is None: min_v_mem = -(2**15) else: @@ -155,17 +200,32 @@ def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): "threshold_low": min_v_mem, "monitor_enable": False, "neurons_initial_value": neurons_state.int().tolist(), - # "neurons_value_kill_bit" : torch.zeros_like(neurons_state).bool().tolist() } ) - # Update parameters from pooling - if layer.pool_layer is not None: - config_dict["destinations"][0]["pooling"] = expand_to_pair( - layer.pool_layer.kernel_size - )[0] - config_dict["destinations"][0]["enable"] = True - else: - pass + + # Configure destinations + destinations = [] + pooling_sizes = layer.pool + for dest_layer_id, pool in zip(destination_indices, pooling_sizes): + # Ignore exit point destinations + if dest_layer_id >= 0: + + try: + # Use scalar value for pooling + pool = sinabs.utils.collapse_pair(pool) + except ValueError: + raise ValueError( + f"Can only do pooling with quadratic kernels. 
Received {pool}" + ) + + dest_data = { + "layer": layer2core_map[dest_layer_id], + "enable": True, + "pooling": pool, + } + destinations.append(dest_data) + + config_dict["destinations"] = destinations # Set kill bits config_dict = cls.set_kill_bits(layer=layer, config_dict=config_dict) @@ -174,28 +234,44 @@ def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): @classmethod def write_dynapcnn_layer_config( - cls, layer: DynapcnnLayer, chip_layer: "CNNLayerConfig" - ): + cls, + layer: DynapcnnLayer, + layer2core_map: Dict[int, int], + destination_indices: List[int], + chip_layer: CNNLayerConfig, + ) -> None: """Write a single layer configuration to the dynapcnn conf object. + Uses the data in `layer` to configure a `CNNLayerConfig` to be + deployed on chip. + Parameters ---------- - layer: - The dynapcnn layer to write the configuration for - chip_layer: CNNLayerConfig - DYNAPCNN configuration object representing the layer to which - configuration is written. + - layer (DynapcnnLayer): Layer instance from which to generate the config + - layer2core_map (Dict): Keys are layer indices, values are corresponding + cores on hardware. Needed to map the destinations.] + - destination_indices (List): Indices of destination layers for `layer` + - chip_layer (CNNLayerConfig): Configuration object of the corrsesponding + on-chip core. Will be changed in-place based on `layer`. """ - config_dict = cls.get_dynapcnn_layer_config_dict(layer=layer) - # Update configuration of the DYNAPCNN layer - chip_layer.dimensions = config_dict["dimensions"] - config_dict.pop("dimensions") - for i in range(len(config_dict["destinations"])): - if "pooling" in config_dict["destinations"][i]: - chip_layer.destinations[i].pooling = config_dict["destinations"][i][ - "pooling" - ] - config_dict.pop("destinations") + + # extracting from a DynapcnnLayer the config. variables for its CNNLayerConfig. + config_dict = cls.get_dynapcnn_layer_config_dict( + layer=layer, + layer2core_map=layer2core_map, + destination_indices=destination_indices, + ) + + # update configuration of the DYNAPCNN layer. + chip_layer.dimensions = config_dict.pop("dimensions") + + # set the destinations configuration. + for dest_idx, destination in enumerate(config_dict.pop("destinations")): + chip_layer.destinations[dest_idx].layer = destination["layer"] + chip_layer.destinations[dest_idx].enable = destination["enable"] + chip_layer.destinations[dest_idx].pooling = destination["pooling"] + + # set remaining configuration. for param, value in config_dict.items(): try: setattr(chip_layer, param, value) @@ -203,36 +279,58 @@ def write_dynapcnn_layer_config( raise TypeError(f"Unexpected parameter {param} or value. 
{e}") @classmethod - def build_config(cls, model: "DynapcnnNetwork", chip_layers: List[int]): - layers = model.sequence - config = cls.get_default_config() + def build_config( + cls, + layers: Dict[int, DynapcnnLayer], + destination_map: Dict[int, List[int]], + layer2core_map: Dict[int, int], + ) -> DynapcnnConfiguration: + """Uses `DynapcnnLayer` objects to configure their equivalent chip cores - has_dvs_layer = False - i_cnn_layer = 0 # Instantiate an iterator for the cnn cores - for i, chip_equivalent_layer in enumerate(layers): - if isinstance(chip_equivalent_layer, DVSLayer): - chip_layer = config.dvs_layer - cls.write_dvs_layer_config(chip_equivalent_layer, chip_layer) - has_dvs_layer = True - elif isinstance(chip_equivalent_layer, DynapcnnLayer): - chip_layer = config.cnn_layers[chip_layers[i_cnn_layer]] - cls.write_dynapcnn_layer_config(chip_equivalent_layer, chip_layer) - i_cnn_layer += 1 - else: - # in our generated network there is a spurious layer... - # should never happen - raise TypeError("Unexpected layer in the model") + Parameters + ---------- + - layers (Dict): Keys are layer indices, values are DynapcnnLayer instances. + - layer2core_map (Dict): Keys are layer indices, values are corresponding + cores on hardware. Needed to map the destinations. + - destination_indices (List): Indices of destination layers for `layer` - if i == len(layers) - 1: - # last layer - chip_layer.destinations[0].enable = False + Returns + ------- + - DynapcnnConfiguration: Config object holding the information to configure + the chip based on the provided `layers`. + """ + config = cls.get_default_config() + config.dvs_layer.pass_sensor_events = False + + # Loop over layers in network and write corresponding configurations + for layer_index, ith_dcnnl in layers.items(): + if isinstance(ith_dcnnl, DynapcnnLayer): + # retrieve config dict for current layer + chip_layer = config.cnn_layers[layer2core_map[layer_index]] + # write core configuration. + cls.write_dynapcnn_layer_config( + layer=ith_dcnnl, + layer2core_map=layer2core_map, + chip_layer=chip_layer, + destination_indices=destination_map[layer_index], + ) + elif isinstance(ith_dcnnl, DVSLayer): + # Uses the DVS camera. + chip_layer = config.dvs_layer + sw_layer = ith_dcnnl + destination_indices = destination_map[layer_index] + # Write camera configuration. + cls.write_dvs_layer_config( + layer=sw_layer, + layer2core_map=layer2core_map, + destination_indices=destination_indices, + chip_layer=chip_layer, + ) else: - # Set destination layer - chip_layer.destinations[0].layer = chip_layers[i_cnn_layer] - chip_layer.destinations[0].enable = True - - if not has_dvs_layer: - config.dvs_layer.pass_sensor_events = False + # shouldn't happen since type checks are made previously. + raise TypeError( + f"Layer (index {layer_index}) is unexpected in the model: \n{ith_dcnnl}" + ) return config @@ -293,12 +391,13 @@ def monitor_layers(cls, config: "DynapcnnConfiguration", layers: List): config.dvs_layer.monitor_enable = True if config.dvs_layer.pooling.x != 1 or config.dvs_layer.pooling.y != 1: warn( - f"DVS layer has pooling and is being monitored. " + "DVS layer has pooling and is being monitored. " "Note that pooling will not be reflected in the monitored events." 
) monitor_layers.remove("dvs") for lyr_indx in monitor_layers: config.cnn_layers[lyr_indx].monitor_enable = True + if any( dest.pooling != 1 for dest in config.cnn_layers[lyr_indx].destinations ): diff --git a/sinabs/backend/dynapcnn/chips/speck2cmini.py b/sinabs/backend/dynapcnn/chips/speck2cmini.py index b044e8da..66b5b4d1 100644 --- a/sinabs/backend/dynapcnn/chips/speck2cmini.py +++ b/sinabs/backend/dynapcnn/chips/speck2cmini.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import samna from samna.speck2cMini.configuration import SpeckConfiguration @@ -29,8 +29,30 @@ def get_output_buffer(cls): return samna.BasicSinkNode_speck2c_mini_event_output_event() @classmethod - def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): - config_dict = super().get_dynapcnn_layer_config_dict(layer=layer) + def get_dynapcnn_layer_config_dict( + cls, + layer: DynapcnnLayer, + layer2core_map: Dict[int, int], + destination_indices: List[int], + ) -> dict: + """Generate config dict from DynapcnnLayer instance + + Parameters + ---------- + - layer (DynapcnnLayer): Layer instance from which to generate the config + - layer2core_map (Dict): Keys are layer indices, values are corresponding + cores on hardware. Needed to map the destinations.] + - destination_indices (List): Indices of destination layers for `layer` + + Returns + ------- + - Dict that holds the information to configure the on-chip core + """ + config_dict = super().get_dynapcnn_layer_config_dict( + layer=layer, + layer2core_map=layer2core_map, + destination_indices=destination_indices, + ) config_dict.pop("weights_kill_bit") config_dict.pop("biases_kill_bit") config_dict.pop("neurons_value_kill_bit") diff --git a/sinabs/backend/dynapcnn/chips/speck2e.py b/sinabs/backend/dynapcnn/chips/speck2e.py index 1e170a9f..799b9f98 100644 --- a/sinabs/backend/dynapcnn/chips/speck2e.py +++ b/sinabs/backend/dynapcnn/chips/speck2e.py @@ -1,3 +1,5 @@ +from typing import Dict + import samna from samna.speck2e.configuration import SpeckConfiguration @@ -28,8 +30,3 @@ def get_output_buffer(cls): @classmethod def set_kill_bits(cls, layer: DynapcnnLayer, config_dict: dict) -> dict: return config_dict - - @classmethod - def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): - config_dict = super().get_dynapcnn_layer_config_dict(layer=layer) - return config_dict diff --git a/sinabs/backend/dynapcnn/chips/speck2f.py b/sinabs/backend/dynapcnn/chips/speck2f.py index b43a5417..d34418f9 100644 --- a/sinabs/backend/dynapcnn/chips/speck2f.py +++ b/sinabs/backend/dynapcnn/chips/speck2f.py @@ -1,3 +1,5 @@ +from typing import Dict, List + import samna from samna.speck2f.configuration import SpeckConfiguration @@ -26,8 +28,30 @@ def get_output_buffer(cls): return samna.BasicSinkNode_speck2f_event_output_event() @classmethod - def get_dynapcnn_layer_config_dict(cls, layer: DynapcnnLayer): - config_dict = super().get_dynapcnn_layer_config_dict(layer=layer) + def get_dynapcnn_layer_config_dict( + cls, + layer: DynapcnnLayer, + layer2core_map: Dict[int, int], + destination_indices: List[int], + ) -> dict: + """Generate config dict from DynapcnnLayer instance + + Parameters + ---------- + - layer (DynapcnnLayer): Layer instance from which to generate the config + - layer2core_map (Dict): Keys are layer indices, values are corresponding + cores on hardware. Needed to map the destinations.] 
+ - destination_indices (List): Indices of destination layers for `layer` + + Returns + ------- + - Dict that holds the information to configure the on-chip core + """ + config_dict = super().get_dynapcnn_layer_config_dict( + layer=layer, + layer2core_map=layer2core_map, + destination_indices=destination_indices, + ) config_dict.pop("weights_kill_bit") config_dict.pop("biases_kill_bit") config_dict.pop("neurons_value_kill_bit") diff --git a/sinabs/backend/dynapcnn/config_builder.py b/sinabs/backend/dynapcnn/config_builder.py index a96e04c6..7a360718 100644 --- a/sinabs/backend/dynapcnn/config_builder.py +++ b/sinabs/backend/dynapcnn/config_builder.py @@ -1,10 +1,16 @@ -import time from abc import ABC, abstractmethod -from typing import List +from typing import Dict, List import samna +from samna.dynapcnn.configuration import DynapcnnConfiguration + +import sinabs +import sinabs.backend +import sinabs.backend.dynapcnn from .dvs_layer import DVSLayer +from .dynapcnn_layer import DynapcnnLayer +from .exceptions import InvalidModel from .mapping import LayerConstraints, get_valid_mapping @@ -30,7 +36,12 @@ def get_default_config(cls): @classmethod @abstractmethod - def build_config(cls, model: "DynapcnnNetwork", chip_layers: List[int]): + def build_config( + cls, + layers: Dict[int, DynapcnnLayer], + destination_map: Dict[int, List[int]], + layer2core_map: Dict[int, int], + ) -> DynapcnnConfiguration: """Build the configuration given a model. Parameters @@ -61,30 +72,19 @@ def monitor_layers(cls, config, layers: List[int]): """Enable the monitor for a given set of layers in the config object.""" @classmethod - def get_valid_mapping(cls, model: "DynapcnnNetwork") -> List[int]: - """Find a valid set of layers for a given model. + def map_layers_to_cores(cls, layers: Dict[int, DynapcnnLayer]) -> Dict[int, int]: + """Find a mapping from DynapcnnLayers onto on-chip cores Parameters ---------- - model (DynapcnnNetwork): - A model + - layers: Dict with layer indices as keys and DynapcnnLayer instances as values Returns ------- - List of core indices corresponding to each layer of the model: - The index of the core on chip to which the i-th layer in the - model is mapped is the value of the i-th entry in the list. + - Dict mapping layer indices (keys) to assigned core IDs (values). 
""" - mapping = get_valid_mapping(model, cls.get_constraints()) - # turn the mapping into a dict - mapping = {m[0]: m[1] for m in mapping} - # Check if there is a dvs layer in the model - num_dynapcnn_cores = len(model.sequence) - if isinstance(model.sequence[0], DVSLayer): - num_dynapcnn_cores -= 1 - # apply the mapping - chip_layers_ordering = [mapping[i] for i in range(num_dynapcnn_cores)] - return chip_layers_ordering + + return get_valid_mapping(layers, cls.get_constraints()) @classmethod def validate_configuration(cls, config) -> bool: diff --git a/sinabs/backend/dynapcnn/connectivity_specs.py b/sinabs/backend/dynapcnn/connectivity_specs.py new file mode 100644 index 00000000..82c23d91 --- /dev/null +++ b/sinabs/backend/dynapcnn/connectivity_specs.py @@ -0,0 +1,48 @@ +""" +functionality : list device-independent supported connections between layers on chip +""" + +from typing import Union + +import torch.nn as nn + +import sinabs.layers as sl + +from .dvs_layer import DVSLayer + +Pooling = (sl.SumPool2d, nn.AvgPool2d) +Weight = (nn.Conv2d, nn.Linear) +Neuron = (sl.IAFSqueeze,) +DVS = (DVSLayer,) +SupportedNodeTypes = (*Pooling, *Weight, *Neuron, *DVS) + +VALID_SINABS_EDGE_TYPES_ABSTRACT = { + # convoluion is always followed by a neuron layer. + (Weight, Neuron): "weight-neuron", + # Neuron layer can be followed by pooling + (Neuron, Pooling): "neuron-pooling", + # Pooling can be followed by another pooling (will be consolidated) + (Pooling, Pooling): "pooling-pooling", + # Neuron layer can be followed by weight layer of next core + (Neuron, Weight): "neuron-weight", + # Pooling can be followed by weight layer of next core + (Pooling, Weight): "pooling-weight", + # Dvs can be followed by weight layer of next core + (DVS, Weight): "dvs-weight", + # Dvs can be followed by pooling layer + (DVS, Pooling): "dvs-pooling", +} + +# Unpack dict +VALID_SINABS_EDGE_TYPES = { + (source_type, target_type): name + for types, name in VALID_SINABS_EDGE_TYPES_ABSTRACT.items() + for source_type in types[0] + for target_type in types[1] +} + +# Only `Merge` layers are allowed to join multiple inputs +LAYER_TYPES_WITH_MULTIPLE_INPUTS = (sl.Merge,) + +# Neuron and pooling layers can have their output sent to multiple cores +LAYER_TYPES_WITH_MULTIPLE_OUTPUTS = (*Neuron, *Pooling, *DVS) diff --git a/sinabs/backend/dynapcnn/dvs_layer.py b/sinabs/backend/dynapcnn/dvs_layer.py index 0104dd8e..69ec7af6 100644 --- a/sinabs/backend/dynapcnn/dvs_layer.py +++ b/sinabs/backend/dynapcnn/dvs_layer.py @@ -3,27 +3,12 @@ import torch.nn as nn from sinabs.layers import SumPool2d +from sinabs.utils import expand_to_pair from .crop2d import Crop2d from .flipdims import FlipDims -def expand_to_pair(value) -> (int, int): - """Expand a given value to a pair (tuple) if an int is passed. - - Parameters - ---------- - value: - int - - Returns - ------- - pair: - (int, int) - """ - return (value, value) if isinstance(value, int) else value - - class DVSLayer(nn.Module): """DVSLayer representing the DVS pixel array on chip and/or the pre-processing. The order of processing is as follows MergePolarity -> Pool -> Cut -> Flip. 
@@ -202,14 +187,13 @@ def get_output_shape_dict(self) -> dict: ) = self.get_output_shape_after_pooling() # Compute dims after cropping - if self.crop_layer is not None: - ( - channel_count, - output_size_y, - output_size_x, - ) = self.crop_layer.get_output_shape( - (channel_count, output_size_y, output_size_x) - ) + ( + channel_count, + output_size_y, + output_size_x, + ) = self.crop_layer.get_output_shape( + (channel_count, output_size_y, output_size_x) + ) # Compute dims after pooling return { @@ -237,11 +221,13 @@ def forward(self, data): # Merge polarities if self.merge_polarities: data = data.sum(1, keepdim=True) + # Pool out = self.pool_layer(data) + # Crop - if self.crop_layer is not None: - out = self.crop_layer(out) + out = self.crop_layer(out) + # Flip stuff out = self.flip_layer(out) @@ -264,15 +250,11 @@ def get_roi(self) -> Tuple[Tuple[int, int], Tuple[int, int]]: ------- ((top, bottom), (left, right)) """ - if self.crop_layer is not None: - _, h, w = self.get_output_shape_after_pooling() - return ( - (self.crop_layer.top_crop, self.crop_layer.bottom_crop), - (self.crop_layer.left_crop, self.crop_layer.right_crop), - ) - else: - _, output_size_y, output_size_x = self.get_output_shape() - return (0, output_size_y), (0, output_size_x) + _, h, w = self.get_output_shape_after_pooling() + return ( + (self.crop_layer.top_crop, self.crop_layer.bottom_crop), + (self.crop_layer.left_crop, self.crop_layer.right_crop), + ) def get_output_shape(self) -> Tuple[int, int, int]: """Output shape of the layer. @@ -292,14 +274,13 @@ def get_output_shape(self) -> Tuple[int, int, int]: output_size_y = input_size_y // pooling[0] # Compute dims after cropping - if self.crop_layer is not None: - ( - channel_count, - output_size_y, - output_size_x, - ) = self.crop_layer.get_output_shape( - (channel_count, output_size_y, output_size_x) - ) + ( + channel_count, + output_size_y, + output_size_x, + ) = self.crop_layer.get_output_shape( + (channel_count, output_size_y, output_size_x) + ) return channel_count, output_size_y, output_size_x diff --git a/sinabs/backend/dynapcnn/dynapcnn_layer.py b/sinabs/backend/dynapcnn/dynapcnn_layer.py index a56454c8..76057854 100644 --- a/sinabs/backend/dynapcnn/dynapcnn_layer.py +++ b/sinabs/backend/dynapcnn/dynapcnn_layer.py @@ -1,40 +1,87 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + from copy import deepcopy -from typing import Dict, Optional, Tuple, Union -from warnings import warn +from functools import partial +from typing import List, Tuple import numpy as np import torch from torch import nn -import sinabs.activation import sinabs.layers as sl from .discretize import discretize_conv_spike_ -from .dvs_layer import expand_to_pair + +# Define sum pooling functional as power-average pooling with power 1 +sum_pool2d = partial(nn.functional.lp_pool2d, norm_type=1) + + +def convert_linear_to_conv( + lin: nn.Linear, input_shape: Tuple[int, int, int] +) -> nn.Conv2d: + """Convert Linear layer to Conv2d. + + Parameters + ---------- + - lin (nn.Linear): linear layer to be converted. + - input_shape (tuple): the tensor shape the layer expects. + + Returns + ------- + - nn.Conv2d: convolutional layer equivalent to `lin`. 
+ """ + in_chan, in_h, in_w = input_shape + if lin.in_features != in_chan * in_h * in_w: + raise ValueError( + "Shape of linear layer weight does not match provided input shape" + ) + + layer = nn.Conv2d( + in_channels=in_chan, + kernel_size=(in_h, in_w), + out_channels=lin.out_features, + padding=0, + bias=lin.bias is not None, + ) + + if lin.bias is not None: + layer.bias.data = lin.bias.data.clone().detach() + + layer.weight.data = ( + lin.weight.data.clone() + .detach() + .reshape((lin.out_features, in_chan, in_h, in_w)) + ) + + return layer class DynapcnnLayer(nn.Module): - """Create a DynapcnnLayer object representing a dynapcnn layer. + """Create a DynapcnnLayer object representing a layer on DynapCNN or Speck. - Requires a convolutional layer, a sinabs spiking layer and an optional - pooling value. The layers are used in the order conv -> spike -> pool. + Requires a convolutional layer, a sinabs spiking layer and a list of + pooling values. The layers are used in the order conv -> spike -> pool. Parameters ---------- conv: torch.nn.Conv2d or torch.nn.Linear - Convolutional or linear layer (linear will be converted to convolutional) + Convolutional or linear layer + (linear will be converted to convolutional) spk: sinabs.layers.IAFSqueeze Sinabs IAF layer in_shape: tuple of int - The input shape, needed to create dynapcnn configs if the network does not - contain an input layer. Convention: (features, height, width) - pool: int or None - Integer representing the sum pooling kernel and stride. If `None`, no - pooling will be applied. + The input shape, needed to create dynapcnn configs if the network + does not contain an input layer. Convention: (features, height, width) + pool: List of integers + Each integer entry represents an output (destination on chip) and + whether pooling should be applied (values > 1) or not (values equal + to 1). The number of entries determines the number of tensors the + layer's forward method returns. discretize: bool Whether to discretize parameters. rescale_weights: int - Layer weights will be divided by this value. + Layer weights will be multiplied by this value. """ def __init__( @@ -42,126 +89,143 @@ def __init__( conv: nn.Conv2d, spk: sl.IAFSqueeze, in_shape: Tuple[int, int, int], - pool: Optional[sl.SumPool2d] = None, + pool: List[int], discretize: bool = True, rescale_weights: int = 1, ): super().__init__() - self.input_shape = in_shape + self.in_shape = in_shape + self.pool = pool + self._discretize = discretize + self._rescale_weights = rescale_weights + if not isinstance(spk, sl.IAFSqueeze): + raise TypeError( + f"Unsupported spiking layer type {type(spk)}. " + "Only `IAFSqueeze` layers are supported." + ) spk = deepcopy(spk) + + # Convert `nn.Linear` to `nn.Conv2d`. 
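        # Illustrative example taken from the notebooks above: nn.Linear(144, 200)
        # receiving a flattened (1, 12, 12) feature map is converted by
        # convert_linear_to_conv() into nn.Conv2d(1, 200, kernel_size=(12, 12)),
        # with the linear weights reshaped to (200, 1, 12, 12); this is the
        # Conv2d(1, 200, kernel_size=(12, 12)) layer visible in the
        # DynapcnnNetwork printouts.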
if isinstance(conv, nn.Linear): - conv = self._convert_linear_to_conv(conv) - if spk.is_state_initialised(): - # Expand dims - spk.v_mem = spk.v_mem.data.unsqueeze(-1).unsqueeze(-1) + conv = convert_linear_to_conv(conv, in_shape) + if spk.is_state_initialised() and (ndim := spk.v_mem.ndim) < 4: + for __ in range(4 - ndim): + # Expand spatial dimensions + spk.v_mem = spk.v_mem.data.unsqueeze(-1) else: conv = deepcopy(conv) - if rescale_weights != 1: + if self._rescale_weights != 1: # this has to be done after copying but before discretizing - conv.weight.data = (conv.weight / rescale_weights).clone().detach() + conv.weight.data = (conv.weight * self._rescale_weights).clone().detach() - self.discretize = discretize - if discretize: - # int conversion is done while writing the config. + # TODO: Does this really need to be enforced here or upon deployment? + # check if convolution kernel is a square. + if conv.kernel_size[0] != conv.kernel_size[1]: + raise ValueError( + "The kernel of a `nn.Conv2d` must have the same height and width." + ) + for pool_size in pool: + if pool_size[0] != pool_size[1]: + raise ValueError("Only square pooling kernels are supported") + + # int conversion is done while writing the config. + if self._discretize: conv, spk = discretize_conv_spike_(conv, spk, to_int=False) - self.conv_layer = conv - self.spk_layer = spk - if pool is not None: - if pool.kernel_size[0] != pool.kernel_size[1]: - raise ValueError("Only square kernels are supported") - self.pool_layer = deepcopy(pool) - else: - self.pool_layer = None + self.conv = conv + self.spk = spk - def _convert_linear_to_conv(self, lin: nn.Linear) -> nn.Conv2d: - """Convert Linear layer to Conv2d. + @property + def conv_layer(self): + return self.conv - Parameters - ---------- - lin: nn.Linear - Linear layer to be converted + @property + def spk_layer(self): + return self.spk - Returns - ------- - nn.Conv2d - Convolutional layer equivalent to `lin`. + @property + def discretize(self): + return self._discretize + + @property + def rescale_weights(self): + return self._rescale_weights + + @property + def conv_out_shape(self): + return self._get_conv_output_shape() + + ####################################################### Public Methods ####################################################### + + def forward(self, x) -> List[torch.Tensor]: + """Torch forward pass. + + ... """ - in_chan, in_h, in_w = self.input_shape + returns = [] - if lin.in_features != in_chan * in_h * in_w: - raise ValueError("Shapes don't match.") + x = self.conv_layer(x) + x = self.spk_layer(x) - layer = nn.Conv2d( - in_channels=in_chan, - kernel_size=(in_h, in_w), - out_channels=lin.out_features, - padding=0, - bias=lin.bias is not None, - ) + for pool in self.pool: - if lin.bias is not None: - layer.bias.data = lin.bias.data.clone().detach() + if pool == 1: + # no pooling is applied. + returns.append(x) + else: + # sum pooling of `(pool, pool)` is applied. + pool_out = sum_pool2d(x, kernel_size=pool) + returns.append(pool_out) - layer.weight.data = ( - lin.weight.data.clone() - .detach() - .reshape((lin.out_features, in_chan, in_h, in_w)) - ) + if len(returns) == 1: + return returns[0] + else: + return tuple(returns) - return layer + def zero_grad(self, set_to_none: bool = False) -> None: + """Call `zero_grad` method of spiking layer""" + return self.spk.zero_grad(set_to_none) def get_neuron_shape(self) -> Tuple[int, int, int]: """Return the output shape of the neuron layer. 
Returns ------- - features, height, width + - conv_out_shape (tuple): formatted as (features, height, width). """ + # same as the convolution's output. + return self._get_conv_output_shape() - def get_shape_after_conv(layer: nn.Conv2d, input_shape): - (ch_in, h_in, w_in) = input_shape - (kh, kw) = expand_to_pair(layer.kernel_size) - (pad_h, pad_w) = expand_to_pair(layer.padding) - (stride_h, stride_w) = expand_to_pair(layer.stride) - - def out_len(in_len, k, s, p): - return (in_len - k + 2 * p) // s + 1 - - out_h = out_len(h_in, kh, stride_h, pad_h) - out_w = out_len(w_in, kw, stride_w, pad_w) - ch_out = layer.out_channels - return ch_out, out_h, out_w - - conv_out_shape = get_shape_after_conv( - self.conv_layer, input_shape=self.input_shape - ) - return conv_out_shape + def get_output_shape(self) -> List[Tuple[int, int, int]]: + """Return the output shapes of the layer, including pooling. - def get_output_shape(self) -> Tuple[int, int, int]: + Returns + ------- + - output_shape (list of tuples): + One entry per destination, each formatted as (features, height, width). + """ neuron_shape = self.get_neuron_shape() # this is the actual output shape, including pooling - if self.pool_layer is not None: - pool = expand_to_pair(self.pool_layer.kernel_size) - return ( + output_shape = [] + for pool in self.pool: + output_shape.append( neuron_shape[0], - neuron_shape[1] // pool[0], - neuron_shape[2] // pool[1], + neuron_shape[1] // pool, + neuron_shape[2] // pool, ) - else: - return neuron_shape + return output_shape def summary(self) -> dict: + """Returns a summary of the convolution's/pooling's kernel sizes and the output shape of the spiking layer.""" + return { - "pool": ( - None if self.pool_layer is None else list(self.pool_layer.kernel_size) - ), + "pool": (self.pool), "kernel": list(self.conv_layer.weight.data.shape), - "neuron": self.get_neuron_shape(), + "neuron": self._get_conv_output_shape(), # neuron layer output has the same shape as the convolution layer ouput. } def memory_summary(self): @@ -183,7 +247,9 @@ def memory_summary(self): """ summary = self.summary() f, c, h, w = summary["kernel"] - f, neuron_height, neuron_width = self.get_neuron_shape() + f, neuron_height, neuron_width = ( + self._get_conv_output_shape() + ) # neuron layer output has the same shape as the convolution layer ouput. return { "kernel": c * pow(2, np.ceil(np.log2(h * w)) + np.ceil(np.log2(f))), @@ -192,13 +258,31 @@ def memory_summary(self): "bias": 0 if self.conv_layer.bias is None else len(self.conv_layer.bias), } - def forward(self, x): - """Torch forward pass.""" - x = self.conv_layer(x) - x = self.spk_layer(x) - if self.pool_layer is not None: - x = self.pool_layer(x) - return x + ####################################################### Private Methods ####################################################### - def zero_grad(self, set_to_none: bool = False) -> None: - return self.spk_layer.zero_grad(set_to_none) + def _get_conv_output_shape(self) -> Tuple[int, int, int]: + """Computes the output dimensions of `conv_layer`. + + Returns + ---------- + - output dimensions (tuple): a tuple describing `(output channels, height, width)`. + """ + # get the layer's parameters. + + out_channels = self.conv_layer.out_channels + kernel_size = self.conv_layer.kernel_size + stride = self.conv_layer.stride + padding = self.conv_layer.padding + dilation = self.conv_layer.dilation + + # compute the output height and width. 
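        # Illustrative check with values from the notebooks above: the second
        # convolution has kernel_size=(4, 4), stride=(1, 1), padding=(0, 0),
        # dilation=(1, 1) and an in_shape of (10, 16, 16), so
        #   out_height = (16 + 2*0 - 1*(4 - 1) - 1) // 1 + 1 == 13
        # i.e. a (10, 13, 13) output, which after the 2x2-kernel conv3 shrinks
        # to (1, 12, 12) and matches fc1's 144 input features.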
+ out_height = ( + (self.in_shape[1] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) + // stride[0] + ) + 1 + out_width = ( + (self.in_shape[2] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) + // stride[1] + ) + 1 + + return (out_channels, out_height, out_width) diff --git a/sinabs/backend/dynapcnn/dynapcnn_layer_utils.py b/sinabs/backend/dynapcnn/dynapcnn_layer_utils.py new file mode 100644 index 00000000..2e67acc8 --- /dev/null +++ b/sinabs/backend/dynapcnn/dynapcnn_layer_utils.py @@ -0,0 +1,341 @@ +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union + +from torch import nn + +from sinabs import layers as sl +from sinabs.utils import expand_to_pair + +from .dynapcnn_layer import DynapcnnLayer + + +def construct_dynapcnnlayers_from_mapper( + dcnnl_map: Dict, + dvs_layer_info: Union[None, Dict], + discretize: bool, + rescale_fn: Optional[Callable] = None, +) -> Tuple[Dict[int, DynapcnnLayer], Dict[int, Set[int]], List[int]]: + """Construct DynapcnnLayer instances from `dcnnl_map` + + Paramters + --------- + + Returns + ------- + - Dict of new DynapcnnLayer instances, with keys corresponding to `dcnnl_map` + - Dict mapping to each layer index a set of destination indices + - List of layer indices that act as entry points to the network + """ + finalize_dcnnl_map(dcnnl_map, dvs_layer_info, rescale_fn) + + dynapcnn_layers = { + layer_idx: construct_single_dynapcnn_layer(layer_info, discretize) + for layer_idx, layer_info in dcnnl_map.items() + } + + destination_map = construct_destination_map(dcnnl_map, dvs_layer_info) + + entry_points = collect_entry_points(dcnnl_map, dvs_layer_info) + + return dynapcnn_layers, destination_map, entry_points + + +def finalize_dcnnl_map( + dcnnl_map: Dict, dvs_info: Union[Dict, None], rescale_fn: Optional[Callable] = None +) -> None: + """Finalize dcnnl map by consolidating information + + Update dcnnl_map in-place + - Consolidate chained pooling layers + - Determine rescaling of layer weights + - Fix input shapes + + Parameters + ---------- + - dcnnl_map: Dict holding info needed to instantiate DynapcnnLayer instances + - rescale_fn: Optional callable that is used to determine layer + rescaling in case of conflicting preceeding average pooling + """ + # Consolidate pooling information for DVS layer + consolidate_dvs_pooling(dvs_info, dcnnl_map) + + # Consolidate pooling information for each destination + for layer_info in dcnnl_map.values(): + consolidate_layer_pooling(layer_info, dcnnl_map) + + for layer_info in dcnnl_map.values(): + # Consolidate scale factors + consolidate_layer_scaling(layer_info, rescale_fn) + + +def consolidate_dvs_pooling(dvs_info: Union[Dict, None], dcnnl_map: Dict): + """Consolidate pooling information for dvs layer + + Update `dvs_info` and `dcnnl_map` in place. + - Extract pooling and scale factor of consecutive pooling operations + - Add entries "cumulative_pooling" and "cumulative_scaling" + - Update DVSLayer pooling if applicable + - For each destination, add cumulative rescale factor to "rescale_factors" + entry in corresponding entry of `dcnnl_map`. + + Parameters + ---------- + - dvs_info: Dict holding info of dvs layer. + - dcnnl_map: Dict holding info needed to instantiate DynapcnnLayer instances + """ + if dvs_info is None or dvs_info["pooling"] is None: + # Nothing to do + return + + # Check whether pooling can be incorporated into the DVSLayer. 
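Editorial note: the folding performed below reduces to simple arithmetic. Kernel sizes of the DVSLayer pooling and of the pooling layer that follows it multiply, and the crop window shrinks by the extra pooling factor. A small numeric sketch with made-up sizes:

# DVSLayer already pools by (2, 2); a following SumPool2d(2) gets folded in.
dvs_pooling = (2, 2)
added_pooling = (2, 2)
cumulative_pooling = (
    dvs_pooling[0] * added_pooling[0],
    dvs_pooling[1] * added_pooling[1],
)
assert cumulative_pooling == (4, 4)

# The crop boundaries refer to the frame after pooling, so they shrink as well.
bottom_crop, right_crop = 128, 128
bottom_crop //= added_pooling[0]
right_crop //= added_pooling[1]
assert (bottom_crop, right_crop) == (64, 64)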
+ dvs_layer = dvs_info["module"] + crop_layer = dvs_layer.crop_layer + if ( + crop_layer.top_crop != 0 + or crop_layer.left_crop != 0 + or crop_layer.bottom_crop != dvs_layer.input_shape[1] + or crop_layer.right_crop != dvs_layer.input_shape[2] + ): + raise ValueError( + "DVSLayer with cropping is followed by a pooling layer. " + "This is currently not supported. Please define pooling " + "directly within the DVSLayer (with the `pool` argument) " + "and remove the pooling layer that follows the DVSLayer" + ) + flip_layer = dvs_layer.flip_layer + if flip_layer.flip_x or flip_layer.flip_y or flip_layer.swap_xy: + raise ValueError( + "DVSLayer with flipping or dimension swapping is followed " + "by a pooling layer. This is currently not supported. " + "Please define pooling directly within the DVSLayer " + "(with the `pool` argument) and remove the pooling " + "layer that follows the DVSLayer" + ) + + # Incorporate pooling into DVSLayer + pool_layer = dvs_info["pooling"]["module"] + added_pooling, scale = extract_pooling_from_module(pool_layer) + dvs_pooling = expand_to_pair(dvs_layer.pool_layer.kernel_size) + cumulative_pooling = ( + dvs_pooling[0] * added_pooling[0], + dvs_pooling[1] * added_pooling[1], + ) + dvs_layer.pool_layer.kernel_size = cumulative_pooling + dvs_layer.pool_layer.stride = None + + # Update cropping layer to account for reduced size after pooling + dvs_layer.crop_layer.bottom_crop //= added_pooling[0] + dvs_layer.crop_layer.right_crop //= added_pooling[1] + + # Set rescale_factor for targeted dynapcnn layers + if dvs_info["destinations"] is not None: + for dest_lyr_idx in dvs_info["destinations"]: + dcnnl_map[dest_lyr_idx]["rescale_factors"].add(scale) + + +def consolidate_layer_pooling(layer_info: Dict, dcnnl_map: Dict): + """Consolidate pooling information for individual layer + + Update `layer_info` and `dcnnl_map` in place. + - Extract pooling and scale factor of consecutive pooling operations + - To each "destination" add entries "cumulative_pooling" and + "cumulative_scaling" + - Add "pooling_list" to `layer_info` with all poolings of a layer + in order of its "destination"s. + - For each destination, add cumulative rescale factor to "rescale_factors" + entry in corresponding entry of `dcnnl_map`. + + Parameters + ---------- + - layer_info: Dict holding info of single layer. Corresponds to + single entry in `dcnnl_map` + - dcnnl_map: Dict holding info needed to instantiate DynapcnnLayer instances + """ + layer_info["pooling_list"] = [] + for destination in layer_info["destinations"]: + pool, scale = consolidate_dest_pooling(destination["pooling_modules"]) + destination["cumulative_pooling"] = pool + layer_info["pooling_list"].append(pool) + destination["cumulative_scaling"] = scale + if (dest_lyr_idx := destination["destination_layer"]) is not None: + dcnnl_map[dest_lyr_idx]["rescale_factors"].add(scale) + + +def consolidate_dest_pooling( + modules: Iterable[nn.Module], +) -> Tuple[Tuple[int, int], float]: + """Consolidate pooling information for consecutive pooling modules + for single destination. + + Parameters + ---------- + modules: Iteravle of pooling modules + + Returns + ------- + cumulative_pooling: Tuple of two ints, indicating pooling along + vertical and horizontal dimensions for all modules together + cumulative_scaling: float, indicating by how much subsequent weights + need to be rescaled to account for average pooling being converted + to sum pooling, considering all provided modules. 
+ """ + cumulative_pooling = [1, 1] + cumulative_scaling = 1.0 + + for pooling_layer in modules: + pooling, rescale_factor = extract_pooling_from_module(pooling_layer) + cumulative_pooling[0] *= pooling[0] + cumulative_pooling[1] *= pooling[1] + cumulative_scaling *= rescale_factor + + return cumulative_pooling, cumulative_scaling + + +def extract_pooling_from_module( + pooling_layer: Union[nn.AvgPool2d, sl.SumPool2d] +) -> Tuple[Tuple[int, int], float]: + """Extract pooling size and required rescaling factor from pooling module + + Parameters + ---------- + pooling_layer: pooling module + + Returns + ------- + pooling: Tuple of two ints, indicating pooling along vertical and horizontal dimensions + scale_factor: float, indicating by how much subsequent weights need to be rescaled to + account for average pooling being converted to sum pooling. + """ + pooling = expand_to_pair(pooling_layer.kernel_size) + + if pooling_layer.stride is not None: + stride = expand_to_pair(pooling_layer.stride) + if pooling != stride: + raise ValueError( + f"Stride length {pooling_layer.stride} should be the same as pooling kernel size {pooling_layer.kernel_size}" + ) + if isinstance(pooling_layer, nn.AvgPool2d): + scale_factor = 1.0 / (pooling[0] * pooling[1]) + elif isinstance(pooling_layer, sl.SumPool2d): + scale_factor = 1.0 + else: + raise ValueError(f"Unsupported type {type(pooling_layer)} for pooling layer") + + return pooling, scale_factor + + +def consolidate_layer_scaling(layer_info: Dict, rescale_fn: Optional[Callable] = None): + """Dertermine scale factor of single layer + + Add "rescale_factor" entry to `layer_info`. If more than one + different rescale factors have been determined due to conflicting + average pooling in preceding layers, requrie `rescale_fn` to + resolve. + + Parameters + ---------- + - layer_info: Dict holding info of single layer. + - rescale_fn: Optional callable that is used to determine layer + rescaling in case of conflicting preceeding average pooling + """ + if len(layer_info["rescale_factors"]) == 0: + rescale_factor = 1 + elif len(layer_info["rescale_factors"]) == 1: + rescale_factor = layer_info["rescale_factors"].pop() + else: + if rescale_fn is None: + raise ValueError( + "Average pooling layers of conflicting sizes pointing to " + "same destination. Either replace them by SumPool2d layers " + "or provide a `rescale_fn` to resolve this" + ) + else: + rescale_factor = rescale_fn(layer_info["rescale_factors"]) + layer_info["rescale_factor"] = rescale_factor + + +def construct_single_dynapcnn_layer( + layer_info: Dict, discretize: bool +) -> DynapcnnLayer: + """Instantiate a DynapcnnLayer instance from the information + in `layer_info' + + Parameters + ---------- + - layer_info: Dict holding info of single layer. 
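To make the rescaling convention above concrete: an `AvgPool2d` that is mapped to sum pooling overestimates its output by the kernel area, so a factor of 1/(k_h*k_w) is recorded and later multiplied into the weights of the convolution that follows. A small check, plus a hypothetical `rescale_fn` (the built-in `rescale_method_1` is not shown in this diff, so the mean used below is only an illustration):

import torch
from torch import nn
from sinabs.layers import SumPool2d

x = torch.rand(1, 1, 8, 8)

# AvgPool2d(2) followed by SumPool2d(2) equals 1/4 of a single 4x4 sum pooling.
chained = SumPool2d(2)(nn.AvgPool2d(2)(x))
consolidated = 0.25 * SumPool2d(4)(x)
assert torch.allclose(chained, consolidated, atol=1e-6)

# Hypothetical resolver for conflicting rescale factors from different branches.
def mean_rescale_fn(rescale_factors: set) -> float:
    return sum(rescale_factors) / len(rescale_factors)

assert mean_rescale_fn({0.25, 0.0625}) == 0.15625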
+ - discretize: bool indicating whether layer parameters should be + discretized (weights, biases, thresholds) + + Returns + ------- + """ + return DynapcnnLayer( + conv=layer_info["conv"]["module"], + spk=layer_info["neuron"]["module"], + in_shape=layer_info["input_shape"], + pool=layer_info["pooling_list"], + discretize=discretize, + rescale_weights=layer_info["rescale_factor"], + ) + + +def construct_destination_map( + dcnnl_map: Dict[int, Dict], dvs_layer_info: Union[None, Dict] +) -> Dict[int, List[int]]: + """Create a dict that holds destinations for each layer + + Parameters + ---------- + - dcnnl_map: Dict holding info needed to instantiate DynapcnnLayer instances + - dynapcnn_layer_info: Dict holding info about DVSLayer instance and its destinations + + Returns + ------- + Dict with layer indices (int) as keys and list of destination indices (int) as values. + Layer outputs that are not sent to other dynapcnn layers are considered + exit points of the network and represented by negative indices. + """ + destination_map = dict() + for layer_index, layer_info in dcnnl_map.items(): + destination_indices = [] + none_counter = 0 + for dest in layer_info["destinations"]: + if (dest_idx := dest["destination_layer"]) is None: + # For `None` destinations use unique negative index + none_counter += 1 + destination_indices.append(-none_counter) + else: + destination_indices.append(dest_idx) + destination_map[layer_index] = destination_indices + if dvs_layer_info is not None: + if (dest_info := dvs_layer_info["destinations"]) is None: + destination_map["dvs"] = [-1] + else: + # Copy destination list from dvs layer info + destination_map["dvs"] = [d for d in dest_info] + + return destination_map + + +def collect_entry_points( + dcnnl_map: Dict[int, Dict], dvs_layer_info: Union[None, Dict] +) -> Set[int]: + """Return set of layer indices that are entry points + + Parameters + ---------- + - dcnnl_map: Dict holding info needed to instantiate DynapcnnLayer instances + - dynapcnn_layer_info: Dict holding info about DVSLayer instance and its destinations + If it is not None, it will be the only entry point returned. 
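For clarity on the exit-point convention used throughout (and implemented by `construct_destination_map` above): every `None` destination becomes a unique negative index, and any layer with a negative destination is treated as a network exit. A toy illustration of that bookkeeping:

# Destinations as they come out of the layer info (indices are illustrative).
layer_destinations = {0: [1], 1: [2, None], 2: [None]}

destination_map = {}
for layer_idx, dests in layer_destinations.items():
    none_counter = 0
    indices = []
    for dest in dests:
        if dest is None:
            none_counter += 1
            indices.append(-none_counter)  # unique negative index marks an exit
        else:
            indices.append(dest)
    destination_map[layer_idx] = indices

assert destination_map == {0: [1], 1: [2, -1], 2: [-1]}

# Exit layers are exactly those with at least one negative destination.
assert [i for i, d in destination_map.items() if any(x < 0 for x in d)] == [1, 2]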
+ + Returns + ------- + Set of all layer indices which act as entry points to the network + """ + if dvs_layer_info is None: + return { + layer_index + for layer_index, layer_info in dcnnl_map.items() + if layer_info["is_entry_node"] + } + else: + return {"dvs"} diff --git a/sinabs/backend/dynapcnn/dynapcnn_network.py b/sinabs/backend/dynapcnn/dynapcnn_network.py index 4a1bb525..17376d8e 100644 --- a/sinabs/backend/dynapcnn/dynapcnn_network.py +++ b/sinabs/backend/dynapcnn/dynapcnn_network.py @@ -1,108 +1,312 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + import time -from subprocess import CalledProcessError -from typing import List, Optional, Sequence, Tuple, Union +from pprint import pformat +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from warnings import warn import samna import torch import torch.nn as nn +from samna.dynapcnn.configuration import DynapcnnConfiguration +from torch import Tensor import sinabs +import sinabs.layers as sl from .chip_factory import ChipFactory from .dvs_layer import DVSLayer from .dynapcnn_layer import DynapcnnLayer from .io import disable_timestamps, enable_timestamps, open_device, reset_timestamps +from .nir_graph_extractor import GraphExtractor from .utils import ( - DEFAULT_IGNORED_LAYER_TYPES, - build_from_list, - convert_model_to_layer_list, + COMPLETELY_IGNORED_LAYER_TYPES, + IGNORED_LAYER_TYPES, infer_input_shape, parse_device_id, ) +from .weight_rescaling_methods import rescale_method_1 class DynapcnnNetwork(nn.Module): - """Given a sinabs spiking network, prepare a dynapcnn-compatible network. This can be used to - test the network will be equivalent once on DYNAPCNN. This class also provides utilities to - make the dynapcnn configuration and upload it to DYNAPCNN. - - The following operations are done when converting to dynapcnn-compatible: - - * multiple avg pooling layers in a row are consolidated into one and \ - turned into sum pooling layers; - * checks are performed on layer hyperparameter compatibility with dynapcnn \ - (kernel sizes, strides, padding) - * checks are performed on network structure compatibility with dynapcnn \ - (certain layers can only be followed by other layers) - * linear layers are turned into convolutional layers - * dropout layers are ignored - * weights, biases and thresholds are discretized according to dynapcnn requirements - - Note that the model parameters are only ever transferred to the device - on the `to` call, so changing a threshold or weight of a model that - is deployed will have no effect on the model on chip until `to` is called again. - """ - def __init__( self, - snn: Union[nn.Sequential, sinabs.Network], + snn: nn.Module, input_shape: Optional[Tuple[int, int, int]] = None, - dvs_input: bool = False, + batch_size: Optional[int] = None, + dvs_input: Optional[bool] = None, discretize: bool = True, + weight_rescaling_fn: Callable = rescale_method_1, ): """ - DynapcnnNetwork: a class turning sinabs networks into dynapcnn - compatible networks, and making dynapcnn configurations. + Given a sinabs spiking network, prepare a dynapcnn-compatible network. This can be used to + test the network will be equivalent once on DYNAPCNN. This class also provides utilities to + make the dynapcnn configuration and upload it to DYNAPCNN. 
Parameters ---------- - snn: sinabs.Network - SNN that determines the structure of the `DynapcnnNetwork` - input_shape: None or tuple of ints - Shape of the input, convention: (features, height, width) - If None, `snn` needs an InputLayer - dvs_input: bool - Does dynapcnn receive input from its DVS camera? - discretize: bool - If True, discretize the parameters and thresholds. - This is needed for uploading weights to dynapcnn. Set to False only for - testing purposes. + - snn (nn.Module): a implementing a spiking network. + - input_shape (tuple or None): a description of the input dimensions + as `(features, height, width)`. If `None`, `snn` must contain a + `DVSLayer` instance, from which the input shape will be inferred. + - batch_size (optional int): If `None`, will try to infer the batch size from the model. + If int value is provided, it has to match the actual batch size of the model. + - dvs_input (bool): optional (default as `None`). Wether or not dynapcnn receive + input from its DVS camera. + If a `DVSLayer` is part of `snn`... + ... and `dvs_input` is `False`, its `disable_pixel_array` attribute + will be set `True`. This means the DVS sensor will be configured + upon deployment but its output will not be sent as input + ... and `dvs_input` is `None`, the `disable_pixel_array` attribute + of the layer will not be changed. + ... and `dvs_input` is `True`, `disable_pixel_array` will be set + `False`, so that the DVS sensor data is sent to the network. + If no `DVSLayer` is part of `snn`... + ... and `dvs_input` is `False` or `None`, no `DVSLayer` will be added + and the DVS sensor will not be configured upon deployment. + ... and `dvs_input` is `True`, a `DVSLayer` instance will be added + to the network, with `disable_pixel_array` set to `False`. + - discretize (bool): If `True`, discretize the parameters and thresholds. This is needed for uploading + weights to dynapcnn. Set to `False` only for testing purposes. + - weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to + the same convolutional layer are combined/re-scaled before applying them. """ super().__init__() - # This attribute stores the location/core-id of each of the DynapcnnLayers upon placement on chip - self.chip_layers_ordering = [] + if isinstance(snn, sinabs.Network): + # Ignore `analog_model` of sinabs `Network` instances + snn = snn.spiking_model + + self.dvs_input = dvs_input + self.input_shape = infer_input_shape(snn, input_shape) + self._layer2core_map = None + + # Infer batch size for dummpy input to graph extractor + if batch_size is None: + batch_size = sinabs.utils.get_smallest_compatible_time_dimension(snn) + # computational graph from original PyTorch module. + self._graph_extractor = GraphExtractor( + snn, + torch.randn((batch_size, *self.input_shape)), + self.dvs_input, + ignore_node_types=COMPLETELY_IGNORED_LAYER_TYPES, + ) - self.input_shape = input_shape # Convert models to sequential - layers = convert_model_to_layer_list( - model=snn, ignore=DEFAULT_IGNORED_LAYER_TYPES + # Remove nodes of ignored classes (including merge nodes) + # Other than `COMPLETELY_IGNORED_LAYER_TYPES`, `IGNORED_LAYER_TYPES` are + # part of the graph initially and are needed to ensure proper handling of + # graph structure (e.g. Merge nodes) or meta-information (e.g. 
+ # `nn.Flatten` for io-shapes) + self._graph_extractor.remove_nodes_by_class(IGNORED_LAYER_TYPES) + + # Module to execute forward pass through network + self._dynapcnn_module = self._graph_extractor.get_dynapcnn_network_module( + discretize=discretize, weight_rescaling_fn=weight_rescaling_fn ) - # Check if dvs input is expected - if dvs_input: - self.dvs_input = True - else: - self.dvs_input = False + self._dynapcnn_module.setup_dynapcnnlayer_graph(index_layers_topologically=True) + + ####################################################### Public Methods ####################################################### + + @property + def all_layers(self): + return self._dynapcnn_module.all_layers + + @property + def dvs_node_info(self): + return self._dynapcnn_module.dvs_node_info + + @property + def dvs_layer(self): + return self._dynapcnn_module.dvs_layer - input_shape = infer_input_shape(layers, input_shape=input_shape) - assert len(input_shape) == 3, "infer_input_shape did not return 3-tuple" + @property + def chip_layers_ordering(self): + warn( + "`chip_layers_ordering` is deprecated. Returning `layer2core_map` instead.", + DeprecationWarning, + ) + return self._layer2core_map + + @property + def dynapcnn_layers(self): + return self._dynapcnn_module.dynapcnn_layers + + @property + def dynapcnn_module(self): + return self._dynapcnn_module + + @property + def exit_layers(self): + return [self.all_layers[i] for i in self._dynapcnn_module.get_exit_layers()] - # Build model from layers - self.sequence = build_from_list( - layers, - in_shape=input_shape, - discretize=discretize, - dvs_input=self.dvs_input, + @property + def exit_layer_ids(self): + return self._dynapcnn_module.get_exit_layers() + + @property + def is_deployed_on_dynapcnn_device(self): + return ( + hasattr(self, "device") + and parse_device_id(self.device)[0] in ChipFactory.supported_devices ) + @property + def layer_destination_map(self): + return self._dynapcnn_module.destination_map + + @property + def layer2core_map(self): + return self._layer2core_map + + @property + def name_2_indx_map(self): + return self._graph_extractor.name_2_indx_map + + def hw_forward(self, x): + """Forwards data through the chip.""" + + # flush buffer. + _ = self.samna_output_buffer.get_events() + + # NOTE: The code to start and stop time stamping is device specific + reset_timestamps(self.device) + enable_timestamps(self.device) + + # send input. + self.samna_input_buffer.write(x) + received_evts = [] + + # record at least until the last event has been replayed. + min_duration = max(event.timestamp for event in x) * 1e-6 + time.sleep(min_duration) + + # keep recording if more events are being registered. + while True: + prev_length = len(received_evts) + time.sleep(0.1) + received_evts.extend(self.samna_output_buffer.get_events()) + if prev_length == len(received_evts): + break + + # disable timestamp + disable_timestamps(self.device) + + return received_evts + + def forward( + self, x, return_complete: bool = False + ) -> Union[List["event"], Tensor, Dict[int, Dict[int, Tensor]]]: + """Forwards data through the `DynapcnnNetwork` instance. + + If the network has been deployed on a Dynapcnn/Speck device the forward + pass happens on the devices. Otherwise the device will be simulated by + passing the data through the `DynapcnnLayer` instances. + + Parameters + ---------- + x: Tensor that serves as input to network. 
Is passed to all layers + that are marked as entry points + return_complete: bool that indicates whether all layer outputs should + be return or only those with no further destinations (default) + + Returns + ------- + The returned object depends on whether the network has been deployed + on chip. If this is the case, a flat list of samna events is returned, + in the order in which the events have been collected. + If the data is passed through the `DynapcnnLayer` instances, the output + depends on `return_complete` and on the network configuration: + * If `return_complete` is `True`, all layer outputs will be returned in a + dict, with layer indices as keys, and nested dicts as values, which + hold destination indices as keys and output tensors as values. + * If `return_complete` is `False` and there is only a single destination + in the whole network that is marked as final (i.e. destination + index in dynapcnn layer handler is negative), it will return the + output as a single tensor. + * If `return_complete` is `False` and no destination in the network + is marked as final, a warning will be raised and the function + returns an empty dict. + * In all other cases a dict will be returned that is of the same + structure as if `return_complete` is `True`, but only with entries + where the destination is marked as final. + """ + if self.is_deployed_on_dynapcnn_device: + return self.hw_forward(x) + else: + # Forward pass through software DynapcnnLayer instance + return self.dynapcnn_module(x, return_complete=return_complete) + + def parameters(self) -> list: + """Gathers all the parameters of the network in a list. This is done by accessing the convolutional layer in each `DynapcnnLayer`, + calling its `.parameters` method and saving it to a list. + + Note: the method assumes no biases are used. + + Returns + ---------- + - parameters (list): a list of parameters of all convolutional layers in the `DynapcnnNetwok`. + """ + parameters = [] + + for layer in self.dynapcnn_layers.values(): + if isinstance(layer, DynapcnnLayer): + parameters.extend(layer.conv_layer.parameters()) + + return parameters + + def memory_summary(self) -> Dict[str, Dict[int, int]]: + """Get a summary of the network's memory requirements. + + Returns + ------- + dict: + A dictionary with keys kernel, neuron, bias. The values are a dicts. + Each nested dict has as keys the indices of all dynapcnn_layers and + as values the corresonding memory values for each layer. + """ + # For each entry (kernel, neuron, bias) provide one nested dict with + # one entry for each layer + summary = {key: dict() for key in ("kernel", "neuron", "bias")} + + for layer_index, layer in self.dynapcnn_layers.items(): + for key, val in layer.memory_summary().items(): + summary[key][layer_index] = val + + return summary + + def init_weights(self, init_fn: nn.init = nn.init.xavier_normal_) -> None: + """Call the weight initialization method `init_fn` on each `DynapcnnLayer.conv_layer.weight.data` in the `DynapcnnNetwork` instance. + + Parameters + ---------- + - init_fn (torch.nn.init): the weight initialization method to be used. 
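The `return_complete` behaviour described in the `forward` docstring above can be pictured with dummy tensors: the nested-dict layout is `{layer_index: {destination_index: output}}`, and the default path keeps only entries whose destination index is negative (the exit points). Indices and shapes below are illustrative only.

import torch

complete_out = {
    0: {1: torch.zeros(10, 4, 7, 7)},    # layer 0 feeds layer 1
    1: {-1: torch.zeros(10, 2, 5, 5)},   # layer 1 feeds an exit point
}

# Keep only exit-point entries, as the default (return_complete=False) path does.
final_out = {
    layer: {dest: out for dest, out in dests.items() if dest < 0}
    for layer, dests in complete_out.items()
    if any(dest < 0 for dest in dests)
}
assert list(final_out) == [1]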
+ """ + for layer in self.dynapcnn_layers.values(): + if isinstance(layer, DynapcnnLayer): + init_fn(layer.conv_layer.weight.data) + + def detach_neuron_states(self) -> None: + """Detach the neuron states and activations from current computation graph (necessary).""" + + for module in self.dynapcnn_layers.values(): + if isinstance(module, DynapcnnLayer): + if isinstance(module.spk_layer, sl.StatefulLayer): + for name, buffer in module.spk_layer.named_buffers(): + buffer.detach_() + def to( self, - device="cpu", - chip_layers_ordering="auto", + device: str = "cpu", monitor_layers: Optional[Union[List, str]] = None, - config_modifier=None, - slow_clk_frequency: int = None, + config_modifier: Optional[Callable] = None, + slow_clk_frequency: Optional[int] = None, + layer2core_map: Union[Dict[int, int], str] = "auto", + chip_layers_ordering: Optional[Union[Sequence[int], str]] = None, ): - """Note that the model parameters are only ever transferred to the device on the `to` call, + """Deploy model to cpu, gpu or a SynSense device. + + Note that the model parameters are only ever transferred to the device on the `to` call, so changing a threshold or weight of a model that is deployed will have no effect on the model on chip until `to` is called again. @@ -112,13 +316,6 @@ def to( device: String cpu:0, cuda:0, dynapcnndevkit, speck2devkit - chip_layers_ordering: sequence of integers or `auto` - The order in which the dynapcnn layers will be used. If `auto`, - an automated procedure will be used to find a valid ordering. - A list of layers on the device where you want each of the model's DynapcnnLayers to be placed. - The index of the core on chip to which the i-th layer in the model is mapped is the value of the i-th entry in the list. - Note: This list should be the same length as the number of dynapcnn layers in your model. - monitor_layers: None/List A list of all layers in the module that you want to monitor. Indexing starts with the first non-dvs layer. If you want to monitor the dvs layer for eg. @@ -127,48 +324,72 @@ def to( monitor_layers = ["dvs"] # If you want to monitor the output of the pre-processing layer monitor_layers = ["dvs", 8] # If you want to monitor preprocessing and layer 8 monitor_layers = "all" # If you want to monitor all the layers + monitor_layers = [-1] # If you want to only monitor exit points of the network (i.e. final layers) config_modifier: A user configuration modifier method. This function can be used to make any custom changes you want to make to the configuration object. + layer2core_map (dict or "auto"): Defines how cores on chip are + assigned to DynapcnnLayers. If `auto`, an automated procedure + will be used to find a valid ordering. Otherwise a dict needs + to be passed, with DynapcnnLayer indices as keys and assigned + core IDs as values. DynapcnnLayer indices have to match those of + `self.dynapcnn_layers`. + + chip_layers_ordering: sequence of integers or `auto` + The order in which the dynapcnn layers will be used. If `auto`, + an automated procedure will be used to find a valid ordering. + A list of layers on the device where you want each of the model's DynapcnnLayers to be placed. + The index of the core on chip to which the i-th layer in the model is mapped is the value of the i-th entry in the list. + Note: This list should be the same length as the number of dynapcnn layers in your model. + Note: This parameter is obsolete and should not be passed anymore. Use + `layer2core_map` instead. 
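Taken together, `parameters`, `init_weights` and `detach_neuron_states` defined above support an ordinary surrogate-gradient training loop on the software side. A hedged sketch, assuming this branch of sinabs is installed; the architecture, input sizes and hyperparameters are made up for illustration:

import torch
from torch import nn
from torch.optim import Adam
from sinabs.layers import IAFSqueeze
from sinabs.backend.dynapcnn import DynapcnnNetwork

batch_size = 1
snn = nn.Sequential(
    nn.Conv2d(1, 4, 3, bias=False),
    IAFSqueeze(batch_size=batch_size),
    nn.AvgPool2d(2),
    nn.Conv2d(4, 2, 3, bias=False),
    IAFSqueeze(batch_size=batch_size),
)

net = DynapcnnNetwork(snn=snn, input_shape=(1, 16, 16), batch_size=batch_size)
net.init_weights(nn.init.xavier_normal_)
optimizer = Adam(net.parameters(), lr=1e-3)

# 10 time steps flattened into the batch dimension, as IAFSqueeze expects.
x = torch.rand(batch_size * 10, 1, 16, 16)

optimizer.zero_grad()
out = net(x)            # software forward pass through the DynapcnnLayer graph
loss = out.sum()
loss.backward()
optimizer.step()

# Detach membrane potentials so the autograd graph does not grow across sequences.
net.detach_neuron_states()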
+ Note ---- chip_layers_ordering and monitor_layers are used only when using synsense devices. For GPU or CPU usage these options are ignored. """ self.device = device + if isinstance(device, torch.device): - return super().to(device) + self._to_device(device) + elif isinstance(device, str): device_name, _ = parse_device_id(device) - if device_name in ChipFactory.supported_devices: # pragma: no cover - # Generate config + + if device_name in ChipFactory.supported_devices: + + # generate config. config = self.make_config( + layer2core_map=layer2core_map, chip_layers_ordering=chip_layers_ordering, device=device, monitor_layers=monitor_layers, config_modifier=config_modifier, ) - # Apply configuration to device + # apply configuration to device. self.samna_device = open_device(device) self.samna_device.get_model().apply_configuration(config) time.sleep(1) - # Set external slow-clock if need + # set external slow-clock if needed. if slow_clk_frequency is not None: dk_io = self.samna_device.get_io_module() dk_io.set_slow_clk(True) dk_io.set_slow_clk_rate(slow_clk_frequency) # Hz builder = ChipFactory(device).get_config_builder() - # Create input source node + + # create input source node. self.samna_input_buffer = builder.get_input_buffer() - # Create output sink node node + + # create output sink node node. self.samna_output_buffer = builder.get_output_buffer() - # Connect source node to device sink + # connect source node to device sink. self.device_input_graph = samna.graph.EventFilterGraph() self.device_input_graph.sequential( [ @@ -177,7 +398,7 @@ def to( ] ) - # Connect sink node to device + # connect sink node to device. self.device_output_graph = samna.graph.EventFilterGraph() self.device_output_graph.sequential( [ @@ -185,131 +406,60 @@ def to( self.samna_output_buffer, ] ) + self.device_input_graph.start() self.device_output_graph.start() self.samna_config = config + return self + else: - return super().to(device) + self._to_device(device) + else: raise Exception("Unknown device description.") - def _make_config( - self, - chip_layers_ordering: Union[Sequence[int], str] = "auto", - device="dynapcnndevkit:0", - monitor_layers: Optional[Union[List, str]] = None, - config_modifier=None, - ) -> Tuple["SamnaConfiguration", bool]: - """Prepare and output the `samna` configuration for this network. - - Parameters - ---------- - - chip_layers_ordering: sequence of integers or `auto` - The order in which the dynapcnn layers will be used. If `auto`, - an automated procedure will be used to find a valid ordering. - A list of layers on the device where you want each of the model's DynapcnnLayers to be placed. - The index of the core on chip to which the i-th layer in the model is mapped is the value of the i-th entry in the list. - Note: This list should be the same length as the number of dynapcnn layers in your model. - - device: String - dynapcnndevkit, speck2b or speck2devkit - - monitor_layers: None/List/Str - A list of all layers in the module that you want to monitor. Indexing starts with the first non-dvs layer. - If you want to monitor the dvs layer for eg. - :: - - monitor_layers = ["dvs"] # If you want to monitor the output of the pre-processing layer - monitor_layers = ["dvs", 8] # If you want to monitor preprocessing and layer 8 - monitor_layers = "all" # If you want to monitor all the layers - - If this value is left as None, by default the last layer of the model is monitored. - - config_modifier: - A user configuration modifier method. 
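Regarding the `config_modifier` hook mentioned here and in `make_config` below: it is simply a callable that receives the samna configuration object and returns it after custom edits. A hypothetical modifier; the attribute touched below, `dvs_layer.merge`, is taken from the previously existing code in this file and is meant only as an example:

def my_config_modifier(config):
    # Merge the two DVS polarities into a single channel, as the previous
    # code did automatically for single-channel inputs.
    config.dvs_layer.merge = True
    return config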
- This function can be used to make any custom changes you want to make to the configuration object. + def is_compatible_with(self, device_type: str) -> bool: + """Check if the current model is compatible with a given device. - Returns - ------- - Configuration object - Object defining the configuration for the device - Bool - True if the configuration is valid for the given device. + Args: + device_type (str): Device type ie speck2b, speck2fmodule - Raises - ------ - ImportError - If samna is not available. + Returns: + bool: True if compatible """ - config_builder = ChipFactory(device).get_config_builder() - - has_dvs_layer = isinstance(self.sequence[0], DVSLayer) - - # Figure out layer ordering - if chip_layers_ordering == "auto": - chip_layers_ordering = config_builder.get_valid_mapping(self) - else: - # Truncate chip_layers_ordering just in case a longer list is passed - if has_dvs_layer: - chip_layers_ordering = chip_layers_ordering[: len(self.sequence) - 1] - chip_layers_ordering = chip_layers_ordering[: len(self.sequence)] - - # Save the chip layers - self.chip_layers_ordering = chip_layers_ordering - # Update config - config = config_builder.build_config(self, chip_layers_ordering) - if self.input_shape and self.input_shape[0] == 1: - config.dvs_layer.merge = True - # Check if any monitoring is enabled and if not, enable monitoring for the last layer - if monitor_layers is None: - monitor_layers = [-1] - elif monitor_layers == "all": - num_cnn_layers = len(self.sequence) - int(has_dvs_layer) - monitor_layers = list(range(num_cnn_layers)) - - # Enable monitors on the specified layers - # Find layers corresponding to the chip - monitor_chip_layers = [ - self.find_chip_layer(lyr) for lyr in monitor_layers if lyr != "dvs" - ] - if "dvs" in monitor_layers: - monitor_chip_layers.append("dvs") - config_builder.monitor_layers(config, monitor_chip_layers) - - # Fix default factory setting to not return input events (UGLY!! Ideally this should happen in samna) - # config.factory_settings.monitor_input_enable = False - - # Apply user config modifier - if config_modifier is not None: - config = config_modifier(config) - - # Validate config - return config, config_builder.validate_configuration(config) + try: + _, is_compatible = self._make_config(device=device_type) + except ValueError as e: + # Catch "No valid mapping found" error + if e.args[0] == ( + "One or more of the DynapcnnLayers could not be mapped to any core." + ): + return False + else: + raise e + return is_compatible def make_config( self, - chip_layers_ordering: Union[Sequence[int], str] = "auto", - device="dynapcnndevkit:0", + layer2core_map: Union[Dict[int, int], str] = "auto", + device: str = "dynapcnndevkit:0", monitor_layers: Optional[Union[List, str]] = None, - config_modifier=None, - ): + config_modifier: Optional[Callable] = None, + chip_layers_ordering: Optional[Union[Sequence[int], str]] = None, + ) -> DynapcnnConfiguration: """Prepare and output the `samna` DYNAPCNN configuration for this network. Parameters ---------- - - chip_layers_ordering: sequence of integers or `auto` - The order in which the dynapcnn layers will be used. If `auto`, - an automated procedure will be used to find a valid ordering. - A list of layers on the device where you want each of the model's DynapcnnLayers to be placed. - Note: This list should be the same length as the number of dynapcnn layers in your model. 
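Since `chip_layers_ordering` is now deprecated in favour of `layer2core_map`, the relation between the two is worth spelling out: the old list is positional over the DynapcnnLayer indices, while the new argument is an explicit dict. The conversion below mirrors the deprecation shim applied when building the configuration (see `_make_config` further down); the indices are illustrative.

dynapcnn_layer_indices = [0, 1, 2]    # keys of `self.dynapcnn_layers`
chip_layers_ordering = [5, 2, 7]      # old style: i-th layer goes to core ordering[i]

layer2core_map = {
    idx: core for idx, core in zip(dynapcnn_layer_indices, chip_layers_ordering)
}
assert layer2core_map == {0: 5, 1: 2, 2: 7}

# New-style call (needs a connected devkit, shown for comparison only):
# net.to("speck2fmodule:0", layer2core_map=layer2core_map)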
- - device: String - dynapcnndevkit, speck2b or speck2devkit - - monitor_layers: None/List/Str + - layer2core_map (dict or "auto"): Defines how cores on chip are + assigned to DynapcnnLayers. If `auto`, an automated procedure + will be used to find a valid ordering. Otherwise a dict needs + to be passed, with DynapcnnLayer indices as keys and assigned + core IDs as values. DynapcnnLayer indices have to match those of + `self.dynapcnn_layers`. + - device: (string): dynapcnndevkit, speck2b or speck2devkit + - monitor_layers: None/List/Str A list of all layers in the module that you want to monitor. Indexing starts with the first non-dvs layer. If you want to monitor the dvs layer for eg. :: @@ -317,12 +467,20 @@ def make_config( monitor_layers = ["dvs"] # If you want to monitor the output of the pre-processing layer monitor_layers = ["dvs", 8] # If you want to monitor preprocessing and layer 8 monitor_layers = "all" # If you want to monitor all the layers + monitor_layers = [-1] # If you want to only monitor exit points of the network (i.e. final layers) If this value is left as None, by default the last layer of the model is monitored. - config_modifier: + - config_modifier (Callable or None): A user configuration modifier method. This function can be used to make any custom changes you want to make to the configuration object. + - chip_layers_ordering (None, sequence of integers or "auto", obsolete): + The order in which the dynapcnn layers will be used. If `auto`, + an automated procedure will be used to find a valid ordering. + A list of layers on the device where you want each of the model's DynapcnnLayers to be placed. + Note: This list should be the same length as the number of dynapcnn layers in your model. + Note: This parameter is obsolete and should not be passed anymore. Use + `layer2core_map` instead. Returns ------- @@ -337,41 +495,60 @@ def make_config( If the generated configuration is not valid for the specified device. """ config, is_compatible = self._make_config( - chip_layers_ordering=chip_layers_ordering, + layer2core_map=layer2core_map, device=device, monitor_layers=monitor_layers, config_modifier=config_modifier, + chip_layers_ordering=chip_layers_ordering, ) + # Validate config if is_compatible: print("Network is valid") return config else: - raise ValueError(f"Generated config is not valid for {device}") + raise ValueError( + f"Generated config is not valid for {device}. " + "Probably one or more layers are too large. Try " + "Reducing the number of neurons or the kernel sizes." + ) - def is_compatible_with(self, device_type: str) -> bool: - """Check if the current model is compatible with a given device. + def has_dvs_layer(self) -> bool: + """Return True if there is a DVSLayer in the network - Args: - device_type (str): Device type ie speck2b, speck2fmodule + Returns + ------- + bool: True if DVSLayer is found within the network. 
+ """ + return self.dvs_layer is not None - Returns: - bool: True if compatible + def zero_grad(self, set_to_none: bool = False) -> None: + """Call `zero_grad` method of each DynapCNN layer + + Parameters + ---------- + - set_to_none (bool): This argument is passed directly to the + `zero_grad` method of each DynapCNN layer """ - try: - _, is_compatible = self._make_config(device=device_type) - except ValueError as e: - # Catch "No valid mapping found" error - if e.args[0] == ("No valid mapping found"): - return False - else: - raise e - return is_compatible + for lyr in self.dynapcnn_layers.values(): + lyr.zero_grad(set_to_none) def reset_states(self, randomize=False): - """Reset the states of the network.""" + """Reset the states of the network. + + Parameters + ---------- + - randomize (bool): If `False` (default), will set all states to 0. + Otherwise will set to random values. + + Notes + ----- + - Setting `randomize` to `True` is only supported for models that have + not yet been deployed on a SynSense device. + """ if hasattr(self, "device") and isinstance(self.device, str): # pragma: no cover device_name, _ = parse_device_id(self.device) + # Reset states on SynSense device if device_name in ChipFactory.supported_devices: config_builder = ChipFactory(self.device).get_config_builder() # Set all the vmem states in the samna config to zero @@ -391,117 +568,206 @@ def reset_states(self, randomize=False): time.sleep(0.1) self.samna_input_graph.start() return + + # Reset states of `DynapcnnLayer` instances for layer in self.sequence: if isinstance(layer, DynapcnnLayer): layer.spk_layer.reset_states(randomize=randomize) - def find_chip_layer(self, layer_idx): - """Given an index of a layer in the model, find the corresponding cnn core id where it is - placed. + ####################################################### Private Methods ####################################################### - > Note that the layer index does not include the DVSLayer. - > For instance your model comprises two layers [DVSLayer, DynapcnnLayer], - > then the index of DynapcnnLayer is 0 and not 1. + def _make_config( + self, + layer2core_map: Union[Dict[int, int], str] = "auto", + device: str = "dynapcnndevkit:0", + monitor_layers: Optional[Union[List, str]] = None, + config_modifier: Optional[Callable] = None, + chip_layers_ordering: Optional[Union[Sequence[int], str]] = None, + ) -> Tuple[DynapcnnConfiguration, bool]: + """Prepare and output the `samna` DYNAPCNN configuration for this network. Parameters ---------- - layer_idx: int - Index of a layer + - layer2core_map (dict or "auto"): Defines how cores on chip are + assigned to DynapcnnLayers. If `auto`, an automated procedure + will be used to find a valid ordering. Otherwise a dict needs + to be passed, with DynapcnnLayer indices as keys and assigned + core IDs as values. DynapcnnLayer indices have to match those of + `self.dynapcnn_layers`. + - device: (string): dynapcnndevkit, speck2b or speck2devkit + - monitor_layers: None/List/Str + A list of all layers in the module that you want to monitor. Indexing starts with the first non-dvs layer. + If you want to monitor the dvs layer for eg. + :: + + monitor_layers = ["dvs"] # If you want to monitor the output of the pre-processing layer + monitor_layers = ["dvs", 8] # If you want to monitor preprocessing and layer 8 + monitor_layers = "all" # If you want to monitor all the layers + monitor_layers = [-1] # If you want to only monitor exit points of the network (i.e. 
final layers) + + If this value is left as None, by default the last layer of the model is monitored. + + - config_modifier (Callable or None): + A user configuration modifier method. + This function can be used to make any custom changes you want to make to the configuration object. + - chip_layers_ordering (None, sequence of integers or "auto", obsolete): + The order in which the dynapcnn layers will be used. If `auto`, + an automated procedure will be used to find a valid ordering. + A list of layers on the device where you want each of the model's DynapcnnLayers to be placed. + Note: This list should be the same length as the number of dynapcnn layers in your model. + Note: This parameter is obsolete and should not be passed anymore. Use + `layer2core_map` instead. Returns ------- - chip_lyr_idx: int - Index of the layer on the chip where the model layer is placed. - """ - # Compute the expected number of cores - num_cores_required = len(self.sequence) - if isinstance(self.sequence[0], DVSLayer): - num_cores_required -= 1 - if len(self.chip_layers_ordering) != num_cores_required: - raise Exception( - f"Number of layers specified in chip_layers_ordering {self.chip_layers_ordering} does not correspond to the number of cores required for this model {num_cores_required}" - ) + Configuration object + Object defining the configuration for the device + Bool + True if the configuration is valid for the given device. - return self.chip_layers_ordering[layer_idx] - def forward(self, x): - if ( - hasattr(self, "device") - and parse_device_id(self.device)[0] in ChipFactory.supported_devices - ): # pragma: no cover - _ = self.samna_output_buffer.get_events() # Flush buffer - # NOTE: The code to start and stop time stamping is device specific - reset_timestamps(self.device) - enable_timestamps(self.device) - # Send input - self.samna_input_buffer.write(x) - received_evts = [] - # Record at least until the last event has been replayed - min_duration = max(event.timestamp for event in x) * 1e-6 - time.sleep(min_duration) - # Keep recording if more events are being registered - while True: - prev_length = len(received_evts) - time.sleep(0.1) - received_evts.extend(self.samna_output_buffer.get_events()) - if prev_length == len(received_evts): - break - # Disable timestamp - disable_timestamps(self.device) - return received_evts + Raises + ------ + ImportError + If samna is not available. + ValueError + If no valid mapping between the layers of this object and the cores of + the provided device can be found. + """ + config_builder = ChipFactory(device).get_config_builder() + + if chip_layers_ordering is not None: + if layer2core_map != "auto": + warn( + "Both `chip_layers_ordering` and `layer2core_map are provided. " + "The parameter `chip_layers_ordering` is deprecated and will " + "be ignored.", + DeprecationWarning, + ) + elif chip_layers_ordering == "auto": + warn( + "The parameter `chip_layers_ordering` is deprecated. Passing " + "'auto' is still accepted, but in the future please use " + "`layer2core_map` instead.", + DeprecationWarning, + ) + else: + layer2core_map = { + idx: core + for idx, core in zip(self.dynapcnn_layers, chip_layers_ordering) + } + warn( + "The parameter `chip_layers_ordering` is deprecated. " + "Because `layer2core_map` is 'auto', and `chip_layers_ordering` " + "is not, will convert `chip_layers_ordering` to a " + "dict matching `layer2core_map`. In the future please use " + "`layer2core_map` instead. 
Please make sure the inferred" + "mapping from DynapcnnLayer index to core index is correct:" + + pformat(layer2core_map), + DeprecationWarning, + ) + if layer2core_map == "auto": + # Assign chip core ID for each DynapcnnLayer. + layer2core_map = config_builder.map_layers_to_cores(self.dynapcnn_layers) else: - """Torch's forward pass.""" - return self.sequence(x) + if not layer2core_map.keys() == self.dynapcnn_layers.keys(): + raise ValueError( + "The keys provided in `layer2core_map` must exactly match " + "the keys in `self.dynapcnn_layers`" + ) - def memory_summary(self): - """Get a summary of the network's memory requirements. + self._layer2core_map = layer2core_map - Returns - ------- - dict: - A dictionary with keys kernel, neuron, bias. - The values are a list of the corresponding number per layer in the same order as the model - """ - summary = {} - - dynapcnn_layers = [ - lyr for lyr in self.sequence if isinstance(lyr, DynapcnnLayer) - ] - summary.update({k: list() for k in dynapcnn_layers[0].memory_summary().keys()}) - for lyr in dynapcnn_layers: - lyr_summary = lyr.memory_summary() - for k, v in lyr_summary.items(): - summary[k].append(v) - return summary + # update config (config. DynapcnnLayer instances into their assigned core). + config = config_builder.build_config( + layers=self.all_layers, + destination_map=self.layer_destination_map, + layer2core_map=layer2core_map, + ) - def zero_grad(self, set_to_none: bool = False) -> None: - for lyr in self.sequence: - lyr.zero_grad(set_to_none) + if monitor_layers is None: + # Monitor all layers with exit point destinations + monitor_layers = self._dynapcnn_module.get_exit_layers() + elif monitor_layers == "all": + monitor_layers = [ + lyr_idx + for lyr_idx, layer in self.dynapcnn_layers.items() + if not isinstance(layer, DVSLayer) + ] + elif -1 in monitor_layers: + # Replace `-1` with exit layer IDs + monitor_layers.remove(-1) + monitor_layers += self._dynapcnn_module.get_exit_layers() + + # Collect cores (chip layers) that are to be monitored + monitor_chip_layers = [] + for lyr_idx in monitor_layers: + if str(lyr_idx).lower() == "dvs": + monitor_chip_layers.append("dvs") + else: + monitor_chip_layers.append(layer2core_map[lyr_idx]) + + # enable monitors on the specified layers. + config_builder.monitor_layers(config, monitor_chip_layers) - def __del__(self): - # Stop the input graph - if hasattr(self, "device_input_graph") and self.device_input_graph: - self.device_input_graph.stop() + if config_modifier is not None: + # apply user config modifier. + config = config_modifier(config) + + # Validate config + return config, config_builder.validate_configuration(config) - # Stop the output graph. 
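The monitor-layer resolution above ultimately maps user-facing layer indices, plus the special entries "dvs" and -1, to chip core IDs. A toy rerun of that logic with made-up indices:

exit_layer_ids = [3]                  # what `get_exit_layers()` would return
layer2core_map = {0: 0, 1: 1, 2: 2, 3: 5}

monitor_layers = ["dvs", -1]          # monitor the DVS output and all exit layers
if -1 in monitor_layers:
    monitor_layers.remove(-1)
    monitor_layers += exit_layer_ids

monitor_chip_layers = [
    "dvs" if str(idx).lower() == "dvs" else layer2core_map[idx]
    for idx in monitor_layers
]
assert monitor_chip_layers == ["dvs", 5]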
- if hasattr(self, "device_output_graph") and self.device_output_graph: - self.device_output_graph.stop() + def _to_device(self, device: torch.device) -> None: + """Access each sub-layer within all `DynapcnnLayer` instances and call `.to(device)` on them.""" + for layer in self.dynapcnn_layers.values(): + if isinstance(layer, sinabs.backend.dynapcnn.dynapcnn_layer.DynapcnnLayer): + layer.to(device) + + for _, data in self._merge_points.items(): + data["merge"].to(device) + + def __str__(self): + pretty_print = "" + if self.dvs_layer is not None: + pretty_print += ( + "-------------------------- [ DVSLayer ] --------------------------\n" + ) + pretty_print += f"{self.dvs_layer}\n\n" + for idx, layer_data in self.dynapcnn_layers.items(): + pretty_print += f"----------------------- [ DynapcnnLayer {idx} ] -----------------------\n" + if self.is_deployed_on_dynapcnn_device: + pretty_print += f"Core {self.layer2core_map[idx]}\n" + pretty_print += f"{layer_data}\n\n" + + return pretty_print + + def __repr__(self): + if self.is_deployed_on_dynapcnn_device: + layer_info = "\n\n".join( + f"{idx} - core: {self.layer2core_map[idx]}\n{pformat(layer)}" + for idx, layer in self.dynapcnn_layers.items() + ) + device_info = f" deployed on {self.device}," + else: + layer_info = "\n\n".join( + f"Index: {idx}\n{pformat(layer)}" + for idx, layer in self.dynapcnn_layers.items() + ) + device_info = f" on {self.device}," if hasattr(self, "device") else "" + return ( + f"DynapCNN Network{device_info} containing:\nDVS Layer: {pformat(self.dvs_layer)}" + "\n\nDynapCNN Layers:\n\n" + layer_info + ) class DynapcnnCompatibleNetwork(DynapcnnNetwork): """Deprecated class, use DynapcnnNetwork instead.""" - def __init__( - self, - snn: Union[nn.Sequential, sinabs.Network], - input_shape: Optional[Tuple[int, int, int]] = None, - dvs_input: bool = False, - discretize: bool = True, - ): + def __init__(self, *args, **kwargs): from warnings import warn warn( "DynapcnnCompatibleNetwork has been renamed to DynapcnnNetwork " + "and will be removed in a future release." ) - super().__init__(snn, input_shape, dvs_input, discretize) + super().__init__(*args, **kwargs) diff --git a/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py b/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py new file mode 100644 index 00000000..bd03d5d7 --- /dev/null +++ b/sinabs/backend/dynapcnn/dynapcnnnetwork_module.py @@ -0,0 +1,379 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +from pprint import pformat +from typing import Dict, List, Optional, Set, Union +from warnings import warn + +import torch.nn as nn +from torch import Tensor + +import sinabs.layers as sl + +from .dvs_layer import DVSLayer +from .dynapcnn_layer import DynapcnnLayer +from .utils import Edge, topological_sorting + + +class DynapcnnNetworkModule(nn.Module): + """Allow forward (and backward) passing through a network of `DynapcnnLayer`s. + + Internally constructs a graph representation based on the provided arguments + and uses this to pass data through all layers in correct order. + + Parameters + ---------- + - dynapcnn_layers (dict): a mapper containing `DynapcnnLayer` instances. + - destination_map (dict): Maps layer indices to list of destination indices. + Exit destinations are marked by negative integers + - entry_points (set): Set of layer indices that act as network entry points. + - dvs_node_info (dict): contains information associated with the `DVSLayer` node. + `None` if no DVS node exists. 
+ + Attributes + ---------- + This class internally builds a graph with `DynapcnnLayer` as nodes and their + connections as edges. Several data structures help efficient retrieval of + information required for the forward pass: + - _dynapcnnlayer_edges: Set of edges connecting dynapcnn layers. Tuples + of indices of source and target layers. + - _sorted_nodes: List of layer indices in topological order, to ensure forward + calls to layers only happen when required inputs are available. + - _node_source_map: Dict with layer indices as keys and list of input layer indices + as values. + """ + + def __init__( + self, + dynapcnn_layers: Dict[int, DynapcnnLayer], + destination_map: Dict[int, List[int]], + entry_points: Set[int], + dvs_node_info: Optional[Dict] = None, + ): + super().__init__() + + self._dvs_node_info = dvs_node_info + + # Unfortunately ModuleDict does not allow for integer keys + module_dict = {str(idx): lyr for idx, lyr in dynapcnn_layers.items()} + self._dynapcnn_layers = nn.ModuleDict(module_dict) + + if self._dvs_node_info is not None: + self._dvs_layer = dvs_node_info["module"] + else: + self._dvs_layer = None + + self._destination_map = destination_map + self._entry_points = entry_points + + # `Merge` layers are stateless. One instance can be used for all merge points during forward pass + self.merge_layer = sl.Merge() + + @property + def all_layers(self): + layers = self.dynapcnn_layers + if self.dvs_layer is not None: + # `self.dynapcnn_layers` is a (shallow) copy. Adding entries won't + # affect `self._dynapcnn_layers` + layers["dvs"] = self.dvs_layer + return layers + + @property + def dvs_node_info(self): + return self._dvs_node_info + + @property + def dvs_layer(self): + return self._dvs_layer + + @property + def destination_map(self): + return self._destination_map + + @property + def dynapcnn_layers(self): + # Convert string-indices to integer-indices and sort by index + return {int(idx): lyr for idx, lyr in sorted(self._dynapcnn_layers.items())} + + @property + def entry_points(self): + return self._entry_points + + @property + def sorted_nodes(self): + return self._sorted_nodes + + @property + def node_source_map(self): + return self._node_source_map + + def get_exit_layers(self) -> List[int]: + """Get layers that act as exit points of the network + + Returns + ------- + - List[int]: Layer indices with at least one exit destination. + """ + return [ + layer_idx + for layer_idx, destinations in self.destination_map.items() + if any(d < 0 for d in destinations) + ] + + def get_exit_points(self) -> Dict[int, Dict]: + """Get details of layers that act as exit points of the network + + Returns + ------- + - Dict[int, Dict]: Dict whose keys are layer indices of `dynapcnn_layers` + with at least one exit destination. Values are list of dicts, providing + for each exit destination the negative valued ID ('destination_id'), + the index of that destination within the list of destinations of the + corresponding `DynapcnnLayer` ('destination_index'), and the pooling + for this destination. 
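The internal bookkeeping listed above (edges, entry points, topological order, source map) can be illustrated on a toy layer graph. The Kahn-style sort below is only a stand-in for `utils.topological_sorting`, which is not shown in this diff:

from collections import defaultdict

# Layer 0 is the entry point; it feeds layers 1 and 2, which both feed
# layer 3; layer 3 only has an exit destination (-1).
destination_map = {0: [1, 2], 1: [3], 2: [3], 3: [-1]}
entry_points = {0}

# Edges between layers (exits are not nodes), plus ('input', idx) entry edges.
edges = {(src, d) for src, dests in destination_map.items() for d in dests if d >= 0}
edges |= {("input", idx) for idx in entry_points}

# Sources per node, as built by `get_node_source_map`.
node_source_map = defaultdict(list)
for src, trg in edges:
    node_source_map[trg].append(src)
assert sorted(node_source_map[3]) == [1, 2]

# Minimal Kahn-style topological sort.
in_degree = {n: 0 for edge in edges for n in edge}
for _, trg in edges:
    in_degree[trg] += 1
ready = [n for n, deg in in_degree.items() if deg == 0]
order = []
while ready:
    node = ready.pop()
    order.append(node)
    for src, trg in edges:
        if src == node:
            in_degree[trg] -= 1
            if in_degree[trg] == 0:
                ready.append(trg)

# Every layer appears after all of its sources.
assert all(order.index(s) < order.index(t) for s, t in edges)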
+ """ + exit_layers = dict() + for layer_idx, destinations in self.destination_map.items(): + exit_destinations = [] + for i, dest in enumerate(destinations): + if dest < 0: + exit_destinations.append( + { + "destination_id": dest, + "destination_index": i, + "pooling": self.dynapcnn_layers[layer_idx].pool[i], + } + ) + if exit_destinations: + exit_layers[layer_idx] = exit_destinations + + return exit_layers + + def setup_dynapcnnlayer_graph( + self, index_layers_topologically: bool = False + ) -> None: + """Set up data structures to run forward pass through dynapcnn layers + + Parameters + ---------- + - index_layers_topologically (bool): If True, will assign new indices to + dynapcnn layers such that they match their topological order within the + network graph. This is not necessary but can help understand the network + more easily when inspecting it. + """ + self._dynapcnnlayer_edges = self.get_dynapcnnlayers_edges() + self.add_entry_points_edges(self._dynapcnnlayer_edges) + self._sorted_nodes = topological_sorting(self._dynapcnnlayer_edges) + self._node_source_map = self.get_node_source_map(self._dynapcnnlayer_edges) + if index_layers_topologically: + self.reindex_layers(self._sorted_nodes) + + def get_dynapcnnlayers_edges(self) -> Set[Edge]: + """Create edges representing connections between `DynapcnnLayer` instances. + + Returns + ---------- + - dcnnl_edges: a set of edges using the IDs of `DynapcnnLayer` instances. These edges describe the computational + graph implemented by the layers of the model (i.e., how the `DynapcnnLayer` instances address each other). + """ + dcnnl_edges = set() + + for dcnnl_idx, destination_indices in self._destination_map.items(): + for dest in destination_indices: + if dest >= 0: # Ignore negative destinations (network exit points) + dcnnl_edges.add((dcnnl_idx, dest)) + + return dcnnl_edges + + def add_entry_points_edges(self, dcnnl_edges: Set[Edge]) -> None: + """Add extra edges `('input', X)` to `dcnnl_edges` for + layers which are entry points of the `DynapcnnNetwork`, i.e. + `handler.entry_node = True`. + + Parameters + ---------- + - dcnnl_edges (Set): tuples representing the output->input mapping between + `DynapcnnLayer` instances. Will be changed in place. + """ + for indx in self._entry_points: + dcnnl_edges.add(("input", indx)) + + def get_node_source_map(self, dcnnl_edges: Set[Edge]) -> Dict[int, List[int]]: + """From a set of edges, create a dict that maps to each node its sources + + Parameters + ---------- + - dcnnl_edges (Set): tuples representing the output->input mapping between + `DynapcnnLayer` instances. + + Returns + ------- + - Dict with layer indices (int) as keys and list of layer indices that + map to corresponding layer + """ + sources = dict() + + for src, trg in dcnnl_edges: + if trg in sources: + sources[trg].append(src) + else: + sources[trg] = [src] + + return sources + + def forward( + self, x, return_complete: bool = False + ) -> Union[Tensor, Dict[int, Dict[int, Tensor]]]: + """Perform a forward pass through all dynapcnn layers + The `setup_dynapcnnlayer_graph` method has to be executed beforehand. + + Parameters + ---------- + x: Tensor that serves as input to network. 
Is passed to all layers + that are marked as entry points + return_complete: bool that indicates whether all layer outputs should + be return or only those with no further destinations (default) + + Returns + ------- + The returned object depends on whether `return_complete` is set and on + the network configuration: + * If `return_complete` is `True`, all layer outputs will be returned in a + dict, with layer indices as keys, and nested dicts as values, which + hold destination indices as keys and output tensors as values. + * If `return_complete` is `False` and there is only a single destination + in the whole network that is marked as exit point (i.e. destination + index in dynapcnn layer handler is negative), it will return the + output as a single tensor. + * If `return_complete` is `False` and no destination in the network + is marked as exit point, a warning will be raised and the function + returns an empty dict. + * In all other cases a dict will be returned that is of the same + structure as if `return_complete` is `True`, but only with entries + where the destination is marked as exit point. + + """ + if not hasattr(self, "_sorted_nodes"): + raise RuntimeError( + "It looks like `setup_dynapcnnlayers_graph` has never been executed. " + "It needs to be called at least once before calling `forward`." + ) + + # For each layer store its outputs as dict with destination layers as keys. + # For input set `x` as input to entry points + layers_outputs = {"input": {ep: x for ep in self.entry_points}} + + for idx_curr in self._sorted_nodes: + # Get inputs to the layer + if len(sources := self._node_source_map[idx_curr]) > 1: + # Layer has multiple inputs + inputs = [layers_outputs[idx_src][idx_curr] for idx_src in sources] + current_input = self.merge_layer(*inputs) + else: + idx_src = sources[0] + current_input = layers_outputs[idx_src][idx_curr] + + # Get current layer instance and destinations + layer = self.all_layers[idx_curr] + destinations = self._destination_map[idx_curr] + + # Forward pass through layer + output = layer(current_input) + + # Store layer output for all destinations + if len(destinations) == 1: + # Output is single tensor + layers_outputs[idx_curr] = {destinations[0]: output} + else: + if isinstance(layer, DVSLayer): + # DVSLayer returns a single tensor (same for all its destinations). + layers_outputs[idx_curr] = { + idx_dest: output for idx_dest in destinations + } + else: + # Output is list of tensors for different destinations + layers_outputs[idx_curr] = { + idx_dest: out for idx_dest, out in zip(destinations, output) + } + + if return_complete: + return layers_outputs + + # Take outputs with exit point destinations as network output + network_outputs = {} + for layer_idx, layer_out in layers_outputs.items(): + outputs = { + idx_dest: out + for idx_dest, out in layer_out.items() + if isinstance(idx_dest, int) and idx_dest < 0 + } + if outputs: + network_outputs[layer_idx] = outputs + + # If no outputs have been found return None and warn + if not network_outputs: + warn( + "No exit points have been found. 
Try setting `return_complete` " + "`True` to get all outputs, or mark exit points by setting " + "corresponding destination layer indices in destination_map " + " to negative integer values" + ) + return dict() + + # Special case with single output: return single tensor + if ( + len(network_outputs) == 1 + and len(out := (next(iter(network_outputs.values())))) == 1 + ): + return next(iter(out.values())) + + # If there is output from multiple layers return all of them in a dict + return network_outputs + + def reindex_layers(self, index_order: List[int]) -> None: + """Reindex layers based on provided order + + Will assign new index to dynapcnn layers and update all internal + attributes accordingly. + + Parameters + ---------- + index_order: List of integers indicating new order of layers: + Position of layer index within this list indicates new index + """ + mapping = {old: new for new, old in enumerate(index_order)} + + def remap(key): + if key in ["dvs", "input"] or (isinstance(key, int) and key < 0): + # Entries 'dvs', 'input' and negative indices are not changed + return key + else: + return mapping[key] + + # Remap all internal objects + self._dynapcnn_layers = nn.ModuleDict( + {str(remap(int(idx))): lyr for idx, lyr in self._dynapcnn_layers.items()} + ) + + self._entry_points = {remap(idx) for idx in self._entry_points} + + self._destination_map = { + remap(idx): [remap(dest) for dest in destinations] + for idx, destinations in self._destination_map.items() + } + + self._dynapcnnlayer_edges = { + (remap(src), remap(trg)) for (src, trg) in self._dynapcnnlayer_edges + } + + self._sorted_nodes = [remap(idx) for idx in self._sorted_nodes] + + self._node_source_map = { + remap(node): [remap(src) for src in sources] + for node, sources in self._node_source_map.items() + } + + def __repr__(self): + return f"DVS Layer: {pformat(self.dvs_layer)}\n\nDynapCNN Layers:\n" + pformat( + self.dynapcnn_layers + ) diff --git a/sinabs/backend/dynapcnn/exceptions.py b/sinabs/backend/dynapcnn/exceptions.py index cd5c63aa..b1d64491 100644 --- a/sinabs/backend/dynapcnn/exceptions.py +++ b/sinabs/backend/dynapcnn/exceptions.py @@ -1,3 +1,13 @@ +from typing import Set, Tuple, Type + +default_invalid_structure_string = ( + "This should never happen, but is most likely due to an unsupported SNN " + "architecture. In general, a dynapcnn network should consist of groups of " + "a weight layer (conv or linear), a spiking layer (IAFSqueeze), and " + "optionally a pooling layer." +) + + class MissingLayer(Exception): index: int @@ -6,8 +16,8 @@ def __init__(self, index: int): class UnexpectedLayer(Exception): - layer_type_found: type - layer_type_expected: type + layer_type_found: Type + layer_type_expected: Type def __init__(self, expected, found): super().__init__(f"Expected {expected} but found {found}") @@ -17,3 +27,112 @@ class InputConfigurationError(Exception): """Is raised when input to dynapcnn is not configured correctly.""" pass + + +class WrongModuleCount(Exception): + dynapcnnlayer_indx: Type + modules_count: Type + + def __init__(self, dynapcnnlayer_indx, modules_count): + super().__init__( + f"A DynapcnnLayer {dynapcnnlayer_indx} should have 2 or 3 modules but found {modules_count}." + ) + + +class WrongPoolingModule(Exception): + pooling_module: Type + + def __init__( + self, + pooling_module, + ): + super().__init__( + f"The function 'utils.build_SumPool2d(mod)' expects 'mod = nn.AvgPool2d' but got 'mod = {pooling_module}'." 
+ ) + + +class UnsupportedLayerType(Exception): + pass + + +class InvalidModel(Exception): + model: Type + + def __init__( + self, + model, + ): + super().__init__( + f"'model' accepts either a DynapcnnNetwork or a DynapcnnNetworkGraph but {model} was given." + ) + + +class InvalidTorchModel(Exception): + network_type: str + + def __init__(self, network_type): + super().__init__(f"A {network_type} needs to be of type nn.Module.") + + +class InvalidGraphStructure(Exception): + pass + + +class InvalidModelWithDVSSetup(Exception): + def __init__(self): + super().__init__( + "The network provided has a DVSLayer instance but argument 'dvs_input' is set to False." + ) + + +# Edge exceptions. + + +class InvalidEdge(Exception): + edge: Tuple[int, int] + source: Type + target: Type + + def __init__(self, edge, source, target): + super().__init__( + f"Invalid edge {edge}: {source} can not target {target}. " + + default_invalid_structure_string + ) + + +class UnknownNode(Exception): + node: int + + def __init__(self, node): + super().__init__( + f"Node {node} can not be found within any DynapcnnLayer mapper." + ) + + +class MaxDestinationsReached(Exception): + dynapcnnlayer_index: int + + def __init__(self, dynapcnnlayer_index): + super().__init__( + f"DynapcnnLayer with index {dynapcnnlayer_index} has more than 2 destinations." + ) + + +class InvalidLayerLoop(Exception): + dynapcnnlayerA_index: int + dynapcnnlayerB_index: int + + def __init__(self, dynapcnnlayerA_index, dynapcnnlayerB_index): + super().__init__( + f"DynapcnnLayer {dynapcnnlayerA_index} can not connect to {dynapcnnlayerB_index} since reverse edge already exists." + ) + + +class InvalidLayerDestination(Exception): + dynapcnnlayerA: Type + dynapcnnlayerB: Type + + def __init__(self, dynapcnnlayerA, dynapcnnlayerB): + super().__init__( + f"DynapcnnLayer {dynapcnnlayerA} in one core can not connect to {dynapcnnlayerB} in another core." + ) diff --git a/sinabs/backend/dynapcnn/io.py b/sinabs/backend/dynapcnn/io.py index 7474b67d..eaeed8c2 100644 --- a/sinabs/backend/dynapcnn/io.py +++ b/sinabs/backend/dynapcnn/io.py @@ -252,7 +252,13 @@ def open_device(device_id: str): """ device_id = standardize_device_id(device_id=device_id) device_map = get_device_map() - device_info = device_map[device_id] + try: + device_info = device_map[device_id] + except KeyError: + msg = f"Device {device_id} has not been found. Make sure it is connected." + if device_map: + msg += "The following devices are available:\n" + "\n".join(device_map) + raise IOError(msg) device_handle = samna.device.open_device(device_info) if device_handle is not None: diff --git a/sinabs/backend/dynapcnn/mapping.py b/sinabs/backend/dynapcnn/mapping.py index 5fefaa0b..d80475d9 100644 --- a/sinabs/backend/dynapcnn/mapping.py +++ b/sinabs/backend/dynapcnn/mapping.py @@ -1,10 +1,13 @@ from collections import deque from copy import deepcopy from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union + +import sinabs from .dvs_layer import DVSLayer from .dynapcnn_layer import DynapcnnLayer +from .exceptions import InvalidModel @dataclass @@ -44,50 +47,75 @@ def find_chip_layers( def get_valid_mapping( - model: "DynapcnnNetwork", constraints: List[LayerConstraints] -) -> List[Tuple[int, int]]: + layers: Dict[int, DynapcnnLayer], constraints: List[LayerConstraints] +) -> Dict[int, int]: """Given a model, find a valid layer ordering for its placement within the constraints provided. 
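+
+    The problem is treated as a bipartite matching between software layers and
+    chip cores: `make_flow_graph` builds a flow graph, `edmonds` computes a
+    maximum flow on it, and `recover_mapping` reads the layer-to-core
+    assignment off the resulting flow.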
Parameters ---------- - model: - DynapcnnNetwork - constraints: - A list of all the layer's constraints + - model: an instance of a DynapcnnNetwork or a DynapcnnNetworkGraph. + - constraints: a list of all the layer's constraints. Returns + - Dict mapping from layer index (key) to assigned core ID (value) ------- """ + # Store layer indices and lists of possible target chips in separate lists + layer_indices = [] layer_mapping = [] - - for layer in model.sequence: - if isinstance(layer, DynapcnnLayer): - layer_mapping.append(find_chip_layers(layer, constraints)) + for layer_index, this_layer in layers.items(): + # Skip DVSLayers + if isinstance(this_layer, DynapcnnLayer): + chip_layers = find_chip_layers(this_layer, constraints) + layer_mapping.append(chip_layers) + layer_indices.append(layer_index) + # Make sure only DynapcnnLayers and DVSLayers are passed + elif not isinstance(this_layer, DVSLayer): + raise ValueError(f"Found unexpected layer type: `{type(this_layer)}") graph = make_flow_graph(layer_mapping, len(constraints)) - # Call mapping + # use Edmonds' Algorithm to find suitable cores for each DynapcnnLayer. new_graph = edmonds(graph, 0, len(graph) - 1) + netmap = recover_mapping(new_graph, len(layer_mapping)) - netmap = recover_mapping(new_graph, layer_mapping) - return netmap + # Convert `netmap` to dict mapping from layer index to core ID + return {layer_idx: core_id for layer_idx, core_id in zip(layer_indices, netmap)} @dataclass -class Edge: +class FlowGraphEdge: s: int t: int cap: int flow: int = 0 - rev: Optional["Edge"] = None + rev: Optional["FlowGraphEdge"] = None def __repr__(self): - return f"Edge from {self.s} to {self.t} with capacity {self.cap} and flow {self.flow}" + return f"FlowGraphEdge from {self.s} to {self.t} with capacity {self.cap} and flow {self.flow}" + + +def edmonds( + graph: List[List[FlowGraphEdge]], source: int, sink: int, verbose: bool = False +) -> List[List[FlowGraphEdge]]: + """Use Edmonds' Algorithm to compute flow of flow graph + Makes a copy of the graph. The original graph is not changed in place. -# graph is list of list of edges. Each edge is -def edmonds(graph, source, sink, verbose: bool = False): + Parameters + ---------- + - graph List[List[FlowGraphEdge]]): Flow graph representation. Each list entry + corresponds to a node and consists of a list holding the outgoing edges + from this node. + - source (int): Index of source node within graph + - sind (int): Index of sink node within graph + - verbose (bool): Print detailed flow information if `True` + + Returns + ------- + List[List[FlowGraphEdge]]: New flow graph with calculated flow + """ graph = deepcopy(graph) flow = 0 while True: @@ -122,31 +150,32 @@ def edmonds(graph, source, sink, verbose: bool = False): def make_flow_graph( layer_mapping: List[List[int]], num_layers: int = 9 -) -> List[List[Edge]]: - """Make a flow graph given all possible chip layers for each DynapcnnCompatibleLayer layer. - Note that the flows are not computed yet. The flow for the graph generated here needs to be - populated by calling the method `edmonds` +) -> List[List[FlowGraphEdge]]: + """Make a flow graph given all possible chip cores for each software layer. + + Note that the flows are not computed yet. The flow for the graph generated here + needs to be populated by calling the method `edmonds` Parameters ---------- - layer_mapping: - List of a list of all layer indices. Eg. 
[[1,3], [4, 6, 1]] for a two layer model - num_layers: - Number of layers on the chip + - layer_mapping: List of a list of matching chip core indices for each software layer. + Eg. [[1,3], [4, 6, 1]] for a two layer model + - num_layers (int): Number of layers on the chip Returns ------- - graph: List[List[Edge]] + List[List[FlowGraphEdge]]: Flow graph representation. Each list entry corresponds + to a node and consists of a list holding the outgoing edges from this node. """ graph = [] # add all our nodes # one source node graph.append([]) # one node for every layer that will be mapped - for x in range(len(layer_mapping)): + for __ in range(len(layer_mapping)): graph.append([]) # one node for every chip layer - for x in range(num_layers): + for __ in range(num_layers): graph.append([]) # one sink node graph.append([]) @@ -154,41 +183,60 @@ def make_flow_graph( target_offset = len(layer_mapping) + 1 # first from source to all layers for i in range(len(layer_mapping)): - graph[0].append(Edge(s=0, t=i + 1, cap=1, flow=0)) - # add the reverse edge - graph[i + 1].append(Edge(s=i + 1, t=0, cap=0, flow=0)) + source_to_layer = FlowGraphEdge(s=0, t=i + 1, cap=1, flow=0) + layer_to_source = FlowGraphEdge(s=i + 1, t=0, cap=0, flow=0) # fill in reverse pointers - graph[0][-1].rev = graph[i + 1][-1] - graph[i + 1][-1].rev = graph[0][-1] + source_to_layer.rev = layer_to_source + layer_to_source.rev = source_to_layer + # append new edges + graph[0].append(source_to_layer) + graph[i + 1].append(layer_to_source) # then from layers to chip layers for i, layer_targets in enumerate(layer_mapping): for target in layer_targets: - graph[i + 1].append(Edge(s=i + 1, t=target + target_offset, cap=1, flow=0)) - graph[target + target_offset].append( - Edge(s=target + target_offset, t=i + 1, cap=0, flow=0) + layer_to_chip = FlowGraphEdge( + s=i + 1, t=target + target_offset, cap=1, flow=0 + ) + chip_to_layer = FlowGraphEdge( + s=target + target_offset, t=i + 1, cap=0, flow=0 ) - graph[i + 1][-1].rev = graph[target + target_offset][-1] - graph[target + target_offset][-1].rev = graph[i + 1][-1] - # print(graph) + layer_to_chip.rev = chip_to_layer + chip_to_layer.rev = layer_to_chip + graph[i + 1].append(layer_to_chip) + graph[target + target_offset].append(chip_to_layer) # then from chip layers to sink - for i, layer in enumerate(graph[target_offset:-1]): - sink = len(graph) - 1 - source = i + target_offset - graph[source].append(Edge(s=source, t=sink, cap=1, flow=0)) - graph[sink].append(Edge(s=sink, t=source, cap=0, flow=0)) - graph[source][-1].rev = graph[sink][-1] + sink = len(graph) - 1 + for chip_node in range(target_offset, sink): + graph[chip_node].append(FlowGraphEdge(s=chip_node, t=sink, cap=1, flow=0)) + graph[sink].append(FlowGraphEdge(s=sink, t=chip_node, cap=0, flow=0)) + graph[chip_node][-1].rev = graph[sink][-1] graph[sink][-1].rev = graph[sink][-1] return graph -def recover_mapping(graph, layer_mapping) -> List[Tuple[int, int]]: +def recover_mapping(graph: List[List[FlowGraphEdge]], num_layers: int) -> List[int]: + """Based on the flow graph retrieve a layer-to-core mapping + + Parameters + ---------- + - graph List[List[FlowGraphEdge]]): Flow graph representation with flow calculated. + Each list entry corresponds to a node and consists of a list holding the + outgoing edges from this node. + - num_layers (int): Number of software layers + + Returns + ------- + List[int]: Assigned core IDs for each layer in order. 
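+
+    Example
+    -------
+    Minimal sketch with hypothetical constraints (two software layers, where
+    layer 0 may only be placed on chip core 0 and layer 1 only on core 1):
+
+        graph = make_flow_graph([[0], [1]], num_layers=9)  # 9 cores on chip
+        graph = edmonds(graph, source=0, sink=len(graph) - 1)
+        recover_mapping(graph, num_layers=2)  # expected result: [0, 1]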
+ """ mapping = [] - for i, layer in enumerate(layer_mapping): - for edge in graph[i + 1]: + for i in range(1, num_layers + 1): # `+1` to skip source node + for edge in graph[i]: if edge.flow == 1: - mapping.append((i, edge.t - len(layer_mapping) - 1)) - if len(mapping) != len(layer_mapping): - raise ValueError("No valid mapping found") + mapping.append(edge.t - num_layers - 1) + if len(mapping) != num_layers: + raise ValueError( + "One or more of the DynapcnnLayers could not be mapped to any core." + ) return mapping diff --git a/sinabs/backend/dynapcnn/nir_graph_extractor.py b/sinabs/backend/dynapcnn/nir_graph_extractor.py new file mode 100644 index 00000000..3f1cf970 --- /dev/null +++ b/sinabs/backend/dynapcnn/nir_graph_extractor.py @@ -0,0 +1,861 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +from copy import deepcopy +from pprint import pformat +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union + +import nirtorch +import torch +import torch.nn as nn + +from sinabs import layers as sl +from sinabs.utils import get_new_index + +from .connectivity_specs import ( + LAYER_TYPES_WITH_MULTIPLE_INPUTS, + LAYER_TYPES_WITH_MULTIPLE_OUTPUTS, + SupportedNodeTypes, +) +from .dvs_layer import DVSLayer +from .dynapcnn_layer_utils import construct_dynapcnnlayers_from_mapper +from .dynapcnnnetwork_module import DynapcnnNetworkModule +from .exceptions import InvalidGraphStructure, UnsupportedLayerType +from .sinabs_edges_handler import ( + collect_dynapcnn_layer_info, + fix_dvs_module_edges, + handle_batchnorm_nodes, +) +from .utils import Edge, topological_sorting + +try: + from nirtorch.graph import TorchGraph +except ImportError: + # In older nirtorch versions TorchGraph is called Graph + from nirtorch.graph import Graph as TorchGraph + + +class GraphExtractor: + def __init__( + self, + spiking_model: nn.Module, + dummy_input: torch.tensor, + dvs_input: Optional[bool] = None, + ignore_node_types: Optional[Iterable[Type]] = None, + ): + """Class implementing the extraction of the computational graph from `spiking_model`, where + each node represents a layer in the model and the list of edges represents how the data flow between + the layers. + + Parameters + ---------- + - spiking_model (nn.Module): a sinabs-compatible spiking network. + - dummy_input (torch.tensor): a random input sample to be fed through the model to acquire both + the computational graph (via `nirtorch`) and the I/O shapes of each node. Its a 4-D shape + with `(batch, channels, heigh, width)`. + + Attributes + ---------- + - edges (set of 2-tuples of integers): + Tuples describing the connections between layers in `spiking_model`. + Each layer (node) is identified by a unique integer ID. + - name_2_index_map (dict): + Keys are original variable names of layers in `spiking_model`. + Values are unique integer IDs. + - entry_nodes (set of ints): + IDs of nodes acting as entry points for the network, i.e. receiving external input. + - indx_2_module_map (dict): + Map from layer ID to the corresponding nn.Module instance. + - nodes_io_shapes (dict): + Map from node ID to dict containing node's in- and output shapes + - dvs_input (bool): optional (default as `None`). Whether or not the model + should start with a `DVSLayer`. + - ignore_node_types (iterable of types): Node types that should be + ignored completely from the graph. This can include, for instance, + `nn.Dropout2d`, which otherwise can result in wrongly inferred + graph structures by NIRTorch. 
Types such as `nn.Flatten`, or sinabs + `Merge` should not be included here, as they are needed to properly + handle graph structure and metadata. They can be removed after + instantiation with `remove_nodes_by_class`. + """ + + # Store state before it is changed due to NIRTorch and + # `self._get_nodes_io_shapes` passing dummy input + original_state = { + n: b.detach().clone() for n, b in spiking_model.named_buffers() + } + + # Empty sequentials will cause nirtorch to fail. Treat this case separately + if isinstance(spiking_model, nn.Sequential) and len(spiking_model) == 0: + self._name_2_indx_map = dict() + self._edges = set() + original_state = {} + else: + + # extract computational graph. + nir_graph = nirtorch.extract_torch_graph( + spiking_model, dummy_input, model_name=None + ).ignore_tensors() + if ignore_node_types is not None: + for node_type in ignore_node_types: + nir_graph = nir_graph.ignore_nodes(node_type) + + # Map node names to indices + self._name_2_indx_map = self._get_name_2_indx_map(nir_graph) + + # Extract edges list from graph + self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map) + + # Store the associated `nn.Module` (layer) of each node. + self._indx_2_module_map = self._get_named_modules(spiking_model) + + # Merges BatchNorm2d/BatchNorm1d nodes with Conv2d/Linear ones. + handle_batchnorm_nodes( + self._edges, self._indx_2_module_map, self._name_2_indx_map + ) + + # Determine entry points to graph + self._entry_nodes = self._get_entry_nodes(self._edges) + + # Make sure DVS input is properly integrated into graph + self._handle_dvs_input(input_shape=dummy_input.shape[1:], dvs_input=dvs_input) + + # retrieves what the I/O shape for each node's module is. + self._nodes_io_shapes = self._get_nodes_io_shapes(dummy_input) + + # Restore original state - after forward passes from nirtorch and `_get_nodes_io_shapes` + for n, b in spiking_model.named_buffers(): + b.set_(original_state[n].clone()) + + # Verify that graph is compatible + self.verify_graph_integrity() + + ####################################################### Publich Methods ####################################################### + + @property + def dvs_layer(self) -> Union[DVSLayer, None]: + idx = self.dvs_node_id + if idx is None: + return None + else: + return self.indx_2_module_map[self.dvs_node_id] + + @property + def dvs_node_id(self) -> Union[int, None]: + return self._get_dvs_node_id() + + @property + def entry_nodes(self) -> Set[int]: + return {n for n in self._entry_nodes} + + @property + def edges(self) -> Set[Edge]: + return {(src, tgt) for src, tgt in self._edges} + + @property + def has_dvs_layer(self) -> bool: + return self.dvs_layer is not None + + @property + def name_2_indx_map(self) -> Dict[str, int]: + return {name: idx for name, idx in self._name_2_indx_map.items()} + + @property + def nodes_io_shapes(self) -> Dict[int, Tuple[torch.Size]]: + return {n: size for n, size in self._nodes_io_shapes.items()} + + @property + def sorted_nodes(self) -> List[int]: + return [n for n in self._sort_graph_nodes()] + + @property + def indx_2_module_map(self) -> Dict[int, nn.Module]: + return {n: module for n, module in self._indx_2_module_map.items()} + + def get_dynapcnn_network_module( + self, discretize: bool = False, weight_rescaling_fn: Optional[Callable] = None + ) -> DynapcnnNetworkModule: + """Create DynapcnnNetworkModule based on stored graph representation + + This includes construction of the DynapcnnLayer instances + + Parameters: + ----------- + - discretize (bool): If 
`True`, discretize the parameters and thresholds. This is needed for uploading + weights to dynapcnn. Set to `False` only for testing purposes. + - weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to + the same convolutional layer are combined/re-scaled before applying them. + + Returns + ------- + - The DynapcnnNetworkModule based on graph representation of this `GraphExtractor` + + """ + # Make sure all nodes are supported and there are no isolated nodes. + self.verify_node_types() + self.verify_no_isolated_nodes() + + # create a dict holding the data necessary to instantiate a `DynapcnnLayer`. + self.dcnnl_map, self.dvs_layer_info = collect_dynapcnn_layer_info( + indx_2_module_map=self.indx_2_module_map, + edges=self.edges, + nodes_io_shapes=self.nodes_io_shapes, + entry_nodes=self.entry_nodes, + ) + + # Special case where there is a disconnected `DVSLayer`: There are no + # Edges for the edges handler to process. Instantiate layer info manually. + if self.dvs_layer_info is None and self.dvs_layer is not None: + self.dvs_layer_info = { + "node_id": self.dvs_node_id, + "input_shape": self.nodes_io_shapes[self.dvs_node_id]["input"], + "module": self.dvs_layer, + "pooling": None, + "destinations": None, + } + + # build `DynapcnnLayer` instances from mapper. + dynapcnn_layers, destination_map, entry_points = ( + construct_dynapcnnlayers_from_mapper( + dcnnl_map=self.dcnnl_map, + dvs_layer_info=self.dvs_layer_info, + discretize=discretize, + rescale_fn=weight_rescaling_fn, + ) + ) + + # Instantiate the DynapcnnNetworkModule + return DynapcnnNetworkModule( + dynapcnn_layers, destination_map, entry_points, self.dvs_layer_info + ) + + def remove_nodes_by_class(self, node_classes: Tuple[Type]): + """Remove nodes of given classes from graph in place. + + Create a new set of edges, considering layers that `DynapcnnNetwork` will ignore. This + is done by setting the source (target) node of an edge where the source (target) node + will be dropped as the node that originally targeted this node to be dropped. + + Will change internal attributes `self._edges`, `self._entry_nodes`, + `self._name_2_indx_map`, and `self._nodes_io_shapes` to reflect the changes. + + Parameters + ---------- + - node_classes (tuple of types): + Layer classes that should be removed from the graph. + + """ + # Compose new graph by creating a dict with all remaining node IDs as keys and set of target node IDs as values + source2target: Dict[int, Set[int]] = {} + for node in self.sorted_nodes: + if isinstance((mod := self.indx_2_module_map[node]), node_classes): + # If an entry node is removed, its targets become entry nodes + if node in self.entry_nodes: + targets = self._find_valid_targets(node, node_classes) + self._entry_nodes.update(targets) + + # Update input shapes of nodes after `Flatten` to the shape before flattening + # Note: This is likely to produce incorrect results if multiple Flatten layers + # come in sequence. 
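+                # (Illustrative case: in a Conv2d -> Flatten -> Linear chain the
+                # Linear node keeps the pre-flatten shape, e.g. (C, H, W), as its
+                # recorded input shape once the Flatten node is dropped.)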
+ if isinstance(mod, nn.Flatten): + shape_before_flatten = self.nodes_io_shapes[node]["input"] + for target_node in self._find_valid_targets(node, node_classes): + self._nodes_io_shapes[target_node][ + "input" + ] = shape_before_flatten + + else: + source2target[node] = self._find_valid_targets(node, node_classes) + + # remapping nodes indices contiguously starting from 0 + remapped_nodes = { + old_idx: new_idx + for new_idx, old_idx in enumerate(sorted(source2target.keys())) + } + + # Parse new set of edges based on remapped node IDs + self._edges = { + (remapped_nodes[src], remapped_nodes[tgt]) + for src, targets in source2target.items() + for tgt in targets + } + + # Update internal graph representation according to changes + self._update_internal_representation(remapped_nodes) + + def get_node_io_shapes(self, node: int) -> Tuple[torch.Size, torch.Size]: + """Returns the I/O tensors' shapes of `node`. + + Returns + ---------- + - input shape (torch.Size): shape of the input tensor to `node`. + - output shape (torch.Size): shape of the output tensor from `node`. + """ + return ( + self._nodes_io_shapes[node]["input"], + self._nodes_io_shapes[node]["output"], + ) + + def verify_graph_integrity(self): + """Apply checks to verify that graph is supported + + Check that: + - Only nodes of specific classes have multiple sources or targets. + + Raises + ------ + - InvalidGraphStructure: If any verification fails + """ + + for node, module in self.indx_2_module_map.items(): + # Make sure there are no individual, unconnected nodes + edges_with_node = {e for e in self.edges if node in e} + if not edges_with_node and not isinstance(module, DVSLayer): + raise InvalidGraphStructure( + f"There is an isolated module of type {type(module)}. Only " + "`DVSLayer` instances can be completely disconnected from " + "any other module. Other than that, layers for DynapCNN " + "consist of groups of weight layers (`Linear` or `Conv2d`), " + "spiking layers (`IAF` or `IAFSqueeze`), and optioanlly " + "pooling layers (`SumPool2d`, `AvgPool2d`)." + ) + # Ensure only certain module types have multiple inputs + if not isinstance(module, LAYER_TYPES_WITH_MULTIPLE_INPUTS): + sources = self._find_all_sources_of_input_to(node) + if len(sources) > 1: + raise InvalidGraphStructure( + f"Only nodes of type {LAYER_TYPES_WITH_MULTIPLE_INPUTS} " + f"can have more than one input. Node {node} is of type " + f"{type(module)} and has {len(sources)} inputs." + ) + # Ensure only certain module types have multiple targets + if not isinstance(module, LAYER_TYPES_WITH_MULTIPLE_OUTPUTS): + targets = self._find_valid_targets(node) + if len(targets) > 1: + raise InvalidGraphStructure( + f"Only nodes of type {LAYER_TYPES_WITH_MULTIPLE_OUTPUTS} " + f"can have more than one output. Node {node} is of type " + f"{type(module)} and has {len(targets)} outputs." + ) + + def verify_node_types(self): + """Verify that all nodes are of a supported type. 
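+
+        The accepted layer classes are those listed in
+        `connectivity_specs.SupportedNodeTypes`; any other type raises an
+        `UnsupportedLayerType` error, with a more specific message where possible.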
+ + Raises + ------ + - UnsupportedLayerType: If any verification fails + """ + unsupported_nodes = dict() + for index, module in self.indx_2_module_map.items(): + if not isinstance(module, SupportedNodeTypes): + node_type = type(module) + if node_type in unsupported_nodes: + unsupported_nodes[node_type].add(index) + else: + unsupported_nodes[node_type] = {index} + # Specific error message for non-squeezing IAF layer + iaf_layers = [] + for idx in unsupported_nodes.pop(sl.IAF, []): + iaf_layers.append(self.indx_2_module_map[idx]) + if iaf_layers: + layer_str = ", ".join(str(lyr) for lyr in (iaf_layers)) + raise UnsupportedLayerType( + f"The provided SNN contains IAF layers:\n{layer_str}.\n" + "For compatibility with torch's `nn.Conv2d` modules, please " + "use `IAFSqueeze` layers instead." + ) + # Specific error message for leaky neuron types + lif_layers = [] + for lif_type in (sl.LIF, sl.LIFSqueeze): + for idx in unsupported_nodes.pop(lif_type, []): + lif_layers.append(self.indx_2_module_map[idx]) + if lif_layers: + layer_str = ", ".join(str(lyr) for lyr in (lif_layers)) + raise UnsupportedLayerType( + f"The provided SNN contains LIF layers:\n{layer_str}.\n" + "Leaky integrate-and-fire dynamics are not supported by " + "DynapCNN. Use non-leaky `IAF` or `IAFSqueeze` layers " + "instead." + ) + # Specific error message for most common non-spiking activation layers + activation_layers = [] + for activation_type in (nn.ReLU, nn.Sigmoid, nn.Tanh, sl.NeuromorphicReLU): + for idx in unsupported_nodes.pop(activation_type, []): + activation_layers.append(self.indx_2_module_map[idx]) + if activation_layers: + layer_str = ", ".join(str(lyr) for lyr in (activation_layers)) + raise UnsupportedLayerType( + "The provided SNN contains non-spiking activation layers:\n" + f"{layer_str}.\nPlease convert them to `IAF` or `IAFSqueeze` " + "layers before instantiating a `DynapcnnNetwork`. You can " + "use the function `sinabs.from_model.from_torch` for this." + ) + if unsupported_nodes: + # More generic error message for all remaining types + raise UnsupportedLayerType( + "One or more layers in the provided SNN are not supported: " + f"{pformat(unsupported_nodes)}. Supported layer types are: " + f"{pformat(SupportedNodeTypes)}." + ) + + def verify_no_isolated_nodes(self): + """Verify that there are no disconnected nodes except for `DVSLayer` instances. + + Raises + ------ + - InvalidGraphStructure when disconnected nodes are detected + """ + for node, module in self.indx_2_module_map.items(): + # Make sure there are no individual, unconnected nodes + edges_with_node = {e for e in self.edges if node in e} + if not edges_with_node and not isinstance(module, DVSLayer): + raise InvalidGraphStructure( + f"There is an isolated module of type {type(module)}. Only " + "`DVSLayer` instances can be completely disconnected from " + "any other module. Other than that, layers for DynapCNN " + "consist of groups of weight layers (`Linear` or `Conv2d`), " + "spiking layers (`IAF` or `IAFSqueeze`), and optioanlly " + "pooling layers (`SumPool2d`, `AvgPool2d`)." + ) + + ####################################################### Pivate Methods ####################################################### + + def _handle_dvs_input( + self, input_shape: Tuple[int, int, int], dvs_input: Optional[bool] = None + ): + """Make sure DVS input is properly integrated into graph + + - Decide whether `DVSLayer` instance needs to be added to the graph + This is the case when `dvs_input==True` and there is no `DVSLayer` yet. 
+ - Make sure edges between DVS related nodes are set properly + - Absorb pooling layers in DVS node if applicable + + Parameters + ---------- + - input_shape (tuple of three integers): Input shape (features, height, width) + - dvs_input (bool or `None` (default)): If `False`, will raise + `InvalidModelWithDvsSetup` if a `DVSLayer` is part of the graph. If `True`, + a `DVSLayer` will be added to the graph if there is none already. If `None`, + the model is considered to be using DVS input only if the graph contains + a `DVSLayer`. + """ + if self.has_dvs_layer: + # Make a copy of the layer so that the original version is not + # changed in place + new_dvs_layer = deepcopy(self.dvs_layer) + self._indx_2_module_map[self.dvs_node_id] = new_dvs_layer + elif dvs_input: + # Insert a DVSLayer node in the graph. + new_dvs_layer = self._add_dvs_node(dvs_input_shape=input_shape) + else: + dvs_input = None + if dvs_input is not None: + # Disable pixel array if `dvs_input` is False + new_dvs_layer.disable_pixel_array = not dvs_input + + # Check for the need of fixing NIR edges extraction when DVS is a node in the graph. If DVS + # is used its node becomes the only entry node in the graph. + fix_dvs_module_edges( + self._edges, + self._indx_2_module_map, + self._name_2_indx_map, + self._entry_nodes, + ) + + # Check if graph structure and DVSLayer.merge_polarities are correctly set (if DVS node exists). + self._validate_dvs_setup(dvs_input_shape=input_shape) + + def _add_dvs_node(self, dvs_input_shape: Tuple[int, int, int]) -> DVSLayer: + """In-place modification of `self._name_2_indx_map`, `self._indx_2_module_map`, and `self._edges` to accomodate the + creation of an extra node in the graph representing the DVS camera of the chip. The DVSLayer node will point to every + other node that is up to this point an entry node of the original graph, so `self._entry_nodes` is modified in-place + to have only one entry: the index of the DVS node. + + Parameters + ---------- + - dvs_input_shape (tuple): shape of the DVSLayer input in format `(features, height, width)` + + Returns + - DVSLayer: A handler to the newly added `DVSLayer` instance + """ + + (features, height, width) = dvs_input_shape + if features > 2: + raise ValueError( + f"A DVSLayer istance can have the feature dimension of its inputs with values 1 or 2 but {features} was given." + ) + + # Find new index to be assigned to DVS node + self._name_2_indx_map["dvs"] = get_new_index(self._name_2_indx_map.values()) + # add module entry for node 'dvs'. + dvs_layer = DVSLayer( + input_shape=(height, width), + merge_polarities=(features == 1), + ) + self._indx_2_module_map[self._name_2_indx_map["dvs"]] = dvs_layer + + # set DVS node as input to each entry node of the graph + self._edges.update( + { + (self._name_2_indx_map["dvs"], entry_node) + for entry_node in self._entry_nodes + } + ) + # DVSLayer node becomes the only entrypoint of the graph + self._entry_nodes = {self._name_2_indx_map["dvs"]} + + return dvs_layer + + def _get_dvs_node_id(self) -> Union[int, None]: + """Loop though all modules and return index of `DVSLayer` + instance if it exists. 
+ + Returns + ------- + - DVSLayer if exactly one is found, otherwise None + + Raises + ------ + - InvalidGraphStructure if more than one DVSLayer is found + + """ + + dvs_layer_indices = { + index + for index, module in self._indx_2_module_map.items() + if isinstance(module, DVSLayer) + } + + if (num_dvs := len(dvs_layer_indices)) == 0: + return + elif num_dvs == 1: + return dvs_layer_indices.pop() + else: + raise InvalidGraphStructure( + f"The provided model has {num_dvs} `DVSLayer`s. At most one is allowed." + ) + + def _validate_dvs_setup(self, dvs_input_shape: Tuple[int, int, int]) -> None: + """If a DVSLayer node exists, makes sure it is the only entry node of the graph. Checks if its `merge_polarities` + attribute matches `dummy_input.shape[0]` (the number of features) and, if not, it will be set based on the numeber of + features of the input. + + Parameters + ---------- + - dvs_input_shape (tuple): shape of the DVSLayer input in format `(features, height, width)`. + """ + + if self.dvs_layer is None: + # No DVSLayer found - nothing to do here. + return + + if (nb_entries := len(self._entry_nodes)) > 1: + raise ValueError( + f"A DVSLayer node exists and there are {nb_entries} entry nodes in the graph: the DVSLayer should be the only entry node." + ) + + (features, _, _) = dvs_input_shape + + if features > 2: + raise ValueError( + f"A DVSLayer istance can have the feature dimension of its inputs with values 1 or 2 but {features} was given." + ) + + if self.dvs_layer.merge_polarities and features != 1: + raise ValueError( + f"The 'DVSLayer.merge_polarities' is set to 'True' which means the number of input features should be 1 (current input shape is {dvs_input_shape})." + ) + + if features == 1: + self.dvs_layer.merge_polarities = True + + def _get_name_2_indx_map(self, nir_graph: TorchGraph) -> Dict[str, int]: + """Assign unique index to each node and return mapper from name to index. + + Parameters + ---------- + - nir_graph (TorchGraph): a NIR graph representation of `spiking_model`. + + Returns + ---------- + - name_2_indx_map (dict): `key` is the original variable name for a layer in + `spiking_model` and `value is an integer representing the layer in a standard format. + """ + + return { + node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list) + } + + def _get_edges_from_nir( + self, nir_graph: TorchGraph, name_2_indx_map: Dict[str, int] + ) -> Set[Edge]: + """Standardize the representation of TorchGraph` into a list of edges, + representing nodes by their indices. + + Parameters + ---------- + - nir_graph (TorchGraph): a NIR graph representation of `spiking_model`. + - name_2_indx_map (dict): Map from node names to unique indices. + + Returns + ---------- + - edges (set): tuples describing the connections between layers in `spiking_model`. + - name_2_indx_map (dict): `key` is the original variable name for a layer in `spiking_model` and `value is an integer representing the layer in a standard format. + - entry_nodes (set): IDs of nodes acting as entry points for the network (i.e., receiving external input). + """ + return { + (name_2_indx_map[src.name], name_2_indx_map[tgt.name]) + for src in nir_graph.node_list + for tgt in src.outgoing_nodes + } + + def _get_entry_nodes(self, edges: Set[Edge]) -> Set[Edge]: + """Find nodes that act as entry points to the graph + + Parameters + ---------- + - edges (set): tuples describing the connections between layers in `spiking_model`. 
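+          Entry nodes are all sources that never appear as a target; for
+          example (illustrative node IDs), edges `{(0, 1), (0, 2), (1, 3)}`
+          would yield `{0}`.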
+ + Returns + ---------- + - entry_nodes (set): IDs of nodes acting as entry points for the network + (i.e., receiving external input). + """ + if not edges: + return set() + + all_sources, all_targets = zip(*edges) + return set(all_sources) - set(all_targets) + + def _get_named_modules(self, model: nn.Module) -> Dict[int, nn.Module]: + """Find for each node in the graph what its associated layer in `model` is. + + Parameters + ---------- + - model (nn.Module): the `spiking_model` used as argument to the class instance. + + Returns + ---------- + - indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`). + """ + + indx_2_module_map = dict() + + for name, module in model.named_modules(): + # Make sure names match those provided by nirtorch nodes + if name in self._name_2_indx_map: + indx_2_module_map[self._name_2_indx_map[name]] = module + else: + # In older nirtorch versions, node names are "sanitized" + # Try with sanitized version of the name + name = nirtorch.utils.sanitize_name(name) + if name in self._name_2_indx_map: + indx_2_module_map[self._name_2_indx_map[name]] = module + + return indx_2_module_map + + def _update_internal_representation(self, remapped_nodes: Dict[int, int]): + """Update internal attributes after remapping of nodes + + Parameters + ---------- + remapped_nodes (dict): Maps previous (key) to new (value) node + indices. Nodes that were removed are not included. + """ + + # Update name-to-index map based on new node indices + self._name_2_indx_map = { + name: remapped_nodes[old_idx] + for name, old_idx in self._name_2_indx_map.items() + if old_idx in remapped_nodes + } + + # Update entry nodes based on new node indices + self._entry_nodes = { + remapped_nodes[old_idx] + for old_idx in self._entry_nodes + if old_idx in remapped_nodes + } + + # Update io-shapes based on new node indices + self._nodes_io_shapes = { + remapped_nodes[old_idx]: shape + for old_idx, shape in self._nodes_io_shapes.items() + if old_idx in remapped_nodes + } + + # Update sinabs module map based on new node indices + self._indx_2_module_map = { + remapped_nodes[old_idx]: module + for old_idx, module in self._indx_2_module_map.items() + if old_idx in remapped_nodes + } + + def _sort_graph_nodes(self) -> List[int]: + """Sort graph nodes topologically. + + Returns + ------- + - sorted_nodes (list of integers): IDs of nodes, sorted. + """ + # Make a temporary copy of edges and include inputs + temp_edges = self.edges + for node in self._entry_nodes: + temp_edges.add(("input", node)) + return topological_sorting(temp_edges) + + def _get_nodes_io_shapes( + self, input_dummy: torch.tensor + ) -> Dict[int, Dict[str, torch.Size]]: + """Iteratively calls the forward method of each `nn.Module` (i.e., a layer/node in the graph) using the topologically + sorted nodes extracted from the computational graph of the model being parsed. + + Parameters + ---------- + - input_dummy (torch.tensor): a sample (random) tensor of the sort of input being fed to the network. + + Returns + ---------- + - nodes_io_map (dict): a dictionary mapping nodes to their I/O shapes. + """ + nodes_io_map = {} + + # propagate inputs through the nodes. + for node in self.sorted_nodes: + + if isinstance(self.indx_2_module_map[node], sl.merge.Merge): + # find `Merge` arguments (at this point the inputs to Merge should have been calculated). + input_nodes = self._find_merge_arguments(node) + + # retrieve arguments output tensors. 
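+                # (Both source outputs already exist at this point because nodes are
+                # visited in topological order via `self.sorted_nodes`.)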
+ inputs = [nodes_io_map[n]["output"] for n in input_nodes] + + # TODO - this is currently a limitation imposed by the validation checks done by Speck once a configuration: it wants + # different input sources to a core to have the same output shapes. + if any(inp.shape != inputs[0].shape for inp in inputs): + raise ValueError( + f"Layer `sinabs.layers.merge.Merge` (node {node}) requires input tensors with the same shape" + ) + + # forward input through the node. + _output = self.indx_2_module_map[node](*inputs) + + # save node's I/O tensors. + nodes_io_map[node] = {"input": inputs[0], "output": _output} + + else: + + if node in self._entry_nodes: + # forward input dummy through node. + _output = self.indx_2_module_map[node](input_dummy) + + # save node's I/O tensors. + nodes_io_map[node] = {"input": input_dummy, "output": _output} + + else: + # find node generating the input to be used. + input_node = self._find_source_of_input_to(node) + _input = nodes_io_map[input_node]["output"] + + # forward input through the node. + _output = self.indx_2_module_map[node](_input) + + # save node's I/O tensors. + nodes_io_map[node] = {"input": _input, "output": _output} + + # replace the I/O tensor information by its shape information, ignoring the batch/time axis + for node, io in nodes_io_map.items(): + input_shape = io["input"].shape[1:] + output_shape = io["output"].shape[1:] + # Linear layers have fewer in/out dimensions. Extend by appending 1's + if (length := len(input_shape)) < 3: + input_shape = (*input_shape, *(1 for __ in range(3 - length))) + assert len(input_shape) == 3 + if (length := len(output_shape)) < 3: + output_shape = (*output_shape, *(1 for __ in range(3 - length))) + assert len(output_shape) == 3 + nodes_io_map[node]["input"] = input_shape + nodes_io_map[node]["output"] = output_shape + + return nodes_io_map + + def _find_all_sources_of_input_to(self, node: int) -> Set[int]: + """Finds all source nodes to `node`. + + Parameters + ---------- + - node (int): the node in the computational graph for which we whish to find the input source (either another node in the + graph or the original input itself to the network). + + Returns + ---------- + - input sources (set of int): IDs of the nodes in the computational graph providing the input to `node`. + """ + return set(src for (src, tgt) in self._edges if tgt == node) + + def _find_source_of_input_to(self, node: int) -> int: + """Finds the first edge `(X, node)` returns `X`. + + Parameters + ---------- + - node (int): the node in the computational graph for which we whish to find the input source (either another node in the + graph or the original input itself to the network). + + Returns + ---------- + - input source (int): ID of the node in the computational graph providing the input to `node`. If `node` is + receiving outside input (i.e., it is a starting node) the return will be -1. For example, this will be the case + when a network with two independent branches (each starts from a different "input node") merge along the computational graph. + """ + sources = self._find_all_sources_of_input_to(node) + if len(sources) == 0: + return -1 + if len(sources) > 1: + raise RuntimeError(f"Node {node} has more than 1 input") + return sources.pop() + + def _find_merge_arguments(self, node: int) -> Edge: + """A `Merge` layer receives two inputs. Return the two inputs to `merge_node` representing a `Merge` layer. + + Returns + ---------- + - args (tuple): the IDs of the nodes that provice the input arguments to a `Merge` layer. 
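+          For example (illustrative IDs): if nodes 2 and 3 both feed into a
+          `Merge` node 5, the result is `(2, 3)`, in no guaranteed order since
+          the sources are collected in a set.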
+ """ + sources = self._find_all_sources_of_input_to(node) + + if len(sources) != 2: + raise ValueError( + f"Number of arguments found for `Merge` node {node} is {len(sources)} (should be 2)." + ) + + return tuple(sources) + + def _find_valid_targets( + self, node: int, ignored_node_classes: Tuple[Type] = () + ) -> Set[int]: + """Find all targets of a node that are not ignored classes + + Return a set of all target nodes that are not of an ignored class. + For target nodes of ignored classes, recursively return their valid + targets. + + Parameters + ---------- + - node (int): ID of node whose targets should be found + - ignored_node_classes (tuple of types): Classes of which nodes should be skiped + + Returns + ------- + - valid_targets (set of int): Set of all recursively found target IDs + """ + targets = set() + for src, tgt in self.edges: + # Search for all edges with node as source + if src == node: + if isinstance(self.indx_2_module_map[tgt], ignored_node_classes): + # Find valid targets of target + targets.update(self._find_valid_targets(tgt, ignored_node_classes)) + else: + # Target is valid, add it to `targets` + targets.add(tgt) + return targets diff --git a/sinabs/backend/dynapcnn/sinabs_edges_handler.py b/sinabs/backend/dynapcnn/sinabs_edges_handler.py new file mode 100644 index 00000000..6e5a3346 --- /dev/null +++ b/sinabs/backend/dynapcnn/sinabs_edges_handler.py @@ -0,0 +1,1008 @@ +""" +functionality : functions implementing the pre-processing of edges into blocks of nodes (modules) for future + creation of DynapcnnLayer objects. +author : Willian Soares Girao +contact : williansoaresgirao@gmail.com +""" + +from typing import Deque, Dict, List, Optional, Set, Tuple, Type, Union + +from torch import Size, nn + +from sinabs.layers import SumPool2d + +from .connectivity_specs import VALID_SINABS_EDGE_TYPES +from .crop2d import Crop2d +from .dvs_layer import DVSLayer +from .exceptions import ( + InvalidEdge, + InvalidGraphStructure, + default_invalid_structure_string, +) +from .flipdims import FlipDims +from .utils import Edge, merge_bn + + +def remap_edges_after_drop( + dropped_node: int, source_of_dropped_node: int, edges: Set[Edge] +) -> Set[Edge]: + """Creates a new set of edges from `edges`. All edges where `dropped_node` is the source node will be used to generate + a new edge where `source_of_dropped_node` becomes the source node (the target is kept). + + Parameters + ---------- + - dropped_node (int): + - source_of_dropped_node (int): + - edges (set): tuples describing the connections between layers in `spiking_model`. + + Returns + ------- + - remapped_edges (set): new set of edges with `source_of_dropped_node` as the source node where `dropped_node` used to be. + """ + remapped_edges = set() + + for src, tgt in edges: + if src == dropped_node: + remapped_edges.add((source_of_dropped_node, tgt)) + + return remapped_edges + + +def handle_batchnorm_nodes( + edges: Set[Edge], + indx_2_module_map: Dict[int, nn.Module], + name_2_indx_map: Dict[str, int], +) -> None: + """Merges `BatchNorm2d`/`BatchNorm1d` layers into `Conv2d`/`Linear` ones. The batch norm nodes will be removed from the graph (by updating all variables + passed as arguments in-place) after their properties are used to re-scale the weights of the convolutional/linear layers associated with batch + normalization via the `weight-batchnorm` edges found in the original graph. + + Parameters + ---------- + - edges (set): tuples describing the connections between layers in `spiking_model`. 
+ - indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`). + - name_2_indx_map (dict): Map from node names to unique indices. + """ + + # Gather indexes of the BatchNorm2d/BatchNorm1d nodes. + bnorm_nodes = { + index + for index, module in indx_2_module_map.items() + if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)) + } + + if len(bnorm_nodes) == 0: + # There are no edges with batch norm - nothing to do here. + return + + # Find weight-bnorm edges. + weight_bnorm_edges = { + (src, tgt) + for (src, tgt) in edges + if ( + isinstance(indx_2_module_map[src], nn.Conv2d) + and isinstance(indx_2_module_map[tgt], nn.BatchNorm2d) + ) + or ( + isinstance(indx_2_module_map[src], nn.Linear) + and isinstance(indx_2_module_map[tgt], nn.BatchNorm1d) + ) + } + + # Merge conv/linear and bnorm layers using 'weight-bnorm' edges. + for edge in weight_bnorm_edges: + bnorm = indx_2_module_map[edge[1]] + weight = indx_2_module_map[edge[0]] + + # merge and update weight node. + indx_2_module_map[edge[0]] = merge_bn(weight, bnorm) + + # Point weight nodes to the targets of their respective batch norm nodes. + new_edges = set() + for weight_id, bnorm_id in weight_bnorm_edges: + new_edges.update( + remap_edges_after_drop( + dropped_node=bnorm_id, source_of_dropped_node=weight_id, edges=edges + ) + ) + # Remove all edges to and from a batch norm node and replace with new edges + bnorm_edges = {e for e in edges if bnorm_nodes.intersection(e)} + edges.difference_update(bnorm_edges) + edges.update(new_edges) + + # Remove references to the bnorm node. + for idx in bnorm_nodes: + indx_2_module_map.pop(idx) + + for name in [name for name, indx in name_2_indx_map.items() if indx in bnorm_nodes]: + name_2_indx_map.pop(name) + + +def fix_dvs_module_edges( + edges: Set[Edge], + indx_2_module_map: Dict[int, nn.Module], + name_2_indx_map: Dict[str, int], + entry_nodes: Set[Edge], +) -> None: + """All arguments are modified in-place to fix wrong node extractions from NIRtorch when a DVSLayer istance is the first layer in the network. + + Modifies `edges` to re-structure the edges related witht the DVSLayer instance. The DVSLayer's forward method feeds data in the + sequence 'DVS -> DVS.pool -> DVS.crop -> DVS.flip', so we remove edges involving these nodes (that are internaly implementend in + the DVSLayer) from the graph and point the node of DVSLayer to the node where it should send its output to. This is also removes + a self-recurrent node with edge '(FlipDims, FlipDims)' that is wrongly extracted. + + Modifies `indx_2_module_map` and `name_2_indx_map` to remove the internal DVSLayer nodes (Crop2d, FlipDims and DVSLayer's pooling) since + these should not be independent nodes in the graph. + + Modifies `entry_nodes` such that the DVSLayer becomes the only entry node of the graph. + + Parameters + ---------- + - edges (set): tuples describing the connections between layers in `spiking_model`. + - indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`). + - name_2_indx_map (dict): Map from node names to unique indices. + - entry_nodes (set): IDs of nodes acting as entry points for the network (i.e., receiving external input). + """ + # TODO - the 'fix_' is to imply there's something odd with the extracted adges for the forward pass implemented by + # the DVSLayer. 
For now this function is fixing these edges to have them representing the information flow through + # this layer as **it should be** but the graph tracing of NIR should be looked into to solve the root problem. + + # spot nodes (ie, modules) used in a DVSLayer instance's forward pass (including the DVSLayer node itself). + dvslayer_nodes = { + index: module + for index, module in indx_2_module_map.items() + if any( + isinstance(module, dvs_node) for dvs_node in (DVSLayer, Crop2d, FlipDims) + ) + } + + if len(dvslayer_nodes) <= 1: + # No module within the DVSLayer instance appears as an independent node - nothing to do here. + return + + # TODO - a `SumPool2d` is also a node that's used inside a DVSLayer instance. In what follows we try to find it + # by looking for pooling nodes that appear in a (pool, crop) edge - the assumption being that if the pooling is + # inputing into a crop layer than the pool is inside the DVSLayer instance. It feels like a hacky way to do it + # so we should revise this. + dvslayer_nodes.update( + { + edge[0]: indx_2_module_map[edge[0]] + for edge in edges + if isinstance(indx_2_module_map[edge[0]], SumPool2d) + and isinstance(indx_2_module_map[edge[1]], Crop2d) + } + ) + + # NIR is extracting an edge (FlipDims, FlipDims) from the DVSLayer: remove self-recurrent nodes from the graph. + for edge in [ + (src, tgt) + for (src, tgt) in edges + if (src == tgt and isinstance(indx_2_module_map[src], FlipDims)) + ]: + edges.remove(edge) + + # Since NIR is not extracting the edges for the DVSLayer correctly, remove all edges involving the DVS. + for edge in [ + (src, tgt) + for (src, tgt) in edges + if (src in dvslayer_nodes or tgt in dvslayer_nodes) + ]: + edges.remove(edge) + + # Get node's indexes based on the module type - just for validation. + dvs_node = [ + key for key, value in dvslayer_nodes.items() if isinstance(value, DVSLayer) + ] + dvs_pool_node = [ + key for key, value in dvslayer_nodes.items() if isinstance(value, SumPool2d) + ] + dvs_crop_node = [ + key for key, value in dvslayer_nodes.items() if isinstance(value, Crop2d) + ] + dvs_flip_node = [ + key for key, value in dvslayer_nodes.items() if isinstance(value, FlipDims) + ] + + if any( + len(node) > 1 + for node in [dvs_node, dvs_pool_node, dvs_crop_node, dvs_flip_node] + ): + raise ValueError( + f"Internal DVS nodes should be single instances but multiple have been found: dvs_node: {len(dvs_node)} dvs_pool_node: {len(dvs_pool_node)} dvs_crop_node: {len(dvs_crop_node)} dvs_flip_node: {len(dvs_flip_node)}" + ) + + # Remove dvs_pool, dvs_crop and dvs_flip nodes from `indx_2_module_map` (these operate within the DVS, not as independent nodes of the final graph). + indx_2_module_map.pop(dvs_pool_node[-1]) + indx_2_module_map.pop(dvs_crop_node[-1]) + indx_2_module_map.pop(dvs_flip_node[-1]) + + # Remove internal DVS modules from name/index map. + # Iterate over copy to prevent iterable from changing size. + n2i_map_copy = {k: v for k, v in name_2_indx_map.items()} + for name, index in n2i_map_copy.items(): + if index in [dvs_pool_node[-1], dvs_crop_node[-1], dvs_flip_node[-1]]: + name_2_indx_map.pop(name) + + dvs_node = dvs_node[0] + if edges: + # Add edges from 'dvs' node to the entry point of the graph. + all_sources, all_targets = zip(*edges) + local_entry_nodes = set(all_sources) - set(all_targets) + edges.update({(dvs_node, node) for node in local_entry_nodes}) + + # DVS becomes the only entry node of the graph. 
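+    # (`entry_nodes` is mutated in place so that the caller's set reflects the change.)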
+ entry_nodes.clear() + entry_nodes.add(dvs_node) + + +def collect_dynapcnn_layer_info( + indx_2_module_map: Dict[int, nn.Module], + edges: Set[Edge], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], + entry_nodes: Set[int], +) -> Tuple[Dict[int, Dict], Union[Dict, None]]: + """Collect information to construct DynapcnnLayer instances. + + Validate and sort edges based on the type of nodes they connect. + Iterate over edges in order of their type. For each neuron->weight edge + generate a new dict to collect information for the corresponding dynapcnn layer. + Then add pooling based on neuron->pooling type edges. Collect additional pooling + from pooling->pooling type edges. Finally set layer destinations based on + neuron/pooling->weight type of edges. + + Parameters + ---------- + - indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + - edges (set of tuples): Represent connections between two nodes in computational graph + - nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + - entry_nodes (set of int): IDs of nodes that receive external input + + Returns + ------- + dynapcnn_layer_info (dict): Each 'key' is the index of a future 'DynapcnnLayer' and + 'value' is a dictionary, with keys 'conv', 'neuron', and 'destinations', + containing corresponding node ids and modules required to build the layer + dvs_layer_info (dict or None): If a DVSLayer is part of the network, this will + be a dict containing the layer itself and its destination indices. + """ + + # Sort edges by edge type (type of layers they connect) + edges_by_type: Dict[str, Set[Edge]] = sort_edges_by_type( + edges=edges, indx_2_module_map=indx_2_module_map + ) + edge_counts_by_type = {t: len(e) for t, e in edges_by_type.items()} + + # Dict to collect information for each future dynapcnn layer + dynapcnn_layer_info = dict() + # Map node IDs to dynapcnn layer ID + node_2_layer_map = dict() + + # Each weight->neuron connection instantiates a new, unique dynapcnn layer + weight_neuron_edges = edges_by_type.get("weight-neuron", set()) + while weight_neuron_edges: + edge = weight_neuron_edges.pop() + init_new_dynapcnnlayer_entry( + dynapcnn_layer_info, + edge, + indx_2_module_map, + nodes_io_shapes, + node_2_layer_map, + entry_nodes, + ) + + # Process all edges related to DVS layer + dvs_layer_info = dvs_setup( + edges_by_type, indx_2_module_map, node_2_layer_map, nodes_io_shapes + ) + + # Process all edges connecting two dynapcnn layers that do not include pooling + neuron_weight_edges = edges_by_type.get("neuron-weight", set()) + while neuron_weight_edges: + edge = neuron_weight_edges.pop() + set_neuron_layer_destination( + dynapcnn_layer_info, + edge, + node_2_layer_map, + nodes_io_shapes, + indx_2_module_map, + ) + + # Add pooling based on neuron->pooling connections + pooling_pooling_edges = edges_by_type.get("pooling-pooling", set()) + neuron_pooling_edges = edges_by_type.get("neuron-pooling", set()) + while neuron_pooling_edges: + edge = neuron_pooling_edges.pop() + # Search pooling-pooling edges for chains of pooling and add to existing entry + pooling_chains, edges_used = trace_paths(edge[1], pooling_pooling_edges) + add_pooling_to_entry( + dynapcnn_layer_info, + edge, + pooling_chains, + indx_2_module_map, + node_2_layer_map, + ) + # Remove handled pooling-pooling edges + pooling_pooling_edges.difference_update(edges_used) + + # After adding pooling make sure all pooling-pooling edges have been handled + if 
len(pooling_pooling_edges) > 0: + unmatched_layers = {edge[0] for edge in pooling_pooling_edges} + raise InvalidGraphStructure( + f"Pooling layers {unmatched_layers} could not be assigned to a " + "dynapcnn layer. This is likely due to an unsupported SNN " + "architecture. Pooling layers must always be preceded by a " + "spiking layer (`IAFSqueeze`), another pooling layer, or" + "DVS input" + ) + + # Add all edges connecting pooling to a new dynapcnn layer + pooling_weight_edges = edges_by_type.get("pooling-weight", set()) + while pooling_weight_edges: + edge = pooling_weight_edges.pop() + set_pooling_layer_destination( + dynapcnn_layer_info, + edge, + node_2_layer_map, + nodes_io_shapes, + indx_2_module_map, + ) + + # Make sure we have taken care of all edges + assert all(len(edges) == 0 for edges in edges_by_type.values()) + + # Set minimal destination entries for layers without child nodes, to act as network outputs + set_exit_destinations(dynapcnn_layer_info) + + # Assert formal correctness of layer info + verify_layer_info(dynapcnn_layer_info, edge_counts_by_type) + + return dynapcnn_layer_info, dvs_layer_info + + +def get_valid_edge_type( + edge: Edge, + layers: Dict[int, nn.Module], + valid_edge_ids: Dict[Tuple[Type, Type], int], +) -> int: + """Checks if the modules each node in 'edge' represent are a valid connection between a sinabs network to be + loaded on Speck and return the edge type + + Parameters + ---------- + edge (tuple of two int): The edge whose type is to be inferred + layers (Dict): Dict with node IDs as keys and layer instances as values + valid_edge_ids: Dict with valid edge-types (tuples of Types) as keys and edge-type-ID as value + + Returns + ---------- + edge_type: the edge type specified in 'valid_edges_map' ('None' if edge is not valid). + """ + source_type = type(layers[edge[0]]) + target_type = type(layers[edge[1]]) + + return valid_edge_ids.get((source_type, target_type), None) + + +def sort_edges_by_type( + edges: Set[Edge], indx_2_module_map: Dict[int, Type] +) -> Dict[str, Set[Edge]]: + """Sort edges by the type of nodes they connect + + Parameters + ---------- + edges (set of tuples): Represent connections between two nodes in computational graph + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + + Returns + ------- + Dict with possible keys "weight-neuron", "neuron-weight", "neuron-pooling", "pooling-pooling", + and "pooling-weight". Values are sets of edges corresponding to these types. + """ + edges_by_type: Dict[str, Set[Edge]] = dict() + + for edge in edges: + edge_type = get_valid_edge_type( + edge, indx_2_module_map, VALID_SINABS_EDGE_TYPES + ) + + # Validate edge type + if edge_type is None: + raise InvalidEdge( + edge, type(indx_2_module_map[edge[0]]), type(indx_2_module_map[edge[1]]) + ) + + if edge_type in edges_by_type: + edges_by_type[edge_type].add(edge) + else: + edges_by_type[edge_type] = {edge} + + return edges_by_type + + +def init_new_dynapcnnlayer_entry( + dynapcnn_layer_info: Dict[int, Dict[int, Dict]], + edge: Edge, + indx_2_module_map: Dict[int, nn.Module], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], + node_2_layer_map: Dict[int, int], + entry_nodes: Set[int], +) -> None: + """Initiate dict to hold information for new dynapcnn layer based on a "weight->neuron" edge. + Change `dynapcnn_layer_info` in-place. + + Parameters + ---------- + dynapcnn_layer_info: Dict with one entry for each future dynapcnn layer. 
+ key is unique dynapcnn layer ID, value is dict with nodes of the layer + Will be updated in-place. + edge: Tuple of 2 integers, indicating edge between two nodes in graph. + Edge source has to be within an existing entry of `dynapcnn_layer_info`. + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. + Will be updated in-place. + entry_nodes (set of int): IDs of nodes that receive external input + """ + # Make sure there are no existing entries holding any of the modules connected by `edge` + assert edge[0] not in node_2_layer_map + assert edge[1] not in node_2_layer_map + + # Take current length of the dict as new, unique ID + layer_id = len(dynapcnn_layer_info) + assert layer_id not in dynapcnn_layer_info + + dynapcnn_layer_info[layer_id] = { + "input_shape": nodes_io_shapes[edge[0]]["input"], + "conv": { + "module": indx_2_module_map[edge[0]], + "node_id": edge[0], + }, + "neuron": { + "module": indx_2_module_map[edge[1]], + "node_id": edge[1], + }, + # This will be used later to account for average pooling in preceding layers + "rescale_factors": set(), + "is_entry_node": edge[0] in entry_nodes, + # Will be populated by `set_[pooling/neuron]_layer_destination` + "destinations": [], + } + node_2_layer_map[edge[0]] = layer_id + node_2_layer_map[edge[1]] = layer_id + + +def add_pooling_to_entry( + dynapcnn_layer_info: Dict[int, Dict], + edge: Edge, + pooling_chains: List[Deque[int]], + indx_2_module_map: Dict[int, nn.Module], + node_2_layer_map: Dict[int, int], +) -> None: + """Add or extend destination information with pooling for existing + entry in `dynapcnn_layer_info`. + + Correct entry is identified by existing neuron node. Destination information is a + dict containing list of IDs and list of modules for each chains of pooling nodes. + + Parameters + ---------- + dynapcnn_layer_info: Dict with one entry for each future dynapcnn layer. + key is unique dynapcnn layer ID, value is dict with nodes of the layer + Will be updated in-place. + edge: Tuple of 2 integers, indicating edge between a neuron node and the pooling + node that starts all provided `pooling_chains`. + Edge source has to be a neuron node within an existing entry of + `dynapcnn_layer_info`, i.e. it has to have been processed already. + pooling_chains: List of deque of int. All sequences ("chains") of connected pooling nodes, + starting from edge[1] + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. + Will be updated in-place. + """ + # Find layer containing edge[0] + try: + layer_idx = node_2_layer_map[edge[0]] + except KeyError: + neuron_layer = indx_2_module_map[edge[0]] + raise InvalidGraphStructure( + f"Spiking layer {neuron_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Spiking " + "layers have to be preceded by a weight layer (`nn.Conv2d` or " + "`nn.Linear`)." 
+ ) + # Make sure all pooling chains start with expected node + assert all(chain[0] == edge[1] for chain in pooling_chains) + + # Keep track of all nodes that have been added + new_nodes = set() + + # For each pooling chain initialize new destination + layer_info = dynapcnn_layer_info[layer_idx] + for chain in pooling_chains: + layer_info["destinations"].append( + { + "pooling_ids": chain, + "pooling_modules": [indx_2_module_map[idx] for idx in chain], + # Setting `destination_layer` to `None` allows for this layer + # to act as network exit point if not destination is added later + "destination_layer": None, + } + ) + new_nodes.update(set(chain)) + + for node in new_nodes: + # Make sure new pooling nodes have not been used elsewhere + assert node not in node_2_layer_map + node_2_layer_map[node] = layer_idx + + +def dvs_setup( + edges_by_type: Dict[str, Set[Edge]], + indx_2_module_map: Dict[int, nn.Module], + node_2_layer_map: Dict[int, int], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], +) -> Union[None, Dict]: + """Generate dict containing information to set up DVS layer + + Parameters + ---------- + edges_by_type (dict of sets of edges): Keys are edge types (str), values are sets of edges. + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. + Will be updated in-place. + nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + + Returns + ------- + dvs_layer_info: Dict containing information about the DVSLayer. + """ + # Process all outgoing edges of a DVSLayer + dvs_weight_edges = edges_by_type.get("dvs-weight", set()) + dvs_pooling_edges = edges_by_type.get("dvs-pooling", set()) + + # Process all dvs->weight edges connecting the DVS camera to a dynapcnn layer. + if dvs_weight_edges: + if dvs_pooling_edges: + raise InvalidGraphStructure( + "DVS layer has destinations with and without pooling. Unlike " + "with CNN layers, pooling of the DVS has to be the same for " + "all destinations." + ) + return init_dvs_entry( + dvs_weight_edges, + indx_2_module_map, + node_2_layer_map, + nodes_io_shapes, + ) + + # Process dvs->pooling edges adding pooling to a DVS Layer + elif dvs_pooling_edges: + # Make sure there is exactly one dvs->pooling edge + if len(dvs_pooling_edges) > 1: + raise InvalidGraphStructure( + "DVSLayer has connects to multiple pooling layers. Unlike " + "with CNN layers, pooling of the DVS has to be the same for " + "all destinations, therefore the DVSLayer can connect to at " + "most one pooling layer." + ) + dvs_pooling_edge = dvs_pooling_edges.pop() + # Find pooling-weight edges that connect DVS layer to dynapcnn layers. 
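+        # For illustration (hypothetical node IDs): given the dvs->pooling edge
+        # (0, 1) and pooling-weight edges {(1, 2), (1, 3), (4, 5)}, the call to
+        # `find_edges_by_source` below returns {(1, 2), (1, 3)}; those edges are
+        # then removed from the remaining pooling-weight set so they are not
+        # processed again as regular pooling->weight edges.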
+ pooling_weight_edges = edges_by_type.get("pooling-weight", set()) + dvs_pooling_weight_edges = find_edges_by_source( + pooling_weight_edges, dvs_pooling_edge[1] + ) + # Remove handled pooling-weight edges + pooling_weight_edges.difference_update(dvs_pooling_weight_edges) + + return init_dvs_entry_with_pooling( + dvs_pooling_edge, + dvs_pooling_weight_edges, + indx_2_module_map, + node_2_layer_map, + nodes_io_shapes, + ) + else: + # If no edges related to DVS have been found return None + return + + +def init_dvs_entry( + dvs_weight_edges: Set[Edge], + indx_2_module_map: Dict[int, nn.Module], + node_2_layer_map: Dict[int, int], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], +) -> Dict: + """Initiate dict to hold information for a DVS Layer configuration + based on "dvs-weight" edges. + + Parameters + ---------- + dvs_weight_edges: Set of edges between two nodes in graph. + Edge source has to be a DVSLayer and the same for all edges. + Edge target has to be within an existing entry of `dynapcnn_layer_info`. + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. + Will be updated in-place. + nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + + Returns + ------- + dvs_layer_info: Dict containing information about the DVSLayer. + """ + + # Pick any of the edges in set to get the DVS node ID. Should be same for all. + dvs_node_id = next(dvs_weight_edges.__iter__())[0] + + # This should never fail + if not all(edge[0] == dvs_node_id for edge in dvs_weight_edges): + raise InvalidGraphStructure( + "The provided network seems to consist of multiple DVS layers. " + "This is not supported." + ) + assert isinstance( + (dvs_layer := indx_2_module_map[dvs_node_id]), DVSLayer + ), f"Source node in edges {dvs_weight_edges} is of type {type(dvs_layer)} (it should be a DVSLayer instance)." + + # Initialize dvs config dict + dvs_layer_info = { + "node_id": dvs_node_id, + "input_shape": nodes_io_shapes[dvs_node_id]["input"], + "module": dvs_layer, + "pooling": None, + } + node_2_layer_map[dvs_node_id] = "dvs" + + # Find destination layer indices + destinations = [] + while dvs_weight_edges: + edge = dvs_weight_edges.pop() + try: + destination_layer_idx = node_2_layer_map[edge[1]] + except KeyError: + weight_layer = indx_2_module_map[edge[1]] + raise InvalidGraphStructure( + f"Weight layer {weight_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Weight " + "layers have to be followed by a spiking layer (`sl.IAFSqueeze`)." + ) + + # Update entry for DVS with new destination. + assert destination_layer_idx not in destinations + destinations.append(destination_layer_idx) + + if destinations: + dvs_layer_info["destinations"] = destinations + else: + dvs_layer_info["destinations"] = None + + return dvs_layer_info + + +def init_dvs_entry_with_pooling( + dvs_pooling_edge: Edge, + pooling_weight_edges: Set[Edge], + indx_2_module_map: Dict[int, nn.Module], + node_2_layer_map: Dict[int, int], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], +) -> Dict: + """Initiate dict to hold information for a DVS Layer configuration with additional pooling + + Parameters + ---------- + dvs_pooling_edge: Edge from DVSLayer to pooling layer. + pooling_weight_edges: Set of edges between pooling layer and weight layer + Edge source has to be the target of `dvs_pooling_edge`. 
+ Edge targets have to be within an existing entry of `dynapcnn_layer_info`. + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. + Will be updated in-place. + nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + + Returns + ------- + dvs_layer_info: Dict containing information about the DVSLayer. + """ + + dvs_node_id, pooling_id = dvs_pooling_edge + + # This should never fail + assert all(edge[0] == pooling_id for edge in pooling_weight_edges) + assert isinstance( + (dvs_layer := indx_2_module_map[dvs_node_id]), DVSLayer + ), f"Source node in edge {dvs_pooling_edge} is of type {type(dvs_layer)} (it should be a DVSLayer instance)." + + # Initialize dvs config dict + dvs_layer_info = { + "node_id": dvs_node_id, + "input_shape": nodes_io_shapes[dvs_node_id]["input"], + "module": dvs_layer, + "pooling": {"module": indx_2_module_map[pooling_id], "node_id": pooling_id}, + } + node_2_layer_map[dvs_node_id] = "dvs" + + # Find destination layer indices + destinations = [] + for edge in pooling_weight_edges: + try: + destination_layer_idx = node_2_layer_map[edge[1]] + except KeyError: + weight_layer = indx_2_module_map[edge[1]] + raise InvalidGraphStructure( + f"Weight layer {weight_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Weight " + "layers have to be followed by a spiking layer (`sl.IAFSqueeze`)." + ) + + # Update entry for DVS with new destination. + assert destination_layer_idx not in destinations + destinations.append(destination_layer_idx) + + if destinations: + dvs_layer_info["destinations"] = destinations + else: + dvs_layer_info["destinations"] = None + + return dvs_layer_info + + +def set_exit_destinations(dynapcnn_layer: Dict) -> None: + """Set minimal destination entries for layers that don't have any. + + This ensures that the forward methods of the resulting DynapcnnLayer + instances return an output, letting these layers act as exit points + of the network. + The destination layer will be `None`, and no pooling applied. + + Parameters + ---------- + dynapcnn_layer_info: Dict with one entry for each future dynapcnn layer. + key is unique dynapcnn layer ID, value is dict with nodes of the layer + Will be updated in-place. + """ + for layer_info in dynapcnn_layer.values(): + if not (destinations := layer_info["destinations"]): + # Add `None` destination to empty destination lists + destinations.append( + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": None, + } + ) + + +def set_neuron_layer_destination( + dynapcnn_layer_info: Dict[int, Dict], + edge: Edge, + node_2_layer_map: Dict[int, int], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], + indx_2_module_map: Dict[int, nn.Module], +) -> None: + """Set destination layer without pooling for existing entry in `dynapcnn_layer_info`. + + Parameters + ---------- + dynapcnn_layer_info: Dict with one entry for each future dynapcnn layer. + key is unique dynapcnn layer ID, value is dict with nodes of the layer + Will be updated in-place. + edge: Tuple of 2 integers, indicating edge between two nodes in graph. + Edge source has to be a neuron layer within an existing entry of + `dynapcnn_layer_info`. Edge target has to be the weight layer of + another dynapcnn layer. + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. 
+ Will be updated in-place. + nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + """ + # Make sure both source (neuron layer) and target (weight layer) have been previously processed + try: + source_layer_idx = node_2_layer_map[edge[0]] + except KeyError: + neuron_layer = indx_2_module_map[edge[0]] + raise InvalidGraphStructure( + f"Spiking layer {neuron_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Spiking " + "layers have to be preceded by a weight layer (`nn.Conv2d` or " + "`nn.Linear`)." + ) + try: + destination_layer_idx = node_2_layer_map[edge[1]] + except KeyError: + weight_layer = indx_2_module_map[edge[1]] + raise InvalidGraphStructure( + f"Weight layer {weight_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Weight " + "layers have to be followed by a spiking layer (`IAFSqueeze`)." + ) + + # Add new destination + output_shape = nodes_io_shapes[edge[0]]["output"] + layer_info = dynapcnn_layer_info[source_layer_idx] + layer_info["destinations"].append( + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": destination_layer_idx, + "output_shape": output_shape, + } + ) + + +def set_pooling_layer_destination( + dynapcnn_layer_info: Dict[int, Dict], + edge: Edge, + node_2_layer_map: Dict[int, int], + nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]], + indx_2_module_map: Dict[int, nn.Module], +) -> None: + """Set destination layer with pooling for existing entry in `dynapcnn_layer_info`. + + Parameters + ---------- + dynapcnn_layer_info: Dict with one entry for each future dynapcnn layer. + key is unique dynapcnn layer ID, value is dict with nodes of the layer + Will be updated in-place. + edge: Tuple of 2 integers, indicating edge between two nodes in graph. + Edge source has to be a pooling layer that is at the end of at least + one pooling chain within an existing entry of `dynapcnn_layer_info`. + Edge target has to be a weight layer within an existing entry of + `dynapcnn_layer_info`. + node_2_layer_map (dict): Maps each node ID to the ID of the layer it is assigned to. + Will be updated in-place. + nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes + indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value` + """ + # Make sure both source (pooling layer) and target (weight layer) have been previously processed + try: + source_layer_idx = node_2_layer_map[edge[0]] + except KeyError: + poolin_layer = indx_2_module_map[edge[0]] + raise InvalidGraphStructure( + f"Layer {poolin_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Pooling " + "layers have to be preceded by a spiking layer (`IAFSqueeze`), " + "another pooling layer, or DVS input" + ) + try: + destination_layer_idx = node_2_layer_map[edge[1]] + except KeyError: + weight_layer = indx_2_module_map[edge[1]] + raise InvalidGraphStructure( + f"Weight layer {weight_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. 
Weight " + "layers have to be preceded by a spiking layer (`IAFSqueeze`), " + "another pooling layer, or DVS input" + ) + + # Find current source node within destinations + layer_info = dynapcnn_layer_info[source_layer_idx] + matched = False + for destination in layer_info["destinations"]: + if destination["pooling_ids"][-1] == edge[0]: + if destination["destination_layer"] is not None: + # Destination is already linked to a postsynaptic layer. This happens when + # pooling nodes have outgoing edges to different weight layer. + # Copy the destination + # TODO: Add unit test for this case + destination = {k: v for k, v in destination.items()} + layer_info["destinations"].append(destination) + matched = True + break + if not matched: + pooling_layer = indx_2_module_map[edge[0]] + raise InvalidGraphStructure( + f"Layer {pooling_layer} cannot be assigned to a dynapcnn layer. " + "This is likely due to an unsupported SNN architecture. Pooling " + "layers have to be preceded by a spiking layer (`IAFSqueeze`), " + "another pooling layer, or DVS input" + ) + + # Set destination layer within destination dict that holds current source node + destination["destination_layer"] = destination_layer_idx + output_shape = nodes_io_shapes[edge[0]]["output"] + destination["output_shape"] = output_shape + + +def trace_paths(node: int, remaining_edges: Set[Edge]) -> List[Deque[int]]: + """Trace any path of collected edges through the graph. + + Start with `node`, and recursively look for paths of connected nodes + within `remaining edges.` + + Parameters + ---------- + node (int): ID of current node + remaining_edges: Set of remaining edges still to be searched + + Returns + ------- + paths: List of deque of int, all paths of connected edges starting from `node`. + processed_edges: Set of edges that are part of the returned paths + """ + paths = [] + processed_edges = set() + for src, tgt in remaining_edges: + if src == node: + processed_edges.add((src, tgt)) + # For each edge with `node` as source, find subsequent pooling nodes recursively + new_remaining = remaining_edges.difference({(src, tgt)}) + branches, new_processed = trace_paths(tgt, new_remaining) + # Make sure no edge was processed twice + assert len(processed_edges.intersection(new_processed)) == 0 + + # Keep track of newly processed edges + processed_edges.update(new_processed) + + # Collect all branching paths of pooling, inserting src at beginning + for branch in branches: + branch.appendleft(src) + paths.append(branch) + + if not paths: + # End of recursion: instantiate a deque only with node + paths = [Deque([node])] + + return paths, processed_edges + + +def find_edges_by_source(edges: Set[Edge], source: int) -> Set[Edge]: + """Utility function to find all edges with a given source node. + + Parameters + ---------- + - edges: Set of `Edge` instances to be searched + - source (int): Node ID that returned edges should have as source + + Returns + ------- + - Set[Edge]: All sets from `edges` that have `source` as source + """ + return {(src, tgt) for (src, tgt) in edges if src == source} + + +def verify_layer_info( + dynapcnn_layer_info: Dict[int, Dict], edge_counts: Optional[Dict[str, int]] = None +): + """Verify that `dynapcnn_layer_info` matches formal requirements. + + - Every layer needs to have at least a `conv`, `neuron`, and `destinations` + entry. + - If `edge_counts` is provided, also make sure that number of layer matches + numbers of edges. 
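+    - As an illustration, an entry with hypothetical layer index `0` passes the
+      per-layer checks if it looks like
+      `{0: {"conv": {...}, "neuron": {...}, "destinations": [...]}}`.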
+ + Parameters + ---------- + - dynapcnn_layer_info: Dict with information to construct and connect + DynapcnnLayer instances + - edge_counts: Optional Dict with edge counts for each edge type. If not + `None`, will be used to do further verifications on `dynapcnn_layer_info` + + Raises + ------ + - InvalidGraphStructure: if any verification fails. + """ + + # Make sure that each dynapcnn layer has at least a weight layer and a neuron layer + for idx, info in dynapcnn_layer_info.items(): + if not "conv" in info: + raise InvalidGraphStructure( + f"DynapCNN layer {idx} has no weight assigned, which should " + "never happen. " + default_invalid_structure_string + ) + if not "neuron" in info: + raise InvalidGraphStructure( + f"DynapCNN layer {idx} has no spiking layer assigned, which " + "should never happen. " + default_invalid_structure_string + ) + if not "destinations" in info: + raise InvalidGraphStructure( + f"DynapCNN layer {idx} has no destination info assigned, which " + "should never happen. " + default_invalid_structure_string + ) + if edge_counts is not None: + # Make sure there are as many layers as edges from weight to neuron + if edge_counts.get("weight-neuron", 0) - len(dynapcnn_layer_info) > 0: + raise InvalidGraphStructure( + "Not all weight-to-neuron edges have been processed, which " + "should never happen. " + default_invalid_structure_string + ) diff --git a/sinabs/backend/dynapcnn/utils.py b/sinabs/backend/dynapcnn/utils.py index 2719fc65..e9e919d1 100644 --- a/sinabs/backend/dynapcnn/utils.py +++ b/sinabs/backend/dynapcnn/utils.py @@ -1,67 +1,134 @@ +from collections import defaultdict, deque from copy import deepcopy -from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, TypeVar, Union import torch import torch.nn as nn -import sinabs import sinabs.layers as sl from .crop2d import Crop2d -from .dvs_layer import DVSLayer, expand_to_pair -from .dynapcnn_layer import DynapcnnLayer -from .exceptions import InputConfigurationError, MissingLayer, UnexpectedLayer -from .flipdims import FlipDims +from .dvs_layer import DVSLayer +from .exceptions import InputConfigurationError if TYPE_CHECKING: from sinabs.backend.dynapcnn.dynapcnn_network import DynapcnnNetwork -DEFAULT_IGNORED_LAYER_TYPES = (nn.Identity, nn.Dropout, nn.Dropout2d, nn.Flatten) +# Other than `COMPLETELY_IGNORED_LAYER_TYPES`, `IGNORED_LAYER_TYPES` are +# part of the graph initially and are needed to ensure proper handling of +# graph structure (e.g. Merge nodes) or meta-information (e.g. +# `nn.Flatten` for io-shapes) +COMPLETELY_IGNORED_LAYER_TYPES = (nn.Identity, nn.Dropout, nn.Dropout2d) +IGNORED_LAYER_TYPES = (nn.Flatten, sl.Merge) +Edge = Tuple[int, int] # Define edge-type alias + + +####################################################### Device Related ####################################################### + + +def parse_device_id(device_id: str) -> Tuple[str, int]: + """Parse device id into device type and device index. + + Args: + device_id (str): Device id typically of the form `device_type:index`. + In case no index is specified, the default index of zero is returned. + + Returns: + Tuple[str, int]: (device_type, index) Returns a tuple with the index and device type. + """ + parts = device_id.split(sep=":") + if len(parts) == 1: + device_type = parts[0] + index = 0 + elif len(parts) == 2: + device_type, index = parts + else: + raise Exception( + "Device id not understood. A string of form `device_type:index` expected." 
+ ) + + return device_type, int(index) + + +def get_device_id(device_type: str, index: int) -> str: + """Generate a device id string given a device type and its index. + + Args: + device_type (str): Device type + index (int): Device index + + Returns: + str: A string of the form `device_type:index` + """ + return f"{device_type}:{index}" + + +def standardize_device_id(device_id: str) -> str: + """Standardize device id string. + + Args: + device_id (str): Device id string. Could be of the form `device_type` or `device_type:index` + + Returns: + str: Returns a sanitized device id of the form `device_type:index` + """ + device_type, index = parse_device_id(device_id=device_id) + return get_device_id(device_type=device_type, index=index) + + +####################################################### DynapcnnNetwork Related ####################################################### -def infer_input_shape( - layers: List[nn.Module], input_shape: Optional[Tuple[int, int, int]] = None -) -> Tuple[int, int, int]: - """Checks if the input_shape is specified. If either of them are specified, then it checks if - the information is consistent and returns the input shape. + +def topological_sorting(edges: Set[Tuple[int, int]]) -> List[int]: + """Performs a topological sorting (using Kahn's algorithm) of a graph descrobed by a list edges. An entry node `X` + of the graph have to be flagged inside `edges` by a tuple `('input', X)`. Parameters ---------- - layers: - List of modules - input_shape : - (channels, height, width) + - edges (set): the edges describing the *acyclic* graph. Returns - ------- - Output shape: - (channels, height, width) + ---------- + - topological_order (list): the nodes sorted by the graph's topology. """ - if input_shape is not None and len(input_shape) != 3: - raise InputConfigurationError( - f"input_shape expected to have length 3 or None but input_shape={input_shape} given." - ) - input_shape_from_layer = None - if layers and isinstance(layers[0], DVSLayer): - input_shape_from_layer = layers[0].input_shape - if len(input_shape_from_layer) != 3: - raise InputConfigurationError( - f"input_shape of layer {layers[0]} expected to have length 3 or None but input_shape={input_shape_from_layer} found." - ) - if (input_shape is not None) and (input_shape_from_layer is not None): - if input_shape == input_shape_from_layer: - return input_shape + graph = defaultdict(list) + in_degree = defaultdict(int) + + # initialize the graph and in-degrees. + for u, v in edges: + if u != "input": + graph[u].append(v) + in_degree[v] += 1 else: - raise InputConfigurationError( - f"Input shape from the layer {input_shape_from_layer} does not match the specified input_shape {input_shape}" - ) - elif input_shape_from_layer is not None: - return input_shape_from_layer - elif input_shape is not None: - return input_shape - else: - raise InputConfigurationError("No input shape could be inferred") + if v not in in_degree: + in_degree[v] = 0 + if v not in in_degree: + in_degree[v] = 0 + + # find all nodes with zero in-degrees. + zero_in_degree_nodes = deque( + [node for node, degree in in_degree.items() if degree == 0] + ) + + # process nodes and create the topological order. + topological_order = [] + + while zero_in_degree_nodes: + node = zero_in_degree_nodes.popleft() + topological_order.append(node) + + for neighbor in graph[node]: + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + zero_in_degree_nodes.append(neighbor) + + # check if all nodes are processed (to handle cycles). 
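+    # (E.g. for edges {('input', 0), (0, 1), (1, 2)} all three nodes get
+    # processed and [0, 1, 2] is returned; for a cycle such as {(0, 1), (1, 0)}
+    # no node has in-degree zero to start from, the lengths below differ and
+    # the error is raised.)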
+ if len(topological_order) == len(in_degree): + return topological_order + + raise ValueError("The graph has a cycle and cannot be topologically sorted.") def convert_cropping2dlayer_to_crop2d( @@ -89,110 +156,24 @@ def convert_cropping2dlayer_to_crop2d( return Crop2d(((top, bottom), (left, right))) -def construct_dvs_layer( - layers: List[nn.Module], - input_shape: Tuple[int, int, int], - idx_start: int = 0, - dvs_input: bool = False, -) -> Tuple[Optional[DVSLayer], int, float]: - """ - Generate a DVSLayer given a list of layers. If `layers` does not start - with a pooling, cropping or flipping layer and `dvs_input` is False, - will return `None` instead of a DVSLayer. - NOTE: The number of channels is implicitly assumed to be 2 because of DVS - - Parameters - ---------- - layers: - List of layers - input_shape: - Shape of input (channels, height, width) - idx_start: - Starting index to scan the list. Default 0 - - Returns - ------- - dvs_layer: - None or DVSLayer - idx_next: int or None - Index of first layer after this layer is constructed - rescale_factor: float - Rescaling factor needed when turning AvgPool to SumPool. May - differ from the pooling kernel in certain cases. - dvs_input: bool - Whether DVSLayer should have pixel array activated. - """ - # Start with defaults - layer_idx_next = idx_start - crop_lyr = None - flip_lyr = None - - if len(input_shape) != 3: - raise ValueError( - f"Input shape should be 3 dimensional but input_shape={input_shape} was given." - ) - - # Return existing DVS layer as is - if len(layers) and isinstance(layers[0], DVSLayer): - return deepcopy(layers[0]), 1, 1 - - # Construct pooling layer - pool_lyr, layer_idx_next, rescale_factor = construct_next_pooling_layer( - layers, layer_idx_next - ) - - # Find next layer (check twice for two layers) - for __ in range(2): - # Go to the next layer - if layer_idx_next < len(layers): - layer = layers[layer_idx_next] - else: - break - # Check layer type - if isinstance(layer, sl.Cropping2dLayer): - # The shape after pooling is - pool = expand_to_pair(pool_lyr.kernel_size) - h = input_shape[1] // pool[0] - w = input_shape[2] // pool[1] - print(f"Input shape to the cropping layer is {h}, {w}") - crop_lyr = convert_cropping2dlayer_to_crop2d(layer, (h, w)) - elif isinstance(layer, Crop2d): - crop_lyr = layer - elif isinstance(layer, FlipDims): - flip_lyr = layer - else: - break - - layer_idx_next += 1 - - # If any parameters have been found or dvs_input is True - if (layer_idx_next > 0) or dvs_input: - dvs_layer = DVSLayer.from_layers( - pool_layer=pool_lyr, - crop_layer=crop_lyr, - flip_layer=flip_lyr, - input_shape=input_shape, - disable_pixel_array=not dvs_input, - ) - return dvs_layer, layer_idx_next, rescale_factor - else: - # No parameters/layers pertaining to DVS preprocessing found - return None, 0, 1 +WeightLayer = TypeVar("WeightLayer", nn.Linear, nn.Conv2d) -def merge_conv_bn(conv, bn): - """Merge a convolutional layer with subsequent batch normalization. +def merge_bn( + weight_layer: WeightLayer, bn: Union[nn.BatchNorm1d, nn.BatchNorm2d] +) -> WeightLayer: + """Merge a convolutional or linear layer with subsequent batch normalization. 
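+
+    The batch norm is folded into the weight layer's parameters. In terms of the
+    variable names used in the body below::
+
+        factor     = gamma / sigmasq.sqrt()   # per-output-channel scale
+        new_weight = weight * factor          # factor broadcast over remaining dims
+        new_bias   = beta + (bias - mu) * factor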
Parameters ---------- - conv: torch.nn.Conv2d - Convolutional layer - bn: torch.nn.Batchnorm2d + weight_layer: torch.nn.Conv2d or nn.Linear + Convolutional or linear layer + bn: torch.nn.Batchnorm2d or nn.Batchnorm1d Batch normalization Returns ------- - torch.nn.Conv2d: Convolutional layer including batch normalization + Weight layer including batch normalization """ mu = bn.running_mean sigmasq = bn.running_var @@ -204,297 +185,56 @@ def merge_conv_bn(conv, bn): factor = gamma / sigmasq.sqrt() - c_weight = conv.weight.data.clone().detach() - c_bias = 0.0 if conv.bias is None else conv.bias.data.clone().detach() - - conv = deepcopy(conv) # TODO: this will cause copying twice - - conv.weight.data = c_weight * factor[:, None, None, None] - conv.bias.data = beta + (c_bias - mu) * factor - - return conv - - -def construct_next_pooling_layer( - layers: List[nn.Module], idx_start: int -) -> Tuple[Optional[sl.SumPool2d], int, float]: - """Consolidate the first `AvgPool2d` objects in `layers` until the first object of different - type. - - Parameters - ---------- - layers: Sequence of layer objects - Contains `AvgPool2d` and other objects. - idx_start: int - Layer index to start construction from - Returns - ------- - lyr_pool: int or tuple of ints - Consolidated pooling size. - idx_next: int - Index of first object in `layers` that is not a `AvgPool2d`, - rescale_factor: float - Rescaling factor needed when turning AvgPool to SumPool. May - differ from the pooling kernel in certain cases. - """ + weight = weight_layer.weight.data.clone().detach() + bias = 0.0 if weight_layer.bias is None else weight_layer.bias.data.clone().detach() - rescale_factor = 1 - cumulative_pooling = expand_to_pair(1) - - idx_next = idx_start - # Figure out pooling dims - while idx_next < len(layers): - lyr = layers[idx_next] - if isinstance(lyr, nn.AvgPool2d): - if lyr.padding != 0: - raise ValueError("Padding is not supported for the pooling layers") - elif isinstance(lyr, sl.SumPool2d): - ... - else: - # Reached a non pooling layer - break - # Increment if it is a pooling layer - idx_next += 1 - - pooling = expand_to_pair(lyr.kernel_size) - if lyr.stride is not None: - stride = expand_to_pair(lyr.stride) - if pooling != stride: - raise ValueError( - f"Stride length {lyr.stride} should be the same as pooling kernel size {lyr.kernel_size}" - ) - # Compute cumulative pooling - cumulative_pooling = ( - cumulative_pooling[0] * pooling[0], - cumulative_pooling[1] * pooling[1], - ) - # Update rescaling factor - if isinstance(lyr, nn.AvgPool2d): - rescale_factor *= pooling[0] * pooling[1] + weight_layer = deepcopy(weight_layer) - # If there are no layers - if cumulative_pooling == (1, 1): - return None, idx_next, 1 + new_bias = beta + (bias - mu) * factor + if weight_layer.bias is None: + weight_layer.bias = nn.Parameter(new_bias) else: - lyr_pool = sl.SumPool2d(cumulative_pooling) - return lyr_pool, idx_next, rescale_factor - - -def construct_next_dynapcnn_layer( - layers: List[nn.Module], - idx_start: int, - in_shape: Tuple[int, int, int], - discretize: bool, - rescale_factor: float = 1, -) -> Tuple[DynapcnnLayer, int, float]: - """Generate a DynapcnnLayer from a Conv2d layer and its subsequent spiking and pooling layers. - - Parameters - ---------- - - layers: sequence of layer objects - First object must be Conv2d, next must be an IAF layer. All pooling - layers that follow immediately are consolidated. Layers after this - will be ignored. 
- idx_start: - Layer index to start construction from - in_shape: tuple of integers - Shape of the input to the first layer in `layers`. Convention: - (input features, height, width) - discretize: bool - Discretize weights and thresholds if True - rescale_factor: float - Weights of Conv2d layer are scaled down by this factor. Can be - used to account for preceding average pooling that gets converted - to sum pooling. - - Returns - ------- - dynapcnn_layer: DynapcnnLayer - DynapcnnLayer - layer_idx_next: int - Index of the next layer after this layer is constructed - rescale_factor: float - rescaling factor to account for average pooling - """ - layer_idx_next = idx_start # Keep track of layer indices - - # Check that the first layer is Conv2d, or Linear - if not isinstance(layers[layer_idx_next], (nn.Conv2d, nn.Linear)): - raise UnexpectedLayer(nn.Conv2d, layers[layer_idx_next]) - - # Identify and consolidate conv layer - lyr_conv = layers[layer_idx_next] - layer_idx_next += 1 - if layer_idx_next >= len(layers): - raise MissingLayer(layer_idx_next) - # Check and consolidate batch norm - if isinstance(layers[layer_idx_next], nn.BatchNorm2d): - lyr_conv = merge_conv_bn(lyr_conv, layers[layer_idx_next]) - layer_idx_next += 1 - - # Check next layer exists - try: - lyr_spk = layers[layer_idx_next] - layer_idx_next += 1 - except IndexError: - raise MissingLayer(layer_idx_next) - - # Check that the next layer is spiking - # TODO: Check that the next layer is an IAF layer - if not isinstance(lyr_spk, sl.IAF): - raise TypeError( - f"Convolution must be followed by IAF spiking layer, found {type(lyr_spk)}" - ) + weight_layer.bias.data = new_bias - # Check for next pooling layer - lyr_pool, i_next, rescale_factor_after_pooling = construct_next_pooling_layer( - layers, layer_idx_next - ) - # Increment layer index to after the pooling layers - layer_idx_next = i_next - - # Compose DynapcnnLayer - dynapcnn_layer = DynapcnnLayer( - conv=lyr_conv, - spk=lyr_spk, - pool=lyr_pool, - in_shape=in_shape, - discretize=discretize, - rescale_weights=rescale_factor, - ) + for __ in range(weight_layer.weight.ndim - factor.ndim): + factor.unsqueeze_(-1) + weight_layer.weight.data = weight * factor - return dynapcnn_layer, layer_idx_next, rescale_factor_after_pooling + return weight_layer -def build_from_list( - layers: List[nn.Module], - in_shape, - discretize=True, - dvs_input=False, -) -> nn.Sequential: - """Build a sequential model of DVSLayer and DynapcnnLayer(s) given a list of layers comprising - a spiking CNN. +def merge_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d: + """Merge a convolutional layer with subsequent batch normalization. Parameters ---------- - - layers: sequence of layer objects - in_shape: tuple of integers - Shape of the input to the first layer in `layers`. Convention: - (channels, height, width) - discretize: bool - Discretize weights and thresholds if True - dvs_input: bool - Whether model should receive DVS input. If `True`, the returned model - will begin with a DVSLayer with `disable_pixel_array` set to False. - Otherwise, the model starts with a DVSLayer only if the first element - in `layers` is a pooling, cropping or flipping layer. 
+ conv: torch.nn.Conv2d + Convolutional layer + bn: torch.nn.Batchnorm2d + Batch normalization Returns ------- - nn.Sequential + torch.nn.Conv2d: Convolutional layer including batch normalization """ - compatible_layers = [] - lyr_indx_next = 0 - # Find and populate dvs layer (NOTE: We are ignoring the channel information here and could lead to problems) - dvs_layer, lyr_indx_next, rescale_factor = construct_dvs_layer( - layers, input_shape=in_shape, idx_start=lyr_indx_next, dvs_input=dvs_input - ) - if dvs_layer is not None: - compatible_layers.append(dvs_layer) - in_shape = dvs_layer.get_output_shape() - # Find and populate dynapcnn layers - while lyr_indx_next < len(layers): - if isinstance(layers[lyr_indx_next], DEFAULT_IGNORED_LAYER_TYPES): - # - Ignore identity, dropout and flatten layers - lyr_indx_next += 1 - continue - dynapcnn_layer, lyr_indx_next, rescale_factor = construct_next_dynapcnn_layer( - layers, - lyr_indx_next, - in_shape=in_shape, - discretize=discretize, - rescale_factor=rescale_factor, - ) - in_shape = dynapcnn_layer.get_output_shape() - compatible_layers.append(dynapcnn_layer) + return merge_bn(conv, bn) - return nn.Sequential(*compatible_layers) - -def convert_model_to_layer_list( - model: Union[nn.Sequential, sinabs.Network], - ignore: Union[Type, Tuple[Type, ...]] = (), -) -> List[nn.Module]: - """Convert a model to a list of layers. +def merge_linear_bn(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear: + """Merge a linear (fully connected) layer with subsequent batch normalization. Parameters ---------- - model: nn.Sequential or sinabs.Network - ignore: type or tuple of types of modules to be ignored + linear: torch.nn.Linear + Linear layer + bn: torch.nn.BatchNorm1d + Batch normalization layer Returns ------- - List[nn.Module] - """ - if isinstance(model, sinabs.Network): - return convert_model_to_layer_list(model.spiking_model) - elif isinstance(model, nn.Sequential): - layers = [layer for layer in model if not isinstance(layer, ignore)] - else: - raise TypeError("Expected torch.nn.Sequential or sinabs.Network") - return layers - - -def parse_device_id(device_id: str) -> Tuple[str, int]: - """Parse device id into device type and device index. - - Args: - device_id (str): Device id typically of the form `device_type:index`. - In case no index is specified, the default index of zero is returned. - - Returns: - Tuple[str, int]: (device_type, index) Returns a tuple with the index and device type. - """ - parts = device_id.split(sep=":") - if len(parts) == 1: - device_type = parts[0] - index = 0 - elif len(parts) == 2: - device_type, index = parts - else: - raise Exception( - "Device id not understood. A string of form `device_type:index` expected." - ) - - return device_type, int(index) - - -def get_device_id(device_type: str, index: int) -> str: - """Generate a device id string given a device type and its index. - - Args: - device_type (str): Device type - index (int): Device index - - Returns: - str: A string of the form `device_type:index` - """ - return f"{device_type}:{index}" - - -def standardize_device_id(device_id: str) -> str: - """Standardize device id string. - - Args: - device_id (str): Device id string. 
Could be of the form `device_type` or `device_type:index` - - Returns: - str: Returns a sanitized device id of the form `device_type:index` + torch.nn.Linear: Linear layer including batch normalization """ - device_type, index = parse_device_id(device_id=device_id) - return get_device_id(device_type=device_type, index=index) + return merge_bn(linear, bn) def extend_readout_layer(model: "DynapcnnNetwork") -> "DynapcnnNetwork": @@ -510,30 +250,77 @@ def extend_readout_layer(model: "DynapcnnNetwork") -> "DynapcnnNetwork": """ model = deepcopy(model) input_shape = model.input_shape - og_readout_conv_layer = model.sequence[ - -1 - ].conv_layer # extract the conv layer from dynapcnn network - og_weight_data = og_readout_conv_layer.weight.data - og_bias_data = og_readout_conv_layer.bias - og_bias = og_bias_data is not None - # modify the out channels - og_out_channels = og_readout_conv_layer.out_channels - new_out_channels = (og_out_channels - 1) * 4 + 1 - og_readout_conv_layer.out_channels = new_out_channels - # build extended weight and replace the old one - ext_weight_shape = (new_out_channels, *og_weight_data.shape[1:]) - ext_weight_data = torch.zeros(ext_weight_shape, dtype=og_weight_data.dtype) - for i in range(og_out_channels): - ext_weight_data[i * 4] = og_weight_data[i] - og_readout_conv_layer.weight.data = ext_weight_data - # build extended bias and replace if necessary - if og_bias: - ext_bias_shape = (new_out_channels,) - ext_bias_data = torch.zeros(ext_bias_shape, dtype=og_bias_data.dtype) + for exit_layer in model.exit_layers: + # extract the conv layer from dynapcnn network + og_readout_conv_layer = exit_layer.conv_layer + og_weight_data = og_readout_conv_layer.weight.data + og_bias_data = og_readout_conv_layer.bias + og_bias = og_bias_data is not None + # modify the out channels + og_out_channels = og_readout_conv_layer.out_channels + new_out_channels = (og_out_channels - 1) * 4 + 1 + og_readout_conv_layer.out_channels = new_out_channels + # build extended weight and replace the old one + ext_weight_shape = (new_out_channels, *og_weight_data.shape[1:]) + ext_weight_data = torch.zeros(ext_weight_shape, dtype=og_weight_data.dtype) for i in range(og_out_channels): - ext_bias_data[i * 4] = og_bias_data[i] - og_readout_conv_layer.bias.data = ext_bias_data - _ = model( - torch.zeros(size=(1, *input_shape)) - ) # run a forward pass to initialize the new weights and last IAF + ext_weight_data[i * 4] = og_weight_data[i] + og_readout_conv_layer.weight.data = ext_weight_data + # build extended bias and replace if necessary + if og_bias: + ext_bias_shape = (new_out_channels,) + ext_bias_data = torch.zeros(ext_bias_shape, dtype=og_bias_data.dtype) + for i in range(og_out_channels): + ext_bias_data[i * 4] = og_bias_data[i] + og_readout_conv_layer.bias.data = ext_bias_data + # run a forward pass to initialize the new weights and last IAF + model(torch.zeros(size=(1, *input_shape))) return model + + +def infer_input_shape( + snn: nn.Module, input_shape: Optional[Tuple[int, int, int]] = None +) -> Tuple[int, int, int]: + """Infer expected shape of input for `snn` either from `input_shape` + or from `DVSLayer` instance within `snn` which provides it. + + If neither are available, raise an InputConfigurationError. + If both are the case, verify that the information is consistent. + + Parameters + ---------- + - snn (nn.Module): The SNN whose input shape is to be inferred + - input_shape (tuple or None): Explicitly provide input shape. + If not None, must be of the format `(channels, height, width)`. 
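+      For illustration (hypothetical shapes): if `snn` contains a `DVSLayer`
+      reporting `input_shape == (2, 64, 64)`, passing `input_shape=(2, 64, 64)`
+      (or `None`) returns `(2, 64, 64)`, while `input_shape=(1, 64, 64)` raises
+      an `InputConfigurationError`.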
+ + Returns + ------- + - tuple: The input shape to `snn`, in the format `(channels, height, width)` + """ + if input_shape is not None and len(input_shape) != 3: + raise InputConfigurationError( + f"input_shape expected to have length 3 or None but input_shape={input_shape} given." + ) + + # Find `DVSLayer` instance and infer input shape from it + input_shape_from_layer = None + for module in snn.modules(): + if isinstance(module, DVSLayer): + input_shape_from_layer = module.input_shape + # Make sure `input_shape_from_layer` is identical to provided `input_shape` + if input_shape is not None and input_shape != input_shape_from_layer: + raise InputConfigurationError( + f"Input shape from `DVSLayer` {input_shape_from_layer} does " + f"not match the specified input_shape {input_shape}" + ) + return input_shape_from_layer + + # If no `DVSLayer` is found, `input_shape` must not be provided + if input_shape is None: + raise InputConfigurationError( + "No input shape could be inferred. Either provide it explicitly " + "with the `input_shape` argument, or provide a model with " + "`DVSLayer` instance." + ) + else: + return input_shape diff --git a/sinabs/backend/dynapcnn/weight_rescaling_methods.py b/sinabs/backend/dynapcnn/weight_rescaling_methods.py new file mode 100644 index 00000000..e08e915f --- /dev/null +++ b/sinabs/backend/dynapcnn/weight_rescaling_methods.py @@ -0,0 +1,56 @@ +# author : Willian Soares Girao +# contact : williansoaresgirao@gmail.com + +import statistics +from typing import Iterable + +import numpy as np + + +def rescale_method_1(scaling_factors: Iterable[int], lambda_: float = 0.5) -> float: + """ + This method will use the average (scaled by `lambda_`) of the computed re-scaling factor + for the pooling layer(s) feeding into a convolutional layer. + + Arguments + --------- + - scaling_factors (list): the list of re-scaling factors computed by each `SumPool2d` layer targeting a + single `Conv2d` layer within a `DynapcnnLayer` instance. + - lambda_ (float): a scaling variable that multiplies the computed average re-scaling factor of the pooling layers. + + Returns + --------- + - the averaged re-scaling factor multiplied by `lambda_` if `len(scaling_factors) > 0`, else `1` is returned. + """ + + if len(scaling_factors) > 0: + return np.round(np.mean(list(scaling_factors)) * lambda_, 2) + else: + return 1.0 + + +def rescale_method_2(scaling_factors: Iterable[int], lambda_: float = 0.5) -> float: + """ + This method will use the harmonic mean (scaled by `lambda_`) of the computed re-scaling factor + for the pooling layer(s) feeding into a convolutional layer. + + Arguments + --------- + - scaling_factors (list): the list of re-scaling factors computed by each `SumPool2d` layer targeting a + single `Conv2d` layer within a `DynapcnnLayer` instance. + - lambda_ (float): a scaling variable that multiplies the computed average re-scaling factor of the pooling layers. + + Returns + --------- + - the averaged re-scaling factor multiplied by `lambda_` if `len(scaling_factors) > 0`, else `1` is returned. + + Note + --------- + - since the harmonic mean is less sensitive to outliers it **could be** that this is a better method + for weight re-scaling when multiple poolings with big differentces in kernel sizes are being considered. 
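+
+    Example
+    ---------
+    - hypothetical scaling factors, chosen only for illustration:
+        `rescale_method_2([4, 16])` -> `3.2` (harmonic mean 6.4 * `lambda_` 0.5)
+        `rescale_method_1([4, 16])` -> `5.0` (arithmetic mean 10.0 * 0.5, for comparison)
+        `rescale_method_2([])` -> `1.0` (no pooling factors)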
+ """ + + if len(scaling_factors) > 0: + return np.round(statistics.harmonic_mean(list(scaling_factors)) * lambda_, 2) + else: + return 1.0 diff --git a/sinabs/from_torch.py b/sinabs/from_torch.py index beebf96b..c901e77d 100644 --- a/sinabs/from_torch.py +++ b/sinabs/from_torch.py @@ -101,6 +101,28 @@ def mapper_fn(module): **kwargs_backend, ).to(device), ) + + elif isinstance(model, nn.Module): + layers = [layer for _, layer in model.named_children()] + + if not isinstance(layers[-1], (nn.ReLU, sl.NeuromorphicReLU)): + snn.add_module( + "spike_output", + spike_layer_class( + spike_threshold=spike_threshold, + spike_fn=spike_fn, + reset_fn=reset_fn, + surrogate_grad_fn=surrogate_grad_fn, + min_v_mem=min_v_mem, + **kwargs_backend, + ).to(device), + ) + + else: + warn( + "Spiking output can only be added to sequential models that do not end in a ReLU. No layer has been added." + ) + else: warn( "Spiking output can only be added to sequential models that do not end in a ReLU. No layer has been added." diff --git a/sinabs/utils.py b/sinabs/utils.py index 7b090963..70998a18 100644 --- a/sinabs/utils.py +++ b/sinabs/utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Iterable, List, Sequence, Tuple, TypeVar, Union import numpy as np import torch @@ -7,6 +7,30 @@ import sinabs +def get_new_index(existing_indices: Sequence) -> int: + """Get a new index that is not yet part of a Sequence of existing indices + + Example: + `get_new_index([0,1,2,3])`: `4` + `get_new_index([0,1,3])`: `2` + + Parameters + ---------- + - existing_indices: Sequence of indices + + Returns + ------- + - int: Smallest number (starting from 0) that is not yet in `existing_indices`. + """ + existing_indices = set(existing_indices) + # Largest possible index is the length of `existing_indices`, if they are + # consecutively numbered. Otherwise, if there is a "gap", this would be + # filled by a smaller number. + possible_indices = range(len(existing_indices) + 1) + unused_indices = existing_indices.symmetric_difference(possible_indices) + return min(unused_indices) + + def reset_states(model: nn.Module) -> None: """Helper function to recursively reset all states of spiking layers within the model. @@ -179,3 +203,135 @@ def set_batch_size(model: nn.Module, batch_size: int): if isinstance(mod, sinabs.layers.SqueezeMixin): mod.batch_size = batch_size # reset_states(mod) + + +def get_batch_size(model: nn.Module) -> int: + """Get batch size from any model with sinabs squeeze layers + + Will raise a ValueError if different squeeze layers within the model + have different batch sizes. Ignores layers with batch size `-1`, if + others provide it. + + Args: + model (nn.Module): pytorch model with sinabs Squeeze layers + + Returns: + batch_size (int): The batch size, `-1` if none is found. + """ + + batch_sizes = { + mod.batch_size + for mod in model.modules() + if isinstance(mod, sinabs.layers.SqueezeMixin) + } + # Ignore values `-1` and `None` + batch_sizes.discard(-1) + batch_sizes.discard(None) + + if len(batch_sizes) == 0: + return -1 + elif len(batch_sizes) == 1: + return batch_sizes.pop() + else: + raise ValueError( + "The model contains layers with different batch sizes: " + ", ".join((str(s) for s in batch_sizes)) + ) + + +def get_num_timesteps(model: nn.Module) -> int: + """Get number of timesteps from any model with sinabs squeeze layers + + Will raise a ValueError if different squeeze layers within the model + have different `num_timesteps` attributes. Ignores layers with value + `-1`, if others provide it. 
+ + Args: + model (nn.Module): pytorch model with sinabs Squeeze layers + + Returns: + num_timesteps (int): The number of time steps, `-1` if none is found. + """ + + numbers = { + mod.num_timesteps + for mod in model.modules() + if isinstance(mod, sinabs.layers.SqueezeMixin) + } + # Ignore values `-1` and `None` + numbers.discard(-1) + numbers.discard(None) + + if len(numbers) == 0: + return -1 + elif len(numbers) == 1: + return numbers.pop() + else: + raise ValueError( + "The model contains layers with different numbers of time steps: " + ", ".join((str(s) for s in numbers)) + ) + + +def get_smallest_compatible_time_dimension(model: nn.Module) -> int: + """Find the smallest size for input to a model with sinabs squeeze layers + along the batch/time (first) dimension. + + Will raise a ValueError if different squeeze layers within the model + have different `num_timesteps` or `batch_size` attributes (except for + `-1`) + + Args: + model (nn.Module): pytorch model with sinabs Squeeze layers + + Returns: + int: The smallest compatible size for the first dimension of + an input to the `model`. + """ + batch_size = abs(get_batch_size(model)) # Use `abs` to turn -1 to 1 + num_timesteps = abs(get_num_timesteps(model)) + # Use `abs` to turn `-1` to `1` + return abs(batch_size * num_timesteps) + + +def expand_to_pair(value) -> Tuple[int, int]: + """Expand a given value to a pair (tuple) if an int is passed. + + Parameters + ---------- + value: + int + + Returns + ------- + pair: + (int, int) + """ + return (value, value) if isinstance(value, int) else value + + +T = TypeVar("T") + + +def collapse_pair(pair: Union[Iterable[T], T]) -> T: + """Collapse an iterable of equal elements by returning only the first + + Parameters + ---------- + pair: Iterable. All elements should be the same. + + Returns + ------- + First item of `pair`. If `pair` is not iterable it will return `pair` itself. + + Raises + ------ + ValueError if not all elements in `pair` are equal. 
+ """ + if isinstance(pair, Iterable): + items = [x for x in pair] + if any(x != items[0] for x in items): + raise ValueError("All elements of `pair` must be the same") + return items[0] + else: + return pair diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dynapcnn/test_auto_mapping.py b/tests/test_dynapcnn/test_auto_mapping.py index 37de88d9..cff40a75 100644 --- a/tests/test_dynapcnn/test_auto_mapping.py +++ b/tests/test_dynapcnn/test_auto_mapping.py @@ -48,4 +48,4 @@ def test_auto_mapping_should_not_work(): graph = make_flow_graph(layer_mapping) new_graph = edmonds(graph, 0, len(graph) - 1) with pytest.raises(ValueError): - mapping = recover_mapping(new_graph, layer_mapping) + mapping = recover_mapping(new_graph, len(layer_mapping)) diff --git a/tests/test_dynapcnn/test_compatible_layer_build.py b/tests/test_dynapcnn/test_compatible_layer_build.py index 54215a52..82ccf5fe 100644 --- a/tests/test_dynapcnn/test_compatible_layer_build.py +++ b/tests/test_dynapcnn/test_compatible_layer_build.py @@ -4,194 +4,38 @@ import sinabs.layers as sl -def test_construct_pooling_from_1_layer(): - layers = [sl.SumPool2d(2)] - - from sinabs.backend.dynapcnn.utils import construct_next_pooling_layer - - pool_lyr, layer_idx_next, rescale_factor = construct_next_pooling_layer(layers, 0) - - assert pool_lyr.kernel_size == (2, 2) - assert layer_idx_next == 1 - assert rescale_factor == 1 +@pytest.mark.parametrize( + ("pooling", "layer_type", "expected_pooling", "expected_scaling"), + [ + (2, sl.SumPool2d, [2, 2], 1), + ((2, 2), sl.SumPool2d, [2, 2], 1), + (3, sl.SumPool2d, [3, 3], 1), + ((4, 4), sl.SumPool2d, [4, 4], 1), + (2, nn.AvgPool2d, [2, 2], 1.0 / 4), + ((2, 2), nn.AvgPool2d, [2, 2], 1.0 / 4), + (3, nn.AvgPool2d, [3, 3], 1.0 / 9), + ((4, 4), nn.AvgPool2d, [4, 4], 1.0 / 16), + ], +) +def test_construct_pooling_from_1_layer( + pooling, layer_type, expected_pooling, expected_scaling +): + layers = [layer_type(pooling)] + + from sinabs.backend.dynapcnn.dynapcnn_layer_utils import consolidate_dest_pooling + + cumulative_pooling, scaling = consolidate_dest_pooling(layers) + + assert cumulative_pooling == expected_pooling + assert scaling == expected_scaling def test_construct_pooling_from_2_layers(): - layers = [sl.SumPool2d(2), nn.AvgPool2d(3), sl.IAF()] - - from sinabs.backend.dynapcnn.utils import construct_next_pooling_layer - - pool_lyr, layer_idx_next, rescale_factor = construct_next_pooling_layer(layers, 0) - - assert pool_lyr.kernel_size == (6, 6) - assert layer_idx_next == 2 - assert rescale_factor == 9 - - -def test_non_square_pooling_kernel(): - layers = [ - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - sl.SumPool2d((2, 3)), - ] - - from sinabs.backend.dynapcnn.utils import construct_next_dynapcnn_layer - - with pytest.raises(ValueError): - _ = construct_next_dynapcnn_layer( - layers, 0, in_shape=(2, 28, 28), discretize=True, rescale_factor=1 - ) - - -def test_construct_dynapcnn_layer_from_3_layers(): - layers = [ - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - sl.SumPool2d(2), - ] - - from sinabs.backend.dynapcnn.utils import construct_next_dynapcnn_layer - - dynapcnn_lyr, layer_idx_next, rescale_factor = construct_next_dynapcnn_layer( - layers, 0, in_shape=(2, 28, 28), discretize=True, rescale_factor=1 - ) - - print(dynapcnn_lyr) - assert layer_idx_next == 3 - assert rescale_factor == 1 - - -def test_construct_dynapcnn_layer_no_pool_layers(): - layers = [ - nn.Conv2d(2, 8, 
kernel_size=3, stride=1, bias=False), - sl.IAF(), - nn.Conv2d(8, 2, kernel_size=3, stride=1, bias=False), - sl.IAF(), - ] - - from sinabs.backend.dynapcnn.utils import construct_next_dynapcnn_layer - - dynapcnn_lyr, layer_idx_next, rescale_factor = construct_next_dynapcnn_layer( - layers, 0, in_shape=(2, 28, 28), discretize=True, rescale_factor=1 - ) - - print(dynapcnn_lyr) - assert layer_idx_next == 2 - assert rescale_factor == 1 - - -def test_construct_dynapcnn_layer_from_8_layers(): - layers = [ - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - sl.SumPool2d(2), - nn.AvgPool2d(2), - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - ] - - from sinabs.backend.dynapcnn.utils import construct_next_dynapcnn_layer - - dynapcnn_lyr, layer_idx_next, rescale_factor = construct_next_dynapcnn_layer( - layers, 0, in_shape=(2, 28, 28), discretize=True, rescale_factor=1 - ) - - print(dynapcnn_lyr) - assert dynapcnn_lyr.pool_layer.kernel_size == (4, 4) - assert layer_idx_next == 4 - assert rescale_factor == 4 - - -def test_build_from_list_dynapcnn_layers_only(): - in_shape = (2, 28, 28) - layers = [ - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - sl.SumPool2d(2), - nn.AvgPool2d(2), - nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False), - sl.IAF(), - nn.Dropout2d(), - nn.Conv2d(16, 2, kernel_size=3, stride=1, bias=False), - sl.IAF(), - nn.Flatten(), - nn.Linear(8, 5), - sl.IAF(), - ] - - from sinabs.backend.dynapcnn.utils import build_from_list - - chip_model = build_from_list(layers, in_shape=in_shape, discretize=True) - - assert len(chip_model) == 4 - assert chip_model[0].get_output_shape() == (8, 6, 6) - assert chip_model[1].get_output_shape() == (16, 4, 4) - assert chip_model[2].get_output_shape() == (2, 2, 2) - assert chip_model[3].get_output_shape() == (5, 1, 1) - - -def test_missing_spiking_layer(): - in_shape = (2, 28, 28) - layers = [ - nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), - sl.IAF(), - sl.SumPool2d(2), - nn.AvgPool2d(2), - nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False), - sl.IAF(), - nn.Dropout2d(), - nn.Conv2d(16, 2, kernel_size=3, stride=1, bias=False), - sl.IAF(), - nn.Flatten(), - nn.Linear(8, 5), - ] - from sinabs.backend.dynapcnn.exceptions import MissingLayer - from sinabs.backend.dynapcnn.utils import build_from_list - - with pytest.raises(MissingLayer): - build_from_list(layers, in_shape=in_shape, discretize=True) - - -def test_incorrect_model_start(): - in_shape = (2, 28, 28) - layers = [ - sl.IAF(), - sl.SumPool2d(2), - nn.AvgPool2d(2), - ] - from sinabs.backend.dynapcnn.exceptions import UnexpectedLayer - from sinabs.backend.dynapcnn.utils import construct_next_dynapcnn_layer - - with pytest.raises(UnexpectedLayer): - construct_next_dynapcnn_layer( - layers, 0, in_shape=in_shape, discretize=True, rescale_factor=1 - ) - - -def test_conversion_to_layer_list(): - from sinabs.backend.dynapcnn.utils import DEFAULT_IGNORED_LAYER_TYPES as DEF_IGNORE - from sinabs.backend.dynapcnn.utils import convert_model_to_layer_list + layers = [sl.SumPool2d(2), nn.AvgPool2d(3)] - model = nn.Sequential( - nn.Conv2d(2, 8, 3), - sl.IAF(), - nn.Conv2d(8, 16, 3), - nn.Identity(), - nn.AvgPool2d(2), - nn.Dropout(0.5), - nn.Conv2d(16, 16, 3), - sl.IAF(), - nn.Flatten(), - nn.Linear(64, 4), - sl.IAF(), - ) + from sinabs.backend.dynapcnn.dynapcnn_layer_utils import consolidate_dest_pooling - layer_list = convert_model_to_layer_list(model, ignore=DEF_IGNORE) + 
cumulative_pooling, scaling = consolidate_dest_pooling(layers) - # Should contain all layers except identity, dropbout, and flatten - assert len(layer_list) == len(model) - 3 - model_indices = (0, 1, 2, 4, 6, 7, 9, 10) - for layer, idx_model in zip(layer_list, model_indices): - assert layer is model[idx_model] + assert cumulative_pooling == [6, 6] + assert scaling == 1.0 / 9 diff --git a/tests/test_dynapcnn/test_config_making.py b/tests/test_dynapcnn/test_config_making.py index 77ae93ff..6e2a8b8b 100644 --- a/tests/test_dynapcnn/test_config_making.py +++ b/tests/test_dynapcnn/test_config_making.py @@ -25,6 +25,7 @@ ) sinabs_model = from_model(ann, add_spiking_output=True, batch_size=1) +# Make sure all states are zero input_shape = (1, 28, 28) hardware_compatible_model = DynapcnnNetwork( @@ -33,25 +34,29 @@ input_shape=input_shape, ) +devices = tuple(ChipFactory.supported_devices.keys()) +devices = [ + "dynapcnndevkit", + "speck2btiny", + "speck2e", + "speck2edevkit", + "speck2fmodule", +] -def test_zero_initial_states(): - for devkit in [ - "dynapcnndevkit", - "speck2btiny", - "speck2e", - "speck2edevkit", - "speck2fmodule", - ]: - config = hardware_compatible_model.make_config("auto", device=devkit) - for idx, lyr in enumerate(config.cnn_layers): - initial_value = torch.tensor(lyr.neurons_initial_value) - shape = initial_value.shape - zeros = torch.zeros(shape, dtype=torch.int) +@pytest.mark.parametrize("device", devices) +def test_zero_initial_states(device): + devkit = device + config = hardware_compatible_model.make_config("auto", device=devkit) + for idx, lyr in enumerate(config.cnn_layers): + initial_value = torch.tensor(lyr.neurons_initial_value) - assert ( - initial_value.all() == zeros.all() - ), f"Initial values of layer{idx} neuron states is not zeros!" + shape = initial_value.shape + zeros = torch.zeros(shape, dtype=torch.int) + + assert ( + initial_value.all() == zeros.all() + ), f"Initial values of layer{idx} neuron states is not zeros!" 
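The parametrized pooling tests in test_compatible_layer_build.py above pin down the arithmetic used when consecutive pooling layers are folded into a single DynapcnnLayer destination: kernel sizes multiply, and every AvgPool2d(k) adds a 1/k**2 weight rescaling, since the chip itself only performs sum pooling. A minimal sketch of that rule, assuming a hypothetical consolidate() helper rather than the library's consolidate_dest_pooling, makes the expected values in the assertions easy to verify by hand:

import torch.nn as nn
import sinabs.layers as sl

def consolidate(pooling_layers):
    # Sketch only: fold consecutive pooling layers into one cumulative kernel
    # plus a scalar rescaling for the convolution weights that follow.
    cumulative, scaling = [1, 1], 1.0
    for lyr in pooling_layers:
        k = lyr.kernel_size
        k = (k, k) if isinstance(k, int) else tuple(k)
        cumulative = [c * ki for c, ki in zip(cumulative, k)]
        if isinstance(lyr, nn.AvgPool2d):
            # Average pooling == sum pooling followed by division by the kernel area
            scaling *= 1.0 / (k[0] * k[1])
    return cumulative, scaling

# Reproduces the expectations asserted above:
assert consolidate([sl.SumPool2d(2)]) == ([2, 2], 1)
assert consolidate([nn.AvgPool2d(3)]) == ([3, 3], 1.0 / 9)
assert consolidate([sl.SumPool2d(2), nn.AvgPool2d(3)]) == ([6, 6], 1.0 / 9)

The scaling never becomes a module of its own; it is carried as a rescale factor for the downstream convolution, which matches the rescale_factor values of 1/9 and 1/16 expected for the layers receiving average-pooled input in the DynapcnnLayer fixtures further down in this diff.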
small_ann = nn.Sequential( @@ -72,23 +77,6 @@ def test_zero_initial_states(): ) -@pytest.mark.parametrize("device", tuple(ChipFactory.supported_devices.keys())) +@pytest.mark.parametrize("device", devices) def test_verify_working_config(device): assert small_hardware_compatible_model.is_compatible_with(device) - - -# Model that is too big to fit on any of our architectures -big_ann = deepcopy(ann) -big_ann.append(nn.ReLU()) -big_ann.append(nn.Linear(10, 999999, bias=False)) - -hardware_incompatible_model = DynapcnnNetwork( - from_model(big_ann, add_spiking_output=True, batch_size=1).cpu(), - discretize=True, - input_shape=input_shape, -) - - -@pytest.mark.parametrize("device", tuple(ChipFactory.supported_devices.keys())) -def test_verify_non_working_config(device): - assert not hardware_incompatible_model.is_compatible_with(device) diff --git a/tests/test_dynapcnn/test_device_movement.py b/tests/test_dynapcnn/test_device_movement.py index a6b9d8a1..99508ba6 100644 --- a/tests/test_dynapcnn/test_device_movement.py +++ b/tests/test_dynapcnn/test_device_movement.py @@ -2,7 +2,6 @@ import torch.nn as nn from sinabs.backend.dynapcnn import DynapcnnNetwork -from sinabs.backend.dynapcnn.mapping import edmonds, make_flow_graph, recover_mapping from sinabs.from_torch import from_model ann = nn.Sequential( @@ -33,7 +32,7 @@ def test_multi_device_movement(): input_shape=input_shape, ) - hardware_compatible_model.to("speck2b:0") + hardware_compatible_model.to("speck2edevkit") print("Second attempt") - hardware_compatible_model.to("speck2b:0") + hardware_compatible_model.to("speck2edevkit") diff --git a/tests/test_dynapcnn/test_discover_device.py b/tests/test_dynapcnn/test_discover_device.py index cb828111..10b33d85 100644 --- a/tests/test_dynapcnn/test_discover_device.py +++ b/tests/test_dynapcnn/test_discover_device.py @@ -3,15 +3,9 @@ from sinabs.backend.dynapcnn import io +pytest.mark.skip("Not suitable for automated testing. Depends on available devices") -@pytest.mark.skip("Not suitable for automated testing. Depends on available devices") -def test_list_all_devices(): - device_map = io.get_device_map() - # Ideally the device map needs to be tested against something expected. - raise NotImplementedError() - -@pytest.mark.skip("Not suitable for automated testing. Depends on available devices") def test_is_device_type(): devices = samna.device.get_all_devices() print([io.is_device_type(d, "dynapcnndevkit") for d in devices]) diff --git a/tests/test_dynapcnn/test_doorbell.py b/tests/test_dynapcnn/test_doorbell.py index fc8040df..4c78ac80 100644 --- a/tests/test_dynapcnn/test_doorbell.py +++ b/tests/test_dynapcnn/test_doorbell.py @@ -3,8 +3,10 @@ It will include testing of the network equivalence, and of the correct output configuration. 
""" +import pytest import samna import torch +from nirtorch.utils import sanitize_name from torch import nn from sinabs.backend.dynapcnn.dynapcnn_network import DynapcnnNetwork @@ -73,12 +75,29 @@ def test_same_result(): def test_auto_config(): - # - Should give an error with the normal layer ordering dynapcnn_net = DynapcnnNetwork(snn, input_shape=input_shape, discretize=True) dynapcnn_net.make_config(chip_layers_ordering=[0, 1, 2, 3, 4]) + dynapcnn_net.make_config(layer2core_map="auto") def test_was_copied(): # - Make sure that layers of different models are distinct objects - for lyr_snn, lyr_dynapcnn in zip(snn.spiking_model, dynapcnn_net.sequence): - assert lyr_snn is not lyr_dynapcnn + # "Sanitize" all layer names, for compatibility with older nirtorch versions + snn_layers = { + sanitize_name(name): lyr for name, lyr in snn.spiking_model.named_modules() + } + idx_2_name_map = { + idx: sanitize_name(name) for name, idx in dynapcnn_net.name_2_indx_map.items() + } + for idx, lyr_info in dynapcnn_net._graph_extractor.dcnnl_map.items(): + conv_lyr_dynapcnn = dynapcnn_net.dynapcnn_layers[idx].conv_layer + conv_node_idx = lyr_info["conv"]["node_id"] + conv_name = idx_2_name_map[conv_node_idx] + conv_lyr_snn = snn_layers[conv_name] + assert conv_lyr_dynapcnn is not conv_lyr_snn + + spk_lyr_dynapcnn = dynapcnn_net.dynapcnn_layers[idx].spk_layer + spk_node_idx = lyr_info["neuron"]["node_id"] + spk_name = idx_2_name_map[spk_node_idx] + spk_lyr_snn = snn_layers[spk_name] + assert spk_lyr_dynapcnn is not spk_lyr_snn diff --git a/tests/test_dynapcnn/test_dvs_input.py b/tests/test_dynapcnn/test_dvs_input.py index 014ccb1c..44d7f275 100644 --- a/tests/test_dynapcnn/test_dvs_input.py +++ b/tests/test_dynapcnn/test_dvs_input.py @@ -1,11 +1,9 @@ """This should test cases of dynapcnn compatible networks with dvs input.""" from itertools import product -from typing import Optional, Tuple +from typing import Optional, Tuple, Union -import numpy as np import pytest -import samna import torch from torch import nn @@ -13,7 +11,7 @@ from sinabs.backend.dynapcnn.dvs_layer import DVSLayer from sinabs.backend.dynapcnn.exceptions import * from sinabs.from_torch import from_model -from sinabs.layers import IAF +from sinabs.layers import IAFSqueeze INPUT_SHAPE = (2, 16, 16) input_data = torch.rand(1, *INPUT_SHAPE, requires_grad=False) * 100.0 @@ -30,7 +28,7 @@ def verify_dvs_config( origin: Tuple[int, int] = (0, 0), cut: Optional[Tuple[int, int]] = None, destination: Optional[int] = None, - dvs_input: bool = True, + dvs_input: Union[bool, None] = True, flip: Optional[dict] = None, merge_polarities: bool = False, ): @@ -42,9 +40,9 @@ def verify_dvs_config( return if destination is None: - assert dvs.destinations[0].enable == False + assert not dvs.destinations[0].enable else: - assert dvs.destinations[0].enable == True + assert dvs.destinations[0].enable assert dvs.destinations[0].layer == destination if cut is None: assert dvs.cut.y == origin[0] + INPUT_SHAPE[1] // pooling[0] - 1 @@ -80,12 +78,14 @@ def forward(self, x): class NetPool2D(nn.Module): - def __init__(self, input_layer: bool = False): + def __init__(self, add_input_layer: bool = False): super().__init__() - layers = [] + if add_input_layer: + layers = [DVSLayer(input_shape=INPUT_SHAPE[1:])] + else: + layers = [] layers += [ - nn.AvgPool2d(kernel_size=(2, 2)), - nn.AvgPool2d(kernel_size=(1, 2)), + nn.AvgPool2d(kernel_size=(2, 4)), nn.Conv2d(2, 4, kernel_size=2, stride=2), nn.ReLU(), ] @@ -106,7 +106,7 @@ def test_dvs_no_pooling(dvs_input): spn = 
DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) # If there is no pooling, a DVSLayer should only be added if `dvs_input` is True - assert isinstance(spn.sequence[0], DVSLayer) == dvs_input + assert spn.has_dvs_layer() == dvs_input # - Make sure missing input shapes cause exception with pytest.raises(InputConfigurationError): @@ -129,36 +129,52 @@ def test_dvs_no_pooling(dvs_input): ) -@pytest.mark.parametrize("dvs_input", (False, True)) -def test_dvs_pooling_2d(dvs_input): +args = product((True, False, None), (True, False)) + + +@pytest.mark.parametrize("dvs_input,add_input_layer", args) +def test_dvs_pooling_2d(dvs_input, add_input_layer): # - ANN and SNN generation - ann = NetPool2D(input_layer=True) + ann = NetPool2D(add_input_layer=add_input_layer) snn = from_model(ann.seq, batch_size=1) snn.eval() # - SPN generation - spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) + if not dvs_input and not add_input_layer: + # No DVS layer is part of the SNN nor being added to it. The pooling layer should cause an exception + with pytest.raises(InvalidGraphStructure): + spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) + return - # When there is pooling, a DVSLayer should also be added if `dvs_input` is True - assert isinstance(spn.sequence[0], DVSLayer) + # If `add_input_layer` is False but `dvs_input` is `True`, a DVS layer will + # be added to the DynapcnnNetwork upon instantiation + spn = DynapcnnNetwork(snn, dvs_input=dvs_input, input_shape=INPUT_SHAPE) + assert spn.has_dvs_layer() - # - Make sure missing input shapes cause exception - with pytest.raises(InputConfigurationError): - spn = DynapcnnNetwork(snn, dvs_input=dvs_input) + if not add_input_layer: + # - Make sure missing input shapes cause exception + with pytest.raises(InputConfigurationError): + spn = DynapcnnNetwork(snn, dvs_input=dvs_input) - # - Compare snn and spn outputs - spn_float = DynapcnnNetwork(snn, discretize=False, input_shape=INPUT_SHAPE) + # - Compare snn and spn outputs. 
- Always add DVS so that pooling layer is properly handled + spn_float = DynapcnnNetwork( + snn, dvs_input=True, discretize=False, input_shape=INPUT_SHAPE + ) snn_out = snn(input_data).squeeze() spn_out = spn_float(input_data).squeeze() assert torch.equal(snn_out.detach(), spn_out) # - Verify DYNAP-CNN config - target_layers = [5] - config = spn.make_config(chip_layers_ordering=target_layers) + # Get index of only DynapcnnLayer to map it to core 5 + cnn_layer_idx = next(spn.dynapcnn_layers.__iter__()) + target_dest = 5 + config = spn.make_config(layer2core_map={cnn_layer_idx: target_dest}) + if dvs_input is None: + dvs_input = not snn.spiking_model[0].disable_pixel_array verify_dvs_config( config, input_shape=INPUT_SHAPE, - destination=target_layers[0], + destination=target_dest, dvs_input=dvs_input, pooling=(2, 4), ) @@ -186,7 +202,7 @@ def __init__( **kwargs_flip, ), nn.Conv2d(n_channels_in, 4, kernel_size=2, stride=2), - IAF(), + IAFSqueeze(batch_size=1), ] self.seq = nn.Sequential(*layers) @@ -291,6 +307,13 @@ def test_whether_dvs_mirror_cfg_is_all_switched_off(dvs_input, pool): snn = nn.Sequential(*layer_list) + if pool and not dvs_input: + with pytest.raises(InvalidGraphStructure): + dynapcnn = DynapcnnNetwork( + snn=snn, input_shape=(1, 128, 128), dvs_input=dvs_input, discretize=True + ) + return + dynapcnn = DynapcnnNetwork( snn=snn, input_shape=(1, 128, 128), dvs_input=dvs_input, discretize=True ) diff --git a/tests/test_dynapcnn/test_dvs_layer.py b/tests/test_dynapcnn/test_dvs_layer.py index a6ada999..2ab12550 100644 --- a/tests/test_dynapcnn/test_dvs_layer.py +++ b/tests/test_dynapcnn/test_dvs_layer.py @@ -77,37 +77,6 @@ def test_from_layers(disable_pixel_array, num_channels): assert dvs_layer.get_roi() == ((0, 59), (0, 54)) -def test_construct_empty(): - from sinabs.backend.dynapcnn.utils import construct_dvs_layer - - layers = [] - - dvs_layer, layer_idx_next, rescale_factor = construct_dvs_layer( - layers, input_shape=(2, 128, 128) - ) - - assert rescale_factor == 1 - assert layer_idx_next == 0 - assert dvs_layer is None - - -def test_construct_from_sumpool(): - import sinabs.layers as sl - from sinabs.backend.dynapcnn.utils import construct_dvs_layer - - layers = [sl.SumPool2d(2), sl.Cropping2dLayer(((1, 1), (1, 1)))] - - dvs_layer, layer_idx_next, rescale_factor = construct_dvs_layer( - layers, input_shape=(2, 128, 128) - ) - - print(dvs_layer) - - assert rescale_factor == 1 - assert layer_idx_next == 2 - assert dvs_layer.get_roi() == ((1, 63), (1, 63)) - - def test_convert_cropping2dlayer_to_crop2d(): import sinabs.layers as sl from sinabs.backend.dynapcnn.utils import convert_cropping2dlayer_to_crop2d diff --git a/tests/test_dynapcnn/test_individual_cases.py b/tests/test_dynapcnn/test_individual_cases.py index 5256daba..d394da7d 100644 --- a/tests/test_dynapcnn/test_individual_cases.py +++ b/tests/test_dynapcnn/test_individual_cases.py @@ -1,5 +1,4 @@ import pytest -import samna import torch from torch import nn @@ -23,8 +22,6 @@ def reset_states(seq): def networks_equal_output(input_data, snn): snn.eval() - snn_out = snn(input_data).squeeze() # forward pass - reset_states(snn) spn = DynapcnnNetwork( snn, @@ -32,15 +29,14 @@ def networks_equal_output(input_data, snn): discretize=False, dvs_input=True, ) - print(spn) + + snn_out = snn(input_data).squeeze() # forward pass spn_out = spn(input_data).squeeze() - print(snn_out.sum(), spn_out.sum()) assert torch.equal(snn_out, spn_out) # this will give an error if the config is not compatible config = spn.make_config() - 
print(spn.chip_layers_ordering) return config @@ -61,10 +57,9 @@ def forward(self, x): snn = from_model(Net().seq, batch_size=1) snn.eval() - snn_out = snn(input_data).squeeze() # forward pass - - snn.reset_states() spn = DynapcnnNetwork(snn, input_shape=input_data.shape[1:], discretize=False) + + snn_out = snn(input_data).squeeze() # forward pass spn_out = spn(input_data).squeeze() assert torch.equal(snn_out, spn_out) @@ -201,27 +196,26 @@ def test_no_spk_ending(): nn.Linear(512, 2), ) - from sinabs.backend.dynapcnn.exceptions import MissingLayer + from sinabs.backend.dynapcnn.exceptions import InvalidGraphStructure - with pytest.raises(MissingLayer): + with pytest.raises(InvalidGraphStructure): DynapcnnNetwork(seq, input_shape=input_data.shape[1:], discretize=False) def test_no_spk_middle(): + from sinabs.backend.dynapcnn.exceptions import InvalidEdge + seq = nn.Sequential( nn.Flatten(), nn.Linear(512, 10), nn.Linear(10, 2), IAFSqueeze(batch_size=1) ) - with pytest.raises(TypeError): + with pytest.raises(InvalidEdge): DynapcnnNetwork(seq, input_shape=input_data.shape[1:], discretize=False) def test_no_conv_layers(): - seq = nn.Sequential() - from sinabs.backend.dynapcnn.dvs_layer import DVSLayer - from sinabs.backend.dynapcnn.utils import infer_input_shape - net = DynapcnnNetwork(snn=seq, input_shape=(2, 10, 10), dvs_input=True) - - assert isinstance(net.sequence[0], DVSLayer) + net = DynapcnnNetwork( + nn.Sequential(DVSLayer(input_shape=(10, 10))), input_shape=(2, 10, 10) + ) diff --git a/tests/test_dynapcnn/test_large_net.py b/tests/test_dynapcnn/test_large_net.py index 3bc8681b..c464a975 100644 --- a/tests/test_dynapcnn/test_large_net.py +++ b/tests/test_dynapcnn/test_large_net.py @@ -88,21 +88,33 @@ def test_same_result(): assert torch.equal(dynapcnn_out.squeeze(), snn_out.squeeze()) -def test_too_large(): - with pytest.raises(ValueError): - # - Should give an error with the normal layer ordering - dynapcnn_net.make_config(chip_layers_ordering=range(9)) - - def test_auto_config(): # - Should give an error with the normal layer ordering dynapcnn_net.make_config(chip_layers_ordering="auto") def test_was_copied(): + from nirtorch.utils import sanitize_name + # - Make sure that layers of different models are distinct objects - for lyr_snn, lyr_dynapcnn in zip(snn.spiking_model, dynapcnn_net.sequence): - assert lyr_snn is not lyr_dynapcnn + snn_layers = { + sanitize_name(name): lyr for name, lyr in snn.spiking_model.named_modules() + } + idx_2_name_map = { + idx: sanitize_name(name) for name, idx in dynapcnn_net.name_2_indx_map.items() + } + for idx, lyr_info in dynapcnn_net._graph_extractor.dcnnl_map.items(): + conv_lyr_dynapcnn = dynapcnn_net.dynapcnn_layers[idx].conv_layer + conv_node_idx = lyr_info["conv"]["node_id"] + conv_name = idx_2_name_map[conv_node_idx] + conv_lyr_snn = snn_layers[conv_name] + assert conv_lyr_dynapcnn is not conv_lyr_snn + + spk_lyr_dynapcnn = dynapcnn_net.dynapcnn_layers[idx].spk_layer + spk_node_idx = lyr_info["neuron"]["node_id"] + spk_name = idx_2_name_map[spk_node_idx] + spk_lyr_snn = snn_layers[spk_name] + assert spk_lyr_dynapcnn is not spk_lyr_snn def test_make_config(): @@ -162,6 +174,7 @@ def test_extended_readout_layer(out_channels: int): ) extended_net = extend_readout_layer(dynapcnn_net) - converted_channels = extended_net.sequence[-1].conv_layer.out_channels + assert len(exit_layers := extended_net.exit_layers) == 1 + converted_channels = exit_layers[0].conv_layer.out_channels assert (out_channels - 1) * 4 + 1 == converted_channels diff --git 
a/tests/test_dynapcnn/test_monitoring.py b/tests/test_dynapcnn/test_monitoring.py index aaad2031..cb589aca 100644 --- a/tests/test_dynapcnn/test_monitoring.py +++ b/tests/test_dynapcnn/test_monitoring.py @@ -72,15 +72,14 @@ def test_default_monitoring(): # As a default the last layer should be monitored config = dynapcnn_net.make_config(device="speck2b:0") - clo = dynapcnn_net.chip_layers_ordering - assert len(clo) > 0 + l2c = dynapcnn_net.layer2core_map + assert len(l2c) > 0 # Check that monitoring is off for all layers except last - for layer in clo[:-1]: - if layer == "dvs": - assert config.dvs_layer.monitor_enable == False + for layer, core in l2c.items(): + if layer in dynapcnn_net.exit_layer_ids: + assert config.cnn_layers[core].monitor_enable == True else: - assert config.cnn_layers[layer].monitor_enable == False - assert config.cnn_layers[clo[-1]].monitor_enable == True + assert config.cnn_layers[core].monitor_enable == False def test_model_level_monitoring_enable(): @@ -98,12 +97,13 @@ def test_model_level_monitoring_enable(): config = dynapcnn_net.make_config( device="speck2b:0", monitor_layers=["dvs", 5, -1] ) - clo = dynapcnn_net.chip_layers_ordering - assert len(clo) > 0 + l2c = dynapcnn_net.layer2core_map + assert len(l2c) > 0 assert config.dvs_layer.monitor_enable == True - assert config.cnn_layers[clo[5]].monitor_enable == True - assert config.cnn_layers[clo[-1]].monitor_enable == True + assert config.cnn_layers[l2c[5]].monitor_enable == True + for idx in dynapcnn_net.exit_layer_ids: + assert config.cnn_layers[l2c[idx]].monitor_enable == True # Specify layers to monitor - should not warn becuase final layer has no pooling with warnings.catch_warnings(): @@ -112,4 +112,4 @@ def test_model_level_monitoring_enable(): # Monitor all layers config = dynapcnn_net.make_config(device="speck2b:0", monitor_layers="all") - assert all(config.cnn_layers[i].monitor_enable == True for i in clo) + assert all(config.cnn_layers[i].monitor_enable == True for i in l2c.values()) diff --git a/tests/test_dynapcnn/test_neuron_leak.py b/tests/test_dynapcnn/test_neuron_leak.py index 4acacdbc..cf4adc49 100644 --- a/tests/test_dynapcnn/test_neuron_leak.py +++ b/tests/test_dynapcnn/test_neuron_leak.py @@ -53,9 +53,9 @@ def test_neuron_leak_config(): snn=snn, discretize=True, dvs_input=True, input_shape=(1, 64, 64) ) samna_cfg = dynapcnn.make_config(device="speck2fmodule") - chip_layers_order = dynapcnn.chip_layers_ordering + layer2core_map = dynapcnn.layer2core_map - for lyr, channel_num in zip(chip_layers_order, [2, 8, 16]): + for lyr, channel_num in zip(layer2core_map.values(), [2, 8, 16]): assert samna_cfg.cnn_layers[lyr].leak_enable is True assert len(samna_cfg.cnn_layers[lyr].biases) == channel_num @@ -124,6 +124,8 @@ def test_neuron_leak(): pre_neuron_state = neuron_states.get((c, x, y), 127) assert ( pre_neuron_state > out_ev.neuron_state + # If `pre_neuron_state` is already at minimum, it can't leak further + or pre_neuron_state == -127 ), "Neuron V_Mem doesn't decrease!" 
neuron_states.update({(c, x, y): out_ev.neuron_state}) print(f"c:{c}, x:{x}, y:{y}, vmem:{out_ev.neuron_state}") diff --git a/tests/test_dynapcnn/test_single_neuron_hardware.py b/tests/test_dynapcnn/test_single_neuron_hardware.py index 6d8fe6c6..a81ac2ad 100644 --- a/tests/test_dynapcnn/test_single_neuron_hardware.py +++ b/tests/test_dynapcnn/test_single_neuron_hardware.py @@ -42,9 +42,9 @@ def test_deploy_dynapcnnnetwork(): model = get_ones_network() sinabs.reset_states(model) - assert model.sequence[0].conv_layer.weight.sum() == 127 - assert model.sequence[0].spk_layer.spike_threshold == 127 - assert model.sequence[0].spk_layer.v_mem.sum() == 0 + assert model.dynapcnn_layers[0].conv_layer.weight.sum() == 127 + assert model.dynapcnn_layers[0].spk_layer.spike_threshold == 127 + assert model.dynapcnn_layers[0].spk_layer.v_mem.sum() == 0 model_output = model(torch.ones((1, 1, 1, 1))) assert model_output.sum() == 1 diff --git a/tests/test_dynapcnn/test_speck2e.py b/tests/test_dynapcnn/test_speck2e.py index f6c6a434..9939adc3 100644 --- a/tests/test_dynapcnn/test_speck2e.py +++ b/tests/test_dynapcnn/test_speck2e.py @@ -42,6 +42,5 @@ def test_speck2e_coordinates(): def test_dvs_layer_generation(): """DVSLayer should be generated is dvs input is enabled even for an empty network.""" - ann = nn.Sequential() network = DynapcnnNetwork(nn.Sequential(), input_shape=(2, 10, 10), dvs_input=True) - assert isinstance(network.sequence[0], DVSLayer) + assert isinstance(network.dvs_layer, DVSLayer) diff --git a/tests/test_dynapcnn/test_speckmini_config_making.py b/tests/test_dynapcnn/test_speckmini_config_making.py index 4ffd28c6..502fa930 100644 --- a/tests/test_dynapcnn/test_speckmini_config_making.py +++ b/tests/test_dynapcnn/test_speckmini_config_making.py @@ -70,16 +70,12 @@ def test_auto_mapping(): for test_device in devices: # test weights/kernel memory mapping - _ = SNN_KERNEL_MEM_TEST.make_config( - chip_layers_ordering="auto", device=test_device - ) - assert SNN_KERNEL_MEM_TEST.chip_layers_ordering == [0, 1, 3, 2, 4] + _ = SNN_KERNEL_MEM_TEST.make_config(layer2core_map="auto", device=test_device) + assert SNN_KERNEL_MEM_TEST.layer2core_map == {0: 0, 1: 1, 2: 3, 3: 2, 4: 4} # test neuron memory mapping - _ = SNN_NEURON_MEM_TEST.make_config( - chip_layers_ordering="auto", device=test_device - ) - assert SNN_NEURON_MEM_TEST.chip_layers_ordering == [2, 0, 1, 4, 3] + _ = SNN_NEURON_MEM_TEST.make_config(layer2core_map="auto", device=test_device) + assert SNN_NEURON_MEM_TEST.layer2core_map == {0: 2, 1: 0, 2: 1, 3: 4, 4: 3} def test_manual_mapping(): @@ -87,15 +83,15 @@ def test_manual_mapping(): for test_device in devices: # test weights/kernel memory mapping - chip_layers_order = [4, 2, 3, 1, 0] + layer2core_map = {0: 4, 1: 2, 2: 3, 3: 1, 4: 0} _ = SNN_KERNEL_MEM_TEST.make_config( - chip_layers_ordering=chip_layers_order, device=test_device + layer2core_map=layer2core_map, device=test_device ) - assert SNN_KERNEL_MEM_TEST.chip_layers_ordering == chip_layers_order + assert SNN_KERNEL_MEM_TEST.layer2core_map == layer2core_map # test neuron memory mapping - chip_layers_order = [1, 0, 2, 3, 4] + chip_layers_order = {0: 1, 1: 0, 2: 2, 3: 3, 4: 4} _ = SNN_NEURON_MEM_TEST.make_config( - chip_layers_ordering=chip_layers_order, device=test_device + layer2core_map=chip_layers_order, device=test_device ) - assert SNN_NEURON_MEM_TEST.chip_layers_ordering == chip_layers_order + assert SNN_NEURON_MEM_TEST.layer2core_map == chip_layers_order diff --git a/tests/test_dynapcnn/test_visualizer.py 
b/tests/test_dynapcnn/test_visualizer.py index 0873f318..dd7cce81 100644 --- a/tests/test_dynapcnn/test_visualizer.py +++ b/tests/test_dynapcnn/test_visualizer.py @@ -39,14 +39,12 @@ def get_demo_dynapcnn_network(): import torch.nn as nn import sinabs - from sinabs.backend.dynapcnn import DynapcnnCompatibleNetwork + from sinabs.backend.dynapcnn import DynapcnnNetwork ann = nn.Sequential(nn.Conv2d(2, 8, (3, 3)), nn.ReLU(), nn.AvgPool2d((2, 2))) snn = sinabs.from_model(ann, input_shape=(2, 64, 64), batch_size=1) - dynapcnn_network = DynapcnnCompatibleNetwork( - snn=snn, input_shape=(2, 64, 64), dvs_input=True - ) + dynapcnn_network = DynapcnnNetwork(snn=snn, input_shape=(2, 64, 64), dvs_input=True) return dynapcnn_network diff --git a/tests/test_dynapcnnlayer/__init__.py b/tests/test_dynapcnnlayer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dynapcnnlayer/conftest_dynapcnnlayer.py b/tests/test_dynapcnnlayer/conftest_dynapcnnlayer.py new file mode 100644 index 00000000..6b030609 --- /dev/null +++ b/tests/test_dynapcnnlayer/conftest_dynapcnnlayer.py @@ -0,0 +1,19 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +from .model_dummy_1 import dcnnl_map_1, expected_output_1 +from .model_dummy_2 import dcnnl_map_2, expected_output_2 +from .model_dummy_3 import dcnnl_map_3, expected_output_3 +from .model_dummy_4 import dcnnl_map_4, expected_output_4 + +# Args: dcnnl_map, discretize, expected_output +args_DynapcnnLayer = [ + (dcnnl_map_1, True, expected_output_1), + (dcnnl_map_1, False, expected_output_1), + (dcnnl_map_2, True, expected_output_2), + (dcnnl_map_2, False, expected_output_2), + (dcnnl_map_3, True, expected_output_3), + (dcnnl_map_3, False, expected_output_3), + (dcnnl_map_4, True, expected_output_4), + (dcnnl_map_4, False, expected_output_4), +] diff --git a/tests/test_dynapcnnlayer/model_dummy_1.py b/tests/test_dynapcnnlayer/model_dummy_1.py new file mode 100644 index 00000000..ffe5bac6 --- /dev/null +++ b/tests/test_dynapcnnlayer/model_dummy_1.py @@ -0,0 +1,194 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a network with residual connections" example in https://github.com/synsense/sinabs/issues/181 + +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze + +dcnnl_map_1 = { + 0: { + "input_shape": (2, 34, 34), + "rescale_factors": set(), + "is_entry_node": True, + "conv": { + "module": nn.Conv2d(2, 10, kernel_size=(2, 2), stride=[1, 1], bias=False), + "node_id": 0, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=3, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 1, + }, + "destinations": [ + { + "pooling_ids": [2], + "pooling_modules": [nn.AvgPool2d(kernel_size=3, stride=3, padding=0)], + "destination_layer": 1, + "output_shape": (10, 11, 11), + }, + { + "pooling_ids": [3], + "pooling_modules": [ + nn.AvgPool2d(kernel_size=4, stride=4, padding=0), + ], + "destination_layer": 2, + "output_shape": (10, 8, 8), + }, + ], + }, + 1: { + "input_shape": (10, 11, 11), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(10, 10, kernel_size=(4, 4), stride=[1, 1], bias=False), + "node_id": 4, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=3, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 6, + }, + "destinations": [ + { + "pooling_ids": [], + 
"pooling_modules": [], + "destination_layer": 3, + "output_shape": (10, 7, 7), + }, + ], + }, + 2: { + "input_shape": (10, 8, 8), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(10, 1, kernel_size=(2, 2), stride=[1, 1], bias=False), + "node_id": 7, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=3, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 8, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 3, + "output_shape": (1, 7, 7), + }, + ], + }, + 3: { + "input_shape": (1, 7, 7), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=49, out_features=500, bias=False), + "node_id": 9, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=3, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 10, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 4, + "output_shape": (500, 1, 1), + }, + ], + }, + 4: { + "input_shape": (500, 1, 1), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=500, out_features=10, bias=False), + "node_id": 11, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=3, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 12, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": None, + } + ], + }, +} + +expected_output_1 = { + 0: { + "input_shape": (2, 34, 34), + "pool": [[3, 3], [4, 4]], + "rescale_factor": 1, + "rescale_factors": set(), + "entry_node": True, + }, + 1: { + "input_shape": (10, 11, 11), + "pool": [[1, 1]], + "rescale_factor": 1.0 / 9, + "rescale_factors": set(), # Single factor will be popped from list + "entry_node": False, + }, + 2: { + "input_shape": (10, 8, 8), + "pool": [[1, 1]], + "rescale_factor": 1.0 / 16, + "rescale_factors": set(), # Single factor will be popped from list + "entry_node": False, + }, + 3: { + "input_shape": (1, 7, 7), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + "entry_node": False, + }, + 4: { + "input_shape": (500, 1, 1), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + "entry_node": False, + }, + "entry_points": {0}, + "destination_map": { + 0: [1, 2], + 1: [3], + 2: [3], + 3: [4], + 4: [-1], + }, +} diff --git a/tests/test_dynapcnnlayer/model_dummy_2.py b/tests/test_dynapcnnlayer/model_dummy_2.py new file mode 100644 index 00000000..aa0e086e --- /dev/null +++ b/tests/test_dynapcnnlayer/model_dummy_2.py @@ -0,0 +1,255 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a network with a merge and a split" in https://github.com/synsense/sinabs/issues/181 + +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, SumPool2d + +dcnnl_map_2 = { + 0: { + "input_shape": (2, 34, 34), + "rescale_factors": set(), + "is_entry_node": True, + "conv": { + "module": nn.Conv2d(2, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 0, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 1, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 1, + "output_shape": (4, 33, 33), + }, + ], + }, + 
1: { + "input_shape": (4, 33, 33), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 2, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 3, + }, + "destinations": [ + { + "pooling_ids": [4], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False), + ], + "destination_layer": 2, + "output_shape": (4, 16, 16), + }, + { + "pooling_ids": [4], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False), + ], + "destination_layer": 3, + "output_shape": (4, 16, 16), + }, + ], + }, + 2: { + "input_shape": (4, 16, 16), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 5, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 7, + }, + "destinations": [ + { + "pooling_ids": [8], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False), + ], + "destination_layer": 4, + "output_shape": (4, 7, 7), + }, + ], + }, + 3: { + "input_shape": (4, 16, 16), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 6, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 11, + }, + "destinations": [ + { + "pooling_ids": [12], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False), + ], + "destination_layer": 6, + "output_shape": (4, 7, 7), + }, + ], + }, + 4: { + "input_shape": (4, 7, 7), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 9, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 10, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 5, + "output_shape": (4, 6, 6), + }, + ], + }, + 5: { + "input_shape": (4, 6, 6), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=144, out_features=10, bias=False), + "node_id": 15, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 16, + }, + "destinations": [], + }, + 6: { + "input_shape": (4, 7, 7), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 13, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=8, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 14, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 5, + "output_shape": (4, 6, 6), + }, + ], + }, +} + +expected_output_2 = { + 0: { + "input_shape": (2, 34, 34), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 1: { + "input_shape": (4, 33, 33), + "pool": [[2, 2], [2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 2: { + "input_shape": (4, 16, 16), + 
"pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 3: { + "input_shape": (4, 16, 16), + "pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 4: { + "input_shape": (4, 7, 7), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 5: { + "input_shape": (4, 6, 6), + "pool": [], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 6: { + "input_shape": (4, 7, 7), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + "entry_points": {0}, + "destination_map": { + 0: [1], + 1: [2, 3], + 2: [4], + 3: [6], + 4: [5], + 6: [5], + 5: [], + }, +} diff --git a/tests/test_dynapcnnlayer/model_dummy_3.py b/tests/test_dynapcnnlayer/model_dummy_3.py new file mode 100644 index 00000000..80564399 --- /dev/null +++ b/tests/test_dynapcnnlayer/model_dummy_3.py @@ -0,0 +1,321 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "two networks with merging outputs" in https://github.com/synsense/sinabs/issues/181 + +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, SumPool2d + +dcnnl_map_3 = { + 0: { + "input_shape": (2, 34, 34), + "rescale_factors": set(), + "is_entry_node": True, + "conv": { + "module": nn.Conv2d(2, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 0, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 1, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 1, + "output_shape": (4, 33, 33), + }, + ], + }, + 1: { + "input_shape": (4, 33, 33), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 2, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 2, + }, + "destinations": [ + { + "pooling_ids": [4], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False) + ], + "destination_layer": 2, + "output_shape": (4, 16, 16), + }, + ], + }, + 2: { + "input_shape": (4, 16, 16), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 5, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 6, + }, + "destinations": [ + { + "pooling_ids": [7], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False) + ], + "destination_layer": 3, + "output_shape": (4, 7, 7), + }, + ], + }, + 3: { + "input_shape": (4, 7, 7), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=196, out_features=100, bias=False), + "node_id": 17, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 18, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 7, + "output_shape": (100, 1, 1), + }, + ], + }, + 4: { + "input_shape": (2, 34, 34), + "rescale_factors": set(), + "is_entry_node": True, + "conv": { + "module": nn.Conv2d(2, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 8, + }, + "neuron": { + 
"module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 9, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 5, + "output_shape": (4, 33, 33), + }, + ], + }, + 5: { + "input_shape": (4, 33, 33), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 10, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 11, + }, + "destinations": [ + { + "pooling_ids": [12], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False) + ], + "destination_layer": 6, + "output_shape": (4, 16, 16), + }, + ], + }, + 6: { + "input_shape": (4, 16, 16), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(4, 4, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 13, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 14, + }, + "destinations": [ + { + "pooling_ids": [15], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False) + ], + "destination_layer": 3, + "output_shape": (4, 7, 7), + }, + ], + }, + 7: { + "input_shape": (100, 1, 1), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=100, out_features=100, bias=False), + "node_id": 19, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 20, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 8, + "output_shape": (100, 1, 1), + }, + ], + }, + 8: { + "input_shape": (100, 1, 1), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=100, out_features=10, bias=False), + "node_id": 21, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 22, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": None, + } + ], + }, +} + +expected_output_3 = { + 0: { + "input_shape": (2, 34, 34), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 1: { + "input_shape": (4, 33, 33), + "pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 2: { + "input_shape": (4, 16, 16), + "pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 3: { + "input_shape": (4, 7, 7), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 4: { + "input_shape": (2, 34, 34), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 5: { + "input_shape": (4, 33, 33), + "pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 6: { + "input_shape": (4, 16, 16), + "pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 7: { + "input_shape": (100, 1, 1), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 8: { + "input_shape": (100, 1, 1), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + "entry_points": {0, 4}, + "destination_map": { + 0: [1], + 1: [2], + 2: [3], + 3: [7], + 4: [5], + 5: [6], + 6: 
[3], + 7: [8], + 8: [-1], + }, +} diff --git a/tests/test_dynapcnnlayer/model_dummy_4.py b/tests/test_dynapcnnlayer/model_dummy_4.py new file mode 100644 index 00000000..4dc1e223 --- /dev/null +++ b/tests/test_dynapcnnlayer/model_dummy_4.py @@ -0,0 +1,228 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a complex network structure" example in https://github.com/synsense/sinabs/issues/181 . """ + +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, SumPool2d + +dcnnl_map_4 = { + 0: { + "input_shape": (2, 34, 34), + "rescale_factors": set(), + "is_entry_node": True, + "conv": { + "module": nn.Conv2d(2, 1, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 0, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 1, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 1, + "output_shape": (1, 33, 33), + }, + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 2, + "output_shape": (1, 33, 33), + }, + ], + }, + 1: { + "input_shape": (1, 33, 33), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 2, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 4, + }, + "destinations": [ + { + "pooling_ids": [5], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False) + ], + "destination_layer": 3, + "output_shape": (1, 16, 16), + }, + ], + }, + 2: { + "input_shape": (1, 33, 33), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 3, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 7, + }, + "destinations": [ + { + "pooling_ids": [8], + "pooling_modules": [ + SumPool2d(kernel_size=2, stride=2, ceil_mode=False) + ], + "destination_layer": 3, + "output_shape": (1, 16, 16), + }, + { + "pooling_ids": [9], + "pooling_modules": [ + SumPool2d(kernel_size=5, stride=5, ceil_mode=False) + ], + "destination_layer": 4, + "output_shape": (1, 6, 6), + }, + ], + }, + 3: { + "input_shape": (1, 16, 16), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 11, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 12, + }, + "destinations": [ + { + "pooling_ids": [13], + "pooling_modules": [ + SumPool2d(kernel_size=3, stride=3, ceil_mode=False) + ], + "destination_layer": 5, + "output_shape": (1, 5, 5), + }, + ], + }, + 4: { + "input_shape": (1, 6, 6), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1), bias=False), + "node_id": 10, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 12, + }, + "destinations": [ + { + "pooling_ids": [], + "pooling_modules": [], + "destination_layer": 5, + 
"output_shape": (1, 5, 5), + }, + ], + }, + 5: { + "input_shape": (1, 5, 5), + "rescale_factors": set(), + "is_entry_node": False, + "conv": { + "module": nn.Linear(in_features=25, out_features=10, bias=False), + "node_id": 16, + }, + "neuron": { + "module": IAFSqueeze( + batch_size=2, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ), + "node_id": 17, + }, + "destinations": [], + }, +} + +expected_output_4 = { + 0: { + "input_shape": (2, 34, 34), + "pool": [[1, 1], [1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 1: { + "input_shape": (1, 33, 33), + "pool": [[2, 2]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 2: { + "input_shape": (1, 33, 33), + "pool": [[2, 2], [5, 5]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 3: { + "input_shape": (1, 16, 16), + "pool": [[3, 3]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 4: { + "input_shape": (1, 6, 6), + "pool": [[1, 1]], + "rescale_factor": 1, + "rescale_factors": set(), + }, + 5: { + "input_shape": (1, 5, 5), + "pool": [], + "rescale_factor": 1, + "rescale_factors": set(), + }, + "entry_points": {0}, + "destination_map": { + 0: [1, 2], + 1: [3], + 2: [3, 4], + 3: [5], + 4: [5], + 5: [], + }, +} diff --git a/tests/test_dynapcnnlayer/test_dynapcnnlayer.py b/tests/test_dynapcnnlayer/test_dynapcnnlayer.py new file mode 100644 index 00000000..9701607e --- /dev/null +++ b/tests/test_dynapcnnlayer/test_dynapcnnlayer.py @@ -0,0 +1,71 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +import pytest + +from sinabs.backend.dynapcnn.dynapcnn_layer_utils import ( + construct_dynapcnnlayers_from_mapper, +) + +from .conftest_dynapcnnlayer import args_DynapcnnLayer + + +@pytest.mark.parametrize( + "dcnnl_map, discretize, expected_output", + args_DynapcnnLayer, +) +def test_DynapcnnLayer(dcnnl_map, discretize, expected_output): + """Tests the instantiation of a set of `DynapcnnLayer` belonging to the same SNN and the data computed + within their constructors and shared among the differntly interacting instances (according to the graph + described by `sinabs_edges`). + """ + + # create a `DynapcnnLayer` from the set of layers in `nodes_to_dcnnl_map[dpcnnl_idx]`. + dynapcnn_layers, destination_map, entry_points = ( + construct_dynapcnnlayers_from_mapper( + dcnnl_map=dcnnl_map, + discretize=discretize, + rescale_fn=None, + dvs_layer_info=None, + ) + ) + + for layer_index, dynapcnn_layer in dynapcnn_layers.items(): + + # Test layer instance + in_shape = expected_output[layer_index]["input_shape"] + pool = expected_output[layer_index]["pool"] + rescale_weights = expected_output[layer_index]["rescale_factor"] + + assert ( + tuple(dynapcnn_layer.in_shape) == in_shape + ), f"wrong 'DynapcnnLayer.in_shape': Should be {in_shape}." + assert ( + dynapcnn_layer.discretize == discretize + ), f"wrong 'DynapcnnLayer.discretize': Should be {discretize}." + in_shape = expected_output[layer_index]["input_shape"] + assert ( + dynapcnn_layer.pool == pool + ), f"wrong 'DynapcnnLayer.pool': Should be {pool}." + in_shape = expected_output[layer_index]["input_shape"] + assert ( + dynapcnn_layer.rescale_weights == rescale_weights + ), f"wrong 'DynapcnnLayer.in_shape': Should be {rescale_weights}." 
+ + # Test entries in layer info that are not directly repeated in layer or handler instances + layer_info = dcnnl_map[layer_index] + rescale_factors = expected_output[layer_index]["rescale_factors"] + + assert ( + layer_info["rescale_factors"] == rescale_factors + ), f"wrong 'rescale_factors' entry: Should be {rescale_factors}." + + # # Convert destination lists to sets to ignore order + # destination_map = {node: set(dests) for node, dests in destination_map.items()} + # Test destination map + assert ( + destination_map == expected_output["destination_map"] + ), "wrong destination map" + + # Test entry point + assert entry_points == expected_output["entry_points"], "wrong entry points" diff --git a/tests/test_dynapcnnnetwork/__init__.py b/tests/test_dynapcnnnetwork/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_dynapcnnnetwork/conftest_dynapcnnnetwork.py b/tests/test_dynapcnnnetwork/conftest_dynapcnnnetwork.py new file mode 100644 index 00000000..7eaad023 --- /dev/null +++ b/tests/test_dynapcnnnetwork/conftest_dynapcnnnetwork.py @@ -0,0 +1,35 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +from .model_dummy_1 import batch_size as batch_size_1 +from .model_dummy_1 import expected_output as expected_output_1 +from .model_dummy_1 import input_shape as input_shape_1 +from .model_dummy_1 import snn as snn_1 +from .model_dummy_2 import batch_size as batch_size_2 +from .model_dummy_2 import expected_output as expected_output_2 +from .model_dummy_2 import input_shape as input_shape_2 +from .model_dummy_2 import snn as snn_2 +from .model_dummy_3 import batch_size as batch_size_3 +from .model_dummy_3 import expected_output as expected_output_3 +from .model_dummy_3 import input_shape as input_shape_3 +from .model_dummy_3 import snn as snn_3 +from .model_dummy_4 import batch_size as batch_size_4 +from .model_dummy_4 import expected_output as expected_output_4 +from .model_dummy_4 import input_shape as input_shape_4 +from .model_dummy_4 import snn as snn_4 +from .model_dummy_seq import ( + expected_seq_1, + expected_seq_2, + input_shape_seq, + seq_1, + seq_2, +) + +args_DynapcnnNetworkTest = [ + (snn_1, input_shape_1, batch_size_1, expected_output_1), + (snn_2, input_shape_2, batch_size_2, expected_output_2), + (snn_3, input_shape_3, batch_size_3, expected_output_3), + (snn_4, input_shape_4, batch_size_4, expected_output_4), + (seq_1, input_shape_seq, 1, expected_seq_1), + (seq_2, input_shape_seq, 1, expected_seq_2), +] diff --git a/tests/test_dynapcnnnetwork/model_dummy_1.py b/tests/test_dynapcnnnetwork/model_dummy_1.py new file mode 100644 index 00000000..c0ad5737 --- /dev/null +++ b/tests/test_dynapcnnnetwork/model_dummy_1.py @@ -0,0 +1,119 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a network with residual connections" example in https://github.com/synsense/sinabs/issues/181 + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + + self.conv1 = nn.Conv2d(2, 10, 2, 1, bias=False) # node 0 + self.iaf1 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 1 + self.pool1 = nn.AvgPool2d(3, 3) # node 2 + self.pool1a = nn.AvgPool2d(4, 4) # node 3 + + self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False) # node 4 + self.iaf2 = 
IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 6 + + self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False) # node 8 + self.iaf3 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 9 + + self.flat = nn.Flatten() + + self.fc1 = nn.Linear(49, 500, bias=False) # node 10 + self.iaf4 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 11 + + self.fc2 = nn.Linear(500, 10, bias=False) # node 12 + self.iaf5 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 13 + + self.adder = Merge() + + def forward(self, x): + + con1_out = self.conv1(x) + iaf1_out = self.iaf1(con1_out) + pool1_out = self.pool1(iaf1_out) + pool1a_out = self.pool1a(iaf1_out) + + conv2_out = self.conv2(pool1_out) + iaf2_out = self.iaf2(conv2_out) + + conv3_out = self.conv3(self.adder(pool1a_out, iaf2_out)) + iaf3_out = self.iaf3(conv3_out) + + flat_out = self.flat(iaf3_out) + + fc1_out = self.fc1(flat_out) + iaf4_out = self.iaf4(fc1_out) + fc2_out = self.fc2(iaf4_out) + iaf5_out = self.iaf5(fc2_out) + + return iaf5_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 3 +input_shape = (channels, height, width) + +snn = SNN(batch_size) + +expected_output = { + "dcnnl_edges": { + (0, 1), + (0, 2), + (1, 2), + (2, 3), + (3, 4), + ("input", 0), + }, + "node_source_map": { + 0: {"input"}, + 1: {0}, + 2: {0, 1}, + 3: {2}, + 4: {3}, + }, + "destination_map": { + 0: {1, 2}, + 1: {2}, + 2: {3}, + 3: {4}, + 4: {-1}, + }, + "entry_points": {0}, + "sorted_nodes": [0, 1, 2, 3, 4], + "output_shape": torch.Size([3, 10, 1, 1]), +} diff --git a/tests/test_dynapcnnnetwork/model_dummy_2.py b/tests/test_dynapcnnnetwork/model_dummy_2.py new file mode 100644 index 00000000..22e645cc --- /dev/null +++ b/tests/test_dynapcnnnetwork/model_dummy_2.py @@ -0,0 +1,161 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a network with a merge and a split" in https://github.com/synsense/sinabs/issues/181 + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge, SumPool2d + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + # -- graph node A -- + self.conv_A = nn.Conv2d(2, 4, 2, 1, bias=False) + self.iaf_A = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + # -- graph node B -- + self.conv_B = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf2_B = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_B = SumPool2d(2, 2) + # -- graph node C -- + self.conv_C = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_C = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_C = SumPool2d(2, 2) + # -- graph node D -- + self.conv_D = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_D = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + # -- graph node E -- + self.conv_E = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf3_E = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + 
spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_E = SumPool2d(2, 2) + # -- graph node F -- + self.conv_F = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_F = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + # -- graph node G -- + self.fc3 = nn.Linear(144, 10, bias=False) + self.iaf3_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + # -- merges -- + self.merge1 = Merge() + + # -- falts -- + self.flat_D = nn.Flatten() + self.flat_F = nn.Flatten() + + def forward(self, x): + # conv 1 - A/0 + convA_out = self.conv_A(x) # node 0 + iaf_A_out = self.iaf_A(convA_out) # node 1 + + # conv 2 - B/1 + conv_B_out = self.conv_B(iaf_A_out) # node 2 + iaf_B_out = self.iaf2_B(conv_B_out) # node 3 + pool_B_out = self.pool_B(iaf_B_out) # node 4 + + # conv 3 - C/2 + conv_C_out = self.conv_C(pool_B_out) # node 5 + iaf_C_out = self.iaf_C(conv_C_out) # node 7 + pool_C_out = self.pool_C(iaf_C_out) # node 8 + + # conv 4 - D/4 + conv_D_out = self.conv_D(pool_C_out) # node 9 + iaf_D_out = self.iaf_D(conv_D_out) # node 10 + + # fc 1 - E/3 + conv_E_out = self.conv_E(pool_B_out) # node 6 + iaf3_E_out = self.iaf3_E(conv_E_out) # node 12 + pool_E_out = self.pool_E(iaf3_E_out) # node 13 + + # fc 2 - F/6 + conv_F_out = self.conv_F(pool_E_out) # node 14 + iaf_F_out = self.iaf_F(conv_F_out) # node 15 + + # fc 2 - G/5 + flat_D_out = self.flat_D(iaf_D_out) # node 11 + flat_F_out = self.flat_F(iaf_F_out) # node 16 + + merge1_out = self.merge1(flat_D_out, flat_F_out) # node 19 + fc3_out = self.fc3(merge1_out) # node 17 + iaf3_fc_out = self.iaf3_fc(fc3_out) # node 18 + + return iaf3_fc_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 8 +input_shape = (channels, height, width) + +snn = SNN(batch_size) + +expected_output = { + "dcnnl_edges": { + (0, 1), + (1, 2), + (1, 3), + (2, 4), + (3, 5), + (4, 6), + (5, 6), + ("input", 0), + }, + "node_source_map": { + 0: {"input"}, + 1: {0}, + 2: {1}, + 3: {1}, + 4: {2}, + 5: {3}, + 6: {4, 5}, + }, + "destination_map": { + 0: {1}, + 1: {2, 3}, + 2: {4}, + 3: {5}, + 4: {6}, + 5: {6}, + 6: {-1}, + }, + "sorted_nodes": [0, 1, 2, 3, 4, 5, 6], + "output_shape": torch.Size([8, 10, 1, 1]), + "entry_points": {0}, +} diff --git a/tests/test_dynapcnnnetwork/model_dummy_3.py b/tests/test_dynapcnnnetwork/model_dummy_3.py new file mode 100644 index 00000000..1decc3d6 --- /dev/null +++ b/tests/test_dynapcnnnetwork/model_dummy_3.py @@ -0,0 +1,188 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "two networks with merging outputs" in https://github.com/synsense/sinabs/issues/181 + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge, SumPool2d + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + + self.conv_A = nn.Conv2d(2, 4, 2, 1, bias=False) + self.iaf_A = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.conv_B = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_B = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_B = SumPool2d(2, 2) + + self.conv_C = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_C = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + 
spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_C = SumPool2d(2, 2) + + self.conv_D = nn.Conv2d(2, 4, 2, 1, bias=False) + self.iaf_D = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.conv_E = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_E = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_E = SumPool2d(2, 2) + + self.conv_F = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_F = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_F = SumPool2d(2, 2) + + self.flat_brach1 = nn.Flatten() + self.flat_brach2 = nn.Flatten() + self.merge = Merge() + + self.fc1 = nn.Linear(196, 100, bias=False) + self.iaf1_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.fc2 = nn.Linear(100, 100, bias=False) + self.iaf2_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.fc3 = nn.Linear(100, 10, bias=False) + self.iaf3_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + def forward(self, x): + # conv 1 - A + conv_A_out = self.conv_A(x) + iaf_A_out = self.iaf_A(conv_A_out) + # conv 2 - B + conv_B_out = self.conv_B(iaf_A_out) + iaf_B_out = self.iaf_B(conv_B_out) + pool_B_out = self.pool_B(iaf_B_out) + # conv 3 - C + conv_C_out = self.conv_C(pool_B_out) + iaf_C_out = self.iaf_C(conv_C_out) + pool_C_out = self.pool_C(iaf_C_out) + + # --- + + # conv 4 - D + conv_D_out = self.conv_D(x) + iaf_D_out = self.iaf_D(conv_D_out) + # conv 5 - E + conv_E_out = self.conv_E(iaf_D_out) + iaf_E_out = self.iaf_E(conv_E_out) + pool_E_out = self.pool_E(iaf_E_out) + # conv 6 - F + conv_F_out = self.conv_F(pool_E_out) + iaf_F_out = self.iaf_F(conv_F_out) + pool_F_out = self.pool_F(iaf_F_out) + + # --- + + flat_brach1_out = self.flat_brach1(pool_C_out) + flat_brach2_out = self.flat_brach2(pool_F_out) + merge_out = self.merge(flat_brach1_out, flat_brach2_out) + + # FC 7 - G + fc1_out = self.fc1(merge_out) + iaf1_fc_out = self.iaf1_fc(fc1_out) + # FC 8 - H + fc2_out = self.fc2(iaf1_fc_out) + iaf2_fc_out = self.iaf2_fc(fc2_out) + # FC 9 - I + fc3_out = self.fc3(iaf2_fc_out) + iaf3_fc_out = self.iaf3_fc(fc3_out) + + return iaf3_fc_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 2 +input_shape = (channels, height, width) + +snn = SNN(batch_size) + +expected_output = { + "dcnnl_edges": { + (0, 2), + (2, 4), + (4, 6), + (6, 7), + (1, 3), + (3, 5), + (5, 6), + (7, 8), + ("input", 0), + ("input", 1), + }, + "node_source_map": { + 0: {"input"}, + 2: {0}, + 4: {2}, + 6: {4, 5}, + 1: {"input"}, + 3: {1}, + 5: {3}, + 7: {6}, + 8: {7}, + }, + "destination_map": { + 0: {2}, + 2: {4}, + 4: {6}, + 6: {7}, + 1: {3}, + 3: {5}, + 5: {6}, + 7: {8}, + 8: {-1}, + }, + "sorted_nodes": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "output_shape": torch.Size([2, 10, 1, 1]), + "entry_points": {0, 1}, +} diff --git a/tests/test_dynapcnnnetwork/model_dummy_4.py b/tests/test_dynapcnnnetwork/model_dummy_4.py new file mode 100644 index 00000000..b47120ce --- /dev/null +++ b/tests/test_dynapcnnnetwork/model_dummy_4.py @@ -0,0 +1,186 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a complex network 
structure" example in https://github.com/synsense/sinabs/issues/181 . """ + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge, SumPool2d + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + + self.conv1 = nn.Conv2d(2, 1, 2, 1, bias=False) + self.iaf1 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.conv2 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf2 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool2 = SumPool2d(2, 2) + + self.conv3 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf3 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool3 = SumPool2d(2, 2) + self.pool3a = SumPool2d(5, 5) + + self.conv4 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf4 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool4 = SumPool2d(3, 3) + + self.flat1 = nn.Flatten() + self.flat2 = nn.Flatten() + + self.conv5 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf5 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.fc2 = nn.Linear(25, 10, bias=False) + self.iaf2_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + # -- merges -- + self.merge1 = Merge() + self.merge2 = Merge() + + def forward(self, x): + # conv 1 - A/0 + con1_out = self.conv1(x) + iaf1_out = self.iaf1(con1_out) + + # conv 2 - B/1 + conv2_out = self.conv2(iaf1_out) + iaf2_out = self.iaf2(conv2_out) + pool2_out = self.pool2(iaf2_out) + + # conv 3 - C/2 + conv3_out = self.conv3(iaf1_out) + iaf3_out = self.iaf3(conv3_out) + pool3_out = self.pool3(iaf3_out) + pool3a_out = self.pool3a(iaf3_out) + + # conv 4 - D/3 + merge1_out = self.merge1(pool2_out, pool3_out) + conv4_out = self.conv4(merge1_out) + iaf4_out = self.iaf4(conv4_out) + pool4_out = self.pool4(iaf4_out) + flat1_out = self.flat1(pool4_out) + + # conv 5 - E/4 + conv5_out = self.conv5(pool3a_out) + iaf5_out = self.iaf5(conv5_out) + flat2_out = self.flat2(iaf5_out) + + # fc 2 - F/5 + merge2_out = self.merge2(flat2_out, flat1_out) + + fc2_out = self.fc2(merge2_out) + iaf2_fc_out = self.iaf2_fc(fc2_out) + + return iaf2_fc_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 2 +input_shape = (channels, height, width) + +snn = SNN(batch_size) + +expected_output = { + "dcnnl_edges": { + (0, 1), + (0, 2), + (1, 3), + (2, 3), + (2, 4), + (3, 5), + (4, 5), + ("input", 0), + }, + "node_source_map": { + 0: {"input"}, + 1: {0}, + 2: {0}, + 3: {1, 2}, + 4: {2}, + 5: {3, 4}, + }, + "destination_map": { + 0: {1, 2}, + 1: {3}, + 2: {3, 4}, + 3: {5}, + 4: {5}, + 5: {-1}, + }, + "sorted_nodes": [0, 1, 2, 3, 4, 5], + "output_shape": torch.Size([2, 10, 1, 1]), + "entry_points": {0}, +} + +# Sometimes the layer that usually gets assgined ID1, gets ID2, and the +# layer with ID 3 gets ID 4. Therefore an alternative solution is defined. +# This is not a bug in sinabs itself but an issue with the test, becuase +# the IDs that the layers are assigned do not always have to be the same. 
+expected_output["alternative"] = { + "dcnnl_edges": { + (0, 1), + (0, 2), + (1, 3), + (1, 4), + (2, 4), + (3, 5), + (4, 5), + ("input", 0), + }, + "node_source_map": { + 0: {"input"}, + 1: {0}, + 2: {0}, + 3: {1}, + 4: {1, 2}, + 5: {3, 4}, + }, + "destination_map": { + 0: {1, 2}, + 1: {3, 4}, + 2: {4}, + 3: {5}, + 4: {5}, + 5: {-1}, + }, + "sorted_nodes": [0, 1, 2, 3, 4, 5], + "output_shape": torch.Size([2, 10, 1, 1]), + "entry_points": {0}, +} diff --git a/tests/test_dynapcnnnetwork/model_dummy_seq.py b/tests/test_dynapcnnnetwork/model_dummy_seq.py new file mode 100644 index 00000000..1da69477 --- /dev/null +++ b/tests/test_dynapcnnnetwork/model_dummy_seq.py @@ -0,0 +1,73 @@ +# implementing sequential models + +import torch +import torch.nn as nn + +from sinabs.layers import IAFSqueeze, SumPool2d + +input_shape_seq = (2, 30, 30) + +seq_1 = nn.Sequential( + nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), + IAFSqueeze(batch_size=1), + nn.Conv2d(8, 2, kernel_size=3, stride=1, bias=False), + IAFSqueeze(batch_size=1), +) + +seq_2 = nn.Sequential( + nn.Conv2d(2, 2, kernel_size=3, stride=1, bias=False), + IAFSqueeze(batch_size=1), + SumPool2d(2), + nn.AvgPool2d(2), + nn.Dropout(0.5), + nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), + IAFSqueeze(batch_size=1), + nn.Conv2d(8, 2, kernel_size=3, stride=1, bias=False), + IAFSqueeze(batch_size=1), + nn.Flatten(), + nn.Linear(3 * 3 * 2, 5), + nn.Identity(), + IAFSqueeze(batch_size=1), +) + +expected_seq_1 = { + "dcnnl_edges": { + ("input", 0), + (0, 1), + }, + "node_source_map": { + 0: {"input"}, + 1: {0}, + }, + "destination_map": { + 0: {1}, + 1: {-1}, + }, + "sorted_nodes": [0, 1], + "output_shape": torch.Size([1, 2, 26, 26]), + "entry_points": {0}, +} + +expected_seq_2 = { + "dcnnl_edges": { + (0, 1), + (1, 2), + (2, 3), + ("input", 0), + }, + "node_source_map": { + 0: {"input"}, + 1: {0}, + 2: {1}, + 3: {2}, + }, + "destination_map": { + 0: {1}, + 1: {2}, + 2: {3}, + 3: {-1}, + }, + "sorted_nodes": [0, 1, 2, 3], + "output_shape": torch.Size([1, 5, 1, 1]), + "entry_points": {0}, +} diff --git a/tests/test_dynapcnnnetwork/test_dynapcnnnetwork.py b/tests/test_dynapcnnnetwork/test_dynapcnnnetwork.py new file mode 100644 index 00000000..293ee451 --- /dev/null +++ b/tests/test_dynapcnnnetwork/test_dynapcnnnetwork.py @@ -0,0 +1,63 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +import pytest +import torch + +from sinabs.backend.dynapcnn.dynapcnn_network import DynapcnnNetwork + +from .conftest_dynapcnnnetwork import args_DynapcnnNetworkTest + + +@pytest.mark.parametrize( + "snn, input_shape, batch_size, expected_output", args_DynapcnnNetworkTest +) +def test_DynapcnnNetwork(snn, input_shape, batch_size, expected_output): + """Tests if the correct graph representing the connections between each DynapcnnLayer within a DynapcnnNetwork + is created; if the DynapcnnLayer instances requiring input from a `Merge` are correctly flagged (along with what + their arguments should be); if the correct topological order of the DynapcnnLayers (i.e., the order in which their + forward methods should be called) is computed; if the output of the model matches what is expected. + """ + + dcnnnet = DynapcnnNetwork(snn, input_shape, batch_size) + + torch.manual_seed(0) + x = torch.randn((batch_size, *input_shape)) + output = dcnnnet(x) + + module = dcnnnet.dynapcnn_module + # For some models there are multiple possible topological sortings, + # such that the assigned node IDs are not always the same. 
+ # To prevent the following tests from failing, alternative expected + # outputs are defined which correspond to different assigned IDs. + if ( + expected_output["dcnnl_edges"] != module._dynapcnnlayer_edges + and "alternative" in expected_output + ): + expected_output = expected_output["alternative"] + print("Using algernative node ID assignment") + assert ( + expected_output["dcnnl_edges"] == module._dynapcnnlayer_edges + ), "wrong list of edges describing DynapcnnLayer connectivity." + + # Convert source lists to sets to ignore order + source_map = { + node: set(sources) for node, sources in module._node_source_map.items() + } + assert expected_output["node_source_map"] == source_map, "wrong node source map" + + # Convert destination lists to sets to ignore order + destination_map = { + node: set(dests) for node, dests in module._destination_map.items() + } + assert ( + expected_output["destination_map"] == destination_map + ), "wrong destination map" + + assert expected_output["entry_points"] == module._entry_points, "wrong entry points" + + assert expected_output["sorted_nodes"] == module._sorted_nodes, "wrong node sorting" + + assert ( + expected_output["output_shape"] == output.shape + ), "wrong model output tensor shape." diff --git a/tests/test_dynapcnnnetwork/test_failcases.py b/tests/test_dynapcnnnetwork/test_failcases.py new file mode 100644 index 00000000..846a76ad --- /dev/null +++ b/tests/test_dynapcnnnetwork/test_failcases.py @@ -0,0 +1,98 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +import pytest +import torch +from torch import nn + +from sinabs import layers as sl +from sinabs.backend.dynapcnn import DynapcnnNetwork +from sinabs.backend.dynapcnn.chip_factory import ChipFactory +from sinabs.backend.dynapcnn.exceptions import ( + InvalidGraphStructure, + UnsupportedLayerType, +) +from sinabs.from_torch import from_model + + +@pytest.mark.parametrize("device", tuple(ChipFactory.supported_devices.keys())) +def test_too_large(device): + + # Model that is too big to fit on any of our architectures + big_ann = nn.Sequential( + nn.Conv2d(1, 3, 5, 1, bias=False), + nn.ReLU(), + nn.AvgPool2d(2, 2), + nn.Conv2d(3, 1, 5, 1, bias=False), + nn.ReLU(), + nn.AvgPool2d(2, 2), + nn.Flatten(), + nn.Linear(16, 999999, bias=False), + ) + input_shape = (1, 28, 28) + + hardware_incompatible_model = DynapcnnNetwork( + from_model(big_ann, add_spiking_output=True, batch_size=1).cpu(), + discretize=True, + input_shape=input_shape, + ) + + assert not hardware_incompatible_model.is_compatible_with(device) + + with pytest.raises(ValueError): + hardware_incompatible_model.to(device) + + +def test_missing_spiking_layer(): + in_shape = (2, 28, 28) + snn = nn.Sequential( + nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False), + sl.IAFSqueeze(batch_size=1), + sl.SumPool2d(2), + nn.AvgPool2d(2), + nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False), + sl.IAFSqueeze(batch_size=1), + nn.Dropout2d(), + nn.Conv2d(16, 2, kernel_size=3, stride=1, bias=False), + sl.IAFSqueeze(batch_size=1), + nn.Flatten(), + nn.Linear(8, 5), + ) + + with pytest.raises(InvalidGraphStructure): + net = DynapcnnNetwork(snn, input_shape=in_shape) + + +def test_incorrect_model_start(): + in_shape = (2, 28, 28) + snn = nn.Sequential( + sl.IAFSqueeze(batch_size=1), + sl.SumPool2d(2), + nn.AvgPool2d(2), + ) + + with pytest.raises(InvalidGraphStructure): + net = DynapcnnNetwork(snn, input_shape=in_shape) + + +unsupported_layers = [ + nn.ReLU(), + nn.Sigmoid(), + nn.Tanh(), + sl.LIF(tau_mem=5), + 
sl.LIFSqueeze(batch_size=1, tau_mem=5), + sl.NeuromorphicReLU(), + sl.Cropping2dLayer(), +] + + +@pytest.mark.parametrize("layer", unsupported_layers) +def test_unsupported_layers(layer): + in_shape = (1, 28, 28) + ann = nn.Sequential( + nn.Conv2d(1, 3, 5, 1, bias=False), + layer, + ) + + with pytest.raises(UnsupportedLayerType): + net = DynapcnnNetwork(ann, input_shape=in_shape) diff --git a/tests/test_graph_extractor/conftest_graph_extractor.py b/tests/test_graph_extractor/conftest_graph_extractor.py new file mode 100644 index 00000000..1192e62a --- /dev/null +++ b/tests/test_graph_extractor/conftest_graph_extractor.py @@ -0,0 +1,22 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +from model_dummy_1 import expected_output as expected_output_1 +from model_dummy_1 import input_dummy as input_dummy_1 +from model_dummy_1 import snn as snn_1 +from model_dummy_2 import expected_output as expected_output_2 +from model_dummy_2 import input_dummy as input_dummy_2 +from model_dummy_2 import snn as snn_2 +from model_dummy_3 import expected_output as expected_output_3 +from model_dummy_3 import input_dummy as input_dummy_3 +from model_dummy_3 import snn as snn_3 +from model_dummy_4 import expected_output as expected_output_4 +from model_dummy_4 import input_dummy as input_dummy_4 +from model_dummy_4 import snn as snn_4 + +args_GraphExtractor = [ + (snn_1, input_dummy_1, expected_output_1), + (snn_2, input_dummy_2, expected_output_2), + (snn_3, input_dummy_3, expected_output_3), + (snn_4, input_dummy_4, expected_output_4), +] diff --git a/tests/test_graph_extractor/model_dummy_1.py b/tests/test_graph_extractor/model_dummy_1.py new file mode 100644 index 00000000..053f21a0 --- /dev/null +++ b/tests/test_graph_extractor/model_dummy_1.py @@ -0,0 +1,152 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a network with residual connections" example in https://github.com/synsense/sinabs/issues/181 + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + + self.conv1 = nn.Conv2d(2, 10, 2, 1, bias=False) # node 0 + self.iaf1 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 1 + self.pool1 = nn.AvgPool2d(3, 3) # node 2 + self.pool1a = nn.AvgPool2d(4, 4) # node 3 + + self.conv2 = nn.Conv2d(10, 10, 4, 1, bias=False) # node 4 + self.iaf2 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 6 + + self.conv3 = nn.Conv2d(10, 1, 2, 1, bias=False) # node 8 + self.iaf3 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 9 + + self.flat = nn.Flatten() + + self.fc1 = nn.Linear(49, 500, bias=False) # node 10 + self.iaf4 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 11 + + self.fc2 = nn.Linear(500, 10, bias=False) # node 12 + self.iaf5 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) # node 13 + + self.adder = Merge() + + def forward(self, x): + + con1_out = self.conv1(x) + iaf1_out = self.iaf1(con1_out) + pool1_out = self.pool1(iaf1_out) + pool1a_out = 
self.pool1a(iaf1_out) + + conv2_out = self.conv2(pool1_out) + iaf2_out = self.iaf2(conv2_out) + + conv3_out = self.conv3(self.adder(pool1a_out, iaf2_out)) + iaf3_out = self.iaf3(conv3_out) + + flat_out = self.flat(iaf3_out) + + fc1_out = self.fc1(flat_out) + iaf4_out = self.iaf4(fc1_out) + fc2_out = self.fc2(iaf4_out) + iaf5_out = self.iaf5(fc2_out) + + return iaf5_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 3 +input_shape = (batch_size, channels, height, width) + +torch.manual_seed(0) +input_dummy = torch.randn(input_shape) + +expected_output = { + "edges": { + (0, 1), + (1, 2), + (1, 3), + (2, 4), + (3, 5), + (4, 6), + (6, 5), + (7, 8), + (8, 9), + (9, 10), + (10, 11), + (11, 12), + (12, 13), + (5, 7), + }, + "name_2_indx_map": { + "conv1": 0, + "iaf1": 1, + "pool1": 2, + "pool1a": 3, + "conv2": 4, + "adder": 5, + "iaf2": 6, + "conv3": 7, + "iaf3": 8, + "flat": 9, + "fc1": 10, + "iaf4": 11, + "fc2": 12, + "iaf5": 13, + }, + "entry_nodes": {0}, + "nodes_io_shapes": { + 0: {"input": torch.Size([2, 34, 34]), "output": torch.Size([10, 33, 33])}, + 1: { + "input": torch.Size([10, 33, 33]), + "output": torch.Size([10, 33, 33]), + }, + 2: { + "input": torch.Size([10, 33, 33]), + "output": torch.Size([10, 11, 11]), + }, + 3: {"input": torch.Size([10, 33, 33]), "output": torch.Size([10, 8, 8])}, + 4: {"input": torch.Size([10, 11, 11]), "output": torch.Size([10, 8, 8])}, + 6: {"input": torch.Size([10, 8, 8]), "output": torch.Size([10, 8, 8])}, + 5: {"input": torch.Size([10, 8, 8]), "output": torch.Size([10, 8, 8])}, + 7: {"input": torch.Size([10, 8, 8]), "output": torch.Size([1, 7, 7])}, + 8: {"input": torch.Size([1, 7, 7]), "output": torch.Size([1, 7, 7])}, + 9: {"input": torch.Size([1, 7, 7]), "output": torch.Size([49, 1, 1])}, + 10: {"input": torch.Size([49, 1, 1]), "output": torch.Size([500, 1, 1])}, + 11: {"input": torch.Size([500, 1, 1]), "output": torch.Size([500, 1, 1])}, + 12: {"input": torch.Size([500, 1, 1]), "output": torch.Size([10, 1, 1])}, + 13: {"input": torch.Size([10, 1, 1]), "output": torch.Size([10, 1, 1])}, + }, +} + +snn = SNN(batch_size) diff --git a/tests/test_graph_extractor/model_dummy_2.py b/tests/test_graph_extractor/model_dummy_2.py new file mode 100644 index 00000000..f4d9bb77 --- /dev/null +++ b/tests/test_graph_extractor/model_dummy_2.py @@ -0,0 +1,200 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a network with a merge and a split" in https://github.com/synsense/sinabs/issues/181 + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge, SumPool2d + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + # -- graph node A -- + self.conv_A = nn.Conv2d(2, 4, 2, 1, bias=False) + self.iaf_A = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + # -- graph node B -- + self.conv_B = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf2_B = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_B = SumPool2d(2, 2) + # -- graph node C -- + self.conv_C = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_C = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_C = SumPool2d(2, 2) + # -- graph node D -- + self.conv_D = nn.Conv2d(4, 4, 2, 1, bias=False) + 
self.iaf_D = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + # -- graph node E -- + self.conv_E = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf3_E = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_E = SumPool2d(2, 2) + # -- graph node F -- + self.conv_F = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_F = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + # -- graph node G -- + self.fc3 = nn.Linear(144, 10, bias=False) + self.iaf3_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + # -- merges -- + self.merge1 = Merge() + + # -- falts -- + self.flat_D = nn.Flatten() + self.flat_F = nn.Flatten() + + def forward(self, x): + # conv 1 - A/0 + convA_out = self.conv_A(x) # node 0 + iaf_A_out = self.iaf_A(convA_out) # node 1 + + # conv 2 - B/1 + conv_B_out = self.conv_B(iaf_A_out) # node 2 + iaf_B_out = self.iaf2_B(conv_B_out) # node 3 + pool_B_out = self.pool_B(iaf_B_out) # node 4 + + # conv 3 - C/2 + conv_C_out = self.conv_C(pool_B_out) # node 5 + iaf_C_out = self.iaf_C(conv_C_out) # node 7 + pool_C_out = self.pool_C(iaf_C_out) # node 8 + + # conv 4 - D/4 + conv_D_out = self.conv_D(pool_C_out) # node 9 + iaf_D_out = self.iaf_D(conv_D_out) # node 10 + + # fc 1 - E/3 + conv_E_out = self.conv_E(pool_B_out) # node 6 + iaf3_E_out = self.iaf3_E(conv_E_out) # node 12 + pool_E_out = self.pool_E(iaf3_E_out) # node 13 + + # fc 2 - F/6 + conv_F_out = self.conv_F(pool_E_out) # node 14 + iaf_F_out = self.iaf_F(conv_F_out) # node 15 + + # fc 2 - G/5 + flat_D_out = self.flat_D(iaf_D_out) # node 11 + flat_F_out = self.flat_F(iaf_F_out) # node 16 + + merge1_out = self.merge1(flat_D_out, flat_F_out) # node 19 + fc3_out = self.fc3(merge1_out) # node 17 + iaf3_fc_out = self.iaf3_fc(fc3_out) # node 18 + + return iaf3_fc_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 8 +input_shape = (batch_size, channels, height, width) + +torch.manual_seed(0) +input_dummy = torch.randn(input_shape) + +expected_output = { + "edges": { + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (4, 6), + (5, 7), + (7, 8), + (8, 9), + (9, 10), + (10, 11), + (6, 12), + (12, 13), + (13, 14), + (14, 15), + (15, 16), + (17, 18), + (19, 17), + (11, 19), + (16, 19), + }, + "name_2_indx_map": { + "conv_A": 0, + "iaf_A": 1, + "conv_B": 2, + "iaf2_B": 3, + "pool_B": 4, + "conv_C": 5, + "conv_E": 6, + "iaf_C": 7, + "pool_C": 8, + "conv_D": 9, + "iaf_D": 10, + "flat_D": 11, + "iaf3_E": 12, + "pool_E": 13, + "conv_F": 14, + "iaf_F": 15, + "flat_F": 16, + "fc3": 17, + "iaf3_fc": 18, + "merge1": 19, + }, + "entry_nodes": {0}, + "nodes_io_shapes": { + 0: {"input": torch.Size([2, 34, 34]), "output": torch.Size([4, 33, 33])}, + 1: {"input": torch.Size([4, 33, 33]), "output": torch.Size([4, 33, 33])}, + 2: {"input": torch.Size([4, 33, 33]), "output": torch.Size([4, 32, 32])}, + 3: {"input": torch.Size([4, 32, 32]), "output": torch.Size([4, 32, 32])}, + 4: {"input": torch.Size([4, 32, 32]), "output": torch.Size([4, 16, 16])}, + 5: {"input": torch.Size([4, 16, 16]), "output": torch.Size([4, 15, 15])}, + 6: {"input": torch.Size([4, 16, 16]), "output": torch.Size([4, 15, 15])}, + 7: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 15, 15])}, + 12: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 15, 15])}, + 8: 
{"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 7, 7])}, + 13: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 7, 7])}, + 9: {"input": torch.Size([4, 7, 7]), "output": torch.Size([4, 6, 6])}, + 14: {"input": torch.Size([4, 7, 7]), "output": torch.Size([4, 6, 6])}, + 10: {"input": torch.Size([4, 6, 6]), "output": torch.Size([4, 6, 6])}, + 15: {"input": torch.Size([4, 6, 6]), "output": torch.Size([4, 6, 6])}, + 11: {"input": torch.Size([4, 6, 6]), "output": torch.Size([144, 1, 1])}, + 16: {"input": torch.Size([4, 6, 6]), "output": torch.Size([144, 1, 1])}, + 19: {"input": torch.Size([144, 1, 1]), "output": torch.Size([144, 1, 1])}, + 17: {"input": torch.Size([144, 1, 1]), "output": torch.Size([10, 1, 1])}, + 18: {"input": torch.Size([10, 1, 1]), "output": torch.Size([10, 1, 1])}, + }, +} + +snn = SNN(batch_size) diff --git a/tests/test_graph_extractor/model_dummy_3.py b/tests/test_graph_extractor/model_dummy_3.py new file mode 100644 index 00000000..7ee4c5a5 --- /dev/null +++ b/tests/test_graph_extractor/model_dummy_3.py @@ -0,0 +1,235 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "two networks with merging outputs" in https://github.com/synsense/sinabs/issues/181 + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge, SumPool2d + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + + self.conv_A = nn.Conv2d(2, 4, 2, 1, bias=False) + self.iaf_A = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.conv_B = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_B = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_B = SumPool2d(2, 2) + + self.conv_C = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_C = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_C = SumPool2d(2, 2) + + self.conv_D = nn.Conv2d(2, 4, 2, 1, bias=False) + self.iaf_D = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.conv_E = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_E = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_E = SumPool2d(2, 2) + + self.conv_F = nn.Conv2d(4, 4, 2, 1, bias=False) + self.iaf_F = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool_F = SumPool2d(2, 2) + + self.flat_brach1 = nn.Flatten() + self.flat_brach2 = nn.Flatten() + self.merge = Merge() + + self.fc1 = nn.Linear(196, 100, bias=False) + self.iaf1_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.fc2 = nn.Linear(100, 100, bias=False) + self.iaf2_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.fc3 = nn.Linear(100, 10, bias=False) + self.iaf3_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + def forward(self, x): + # conv 1 - A + conv_A_out = self.conv_A(x) + iaf_A_out = self.iaf_A(conv_A_out) 
+ # conv 2 - B + conv_B_out = self.conv_B(iaf_A_out) + iaf_B_out = self.iaf_B(conv_B_out) + pool_B_out = self.pool_B(iaf_B_out) + # conv 3 - C + conv_C_out = self.conv_C(pool_B_out) + iaf_C_out = self.iaf_C(conv_C_out) + pool_C_out = self.pool_C(iaf_C_out) + + # --- + + # conv 4 - D + conv_D_out = self.conv_D(x) + iaf_D_out = self.iaf_D(conv_D_out) + # conv 5 - E + conv_E_out = self.conv_E(iaf_D_out) + iaf_E_out = self.iaf_E(conv_E_out) + pool_E_out = self.pool_E(iaf_E_out) + # conv 6 - F + conv_F_out = self.conv_F(pool_E_out) + iaf_F_out = self.iaf_F(conv_F_out) + pool_F_out = self.pool_F(iaf_F_out) + + # --- + + flat_brach1_out = self.flat_brach1(pool_C_out) + flat_brach2_out = self.flat_brach2(pool_F_out) + merge_out = self.merge(flat_brach1_out, flat_brach2_out) + + # FC 7 - G + fc1_out = self.fc1(merge_out) + iaf1_fc_out = self.iaf1_fc(fc1_out) + # FC 8 - H + fc2_out = self.fc2(iaf1_fc_out) + iaf2_fc_out = self.iaf2_fc(fc2_out) + # FC 9 - I + fc3_out = self.fc3(iaf2_fc_out) + iaf3_fc_out = self.iaf3_fc(fc3_out) + + return iaf3_fc_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 2 +input_shape = (batch_size, channels, height, width) + +torch.manual_seed(0) +input_dummy = torch.randn(input_shape) + +expected_output = { + "edges": { + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 6), + (6, 7), + (7, 8), + (9, 10), + (10, 11), + (11, 12), + (12, 13), + (13, 14), + (14, 15), + (15, 16), + (16, 17), + (8, 18), + (17, 18), + (18, 19), + (19, 20), + (20, 21), + (21, 22), + (22, 23), + (23, 24), + }, + "name_2_indx_map": { + "conv_A": 0, + "iaf_A": 1, + "conv_B": 2, + "iaf_B": 3, + "pool_B": 4, + "conv_C": 5, + "iaf_C": 6, + "pool_C": 7, + "flat_brach1": 8, + "conv_D": 9, + "iaf_D": 10, + "conv_E": 11, + "iaf_E": 12, + "pool_E": 13, + "conv_F": 14, + "iaf_F": 15, + "pool_F": 16, + "flat_brach2": 17, + "merge": 18, + "fc1": 19, + "iaf1_fc": 20, + "fc2": 21, + "iaf2_fc": 22, + "fc3": 23, + "iaf3_fc": 24, + }, + "entry_nodes": {0, 9}, + "nodes_io_shapes": { + 0: {"input": torch.Size([2, 34, 34]), "output": torch.Size([4, 33, 33])}, + 9: {"input": torch.Size([2, 34, 34]), "output": torch.Size([4, 33, 33])}, + 1: {"input": torch.Size([4, 33, 33]), "output": torch.Size([4, 33, 33])}, + 10: {"input": torch.Size([4, 33, 33]), "output": torch.Size([4, 33, 33])}, + 2: {"input": torch.Size([4, 33, 33]), "output": torch.Size([4, 32, 32])}, + 11: {"input": torch.Size([4, 33, 33]), "output": torch.Size([4, 32, 32])}, + 3: {"input": torch.Size([4, 32, 32]), "output": torch.Size([4, 32, 32])}, + 12: {"input": torch.Size([4, 32, 32]), "output": torch.Size([4, 32, 32])}, + 4: {"input": torch.Size([4, 32, 32]), "output": torch.Size([4, 16, 16])}, + 13: {"input": torch.Size([4, 32, 32]), "output": torch.Size([4, 16, 16])}, + 5: {"input": torch.Size([4, 16, 16]), "output": torch.Size([4, 15, 15])}, + 14: {"input": torch.Size([4, 16, 16]), "output": torch.Size([4, 15, 15])}, + 6: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 15, 15])}, + 15: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 15, 15])}, + 7: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 7, 7])}, + 16: {"input": torch.Size([4, 15, 15]), "output": torch.Size([4, 7, 7])}, + 8: {"input": torch.Size([4, 7, 7]), "output": torch.Size([196, 1, 1])}, + 17: {"input": torch.Size([4, 7, 7]), "output": torch.Size([196, 1, 1])}, + 18: {"input": torch.Size([196, 1, 1]), "output": torch.Size([196, 1, 1])}, + 19: {"input": torch.Size([196, 1, 1]), "output": torch.Size([100, 1, 1])}, + 20: {"input": 
torch.Size([100, 1, 1]), "output": torch.Size([100, 1, 1])}, + 21: {"input": torch.Size([100, 1, 1]), "output": torch.Size([100, 1, 1])}, + 22: {"input": torch.Size([100, 1, 1]), "output": torch.Size([100, 1, 1])}, + 23: {"input": torch.Size([100, 1, 1]), "output": torch.Size([10, 1, 1])}, + 24: {"input": torch.Size([10, 1, 1]), "output": torch.Size([10, 1, 1])}, + }, +} + +snn = SNN(batch_size) diff --git a/tests/test_graph_extractor/model_dummy_4.py b/tests/test_graph_extractor/model_dummy_4.py new file mode 100644 index 00000000..1e4d4da6 --- /dev/null +++ b/tests/test_graph_extractor/model_dummy_4.py @@ -0,0 +1,192 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com +# implementing "a complex network structure" example in https://github.com/synsense/sinabs/issues/181 . """ + +import torch +import torch.nn as nn + +from sinabs.activation.surrogate_gradient_fn import PeriodicExponential +from sinabs.layers import IAFSqueeze, Merge, SumPool2d + + +class SNN(nn.Module): + def __init__(self, batch_size) -> None: + super().__init__() + + self.conv1 = nn.Conv2d(2, 1, 2, 1, bias=False) + self.iaf1 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.conv2 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf2 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool2 = SumPool2d(2, 2) + + self.conv3 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf3 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool3 = SumPool2d(2, 2) + self.pool3a = SumPool2d(5, 5) + + self.conv4 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf4 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + self.pool4 = SumPool2d(3, 3) + + self.flat1 = nn.Flatten() + self.flat2 = nn.Flatten() + + self.conv5 = nn.Conv2d(1, 1, 2, 1, bias=False) + self.iaf5 = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + self.fc2 = nn.Linear(25, 10, bias=False) + self.iaf2_fc = IAFSqueeze( + batch_size=batch_size, + min_v_mem=-1.0, + spike_threshold=1.0, + surrogate_grad_fn=PeriodicExponential(), + ) + + # -- merges -- + self.merge1 = Merge() + self.merge2 = Merge() + + def forward(self, x): + # conv 1 - A/0 + con1_out = self.conv1(x) + iaf1_out = self.iaf1(con1_out) + + # conv 2 - B/1 + conv2_out = self.conv2(iaf1_out) + iaf2_out = self.iaf2(conv2_out) + pool2_out = self.pool2(iaf2_out) + + # conv 3 - C/2 + conv3_out = self.conv3(iaf1_out) + iaf3_out = self.iaf3(conv3_out) + pool3_out = self.pool3(iaf3_out) + pool3a_out = self.pool3a(iaf3_out) + + # conv 4 - D/3 + merge1_out = self.merge1(pool2_out, pool3_out) + conv4_out = self.conv4(merge1_out) + iaf4_out = self.iaf4(conv4_out) + pool4_out = self.pool4(iaf4_out) + flat1_out = self.flat1(pool4_out) + + # conv 5 - E/4 + conv5_out = self.conv5(pool3a_out) + iaf5_out = self.iaf5(conv5_out) + flat2_out = self.flat2(iaf5_out) + + # fc 2 - F/5 + merge2_out = self.merge2(flat2_out, flat1_out) + + fc2_out = self.fc2(merge2_out) + iaf2_fc_out = self.iaf2_fc(fc2_out) + + return iaf2_fc_out + + +channels = 2 +height = 34 +width = 34 +batch_size = 2 +input_shape = (batch_size, channels, height, width) + +torch.manual_seed(0) +input_dummy = torch.randn(input_shape) + +expected_output = { + "edges": { + 
(0, 1), + (1, 2), + (1, 3), + (2, 4), + (4, 5), + (5, 6), + (3, 7), + (7, 8), + (7, 9), + (8, 6), + (9, 10), + (11, 12), + (12, 13), + (13, 14), + (14, 15), + (16, 15), + (10, 17), + (17, 16), + (18, 19), + (6, 11), + (15, 18), + }, + "name_2_indx_map": { + "conv1": 0, + "iaf1": 1, + "conv2": 2, + "conv3": 3, + "iaf2": 4, + "pool2": 5, + "merge1": 6, + "iaf3": 7, + "pool3": 8, + "pool3a": 9, + "conv5": 10, + "conv4": 11, + "iaf4": 12, + "pool4": 13, + "flat1": 14, + "merge2": 15, + "flat2": 16, + "iaf5": 17, + "fc2": 18, + "iaf2_fc": 19, + }, + "entry_nodes": {0}, + "nodes_io_shapes": { + 0: {"input": torch.Size([2, 34, 34]), "output": torch.Size([1, 33, 33])}, + 1: {"input": torch.Size([1, 33, 33]), "output": torch.Size([1, 33, 33])}, + 2: {"input": torch.Size([1, 33, 33]), "output": torch.Size([1, 32, 32])}, + 3: {"input": torch.Size([1, 33, 33]), "output": torch.Size([1, 32, 32])}, + 4: {"input": torch.Size([1, 32, 32]), "output": torch.Size([1, 32, 32])}, + 7: {"input": torch.Size([1, 32, 32]), "output": torch.Size([1, 32, 32])}, + 5: {"input": torch.Size([1, 32, 32]), "output": torch.Size([1, 16, 16])}, + 8: {"input": torch.Size([1, 32, 32]), "output": torch.Size([1, 16, 16])}, + 9: {"input": torch.Size([1, 32, 32]), "output": torch.Size([1, 6, 6])}, + 6: {"input": torch.Size([1, 16, 16]), "output": torch.Size([1, 16, 16])}, + 10: {"input": torch.Size([1, 6, 6]), "output": torch.Size([1, 5, 5])}, + 11: {"input": torch.Size([1, 16, 16]), "output": torch.Size([1, 15, 15])}, + 17: {"input": torch.Size([1, 5, 5]), "output": torch.Size([1, 5, 5])}, + 12: {"input": torch.Size([1, 15, 15]), "output": torch.Size([1, 15, 15])}, + 16: {"input": torch.Size([1, 5, 5]), "output": torch.Size([25, 1, 1])}, + 13: {"input": torch.Size([1, 15, 15]), "output": torch.Size([1, 5, 5])}, + 14: {"input": torch.Size([1, 5, 5]), "output": torch.Size([25, 1, 1])}, + 15: {"input": torch.Size([25, 1, 1]), "output": torch.Size([25, 1, 1])}, + 18: {"input": torch.Size([25, 1, 1]), "output": torch.Size([10, 1, 1])}, + 19: {"input": torch.Size([10, 1, 1]), "output": torch.Size([10, 1, 1])}, + }, +} + +snn = SNN(batch_size) diff --git a/tests/test_graph_extractor/test_graph_extractor.py b/tests/test_graph_extractor/test_graph_extractor.py new file mode 100644 index 00000000..20acdd16 --- /dev/null +++ b/tests/test_graph_extractor/test_graph_extractor.py @@ -0,0 +1,67 @@ +# author : Willian Soares Girao +# contact : wsoaresgirao@gmail.com + +import pytest +from conftest_graph_extractor import args_GraphExtractor + +from sinabs.backend.dynapcnn.nir_graph_extractor import GraphExtractor + + +def fix_node_ids(expected_output, graph_extractor): + """Match node IDs between graph extractor and expected output + + Node IDs can be assigned in many ways. 
This function prevents test + errors from generated IDs not matching expected output + + Parameters + ---------- + expected_output: Dict with expected output + graph_extractor: GraphExtractor instance + + Returns + ------- + Expected outputs with remapped node IDs + """ + idx_map = { + expected_idx: graph_extractor.name_2_indx_map[name] + for name, expected_idx in expected_output["name_2_indx_map"].items() + } + edges = {(idx_map[src], idx_map[tgt]) for src, tgt in expected_output["edges"]} + name_2_indx_map = { + name: idx_map[idx] for name, idx in expected_output["name_2_indx_map"].items() + } + entry_nodes = {idx_map[node] for node in expected_output["entry_nodes"]} + nodes_io_shapes = { + idx_map[node]: shape + for node, shape in expected_output["nodes_io_shapes"].items() + } + return { + "edges": edges, + "name_2_indx_map": name_2_indx_map, + "entry_nodes": entry_nodes, + "nodes_io_shapes": nodes_io_shapes, + } + + +@pytest.mark.parametrize("snn, input_dummy, expected_output", args_GraphExtractor) +def test_GraphExtractor(snn, input_dummy, expected_output): + """Tests the graph extraction from the original SNN being turned into a `DynapcnnNetwork`. These tests + verify the correct functionality of the `GraphExtractor` class, which implements the first pre-processing + step on the conversion of the SNN into a DynapcnnNetwork. + """ + + graph_tracer = GraphExtractor(snn, input_dummy) + expected_output = fix_node_ids(expected_output, graph_tracer) + + assert ( + expected_output["edges"] == graph_tracer.edges + ), "wrong list of edges extracted from the SNN." + assert ( + expected_output["name_2_indx_map"] == graph_tracer.name_2_indx_map + ), "wrong mapping from layer variable name to node ID." + assert ( + expected_output["entry_nodes"] == graph_tracer.entry_nodes + ), "wrong list with entry node's IDs (i.e., layers serving as input to the SNN)." + assert ( + expected_output["nodes_io_shapes"] == graph_tracer.nodes_io_shapes + ), "wrong I/O shapes computed for one or more nodes." diff --git a/tests/test_hooks.py b/tests/test_hooks.py index c45da42c..d213a0a9 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -139,4 +139,4 @@ def test_input_diff_hook(): layer.register_forward_hook(hooks.input_diff_hook) model(inp) for idx, correct in correct_values.items(): - assert (model[idx].hook_data["diff_output"] == correct).all() + assert torch.allclose(model[idx].hook_data["diff_output"], correct)
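
For orientation, a minimal usage sketch of how one of the new sequential fixtures is exercised; it mirrors seq_1 from model_dummy_seq.py and test_dynapcnnnetwork.py above, and assumes only the attributes those tests already rely on (dynapcnn_module, _sorted_nodes, _destination_map, _entry_points). Expected values follow expected_seq_1.

import torch
import torch.nn as nn
from sinabs.layers import IAFSqueeze
from sinabs.backend.dynapcnn import DynapcnnNetwork

# Same two-layer sequential SNN as `seq_1` in model_dummy_seq.py.
snn = nn.Sequential(
    nn.Conv2d(2, 8, kernel_size=3, stride=1, bias=False),
    IAFSqueeze(batch_size=1),
    nn.Conv2d(8, 2, kernel_size=3, stride=1, bias=False),
    IAFSqueeze(batch_size=1),
)

net = DynapcnnNetwork(snn, input_shape=(2, 30, 30), batch_size=1)
module = net.dynapcnn_module

# Graph-level information the new tests assert on (per expected_seq_1):
# topological order [0, 1], layer 0 feeding layer 1, layer 1 exiting (-1).
print(module._sorted_nodes)
print(module._destination_map)
print(module._entry_points)

# Forward pass through the DynapcnnLayer graph; expected_seq_1 gives
# an output shape of torch.Size([1, 2, 26, 26]).
out = net(torch.randn(1, 2, 30, 30))
print(out.shape)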