A compilation of PyTorch implementations of modules from various ML papers, especially in computer vision. It contains a mix of self-implementations as well as unofficial and official implementations. More to be added.
```bash
$ pip install torch-modules-compilation
```
- Bottleneck Residual Block
- Depthwise Separable Convolution
- SAGAN self-attention module
- Global-Local Attention Module
- Global Context Module
- LFSA Tokenizer and Refinement Block
- Parameter-Free Channel Attention (PFCA)
- Patch Merger
- ResBlock
- Up/Down sample ResBlock
- Residual MLP Block
- Residual MLP Downsampling Block
- Transformer Encoder Layer
- UNet Encoder and Decoder
- Squeeze-Excitation Module
- Token Learner
- Triplet Attention
Your basic bottleneck residual block, as used in ResNets. Image from the paper "Deep Residual Learning for Image Recognition".
- in_channels (int): number of input channels
- bottleneck_channels (int): number of bottleneck channels; usually less than the number of input channels
- dropout (float): dropout rate; applied after every convolution
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 256, 16, 16) # (batch_size, channels, height, width)
block = modules.BottleneckResBlock(in_channels=256, bottleneck_channels=64)

block(x).shape # (32, 256, 16, 16)
```
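
For intuition, here is a minimal sketch of the idea behind such a block: a 1x1 convolution reduces the channels to the bottleneck width, a 3x3 convolution operates in the bottleneck, and a 1x1 convolution expands back before the skip connection. The class name and the normalization/activation/dropout placement below are illustrative assumptions, not necessarily this package's exact implementation.

```python
import torch
import torch.nn as nn

class BottleneckSketch(nn.Module):
    """Illustrative bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus skip."""
    def __init__(self, in_channels, bottleneck_channels, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(bottleneck_channels, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
            nn.Dropout2d(dropout),
        )

    def forward(self, x):
        return torch.relu(self.net(x) + x)  # residual connection
```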
A depthwise separable convolution, consisting of a depthwise convolution followed by a pointwise convolution. Used in MobileNets, as described in the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications". Image also from this paper.
- in_channels (int): number of input channels
- out_channels (int): number of output channels
- kernel_size (int): size of the depthwise convolution kernel
- stride (int): stride of the depthwise convolution
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.DepthwiseSepConv(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)

block(x).shape # (32, 128, 16, 16)
```
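
As a sketch of the mechanism: the depthwise step is a grouped convolution with groups=in_channels (one filter per channel), and the pointwise step is a 1x1 convolution that mixes channels. The class name below is illustrative; the package's module may differ in details.

```python
import torch.nn as nn

class DepthwiseSepConvSketch(nn.Module):
    """Illustrative depthwise separable convolution: depthwise conv followed by 1x1 pointwise conv."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups=in_channels makes the convolution depthwise (each channel filtered independently)
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   stride=stride, padding=padding, groups=in_channels)
        # the 1x1 convolution mixes information across channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
```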
A feature-map self-attention module used in SAGAN ("Self-Attention Generative Adversarial Networks"). Image also from this paper. This implementation was copied and modified from https://github.com/rosinality/sagan-pytorch/blob/master/model.py#L82 under the Apache 2.0 license; the modification removes spectral initialization.
- in_channels (int): number of input channels
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.FeatureMapSelfAttention(in_channels=64)

block(x).shape # (32, 64, 16, 16)
```
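
A rough sketch of the mechanism, for reference: 1x1 convolutions project the feature map to queries, keys, and values, attention is computed over all spatial positions, and a learnable gamma gates the residual, as in the SAGAN paper. The class name and the channels//8 projection size are assumptions; details may differ from this package's code.

```python
import torch
import torch.nn as nn

class FeatureMapSelfAttentionSketch(nn.Module):
    """Illustrative SAGAN-style self-attention over the spatial positions of a feature map."""
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight, starts at 0

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).flatten(2)                          # (b, c//8, h*w)
        k = self.key(x).flatten(2)                            # (b, c//8, h*w)
        v = self.value(x).flatten(2)                          # (b, c, h*w)
        attn = torch.softmax(q.transpose(1, 2) @ k, dim=-1)   # (b, h*w, h*w)
        out = (v @ attn.transpose(1, 2)).view(b, c, h, w)
        return self.gamma * out + x
```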
A convolutional attention module introduced in the paper "All the attention you need: Global-local, spatial-channel attention for image retrieval". Image also from this paper.
- in_channels (int): number of channels of the input feature map
- num_reduced_channels (int): number of channels that the local and global spatial attention modules reduce the input feature map to; refer to figures 3 and 5 in the paper
- feature_map_size (int): height/width of the feature map; must be at least 7, due to the 7x7 convolution (3x3 dilated conv) in the module
- kernel_size (int): scope of the inter-channel attention
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.GLAM(in_channels=64, num_reduced_channels=48, feature_map_size=16, kernel_size=5)
# height and width are equal to feature_map_size

block(x).shape # (32, 64, 16, 16)
```
A self-attention-like (non-local) block on feature maps. Implementation of "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond".
- input_channels (int): number of input channels
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.GlobalContextModule(input_channels=64)

block(x).shape # (32, 64, 16, 16)
```
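
A rough sketch of the GCNet idea: a softmax-weighted global pooling gathers context from all spatial positions, a small bottleneck transform processes it per channel, and the result is added back to the input. The reduction ratio and normalization choices below are illustrative assumptions, not necessarily this package's exact implementation.

```python
import torch
import torch.nn as nn

class GlobalContextSketch(nn.Module):
    """Illustrative GCNet-style block: global context pooling + channel transform, added to the input."""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.context_mask = nn.Conv2d(channels, 1, kernel_size=1)  # attention logits over spatial positions
        self.transform = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.LayerNorm([channels // reduction, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        weights = torch.softmax(self.context_mask(x).view(b, 1, h * w), dim=-1)            # (b, 1, h*w)
        context = (x.view(b, c, h * w) @ weights.transpose(1, 2)).view(b, c, 1, 1)          # weighted global pooling
        return x + self.transform(context)  # per-channel context broadcast over all positions
```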
Implementation of the tokenizer in "Learning Token-Based Representation for Image Retrieval". These are two modules: the tokenizer module, which takes feature maps from a CNN (in the paper's case, feature maps from a local-feature self-attention module) and tokenizes them "into L visual tokens", used prior to the refinement block as described in the paper; and the refinement block, which "enhance[s] the obtained visual tokens with self-attention and cross-attention."
LFSA Tokenizer
- in_channels (int): number of input channels
- num_att_maps (int): number of tokens to tokenize the input into; also the number of channels used by the spatial attention
Refinement Block
- d_model (int): dimensionality/channels of the input
- nhead (int): number of attention heads in the transformer
- dim_feedforward (int): number of hidden dimensions in the feedforward layers
- dropout (float): dropout rate
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
tokenizer = modules.LFSATokenizer(in_channels=64, num_att_maps=48)
refinement_block = modules.RefinementBlock(d_model=64, nhead=2, dim_feedforward=48*4, dropout=0.1)

visual_tokens, cnn_output = tokenizer(x)
print(visual_tokens.shape) # (32, 48, 64)
print(cnn_output.shape) # (32, 16*16, 64)

output = refinement_block(visual_tokens, cnn_output)
print(output.shape) # (32, 48, 64)
```
A channel attention module for convolutional feature maps without any trainable parameters. Used in, and image from, the paper "Parameter-Free Channel Attention for Image Classification and Super-Resolution".
- feature_map_size (int): height/width of the input feature map
- _lambda (float): a hyperparameter that is added to the variance (default: 1e-4)
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ParameterFreeChannelAttention(feature_map_size=16)

block(x).shape # (32, 64, 16, 16)
```
Merges N tokens into M tokens in transformer models; typically inserted between transformer layers. Introduced in the paper "Learning to Merge Tokens in Vision Transformers". Image from this paper. Copied from lucidrains' repo under the MIT license.
- dim (int): dimensionality/channels of the tokens
- output_tokens (int): number of output merged tokens
- norm (bool): normalize the input before merging
- scale (bool): scale the attention matrix by the square root of dim (for numerical stability)
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16) # (batch_size, seq_length, channels)
block = modules.PatchMerger(dim=16, output_tokens=48, scale=True)

block(x).shape # (32, 48, 16)
```
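
Since the module is small, here is a sketch of the underlying idea, following lucidrains' formulation: a learned (output_tokens x dim) query matrix attends over the N input tokens, and the attention weights produce the M merged output tokens. The class name and exact argument handling are assumptions; this package's code may differ slightly.

```python
import torch
import torch.nn as nn

class PatchMergerSketch(nn.Module):
    """Illustrative PatchMerger: learned queries attend over N input tokens to produce M merged tokens."""
    def __init__(self, dim, output_tokens, norm=True, scale=True):
        super().__init__()
        self.scale = dim ** -0.5 if scale else 1.0
        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()
        self.queries = nn.Parameter(torch.randn(output_tokens, dim))

    def forward(self, x):                  # x: (batch, N, dim)
        x = self.norm(x)
        attn = torch.softmax(self.queries @ x.transpose(1, 2) * self.scale, dim=-1)  # (batch, M, N)
        return attn @ x                    # (batch, M, dim)
```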
Your basic residual block, as used in ResNets. Image from the original paper "Deep Residual Learning for Image Recognition".
- in_channels (int): number of input channels
- kernel_size (int): kernel size
- dropout (float): dropout rate
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ResBlock(in_channels=64, kernel_size=3, dropout=0.2)

block(x).shape # (32, 64, 16, 16)
```
Composed of several residual blocks with a down/up-sampling layer at the end; adapted from Stable Diffusion's ResnetBlock.
- in_channels (int): number of input channels
- out_channels (int): number of output channels
- num_groups (int): number of groups for Group Normalization
- num_layers (int): number of residual blocks
- dropout (float): dropout rate
- sample (str): one of "down", "up", or "none"; use "down" for 2x downsampling, "up" for 2x upsampling, or "none" for no resampling
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 96, 96) # (batch_size, channels, height, width)
block = modules.ResBlockUpDownSample(
    in_channels=64,
    out_channels=128,
    num_groups=8,
    num_layers=2,
    dropout=0.1,
    sample='down'
)

block(x).shape # (32, 128, 48, 48)
```
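
A rough sketch of one way such a block can be put together, in the spirit of Stable Diffusion's ResnetBlock: GroupNorm/SiLU/Conv layers with a 1x1 skip projection, followed by optional 2x resampling. The single skip around the whole stack, the nearest-neighbor upsampling, and the average-pooling downsampling are assumptions, not necessarily this package's exact code.

```python
import torch
import torch.nn as nn

class ResBlockUpDownSampleSketch(nn.Module):
    """Illustrative stack of GroupNorm/SiLU/Conv blocks with a skip projection and optional 2x resampling."""
    def __init__(self, in_channels, out_channels, num_groups, num_layers, dropout=0.0, sample="none"):
        super().__init__()
        blocks, channels = [], in_channels
        for _ in range(num_layers):
            blocks.append(nn.Sequential(
                nn.GroupNorm(num_groups, channels),
                nn.SiLU(),
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.Dropout2d(dropout),
            ))
            channels = out_channels
        self.blocks = nn.ModuleList(blocks)
        self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # match channels for the residual
        if sample == "down":
            self.resample = nn.AvgPool2d(kernel_size=2)
        elif sample == "up":
            self.resample = nn.Upsample(scale_factor=2, mode="nearest")
        else:
            self.resample = nn.Identity()

    def forward(self, x):
        h = x
        for block in self.blocks:
            h = block(h)
        return self.resample(h + self.skip(x))
```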
An improvement over standard MLPs, using residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP block (eq. 5 in the paper).
- dim (int): number of input dimensions
- ic_first (bool): whether to normalize and apply dropout at the start of the block
- dropout (float): dropout rate
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_block(dim=96, ic_first=True, dropout=0.1)

block(x).shape # (32, 96)
```

An improvement over standard MLPs, using residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP downsampling block (eq. 6 in the paper).
- dim (int): number of input dimensions
- downsample_dim (int): number of output dimensions
- dropout (float): dropout rate
```python
import torch
from torch_modules_compilation import modules

x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_downsample(dim=96, downsample_dim=48, dropout=0.1)

block(x).shape # (32, 48)
```
Standard transformer encoder layer with queries, keys, and values as inputs.
- d_model (int): model dimensionality
- nhead (int): number of attention heads
- dim_feedforward (int): number of hidden dimensions in the feedforward layers
- dropout (float): dropout rate
- kdim (int, optional): dimensions of the keys
- vdim (int, optional): dimensions of the values
```python
import torch
from torch_modules_compilation import modules

queries = torch.randn(32, 20, 64) # (batch_size, seq_length, dim)
keys = torch.randn(32, 19, 48) # (batch_size, seq_length, dim)
values = torch.randn(32, 19, 96) # (batch_size, seq_length, dim)

block = modules.TransformerEncoderLayer(
    d_model=64,
    nhead=8,
    dim_feedforward=256,
    dropout=0.2,
    kdim=48,
    vdim=96
)

block(queries, keys, values).shape # (32, 20, 64)
```
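
For reference, a layer like this can be sketched on top of nn.MultiheadAttention, which already supports separate key/value dimensions via its kdim and vdim arguments. The post-norm layout and ReLU feedforward below are assumptions, not necessarily how this package builds its layer.

```python
import torch
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    """Illustrative encoder layer taking separate query/key/value inputs."""
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.0, kdim=None, vdim=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                          kdim=kdim, vdim=vdim, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        attn_out, _ = self.attn(queries, keys, values)        # cross/self-attention
        x = self.norm1(queries + self.dropout(attn_out))      # residual + norm
        return self.norm2(x + self.dropout(self.ff(x)))       # feedforward + residual + norm
```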
Standard UNet implementation. From the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".
UNet Encoder
- channels (list of ints): a list containing the number of channels in the encoder, e.g. [3, 64, 128, 256]
- dropout (float): dropout rate
UNet Decoder
- channels (list of ints): a list containing the number of channels in the decoder, e.g. [256, 128, 64, 3]
- dropout (float): dropout rate
```python
import torch
from torch_modules_compilation import modules

images = torch.randn(16, 3, 224, 224) # (batch_size, channels, height, width)
unet_encoder = modules.UnetEncoder(channels=[3,64,128,256], dropout=0.1)
unet_decoder = modules.UnetDecoder(channels=[256,128,64,3], dropout=0.1)

encoder_features = unet_encoder(images)
output = unet_decoder(encoder_features)

print(output.shape) # (16, 64, 224, 224)
```
Module that computes channel-wise interactions in a feature map. From "Squeeze-and-Excitation Networks".
- in_channels (int): number of input channels
- reduced_channels (int): number of channels to reduce to in the "squeeze" part of the module
- feature_map_size (int): height/width of the feature map
```python
import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(16, 128, 64, 64) # (batch_size, channels, height, width)
se_module = modules.SEModule(in_channels=128, reduced_channels=32, feature_map_size=64)

se_module(feature_maps) # shape (16, 128, 64, 64); same as input
```
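
A minimal sketch of the squeeze-and-excitation idea: a global "squeeze" pools each channel to a scalar, a small "excitation" MLP produces per-channel weights, and the input is rescaled channel-wise. Note that this sketch uses global average pooling and therefore does not need the feature_map_size argument; the package's module may use a fixed-size pooling instead.

```python
import torch
import torch.nn as nn

class SEModuleSketch(nn.Module):
    """Illustrative squeeze-and-excitation: squeeze (global pool), excite (MLP + sigmoid), rescale."""
    def __init__(self, in_channels, reduced_channels):
        super().__init__()
        self.excite = nn.Sequential(
            nn.Linear(in_channels, reduced_channels),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        squeezed = x.mean(dim=(2, 3))                    # (b, c): global average pooling
        weights = self.excite(squeezed).view(b, c, 1, 1)  # per-channel gates in [0, 1]
        return x * weights                                # rescale each channel
```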
Module designed for reducing and generating visual tokens given a feature map. From "TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?"
- in_channels (int): number of input channels
- num_tokens (int): number of tokens to reduce to
```python
import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
token_learner = modules.TokenLearner(in_channels=16, num_tokens=50) # reduce tokens from 10*10 to 50

token_learner(feature_maps) # shape (2, 50, 16)
```
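
A rough sketch of the TokenLearner idea: each output token gets its own spatial attention map, and the feature map is pooled under each map to produce that token. The two-convolution attention head and the sigmoid gating below are illustrative assumptions, not necessarily this package's exact implementation.

```python
import torch
import torch.nn as nn

class TokenLearnerSketch(nn.Module):
    """Illustrative TokenLearner: one spatial attention map per output token, then weighted spatial pooling."""
    def __init__(self, in_channels, num_tokens):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, num_tokens, kernel_size=3, padding=1),
            nn.Conv2d(num_tokens, num_tokens, kernel_size=3, padding=1),
        )

    def forward(self, x):                                            # x: (b, c, h, w)
        b, c, h, w = x.shape
        maps = torch.sigmoid(self.attention(x)).view(b, -1, h * w)    # (b, num_tokens, h*w)
        feats = x.view(b, c, h * w).transpose(1, 2)                   # (b, h*w, c)
        return maps @ feats / (h * w)                                 # (b, num_tokens, c)
```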
Computes attention in a feature map across all three dimensions (channel and both spatial dimensions). From "Rotate to Attend: Convolutional Triplet Attention Module".
- in_channels (int): number of input channels
- height (int): height of the feature map
- width (int): width of the feature map
- kernel_size (int): kernel size of the convolutions (default: 7)
```python
import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
triplet_attention = modules.TripletAttention(in_channels=16, height=10, width=10)

triplet_attention(feature_maps) # shape (2, 16, 10, 10); same as input
```
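
A rough sketch of the three-branch idea: each branch rotates the tensor so a different pair of dimensions interacts, pools the remaining dimension with max and mean ("Z-pool"), applies a kxk convolution and a sigmoid gate, and the three gated outputs are averaged. This sketch does not need the in_channels/height/width arguments; the package's module may differ in details.

```python
import torch
import torch.nn as nn

def z_pool(x):
    # concatenate max and mean over dim 1
    return torch.cat([x.max(dim=1, keepdim=True).values, x.mean(dim=1, keepdim=True)], dim=1)

class TripletAttentionSketch(nn.Module):
    """Illustrative triplet attention: gates over (H,W), (C,W), and (C,H) branches, averaged."""
    def __init__(self, kernel_size=7):
        super().__init__()
        pad = kernel_size // 2
        self.conv_hw = nn.Conv2d(2, 1, kernel_size, padding=pad)
        self.conv_cw = nn.Conv2d(2, 1, kernel_size, padding=pad)
        self.conv_ch = nn.Conv2d(2, 1, kernel_size, padding=pad)

    def forward(self, x):                                                    # (b, c, h, w)
        # branch 1: spatial attention over (h, w)
        out_hw = x * torch.sigmoid(self.conv_hw(z_pool(x)))
        # branch 2: rotate so h plays the channel role -> attention over (c, w)
        x_cw = x.permute(0, 2, 1, 3)
        out_cw = (x_cw * torch.sigmoid(self.conv_cw(z_pool(x_cw)))).permute(0, 2, 1, 3)
        # branch 3: rotate so w plays the channel role -> attention over (c, h)
        x_ch = x.permute(0, 3, 2, 1)
        out_ch = (x_ch * torch.sigmoid(self.conv_ch(z_pool(x_ch)))).permute(0, 3, 2, 1)
        return (out_hw + out_cw + out_ch) / 3
```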
Some of these modules are copied or adapted from other repositories and, where noted above, carry their own licenses (such as MIT and Apache 2.0). Take note of these licenses when using this code in your work. The rest are my own implementations and are released under the MIT license; see this repo's license file.