From af7939d4a7367c14d15c736bfd18fb7f4d839912 Mon Sep 17 00:00:00 2001 From: edgar Date: Fri, 20 Dec 2024 21:32:05 +0100 Subject: [PATCH] make DepthAnythingV2 more reusable --- .../examples/depth_anything_v2/main.rs | 6 +-- .../src/models/depth_anything_v2.rs | 44 +++++++++++-------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/candle-examples/examples/depth_anything_v2/main.rs b/candle-examples/examples/depth_anything_v2/main.rs index ef337ebab4..2608b40d38 100644 --- a/candle-examples/examples/depth_anything_v2/main.rs +++ b/candle-examples/examples/depth_anything_v2/main.rs @@ -6,10 +6,8 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -use std::ffi::OsString; -use std::path::PathBuf; - use clap::Parser; +use std::{ffi::OsString, path::PathBuf, sync::Arc}; use candle::DType::{F32, U8}; use candle::{DType, Device, Module, Result, Tensor}; @@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> { }; let config = DepthAnythingV2Config::vit_small(); - let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?; + let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?; let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?; diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs index 8eddbf2af5..f12047574c 100644 --- a/candle-transformers/src/models/depth_anything_v2.rs +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -4,6 +4,8 @@ //! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything) //! +use std::sync::Arc; + use candle::D::Minus1; use candle::{Module, Result, Tensor}; use candle_nn::ops::Identity; @@ -365,16 +367,18 @@ impl Scratch { const NUM_CHANNELS: usize = 4; -pub struct DPTHead<'a> { - conf: &'a DepthAnythingV2Config, +pub struct DPTHead { projections: Vec, resize_layers: Vec>, readout_projections: Vec, scratch: Scratch, + use_class_token: bool, + input_image_size: usize, + target_patch_size: usize, } -impl<'a> DPTHead<'a> { - pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result { +impl DPTHead { + pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result { let mut projections: Vec = Vec::with_capacity(conf.out_channel_sizes.len()); for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() { projections.push(conv2d( @@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> { let scratch = Scratch::new(conf, vb.pp("scratch"))?; Ok(Self { - conf, projections, resize_layers, readout_projections, scratch, + use_class_token: conf.use_class_token, + input_image_size: conf.input_image_size, + target_patch_size: conf.target_patch_size, }) } } -impl Module for DPTHead<'_> { +impl Module for DPTHead { fn forward(&self, xs: &Tensor) -> Result { let mut out: Vec = Vec::with_capacity(NUM_CHANNELS); for i in 0..NUM_CHANNELS { - let x = if self.conf.use_class_token { + let x = if self.use_class_token { let x = xs.get(i)?.get(0)?; let class_token = xs.get(i)?.get(1)?; let readout = class_token.unsqueeze(1)?.expand(x.shape())?; @@ -473,8 +479,8 @@ impl Module for DPTHead<'_> { let x = x.permute((0, 2, 1))?.reshape(( x_dims[0], x_dims[x_dims.len() - 1], - self.conf.target_patch_size, - self.conf.target_patch_size, + self.target_patch_size, + self.target_patch_size, ))?; let x = self.projections[i].forward(&x)?; @@ -515,25 +521,25 @@ impl Module for DPTHead<'_> { let out = self.scratch.output_conv1.forward(&path1)?; - let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?; + let out = out.interpolate2d(self.input_image_size, self.input_image_size)?; self.scratch.output_conv2.forward(&out) } } -pub struct DepthAnythingV2<'a> { - pretrained: &'a DinoVisionTransformer, - depth_head: DPTHead<'a>, - conf: &'a DepthAnythingV2Config, +pub struct DepthAnythingV2 { + pretrained: Arc, + depth_head: DPTHead, + conf: DepthAnythingV2Config, } -impl<'a> DepthAnythingV2<'a> { +impl<'a> DepthAnythingV2 { pub fn new( - pretrained: &'a DinoVisionTransformer, - conf: &'a DepthAnythingV2Config, + pretrained: Arc, + conf: DepthAnythingV2Config, vb: VarBuilder, ) -> Result { - let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?; + let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?; Ok(Self { pretrained, @@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> { } } -impl Module for DepthAnythingV2<'_> { +impl Module for DepthAnythingV2 { fn forward(&self, xs: &Tensor) -> Result { let features = self.pretrained.get_intermediate_layers( xs,