
Simplify DynapcnnLayer by removing _pool_layers attribute.
bauerfe committed Sep 17, 2024
1 parent 72c8e6e commit f9a797d
Showing 1 changed file with 20 additions and 50 deletions.
70 changes: 20 additions & 50 deletions sinabs/backend/dynapcnn/dynapcnn_layer.py
@@ -2,17 +2,22 @@
# contact : [email protected]

from copy import deepcopy
-from typing import Dict, Callable, Tuple, Union, List
+from functools import partial
+from typing import Tuple, List

import numpy as np
import torch
from torch import nn

import sinabs.activation
import sinabs.layers as sl

from .discretize import discretize_conv_spike_


+# Define sum pooling functional as power-average pooling with power 1
+sum_pool2d = partial(nn.functional.lp_pool2d, norm_type=1)

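As an aside on the helper above: with norm_type=1, lp_pool2d raises each element to the power 1, sums over the window, and takes the first root, i.e. it is plain sum pooling. A quick sanity check (a sketch, not part of the commit):

    import torch
    from functools import partial
    from torch import nn

    sum_pool2d = partial(nn.functional.lp_pool2d, norm_type=1)

    x = torch.randint(0, 3, (1, 2, 8, 8)).float()  # non-negative spike counts
    summed = sum_pool2d(x, kernel_size=2)          # stride defaults to kernel_size
    # Average pooling scaled by the window area gives the same sum.
    assert torch.allclose(summed, nn.functional.avg_pool2d(x, kernel_size=2) * 4)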

class DynapcnnLayer(nn.Module):
"""Create a DynapcnnLayer object representing a layer on DynapCNN or Speck.
@@ -74,15 +79,17 @@ def __init__(
+        if self._discretize:
+            conv, spk = discretize_conv_spike_(conv, spk, to_int=False)

-        if discretize:
-            # int conversion is done while writing the config.
-            conv, spk = discretize_conv_spike_(conv, spk, to_int=False)
+        self._conv = conv
+        self._spk = spk

+    @property
+    def conv(self):
+        return self._conv

-        self.conv = conv
-        self.spk = spk
+    @property
+    def spk(self):
+        return self._spk

-        self._pool_lyrs = self._make_pool_layers()  # creates SumPool2d layers from `pool`.

    @property
    def pool(self):
        return self._pool
@@ -101,7 +108,7 @@ def conv_out_shape(self):

####################################################### Public Methods #######################################################

-    def forward(self, x):
+    def forward(self, x) -> List[torch.Tensor]:
        """Torch forward pass.
        ...
@@ -113,12 +120,13 @@ def forward(self, x):
        x = self.spk(x)

        for pool in self._pool:

            if pool == 1:
                # no pooling is applied.
                returns.append(x)
            else:
                # sum pooling of `(pool, pool)` is applied.
-                pool_out = self._pool_lyrs[pool](x)
+                pool_out = sum_pool2d(x, kernel_size=pool)
                returns.append(pool_out)

        return tuple(returns)
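To illustrate the forward contract this hunk establishes (one output per entry in pool), here is a hypothetical usage sketch; the constructor keywords in_shape, conv, spk, pool and discretize are inferred from the attributes used in this file and may not match the actual signature exactly:

    import torch
    from torch import nn
    import sinabs.layers as sl
    from sinabs.backend.dynapcnn.dynapcnn_layer import DynapcnnLayer

    # One layer fanning out to two destinations: no pooling (1) and 2x2 sum pooling (2).
    layer = DynapcnnLayer(
        in_shape=(1, 8, 8),
        conv=nn.Conv2d(1, 4, kernel_size=3, padding=1, bias=False),
        spk=sl.IAFSqueeze(batch_size=1),
        pool=[1, 2],
        discretize=False,
    )
    out_full, out_pooled = layer(torch.rand(1, 1, 8, 8))
    print(out_full.shape, out_pooled.shape)  # (1, 4, 8, 8) and (1, 4, 4, 4)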
@@ -212,41 +220,6 @@ def _convert_linear_to_conv(self, lin: nn.Linear, layer_data: dict) -> Tuple[nn.
        )

        return layer, input_shape
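Only the tail of _convert_linear_to_conv is visible in this hunk. A sketch of the standard Linear-to-Conv2d conversion such a helper typically performs (assumed behaviour, not the commit's code; linear_to_conv2d is a hypothetical name):

    import torch
    from torch import nn

    def linear_to_conv2d(lin: nn.Linear, input_shape):
        # Build a Conv2d equivalent to `lin` applied to a flattened (C, H, W) input.
        channels, height, width = input_shape
        conv = nn.Conv2d(
            in_channels=channels,
            out_channels=lin.out_features,
            kernel_size=(height, width),
            bias=lin.bias is not None,
        )
        # The reshape below matches the row-major flattening order of (C, H, W).
        conv.weight.data = lin.weight.data.reshape(
            lin.out_features, channels, height, width
        )
        if lin.bias is not None:
            conv.bias.data = lin.bias.data
        return conv

    lin = nn.Linear(2 * 4 * 4, 10)
    conv = linear_to_conv2d(lin, (2, 4, 4))
    x = torch.rand(1, 2, 4, 4)
    assert torch.allclose(lin(x.flatten(1)), conv(x).flatten(1), atol=1e-6)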

-    def _make_pool_layers(self) -> Dict[int, sl.SumPool2d]:
-        """Creates a `sl.SumPool2d` for each entry in `self._pool` greater than one.
-        Note: a kernel size (value > 1) in `self._pool` is by default also used as the stride of the pooling layer.
-        Returns
-        -------
-        - pool_lyrs (dict): each key is a value greater than 1 in `self._pool`; its value is the `sl.SumPool2d` it represents.
-        """
-
-        pool_lyrs = {}
-
-        # validate that the pool entries are integers.
-        for item in self._pool:
-            if not isinstance(item, int):
-                raise ValueError(f"Item '{item}' in `pool` is not an integer.")
-
-        # create layers from the pool list.
-        for kernel_s in self._pool:
-
-            if kernel_s != 1:
-
-                pooling = (kernel_s, kernel_s)
-
-                # compute cumulative pooling.
-                cumulative_pooling = (
-                    cumulative_pooling[0] * pooling[0],
-                    cumulative_pooling[1] * pooling[1],
-                )
-
-                # create SumPool2d layer.
-                pool_lyrs[kernel_s] = sl.SumPool2d(cumulative_pooling)
-
-        return pool_lyrs
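The deleted helper built sl.SumPool2d modules keyed by kernel size; the functional sum_pool2d that forward now uses should behave identically, assuming sinabs's SumPool2d keeps its LPPool-with-norm_type-1 semantics:

    import torch
    from functools import partial
    from torch import nn
    import sinabs.layers as sl

    sum_pool2d = partial(nn.functional.lp_pool2d, norm_type=1)

    x = torch.rand(1, 3, 8, 8)
    # sl.SumPool2d(k) strides by k by default, just like the functional form.
    assert torch.allclose(sl.SumPool2d(2)(x), sum_pool2d(x, kernel_size=2))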

    def _get_conv_output_shape(self) -> Tuple[int, int, int]:
        """Computes the output dimensions of `conv_layer`.
@@ -257,10 +230,7 @@ def _get_conv_output_shape(self):
"""
# get the layer's parameters.

spk = deepcopy()
out_channels = self.conv.out_channels

spk = deepcopy()
kernel_size = self.conv.kernel_size
stride = self.conv.stride
padding = self.conv.padding
@@ -270,4 +240,4 @@
        out_height = ((self.in_shape[1] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0]) + 1
        out_width = ((self.in_shape[2] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1]) + 1

-        return (out_channels, out_height, out_width)
+        return (out_channels, out_height, out_width)
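The two expressions above are the standard convolution output-shape arithmetic. A worked example with made-up numbers, checked against PyTorch:

    import torch
    from torch import nn

    # (32 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 16 for both height and width.
    conv = nn.Conv2d(2, 8, kernel_size=3, stride=2, padding=1)
    out = conv(torch.rand(1, 2, 32, 32))
    assert out.shape == (1, 8, 16, 16)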
