From bfa37894f160d97f5e0c2911ab04888097b11d24 Mon Sep 17 00:00:00 2001
From: amelie
Date: Mon, 23 Dec 2024 10:42:57 +0100
Subject: [PATCH] init commit: add position id in meshgrid

---
 .../src/models/pixtral/vision_model.rs       | 23 ++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs
index 20d8f08231..65719c281b 100644
--- a/candle-transformers/src/models/pixtral/vision_model.rs
+++ b/candle-transformers/src/models/pixtral/vision_model.rs
@@ -1,4 +1,4 @@
-use candle::{DType, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
 
 fn default_act() -> candle_nn::Activation {
@@ -286,6 +286,7 @@ pub struct Model {
     ln_pre: RmsNorm,
     transformer: Transformer,
     patch_positional_embedding: RotaryEmbedding,
+    max_image_width: u32,
 }
 
 impl Model {
@@ -305,18 +306,38 @@ impl Model {
         let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
         let patch_positional_embedding =
             RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
+        let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
         Ok(Self {
             patch_conv,
             ln_pre,
             transformer,
             patch_positional_embedding,
+            max_image_width,
         })
     }
+
+    pub fn position_ids_in_meshgrid(
+        &self,
+        num_patches_h: usize,
+        num_patches_w: usize,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let idx = Tensor::arange(0, num_patches_h as u32, device)?;
+        let idy = Tensor::arange(0, num_patches_w as u32, device)?;
+        let mesh = Tensor::meshgrid(&[idx, idy], false)?;
+        let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?;
+        Ok(ids.flatten_all()?)
+    }
 }
 
 impl Module for Model {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let patch_embeds = xs.apply(&self.patch_conv)?;
+        let subsampled_positions = self.position_ids_in_meshgrid(
+            patch_embeds.dim(2)?,
+            patch_embeds.dim(3)?,
+            patch_embeds.device(),
+        )?;
         let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
         self.transformer
             .forward(&patch_embeds, &self.patch_positional_embedding, None)
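
Note (illustrative sketch, not part of the patch): position_ids_in_meshgrid assigns each patch a row-major id with the row stride fixed to max_image_width = image_size / patch_size, so an image narrower than the maximum leaves gaps in the id sequence. The standalone Rust sketch below mirrors that arithmetic without candle, using made-up toy numbers (a 2x3 patch grid, maximum width of 4 patches) purely to show the expected ids.

// Sketch only: reproduces the row * max_image_width + col formula
// from position_ids_in_meshgrid in plain Rust, with toy inputs.
fn position_ids(num_patches_h: usize, num_patches_w: usize, max_image_width: usize) -> Vec<u32> {
    let mut ids = Vec::with_capacity(num_patches_h * num_patches_w);
    for row in 0..num_patches_h {
        for col in 0..num_patches_w {
            // Same formula as the patch: row index scaled by the maximum
            // image width (in patches), plus the column index.
            ids.push((row * max_image_width + col) as u32);
        }
    }
    ids
}

fn main() {
    // Hypothetical 2x3 patch grid with a maximum image width of 4 patches:
    // each row advances the id by 4, so id 3 is skipped.
    assert_eq!(position_ids(2, 3, 4), vec![0, 1, 2, 4, 5, 6]);
}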