
Commit

feature(models): Add unet_2d_blocks (#18)
dcvz authored Apr 8, 2024
1 parent 5aaacc7 commit a673888
Showing 8 changed files with 974 additions and 25 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
@@ -16,9 +16,12 @@ torch = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = { version = "0.12.1", default-features = false }
burn = { version = "0.13.0", default-features = false }
num-traits = { version = "0.2.17", default-features = false }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] }

[patch.crates-io]
burn = { git = "https://github.com/tracel-ai/burn" }
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -9,7 +9,7 @@ pub mod pipelines;
pub mod transformers;
pub mod utils;

#[cfg(all(test, not(feature = "wgpu"), not(feature = "torch")))]
#[cfg(all(test, feature = "ndarray"))]
use burn::backend::ndarray;

#[cfg(all(test, feature = "torch"))]
@@ -20,7 +20,7 @@ use burn::backend::wgpu;

extern crate alloc;

#[cfg(all(test, not(feature = "wgpu"), not(feature = "torch")))]
#[cfg(all(test, feature = "ndarray"))]
pub type TestBackend = ndarray::NdArray<f32>;

#[cfg(all(test, feature = "torch"))]
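For context, a minimal sketch (not part of this commit) of a test relying on the `ndarray`-gated `TestBackend` alias; the module name and test body are assumptions based on burn 0.13's tensor API.

#[cfg(all(test, feature = "ndarray"))]
mod test_backend_smoke {
    use burn::tensor::Tensor;

    use crate::TestBackend;

    #[test]
    fn builds_a_tensor_on_the_test_backend() {
        let device = Default::default();
        // A 2x2 tensor created on the ndarray-backed TestBackend.
        let t = Tensor::<TestBackend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
        assert_eq!(t.dims(), [2, 2]);
    }
}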
50 changes: 32 additions & 18 deletions src/models/attention.rs
@@ -297,36 +297,50 @@ pub struct SpatialTransformerConfig {
pub n_groups: usize,
pub d_context: Option<usize>,
pub sliced_attn_size: Option<usize>,
// #[config(default = false)]
// pub use_linear_projection: bool,
#[config(default = false)]
pub use_linear_projection: bool,
pub in_channels: usize,
pub n_heads: usize,
pub d_head: usize,
}

//#[derive(Config, Debug)]
//enum Proj<B: Backend> {
// Conv2d(nn::conv::Conv2d<B>),
// Linear(nn::Linear<B>)
//}
#[derive(Module, Debug)]
enum Proj<B: Backend> {
Conv2d(nn::conv::Conv2d<B>),
Linear(nn::Linear<B>),
}

impl<B: Backend> Proj<B> {
fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
match self {
Proj::Conv2d(conv) => conv.forward(xs),
Proj::Linear(linear) => linear.forward(xs),
}
}
}

/// Aka Transformer2DModel
#[derive(Module, Debug)]
pub struct SpatialTransformer<B: Backend> {
norm: GroupNorm<B>,
proj_in: nn::conv::Conv2d<B>,
proj_in: Proj<B>,
transformer_blocks: Vec<BasicTransformerBlock<B>>,
proj_out: nn::conv::Conv2d<B>,
}

impl SpatialTransformerConfig {
fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> SpatialTransformer<B> {
let d_inner = self.n_heads * self.d_head;
let norm = GroupNormConfig::new(self.n_groups, self.in_channels)
.with_epsilon(1e-6)
.init(device);
// let proj_in = if config.use_linear_projection {
let proj_in = nn::conv::Conv2dConfig::new([self.in_channels, d_inner], [1, 1]).init(device);
let proj_in = if self.use_linear_projection {
Proj::Linear(nn::LinearConfig::new(self.in_channels, d_inner).init(device))
} else {
Proj::Conv2d(
nn::conv::Conv2dConfig::new([self.in_channels, d_inner], [1, 1]).init(device),
)
};

let mut transformer_blocks = vec![];
for _index in 0..self.depth {
@@ -351,7 +365,7 @@ impl SpatialTransformerConfig {
}

impl<B: Backend> SpatialTransformer<B> {
fn forward(&self, xs: Tensor<B, 4>, context: Option<Tensor<B, 3>>) -> Tensor<B, 4> {
pub fn forward(&self, xs: Tensor<B, 4>, context: Option<Tensor<B, 3>>) -> Tensor<B, 4> {
let [n_batch, _n_channel, height, weight] = xs.dims();

let residual = xs.clone();
@@ -402,7 +416,7 @@ pub struct AttentionBlock<B: Backend> {
}

impl AttentionBlockConfig {
fn init<B: Backend>(&self, device: &B::Device) -> AttentionBlock<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> AttentionBlock<B> {
let n_head_channels = self.n_head_channels.unwrap_or(self.channels);
let n_heads = self.channels / n_head_channels;
let group_norm = GroupNormConfig::new(self.n_groups, self.channels)
@@ -433,7 +447,7 @@ impl<B: Backend> AttentionBlock<B> {
.swap_dims(1, 2)
}

fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
pub fn forward(&self, xs: Tensor<B, 4>) -> Tensor<B, 4> {
let residual = xs.clone();
let [n_batch, channel, height, width] = xs.dims();
let xs = self
@@ -506,8 +520,8 @@ mod tests {

let geglu = GeGlu {
proj: nn::Linear {
weight: Param::new(ParamId::new(), weight),
bias: Some(Param::new(ParamId::new(), bias)),
weight: Param::initialized(ParamId::new(), weight),
bias: Some(Param::initialized(ParamId::new(), bias)),
},
};

@@ -562,8 +576,8 @@ mod tests {

let geglu = GeGlu {
proj: nn::Linear {
weight: Param::new(ParamId::new(), weight),
bias: Some(Param::new(ParamId::new(), bias)),
weight: Param::initialized(ParamId::new(), weight),
bias: Some(Param::initialized(ParamId::new(), bias)),
},
};

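As a usage sketch only: the now-public `forward` signature comes from this diff, but the module path and the already-initialized transformer are assumptions.

use burn::tensor::{backend::Backend, Tensor};

use crate::models::attention::SpatialTransformer;

// Hypothetical call site for SpatialTransformer::forward.
fn apply_spatial_transformer<B: Backend>(
    block: &SpatialTransformer<B>,
    xs: Tensor<B, 4>,              // [batch, in_channels, height, width]
    context: Option<Tensor<B, 3>>, // optional cross-attention context
) -> Tensor<B, 4> {
    block.forward(xs, context)
}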
3 changes: 2 additions & 1 deletion src/models/embeddings.rs
@@ -54,14 +54,15 @@ impl<B: Backend> Timesteps<B> {
Self {
num_channels,
flip_sin_to_cos,

downscale_freq_shift,
_backend: PhantomData,
}
}

pub fn forward<const D1: usize, const D2: usize>(&self, xs: Tensor<B, D1>) -> Tensor<B, D2> {
let half_dim = self.num_channels / 2;
let exponent = Tensor::arange(0..half_dim, &xs.device()).float() * -f64::ln(10000.);
let exponent = Tensor::arange(0..half_dim as i64, &xs.device()).float() * -f64::ln(10000.);
let exponent = exponent / (half_dim as f64 - self.downscale_freq_shift);
let emb = exponent.exp();
// emb = timesteps[:, None].float() * emb[None, :]
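The functional change here is the cast to `i64`, since `Tensor::arange` in burn 0.13 takes a `Range<i64>` plus a device. A standalone sketch of the same pattern (function and parameter names are illustrative, not crate code):

use burn::tensor::{backend::Backend, Int, Tensor};

// Mirrors the exponent computation above: arange over half the channel count,
// scaled by -ln(10000) and the downscale shift, as in sinusoidal embeddings.
fn frequency_exponents<B: Backend>(
    half_dim: usize,
    downscale_freq_shift: f64,
    device: &B::Device,
) -> Tensor<B, 1> {
    let exponent =
        Tensor::<B, 1, Int>::arange(0..half_dim as i64, device).float() * -f64::ln(10000.);
    exponent / (half_dim as f64 - downscale_freq_shift)
}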
1 change: 1 addition & 0 deletions src/models/mod.rs
@@ -5,3 +5,4 @@
pub mod attention;
pub mod embeddings;
pub mod resnet;
pub mod unet_2d_blocks;
6 changes: 4 additions & 2 deletions src/models/resnet.rs
@@ -99,7 +99,7 @@ impl<B: Backend> ResnetBlock2D<B> {

let xs = self.norm1.forward(xs.clone());
let xs = self.conv1.forward(silu(xs));
match (temb, &self.time_emb_proj) {
let xs = match (temb, &self.time_emb_proj) {
(Some(temb), Some(time_emb_proj)) => {
time_emb_proj
.forward(silu(temb))
@@ -109,7 +109,9 @@
}
_ => xs.clone(),
};
let xs = self.conv2.forward(silu(self.norm2.forward(xs)));
let xs = self.norm2.forward(xs);
let xs = silu(xs);
let xs = self.conv2.forward(xs);
(shortcut_xs + xs) / self.output_scale_factor
}
}
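The substantive fix is that the `match` result is now bound back to `xs`, so the time-embedding branch actually flows into `norm2`/`conv2` instead of being computed and dropped. A reduced illustration of that control flow (plain floats stand in for tensors; names are illustrative):

fn resnet_block_body(xs: f32, temb: Option<f32>) -> f32 {
    // The match now yields the value used by the rest of the block.
    let xs = match temb {
        Some(temb) => xs + temb, // time-embedding projection branch in the real code
        None => xs,              // passthrough when no time embedding is supplied
    };
    // norm2 -> silu -> conv2 -> residual scaling follow in the real implementation.
    xs
}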
[Diffs for the remaining files, including the new src/models/unet_2d_blocks.rs, are not shown.]
