From a6738888093a4a3a0f9f5e9217d6b002ebaeed39 Mon Sep 17 00:00:00 2001
From: David Chavez
Date: Mon, 8 Apr 2024 18:08:49 +0200
Subject: [PATCH] feature(models): Add unet_2d_blocks (#18)

---
 Cargo.toml                   |   5 +-
 src/lib.rs                   |   4 +-
 src/models/attention.rs      |  50 +-
 src/models/embeddings.rs     |   3 +-
 src/models/mod.rs            |   1 +
 src/models/resnet.rs         |   6 +-
 src/models/unet_2d_blocks.rs | 927 +++++++++++++++++++++++++++++++++++
 src/transformers/clip.rs     |   3 +-
 8 files changed, 974 insertions(+), 25 deletions(-)
 create mode 100644 src/models/unet_2d_blocks.rs

diff --git a/Cargo.toml b/Cargo.toml
index fa74e4c..3c17d05 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,9 +16,12 @@ torch = ["burn/tch"]
 wgpu = ["burn/wgpu"]
 
 [dependencies]
-burn = { version = "0.12.1", default-features = false }
+burn = { version = "0.13.0", default-features = false }
 num-traits = { version = "0.2.17", default-features = false }
 serde = { version = "1.0.192", default-features = false, features = [
     "derive",
     "alloc",
 ] }
+
+[patch.crates-io]
+burn = { git = "https://github.com/tracel-ai/burn" }
diff --git a/src/lib.rs b/src/lib.rs
index 4720980..679284c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,7 +9,7 @@ pub mod pipelines;
 pub mod transformers;
 pub mod utils;
 
-#[cfg(all(test, not(feature = "wgpu"), not(feature = "torch")))]
+#[cfg(all(test, feature = "ndarray"))]
 use burn::backend::ndarray;
 
 #[cfg(all(test, feature = "torch"))]
@@ -20,7 +20,7 @@ use burn::backend::wgpu;
 
 extern crate alloc;
 
-#[cfg(all(test, not(feature = "wgpu"), not(feature = "torch")))]
+#[cfg(all(test, feature = "ndarray"))]
 pub type TestBackend = ndarray::NdArray;
 
 #[cfg(all(test, feature = "torch"))]
diff --git a/src/models/attention.rs b/src/models/attention.rs
index c3c5bff..70372f1 100644
--- a/src/models/attention.rs
+++ b/src/models/attention.rs
@@ -297,36 +297,50 @@ pub struct SpatialTransformerConfig {
     pub n_groups: usize,
     pub d_context: Option<usize>,
     pub sliced_attn_size: Option<usize>,
-    // #[config(default = false)]
-    // pub use_linear_projection: bool,
+    #[config(default = false)]
+    pub use_linear_projection: bool,
     pub in_channels: usize,
     pub n_heads: usize,
     pub d_head: usize,
 }
 
-//#[derive(Config, Debug)]
-//enum Proj {
-//    Conv2d(nn::conv::Conv2d<B>),
-//    Linear(nn::Linear<B>)
-//}
+#[derive(Module, Debug)]
+enum Proj<B: Backend> {
+    Conv2d(nn::conv::Conv2d<B>),
+    Linear(nn::Linear<B>),
+}
+
+impl<B: Backend> Proj<B> {
+    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+        match self {
+            Proj::Conv2d(conv) => conv.forward(xs),
+            Proj::Linear(linear) => linear.forward(xs),
+        }
+    }
+}
 
 /// Aka Transformer2DModel
 #[derive(Module, Debug)]
 pub struct SpatialTransformer<B: Backend> {
     norm: GroupNorm<B>,
-    proj_in: nn::conv::Conv2d<B>,
+    proj_in: Proj<B>,
     transformer_blocks: Vec<BasicTransformerBlock<B>>,
     proj_out: nn::conv::Conv2d<B>,
 }
 
 impl SpatialTransformerConfig {
-    fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
         let d_inner = self.n_heads * self.d_head;
         let norm = GroupNormConfig::new(self.n_groups, self.in_channels)
             .with_epsilon(1e-6)
             .init(device);
-        // let proj_in = if config.use_linear_projection {
-        let proj_in = nn::conv::Conv2dConfig::new([self.in_channels, d_inner], [1, 1]).init(device);
+        let proj_in = if self.use_linear_projection {
+            Proj::Linear(nn::LinearConfig::new(self.in_channels, d_inner).init(device))
+        } else {
+            Proj::Conv2d(
+                nn::conv::Conv2dConfig::new([self.in_channels, d_inner], [1, 1]).init(device),
+            )
+        };
 
         let mut transformer_blocks = vec![];
         for _index in 0..self.depth {
@@ -351,7 +365,7 @@ impl SpatialTransformerConfig {
     }
 }
 
 impl<B: Backend> SpatialTransformer<B> {
-    fn forward(&self, xs: Tensor<B, 4>, context: Option<Tensor<B, 3>>) -> Tensor<B, 4> {
+    pub fn forward(&self, xs: Tensor<B, 4>, context: Option<Tensor<B, 3>>) -> Tensor<B, 4> {
         let [n_batch, _n_channel, height, weight] = xs.dims();
         let residual = xs.clone();
 
@@ -402,7 +416,7 @@ pub struct AttentionBlock<B: Backend> {
 }
 
 impl AttentionBlockConfig {
-    fn init<B: Backend>(&self, device: &B::Device) -> AttentionBlock<B> {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> AttentionBlock<B> {
         let n_head_channels = self.n_head_channels.unwrap_or(self.channels);
         let n_heads = self.channels / n_head_channels;
         let group_norm = GroupNormConfig::new(self.n_groups, self.channels)
@@ -433,7 +447,7 @@ impl<B: Backend> AttentionBlock<B> {
             .swap_dims(1, 2)
     }
 
-    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+    pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
         let residual = xs.clone();
         let [n_batch, channel, height, width] = xs.dims();
         let xs = self
@@ -506,8 +520,8 @@ mod tests {
 
         let geglu = GeGlu {
             proj: nn::Linear {
-                weight: Param::new(ParamId::new(), weight),
-                bias: Some(Param::new(ParamId::new(), bias)),
+                weight: Param::initialized(ParamId::new(), weight),
+                bias: Some(Param::initialized(ParamId::new(), bias)),
             },
         };
 
@@ -562,8 +576,8 @@ mod tests {
 
         let geglu = GeGlu {
             proj: nn::Linear {
-                weight: Param::new(ParamId::new(), weight),
-                bias: Some(Param::new(ParamId::new(), bias)),
+                weight: Param::initialized(ParamId::new(), weight),
+                bias: Some(Param::initialized(ParamId::new(), bias)),
             },
         };
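The attention.rs change above replaces the commented-out projection stub with a real `Proj` module, selected by the new `use_linear_projection` flag. A minimal construction sketch (assumes a `device` and the crate's `TestBackend` test alias in scope; the channel/head numbers are illustrative, loosely mirroring Stable Diffusion's first transformer stage):

```rust
// 320 channels, 8 heads of 40 dims each, CLIP context width 768 (all assumed values).
let cfg = SpatialTransformerConfig::new(320, 8, 40)
    .with_d_context(Some(768))
    .with_use_linear_projection(true); // Linear proj_in instead of a 1x1 Conv2d
let transformer = cfg.init::<TestBackend>(&device);
```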
diff --git a/src/models/embeddings.rs b/src/models/embeddings.rs
index 8338e8e..8a13124 100644
--- a/src/models/embeddings.rs
+++ b/src/models/embeddings.rs
@@ -54,6 +54,7 @@ impl<B: Backend> Timesteps<B> {
         Self {
             num_channels,
            flip_sin_to_cos,
+            downscale_freq_shift,
            _backend: PhantomData,
         }
     }
@@ -61,7 +62,7 @@ impl<B: Backend> Timesteps<B> {
 
     pub fn forward(&self, xs: Tensor<B, 1>) -> Tensor<B, 2> {
         let half_dim = self.num_channels / 2;
-        let exponent = Tensor::arange(0..half_dim, &xs.device()).float() * -f64::ln(10000.);
+        let exponent = Tensor::arange(0..half_dim as i64, &xs.device()).float() * -f64::ln(10000.);
         let exponent = exponent / (half_dim as f64 - self.downscale_freq_shift);
         let emb = exponent.exp();
         // emb = timesteps[:, None].float() * emb[None, :]
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 845a097..5d7234d 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -5,3 +5,4 @@
 pub mod attention;
 pub mod embeddings;
 pub mod resnet;
+pub mod unet_2d_blocks;
diff --git a/src/models/resnet.rs b/src/models/resnet.rs
index 972d43c..57c1e31 100644
--- a/src/models/resnet.rs
+++ b/src/models/resnet.rs
@@ -99,7 +99,7 @@ impl<B: Backend> ResnetBlock2D<B> {
 
         let xs = self.norm1.forward(xs.clone());
         let xs = self.conv1.forward(silu(xs));
-        match (temb, &self.time_emb_proj) {
+        let xs = match (temb, &self.time_emb_proj) {
             (Some(temb), Some(time_emb_proj)) => {
                 time_emb_proj
                     .forward(silu(temb))
@@ -109,7 +109,9 @@ impl<B: Backend> ResnetBlock2D<B> {
             }
             _ => xs.clone(),
         };
-        let xs = self.conv2.forward(silu(self.norm2.forward(xs)));
+        let xs = self.norm2.forward(xs);
+        let xs = silu(xs);
+        let xs = self.conv2.forward(xs);
         (shortcut_xs + xs) / self.output_scale_factor
     }
 }
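The `Timesteps` fix only casts the range end to `i64`, which is what burn's `Tensor::arange` expects; the sinusoidal embedding math itself is unchanged. As a quick sanity sketch of that math in plain Rust (assuming `downscale_freq_shift = 0.0` and `num_channels = 4`):

```rust
fn main() {
    // Frequencies follow exp(-ln(10000) * i / half_dim) for i in 0..half_dim.
    let half_dim = 2usize;
    let freqs: Vec<f64> = (0..half_dim)
        .map(|i| (-(10000f64).ln() * i as f64 / half_dim as f64).exp())
        .collect();
    // freqs == [1.0, 0.01]; a timestep t = 1.0 then embeds as
    // [sin(1.0), sin(0.01), cos(1.0), cos(0.01)] (the sin/cos halves swap
    // when flip_sin_to_cos is set).
    println!("{freqs:?}");
}
```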
diff --git a/src/models/unet_2d_blocks.rs b/src/models/unet_2d_blocks.rs
new file mode 100644
index 0000000..3a856df
--- /dev/null
+++ b/src/models/unet_2d_blocks.rs
@@ -0,0 +1,927 @@
+//! 2D UNet Building Blocks
+//!
+
+use burn::{
+    config::Config,
+    module::Module,
+    nn,
+    tensor::{
+        backend::Backend,
+        module::{avg_pool2d, interpolate},
+        ops::{InterpolateMode, InterpolateOptions},
+        Tensor,
+    },
+};
+
+use crate::utils::pad_with_zeros;
+
+use super::{
+    attention::{
+        AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
+    },
+    resnet::{ResnetBlock2D, ResnetBlock2DConfig},
+};
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+#[derive(Config)]
+struct Downsample2DConfig {
+    in_channels: usize,
+    use_conv: bool,
+    out_channels: usize,
+    padding: usize,
+}
+
+#[derive(Module, Debug)]
+struct Downsample2D<B: Backend> {
+    conv: Option<nn::conv::Conv2d<B>>,
+    padding: usize,
+}
+
+impl Downsample2DConfig {
+    fn init<B: Backend>(&self, device: &B::Device) -> Downsample2D<B> {
+        let conv = if self.use_conv {
+            let conv = nn::conv::Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
+                .with_stride([2, 2])
+                .with_padding(nn::PaddingConfig2d::Explicit(self.padding, self.padding))
+                .init(device);
+
+            Some(conv)
+        } else {
+            None
+        };
+
+        Downsample2D {
+            conv,
+            padding: self.padding,
+        }
+    }
+}
+
+impl<B: Backend> Downsample2D<B> {
+    fn pad_tensor(xs: Tensor<B, 4>, padding: usize) -> Tensor<B, 4> {
+        if padding == 0 {
+            let xs = pad_with_zeros(xs, 4 - 1, 0, 1);
+            return pad_with_zeros(xs, 4 - 2, 0, 1);
+        }
+
+        xs
+    }
+
+    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+        match &self.conv {
+            None => avg_pool2d(xs, [2, 2], [2, 2], [0, 0], true),
+            Some(conv) => conv.forward(Self::pad_tensor(xs, self.padding)),
+        }
+    }
+}
+
+#[derive(Config)]
+struct Upsample2DConfig {
+    in_channels: usize,
+    out_channels: usize,
+}
+
+// This does not support the conv-transpose mode.
+#[derive(Module, Debug)]
+struct Upsample2D<B: Backend> {
+    conv: nn::conv::Conv2d<B>,
+}
+
+impl Upsample2DConfig {
+    fn init<B: Backend>(&self, device: &B::Device) -> Upsample2D<B> {
+        let conv = nn::conv::Conv2dConfig::new([self.in_channels, self.out_channels], [3, 3])
+            .with_padding(nn::PaddingConfig2d::Explicit(1, 1))
+            .init(device);
+
+        Upsample2D { conv }
+    }
+}
+
+impl<B: Backend> Upsample2D<B> {
+    fn forward(&self, xs: Tensor<B, 4>, size: Option<(usize, usize)>) -> Tensor<B, 4> {
+        let xs = match size {
+            None => {
+                let [_bsize, _channels, height, width] = xs.dims();
+                interpolate(
+                    xs,
+                    [2 * height, 2 * width],
+                    InterpolateOptions::new(InterpolateMode::Nearest),
+                )
+            }
+            Some((h, w)) => interpolate(
+                xs,
+                [h, w],
+                InterpolateOptions::new(InterpolateMode::Nearest),
+            ),
+        };
+
+        self.conv.forward(xs)
+    }
+}
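+
+// Shape sketch (illustrative, not exercised by the tests below): both
+// `Downsample2D` paths halve the spatial size (a stride-2 3x3 conv when
+// `use_conv` is set, a 2x2 average pool otherwise), while `Upsample2D`
+// doubles it via nearest-neighbour interpolation unless an explicit `size`
+// is passed. Assuming a `device` and the crate's `TestBackend`:
+//
+//     let down = Downsample2DConfig::new(4, true, 4, 1).init::<TestBackend>(&device);
+//     let xs = Tensor::<TestBackend, 4>::zeros([1, 4, 32, 32], &device);
+//     assert_eq!(down.forward(xs).dims(), [1, 4, 16, 16]);
+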
+#[derive(Config, Debug)]
+pub struct DownEncoderBlock2DConfig {
+    pub in_channels: usize,
+    pub out_channels: usize,
+    #[config(default = 1)]
+    pub n_layers: usize,
+    #[config(default = 1e-6)]
+    pub resnet_eps: f64,
+    #[config(default = 32)]
+    pub resnet_groups: usize,
+    #[config(default = 1.)]
+    pub output_scale_factor: f64,
+    #[config(default = true)]
+    pub add_downsample: bool,
+    #[config(default = 1)]
+    pub downsample_padding: usize,
+}
+
+#[derive(Module, Debug)]
+pub struct DownEncoderBlock2D<B: Backend> {
+    resnets: Vec<ResnetBlock2D<B>>,
+    downsampler: Option<Downsample2D<B>>,
+}
+
+impl DownEncoderBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> DownEncoderBlock2D<B> {
+        let resnets: Vec<_> = {
+            (0..(self.n_layers))
+                .map(|i| {
+                    let in_channels = if i == 0 {
+                        self.in_channels
+                    } else {
+                        self.out_channels
+                    };
+
+                    let conv_cfg = ResnetBlock2DConfig::new(in_channels)
+                        .with_out_channels(Some(self.out_channels))
+                        .with_groups(self.resnet_groups)
+                        .with_eps(self.resnet_eps)
+                        .with_output_scale_factor(self.output_scale_factor);
+
+                    conv_cfg.init(device)
+                })
+                .collect()
+        };
+
+        let downsampler = if self.add_downsample {
+            let downsample_cfg = Downsample2DConfig {
+                in_channels: self.out_channels,
+                use_conv: true,
+                out_channels: self.out_channels,
+                padding: self.downsample_padding,
+            };
+            Some(downsample_cfg.init(device))
+        } else {
+            None
+        };
+
+        DownEncoderBlock2D {
+            resnets,
+            downsampler,
+        }
+    }
+}
+
+impl<B: Backend> DownEncoderBlock2D<B> {
+    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+        let mut xs = xs;
+        for resnet in self.resnets.iter() {
+            xs = resnet.forward(xs, None)
+        }
+        match &self.downsampler {
+            Some(downsampler) => downsampler.forward(xs),
+            None => xs,
+        }
+    }
+}
+
+#[derive(Config, Debug)]
+pub struct UpDecoderBlock2DConfig {
+    pub in_channels: usize,
+    pub out_channels: usize,
+    #[config(default = 1)]
+    pub n_layers: usize,
+    #[config(default = 1e-6)]
+    pub resnet_eps: f64,
+    #[config(default = 32)]
+    pub resnet_groups: usize,
+    #[config(default = 1.)]
+    pub output_scale_factor: f64,
+    #[config(default = true)]
+    pub add_upsample: bool,
+}
+
+#[derive(Module, Debug)]
+pub struct UpDecoderBlock2D<B: Backend> {
+    resnets: Vec<ResnetBlock2D<B>>,
+    upsampler: Option<Upsample2D<B>>,
+}
+
+impl UpDecoderBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> UpDecoderBlock2D<B> {
+        let resnets: Vec<_> = {
+            (0..(self.n_layers))
+                .map(|i| {
+                    let in_channels = if i == 0 {
+                        self.in_channels
+                    } else {
+                        self.out_channels
+                    };
+
+                    let conv_cfg = ResnetBlock2DConfig::new(in_channels)
+                        .with_out_channels(Some(self.out_channels))
+                        .with_groups(self.resnet_groups)
+                        .with_eps(self.resnet_eps)
+                        .with_output_scale_factor(self.output_scale_factor);
+
+                    conv_cfg.init(device)
+                })
+                .collect()
+        };
+
+        let upsampler = if self.add_upsample {
+            let upsample = Upsample2DConfig {
+                in_channels: self.out_channels,
+                out_channels: self.out_channels,
+            };
+
+            Some(upsample.init(device))
+        } else {
+            None
+        };
+
+        UpDecoderBlock2D { resnets, upsampler }
+    }
+}
+
+impl<B: Backend> UpDecoderBlock2D<B> {
+    fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
+        let mut xs = xs;
+        for resnet in self.resnets.iter() {
+            xs = resnet.forward(xs, None)
+        }
+        match &self.upsampler {
+            Some(upsampler) => upsampler.forward(xs, None),
+            None => xs,
+        }
+    }
+}
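+
+// Usage sketch (assumes the crate's `TestBackend` and a `device` in scope):
+// stacking encoder blocks halves the spatial size at each stage, and decoder
+// blocks mirror it on the way back up, as in a VAE.
+//
+//     let enc = DownEncoderBlock2DConfig::new(32, 64).init::<TestBackend>(&device);
+//     let dec = UpDecoderBlock2DConfig::new(64, 32).init::<TestBackend>(&device);
+//     let xs = Tensor::<TestBackend, 4>::zeros([1, 32, 64, 64], &device);
+//     let same_size = dec.forward(enc.forward(xs)); // back to [1, 32, 64, 64]
+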
+#[derive(Config, Debug)]
+pub struct UNetMidBlock2DConfig {
+    in_channels: usize,
+    temb_channels: Option<usize>,
+    #[config(default = 1)]
+    pub n_layers: usize,
+    #[config(default = 1e-6)]
+    pub resnet_eps: f64,
+    pub resnet_groups: Option<usize>,
+    pub attn_num_head_channels: Option<usize>,
+    // attention_type "default"
+    #[config(default = 1.)]
+    pub output_scale_factor: f64,
+}
+
+#[derive(Module, Debug)]
+struct AttentionResnetBlock2D<B: Backend> {
+    attention_block: AttentionBlock<B>,
+    resnet_block: ResnetBlock2D<B>,
+}
+
+#[derive(Module, Debug)]
+pub struct UNetMidBlock2D<B: Backend> {
+    resnet: ResnetBlock2D<B>,
+    attn_resnets: Vec<AttentionResnetBlock2D<B>>,
+}
+
+impl UNetMidBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> UNetMidBlock2D<B> {
+        let resnet_groups = self
+            .resnet_groups
+            .unwrap_or_else(|| usize::min(self.in_channels / 4, 32));
+
+        let resnet = ResnetBlock2DConfig::new(self.in_channels)
+            .with_eps(self.resnet_eps)
+            .with_groups(resnet_groups)
+            .with_output_scale_factor(self.output_scale_factor)
+            .with_temb_channels(self.temb_channels)
+            .init(device);
+
+        let mut attn_resnets = vec![];
+        for _index in 0..self.n_layers {
+            let attention_block = AttentionBlockConfig::new(self.in_channels)
+                .with_n_head_channels(self.attn_num_head_channels)
+                .with_n_groups(resnet_groups)
+                .with_rescale_output_factor(self.output_scale_factor)
+                .with_eps(self.resnet_eps)
+                .init(device);
+
+            let resnet_block = ResnetBlock2DConfig::new(self.in_channels)
+                .with_eps(self.resnet_eps)
+                .with_groups(resnet_groups)
+                .with_output_scale_factor(self.output_scale_factor)
+                .with_temb_channels(self.temb_channels)
+                .init(device);
+
+            attn_resnets.push(AttentionResnetBlock2D {
+                attention_block,
+                resnet_block,
+            })
+        }
+
+        UNetMidBlock2D {
+            resnet,
+            attn_resnets,
+        }
+    }
+}
+
+impl<B: Backend> UNetMidBlock2D<B> {
+    pub fn forward(&self, xs: Tensor<B, 4>, temb: Option<Tensor<B, 2>>) -> Tensor<B, 4> {
+        let mut xs = self.resnet.forward(xs, temb.clone());
+        for block in self.attn_resnets.iter() {
+            xs = block
+                .resnet_block
+                .forward(block.attention_block.forward(xs), temb.clone())
+        }
+
+        xs
+    }
+}
+
+#[derive(Config)]
+pub struct UNetMidBlock2DCrossAttnConfig {
+    in_channels: usize,
+    temb_channels: Option<usize>,
+    #[config(default = 1)]
+    pub n_layers: usize,
+    #[config(default = 1e-6)]
+    pub resnet_eps: f64,
+    // Note: Should default to 32
+    pub resnet_groups: Option<usize>,
+    #[config(default = 1)]
+    pub attn_num_head_channels: usize,
+    // attention_type "default"
+    #[config(default = 1.)]
+    pub output_scale_factor: f64,
+    #[config(default = 1280)]
+    pub cross_attn_dim: usize,
+    pub sliced_attention_size: Option<usize>,
+    #[config(default = false)]
+    pub use_linear_projection: bool,
+    #[config(default = 1)]
+    pub transformer_layers_per_block: usize,
+}
+
+#[derive(Module, Debug)]
+struct SpatialTransformerResnetBlock2D<B: Backend> {
+    spatial_transformer: SpatialTransformer<B>,
+    resnet_block: ResnetBlock2D<B>,
+}
+
+#[derive(Module, Debug)]
+pub struct UNetMidBlock2DCrossAttn<B: Backend> {
+    resnet: ResnetBlock2D<B>,
+    attn_resnets: Vec<SpatialTransformerResnetBlock2D<B>>,
+}
+
+impl UNetMidBlock2DCrossAttnConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> UNetMidBlock2DCrossAttn<B> {
+        let resnet_groups = self
+            .resnet_groups
+            .unwrap_or_else(|| usize::min(self.in_channels / 4, 32));
+        let resnet = ResnetBlock2DConfig::new(self.in_channels)
+            .with_eps(self.resnet_eps)
+            .with_groups(resnet_groups)
+            .with_output_scale_factor(self.output_scale_factor)
+            .with_temb_channels(self.temb_channels)
+            .init(device);
+
+        let mut attn_resnets = vec![];
+        for _index in 0..self.n_layers {
+            let spatial_transformer = SpatialTransformerConfig::new(
+                self.in_channels,
+                self.attn_num_head_channels,
+                self.in_channels / self.attn_num_head_channels,
+            )
+            .with_depth(1)
+            .with_n_groups(resnet_groups)
+            .with_d_context(Some(self.cross_attn_dim))
+            .with_sliced_attn_size(self.sliced_attention_size)
+            .with_use_linear_projection(self.use_linear_projection)
+            .init(device);
+
+            let resnet_block = ResnetBlock2DConfig::new(self.in_channels)
+                .with_eps(self.resnet_eps)
+                .with_groups(resnet_groups)
+                .with_output_scale_factor(self.output_scale_factor)
+                .with_temb_channels(self.temb_channels)
+                .init(device);
+
+            attn_resnets.push(SpatialTransformerResnetBlock2D {
+                spatial_transformer,
+                resnet_block,
+            })
+        }
+
+        UNetMidBlock2DCrossAttn {
+            resnet,
+            attn_resnets,
+        }
+    }
+}
+
+impl<B: Backend> UNetMidBlock2DCrossAttn<B> {
+    pub fn forward(
+        &self,
+        xs: Tensor<B, 4>,
+        temb: Option<Tensor<B, 2>>,
+        encoder_hidden_states: Option<Tensor<B, 3>>,
+    ) -> Tensor<B, 4> {
+        let mut xs = self.resnet.forward(xs, temb.clone());
+        for block in self.attn_resnets.iter() {
+            let trans = block
+                .spatial_transformer
+                .forward(xs, encoder_hidden_states.clone());
+            // Each pair uses its own resnet, not the shared input resnet.
+            xs = block.resnet_block.forward(trans, temb.clone());
+        }
+
+        xs
+    }
+}
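+
+// Call sketch (shapes are assumptions for illustration): `xs` is
+// [batch, in_channels, h, w], `temb` is [batch, temb_channels] and
+// `encoder_hidden_states` is [batch, seq_len, cross_attn_dim], e.g. CLIP
+// text states:
+//
+//     let mid = UNetMidBlock2DCrossAttnConfig::new(1280)
+//         .with_temb_channels(Some(1280))
+//         .init::<TestBackend>(&device);
+//     let out = mid.forward(xs, Some(temb), Some(context));
+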
+#[derive(Config, Copy)]
+pub struct DownBlock2DConfig {
+    in_channels: usize,
+    out_channels: usize,
+    temb_channels: Option<usize>,
+    #[config(default = 1)]
+    pub n_layers: usize,
+    #[config(default = 1e-6)]
+    pub resnet_eps: f64,
+    // resnet_time_scale_shift: "default"
+    // resnet_act_fn: "swish"
+    #[config(default = 32)]
+    pub resnet_groups: usize,
+    #[config(default = 1.)]
+    pub output_scale_factor: f64,
+    #[config(default = true)]
+    pub add_downsample: bool,
+    #[config(default = 1)]
+    pub downsample_padding: usize,
+}
+
+#[derive(Module, Debug)]
+pub struct DownBlock2D<B: Backend> {
+    resnets: Vec<ResnetBlock2D<B>>,
+    downsampler: Option<Downsample2D<B>>,
+}
+
+impl DownBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> DownBlock2D<B> {
+        let resnets = (0..self.n_layers)
+            .map(|i| {
+                // Only the first resnet sees the block's input channels; the
+                // rest stay at `out_channels`.
+                let in_channels = if i == 0 {
+                    self.in_channels
+                } else {
+                    self.out_channels
+                };
+
+                ResnetBlock2DConfig::new(in_channels)
+                    .with_out_channels(Some(self.out_channels))
+                    .with_eps(self.resnet_eps)
+                    .with_groups(self.resnet_groups)
+                    .with_output_scale_factor(self.output_scale_factor)
+                    .with_temb_channels(self.temb_channels)
+                    .init(device)
+            })
+            .collect();
+
+        let downsampler = if self.add_downsample {
+            Some(
+                Downsample2DConfig::new(
+                    self.out_channels,
+                    true,
+                    self.out_channels,
+                    self.downsample_padding,
+                )
+                .init(device),
+            )
+        } else {
+            None
+        };
+
+        DownBlock2D {
+            resnets,
+            downsampler,
+        }
+    }
+}
+
+impl<B: Backend> DownBlock2D<B> {
+    pub fn forward(
+        &self,
+        xs: Tensor<B, 4>,
+        temb: Option<Tensor<B, 2>>,
+    ) -> (Tensor<B, 4>, Vec<Tensor<B, 4>>) {
+        let mut xs = xs;
+        let mut output_states = vec![];
+        for resnet in self.resnets.iter() {
+            xs = resnet.forward(xs, temb.clone());
+            output_states.push(xs.clone());
+        }
+
+        if let Some(downsampler) = &self.downsampler {
+            xs = downsampler.forward(xs);
+            output_states.push(xs.clone());
+        }
+
+        (xs, output_states)
+    }
+}
+
+#[derive(Config)]
+pub struct CrossAttnDownBlock2DConfig {
+    in_channels: usize,
+    out_channels: usize,
+    temb_channels: Option<usize>,
+    pub downblock: DownBlock2DConfig,
+    #[config(default = 1)]
+    pub attn_num_head_channels: usize,
+    #[config(default = 1280)]
+    pub cross_attention_dim: usize,
+    // attention_type: "default"
+    pub sliced_attention_size: Option<usize>,
+    #[config(default = false)]
+    pub use_linear_projection: bool,
+}
+
+#[derive(Module, Debug)]
+pub struct CrossAttnDownBlock2D<B: Backend> {
+    downblock: DownBlock2D<B>,
+    attentions: Vec<SpatialTransformer<B>>,
+}
+
+impl CrossAttnDownBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttnDownBlock2D<B> {
+        let mut downblock = self.downblock;
+        downblock.in_channels = self.in_channels;
+        downblock.out_channels = self.out_channels;
+        downblock.temb_channels = self.temb_channels;
+        // Initialize the adjusted copy, not the untouched `self.downblock`.
+        let downblock = downblock.init(device);
+
+        let attentions = (0..self.downblock.n_layers)
+            .map(|_| {
+                SpatialTransformerConfig::new(
+                    self.out_channels,
+                    self.attn_num_head_channels,
+                    self.out_channels / self.attn_num_head_channels,
+                )
+                .with_depth(1)
+                .with_d_context(Some(self.cross_attention_dim))
+                .with_n_groups(self.downblock.resnet_groups)
+                .with_sliced_attn_size(self.sliced_attention_size)
+                .with_use_linear_projection(self.use_linear_projection)
+                .init(device)
+            })
+            .collect();
+
+        CrossAttnDownBlock2D {
+            downblock,
+            attentions,
+        }
+    }
+}
+
+impl<B: Backend> CrossAttnDownBlock2D<B> {
+    pub fn forward(
+        &self,
+        xs: Tensor<B, 4>,
+        temb: Option<Tensor<B, 2>>,
+        encoder_hidden_states: Option<Tensor<B, 3>>,
+    ) -> (Tensor<B, 4>, Vec<Tensor<B, 4>>) {
+        let mut xs = xs;
+        let mut output_states = vec![];
+        for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
+            xs = resnet.forward(xs, temb.clone());
+            xs = attn.forward(xs, encoder_hidden_states.clone());
+            output_states.push(xs.clone());
+        }
+
+        if let Some(downsampler) = &self.downblock.downsampler {
+            xs = downsampler.forward(xs);
+            output_states.push(xs.clone());
+        }
+
+        (xs, output_states)
+    }
+}
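+
+// Skip-connection sketch (names are illustrative): every resnet (and the
+// optional downsampler) pushes its output into `output_states`, so a matching
+// up block can consume them in reverse order on the decoder path:
+//
+//     let (xs, skips) = down_block.forward(xs, Some(temb.clone()));
+//     // ... after the mid block:
+//     let xs = up_block.forward(xs, &skips, Some(temb), None);
+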
+#[derive(Config)]
+pub struct UpBlock2DConfig {
+    in_channels: usize,
+    prev_output_channels: usize,
+    out_channels: usize,
+    temb_channels: Option<usize>,
+    #[config(default = 1)]
+    pub n_layers: usize,
+    #[config(default = 1e-6)]
+    pub resnet_eps: f64,
+    // resnet_time_scale_shift: "default"
+    // resnet_act_fn: "swish"
+    #[config(default = 32)]
+    pub resnet_groups: usize,
+    #[config(default = 1.)]
+    pub output_scale_factor: f64,
+    #[config(default = true)]
+    pub add_upsample: bool,
+}
+
+#[derive(Module, Debug)]
+pub struct UpBlock2D<B: Backend> {
+    resnets: Vec<ResnetBlock2D<B>>,
+    upsampler: Option<Upsample2D<B>>,
+}
+
+impl UpBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> UpBlock2D<B> {
+        let resnets = (0..self.n_layers)
+            .map(|i| {
+                let res_skip_channels = if i == self.n_layers - 1 {
+                    self.in_channels
+                } else {
+                    self.out_channels
+                };
+
+                let resnet_in_channels = if i == 0 {
+                    self.prev_output_channels
+                } else {
+                    self.out_channels
+                };
+
+                // Each resnet consumes the upsampled features concatenated
+                // with one skip tensor, hence the summed channel count.
+                let in_channels = resnet_in_channels + res_skip_channels;
+
+                ResnetBlock2DConfig::new(in_channels)
+                    .with_out_channels(Some(self.out_channels))
+                    .with_temb_channels(self.temb_channels)
+                    .with_eps(self.resnet_eps)
+                    .with_groups(self.resnet_groups)
+                    .with_output_scale_factor(self.output_scale_factor)
+                    .init(device)
+            })
+            .collect();
+
+        let upsampler = if self.add_upsample {
+            let upsampler =
+                Upsample2DConfig::new(self.out_channels, self.out_channels).init(device);
+            Some(upsampler)
+        } else {
+            None
+        };
+
+        UpBlock2D { resnets, upsampler }
+    }
+}
+
+impl<B: Backend> UpBlock2D<B> {
+    pub fn forward(
+        &self,
+        xs: Tensor<B, 4>,
+        res_xs: &[Tensor<B, 4>],
+        temb: Option<Tensor<B, 2>>,
+        upsample_size: Option<(usize, usize)>,
+    ) -> Tensor<B, 4> {
+        let mut xs = xs;
+        for (index, resnet) in self.resnets.iter().enumerate() {
+            xs = Tensor::cat(
+                vec![xs.clone(), res_xs[res_xs.len() - index - 1].clone()],
+                1,
+            );
+            xs = resnet.forward(xs, temb.clone());
+        }
+
+        match &self.upsampler {
+            Some(upsampler) => upsampler.forward(xs, upsample_size),
+            None => xs,
+        }
+    }
+}
+
+#[derive(Config)]
+pub struct CrossAttnUpBlock2DConfig {
+    in_channels: usize,
+    prev_output_channels: usize,
+    out_channels: usize,
+    temb_channels: Option<usize>,
+    pub upblock: UpBlock2DConfig,
+    #[config(default = 1)]
+    pub attn_num_head_channels: usize,
+    #[config(default = 1280)]
+    pub cross_attention_dim: usize,
+    // attention_type: "default"
+    pub sliced_attention_size: Option<usize>,
+    #[config(default = false)]
+    pub use_linear_projection: bool,
+}
+
+#[derive(Module, Debug)]
+pub struct CrossAttnUpBlock2D<B: Backend> {
+    pub upblock: UpBlock2D<B>,
+    pub attentions: Vec<SpatialTransformer<B>>,
+}
+
+impl CrossAttnUpBlock2DConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttnUpBlock2D<B> {
+        let mut upblock_config = self.upblock.clone();
+        upblock_config.in_channels = self.in_channels;
+        upblock_config.prev_output_channels = self.prev_output_channels;
+        upblock_config.out_channels = self.out_channels;
+        upblock_config.temb_channels = self.temb_channels;
+        let upblock = upblock_config.init(device);
+
+        let attentions = (0..self.upblock.n_layers)
+            .map(|_| {
+                SpatialTransformerConfig::new(
+                    self.out_channels,
+                    self.attn_num_head_channels,
+                    self.out_channels / self.attn_num_head_channels,
+                )
+                .with_depth(1)
+                .with_d_context(Some(self.cross_attention_dim))
+                .with_n_groups(self.upblock.resnet_groups)
+                .with_sliced_attn_size(self.sliced_attention_size)
+                .with_use_linear_projection(self.use_linear_projection)
+                .init(device)
+            })
+            .collect();
+
+        CrossAttnUpBlock2D {
+            upblock,
+            attentions,
+        }
+    }
+}
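+
+// Usage sketch (illustrative): the up block walks `res_xs` from the back,
+// concatenating one skip tensor on the channel dimension before each
+// resnet/attention pair, then optionally upsamples:
+//
+//     let xs = cross_up.forward(xs, &skips, Some(temb), None, Some(context));
+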
+impl<B: Backend> CrossAttnUpBlock2D<B> {
+    pub fn forward(
+        &self,
+        xs: Tensor<B, 4>,
+        res_xs: &[Tensor<B, 4>],
+        temb: Option<Tensor<B, 2>>,
+        upsample_size: Option<(usize, usize)>,
+        encoder_hidden_states: Option<Tensor<B, 3>>,
+    ) -> Tensor<B, 4> {
+        let mut xs = xs;
+        for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+            xs = Tensor::cat(
+                vec![xs.clone(), res_xs[res_xs.len() - index - 1].clone()],
+                1,
+            );
+            xs = resnet.forward(xs, temb.clone());
+            xs = self.attentions[index].forward(xs, encoder_hidden_states.clone());
+        }
+
+        match &self.upblock.upsampler {
+            Some(upsampler) => upsampler.forward(xs, upsample_size),
+            None => xs,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn::tensor::{Data, Distribution, Shape};
+
+    #[test]
+    fn test_downsample_2d_no_conv() {
+        let device = Default::default();
+        let tensor: Tensor<TestBackend, 4> = Tensor::from_data(
+            Data::from([
+                [
+                    [[0.0351, 0.4179], [0.0137, 0.6947]],
+                    [[0.9526, 0.5386], [0.2856, 0.1839]],
+                    [[0.3215, 0.4595], [0.6777, 0.3946]],
+                    [[0.5221, 0.4230], [0.2774, 0.1069]],
+                ],
+                [
+                    [[0.8941, 0.8696], [0.5735, 0.8750]],
+                    [[0.6718, 0.4144], [0.1038, 0.2629]],
+                    [[0.7467, 0.9415], [0.5005, 0.6309]],
+                    [[0.6534, 0.2019], [0.3670, 0.8074]],
+                ],
+            ]),
+            &device,
+        );
+
+        let downsample_2d = Downsample2DConfig::new(4, false, 4, 0).init(&device);
+        let output = downsample_2d.forward(tensor);
+
+        output.into_data().assert_approx_eq(
+            &Data::from([
+                [[[0.2904]], [[0.4902]], [[0.4633]], [[0.3323]]],
+                [[[0.8031]], [[0.3632]], [[0.7049]], [[0.5074]]],
+            ]),
+            3,
+        );
+    }
+
+    #[test]
+    fn test_pad_tensor_0() {
+        let device = Default::default();
+        let tensor: Tensor<TestBackend, 4> = Tensor::from_data(
+            Data::from([
+                [
+                    [[0.8600, 0.9473], [0.2543, 0.6181]],
+                    [[0.3889, 0.7722], [0.6736, 0.0454]],
+                    [[0.2809, 0.4672], [0.1632, 0.3959]],
+                    [[0.5317, 0.0831], [0.8353, 0.3654]],
+                ],
+                [
+                    [[0.6106, 0.4130], [0.7932, 0.8800]],
+                    [[0.8750, 0.1991], [0.7018, 0.7865]],
+                    [[0.7470, 0.2071], [0.2699, 0.4425]],
+                    [[0.7763, 0.0227], [0.6210, 0.0730]],
+                ],
+            ]),
+            &device,
+        );
+
+        let output = Downsample2D::pad_tensor(tensor, 0);
+
+        output.into_data().assert_approx_eq(
+            &Data::from([
+                [
+                    [
+                        [0.8600, 0.9473, 0.0000],
+                        [0.2543, 0.6181, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                    [
+                        [0.3889, 0.7722, 0.0000],
+                        [0.6736, 0.0454, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                    [
+                        [0.2809, 0.4672, 0.0000],
+                        [0.1632, 0.3959, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                    [
+                        [0.5317, 0.0831, 0.0000],
+                        [0.8353, 0.3654, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                ],
+                [
+                    [
+                        [0.6106, 0.4130, 0.0000],
+                        [0.7932, 0.8800, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                    [
+                        [0.8750, 0.1991, 0.0000],
+                        [0.7018, 0.7865, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                    [
+                        [0.7470, 0.2071, 0.0000],
+                        [0.2699, 0.4425, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                    [
+                        [0.7763, 0.0227, 0.0000],
+                        [0.6210, 0.0730, 0.0000],
+                        [0.0000, 0.0000, 0.0000],
+                    ],
+                ],
+            ]),
+            3,
+        );
+    }
+
+    #[test]
+    fn test_down_encoder_block2d() {
+        TestBackend::seed(0);
+
+        let device = Default::default();
+        let block = DownEncoderBlock2DConfig::new(32, 32).init::<TestBackend>(&device);
+
+        let tensor: Tensor<TestBackend, 4> =
+            Tensor::random([4, 32, 32, 32], Distribution::Default, &device);
+        let output = block.forward(tensor.clone());
+
+        assert_eq!(output.shape(), Shape::new([4, 32, 16, 16]));
+    }
+
+    #[test]
+    fn test_up_decoder_block2d() {
+        TestBackend::seed(0);
+
+        let device = Default::default();
+        let block = UpDecoderBlock2DConfig::new(32, 32).init::<TestBackend>(&device);
+
+        let tensor: Tensor<TestBackend, 4> =
+            Tensor::random([4, 32, 32, 32], Distribution::Default, &device);
+        let output = block.forward(tensor.clone());
+
+        assert_eq!(output.shape(), Shape::new([4, 32, 64, 64]));
+    }
+}
diff --git a/src/transformers/clip.rs b/src/transformers/clip.rs
index 4798d98..ca1ea0c 100644
--- a/src/transformers/clip.rs
+++ b/src/transformers/clip.rs
@@ -124,7 +124,8 @@ impl ClipConfig {
             nn::EmbeddingConfig::new(self.vocab_size, self.embed_dim).init(device);
         let position_embedding =
             nn::EmbeddingConfig::new(self.max_position_embeddings, self.embed_dim).init(device);
-        let position_ids = Tensor::arange(0..self.max_position_embeddings, device).unsqueeze();
+        let position_ids =
+            Tensor::arange(0..self.max_position_embeddings as i64, device).unsqueeze();
 
         ClipTextEmbeddings {
             token_embedding,