Skip to content

Commit

Permalink
init commit: add position id in meshgrid
Browse files Browse the repository at this point in the history
  • Loading branch information
ameroyer committed Dec 23, 2024
1 parent 62ced44 commit bfa3789
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion candle-transformers/src/models/pixtral/vision_model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use candle::{DType, Module, Result, Tensor, D};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};

fn default_act() -> candle_nn::Activation {
Expand Down Expand Up @@ -286,6 +286,7 @@ pub struct Model {
ln_pre: RmsNorm,
transformer: Transformer,
patch_positional_embedding: RotaryEmbedding,
max_image_width: u32,
}

impl Model {
Expand All @@ -305,18 +306,38 @@ impl Model {
let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
let patch_positional_embedding =
RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
Ok(Self {
patch_conv,
ln_pre,
transformer,
patch_positional_embedding,
max_image_width,
})
}

/// Computes one position id per patch for a `num_patches_h` x `num_patches_w` patch grid.
///
/// Each id is `first_grid * max_image_width + second_grid`, where the two grids come
/// from `Tensor::meshgrid` over `0..num_patches_h` and `0..num_patches_w`, and
/// `max_image_width` is the value stored on the model. The result is flattened to a
/// single dimension.
///
/// NOTE(review): presumably `meshgrid(.., false)` selects "ij" indexing, so the first
/// grid holds row indices and the ids are row-major — confirm against the candle docs.
///
/// # Errors
///
/// Propagates any error returned by the underlying tensor operations.
pub fn position_ids_in_meshgrid(
    &self,
    num_patches_h: usize,
    num_patches_w: usize,
    device: &Device,
) -> Result<Tensor> {
    // Index ranges along each axis of the patch grid.
    let row_range = Tensor::arange(0, num_patches_h as u32, device)?;
    let col_range = Tensor::arange(0, num_patches_w as u32, device)?;
    let grids = Tensor::meshgrid(&[row_range, col_range], false)?;

    // Scale the first grid by the stored maximum width, then offset by the second grid.
    let row_stride = self.max_image_width as f64;
    let position_ids = (&grids[0] * row_stride + &grids[1])?;
    position_ids.flatten_all()
}
}

impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let patch_embeds = xs.apply(&self.patch_conv)?;
let susampled_positions = self.position_ids_in_meshgrid(
patch_embeds.dim(2)?,
patch_embeds.dim(3)?,
patch_embeds.device(),
)?;
let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
self.transformer
.forward(&patch_embeds, &self.patch_positional_embedding, None)
Expand Down

0 comments on commit bfa3789

Please sign in to comment.