Skip to content

Commit

Permalink
Add support for TrOCR Model (#1303)
Browse files Browse the repository at this point in the history
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* add trocr model

* fix formatting

* commit the actual model lol

* more formatting

* remove tokenizer config
  • Loading branch information
ToluClassics authored Nov 9, 2023
1 parent e669747 commit 6958384
Show file tree
Hide file tree
Showing 7 changed files with 767 additions and 15 deletions.
Binary file added candle-examples/examples/trocr/assets/trocr.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
154 changes: 154 additions & 0 deletions candle-examples/examples/trocr/image_processor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
use image::{DynamicImage, ImageBuffer};
use serde::Deserialize;
use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProcessorConfig {
do_resize: bool,
height: u32,
width: u32,
do_rescale: bool,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}

impl Default for ProcessorConfig {
fn default() -> Self {
Self {
do_resize: true,
height: 384,
width: 384,
do_rescale: true,
do_normalize: true,
image_mean: vec![0.5, 0.5, 0.5],
image_std: vec![0.5, 0.5, 0.5],
}
}
}

pub struct ViTImageProcessor {
do_resize: bool,
height: u32,
width: u32,
do_normalize: bool,
image_mean: Vec<f32>,
image_std: Vec<f32>,
}

impl ViTImageProcessor {
pub fn new(config: &ProcessorConfig) -> Self {
Self {
do_resize: config.do_resize,
height: config.height,
width: config.width,
do_normalize: config.do_normalize,
image_mean: config.image_mean.clone(),
image_std: config.image_std.clone(),
}
}

pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> {
let height = self.height as usize;
let width = self.width as usize;
let channels = 3;

let images = self.load_images(images)?;

let resized_images: Vec<DynamicImage> = if self.do_resize {
images
.iter()
.map(|image| self.resize(image.clone(), None).unwrap())
.collect()
} else {
images
};

let normalized_images: Vec<Tensor> = if self.do_normalize {
resized_images
.iter()
.map(|image| self.normalize(image.clone(), None, None).unwrap())
.collect()
} else {
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> =
resized_images.iter().map(|image| image.to_rgb8()).collect();
let data = resized_images
.into_iter()
.map(|image| image.into_raw())
.collect::<Vec<Vec<u8>>>();

data.iter()
.map(|image| {
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu)
.unwrap()
.permute((2, 0, 1))
.unwrap()
})
.collect::<Vec<Tensor>>()
};

Tensor::stack(&normalized_images, 0)
}

fn resize(
&self,
image: image::DynamicImage,
size: Option<HashMap<String, u32>>,
) -> Result<image::DynamicImage> {
let (height, width) = match &size {
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()),
None => (&self.height, &self.width),
};

let resized_image =
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle);

Ok(resized_image)
}

fn normalize(
&self,
image: image::DynamicImage,
mean: Option<Vec<f32>>,
std: Option<Vec<f32>>,
) -> Result<Tensor> {
let mean = match mean {
Some(mean) => mean,
None => self.image_mean.clone(),
};

let std = match std {
Some(std) => std,
None => self.image_std.clone(),
};

let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?;
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?;

let image = image.to_rgb8();
let data = image.into_raw();

let height = self.height as usize;
let width = self.width as usize;
let channels = 3;

let data =
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?;

(data.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}

pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
let mut images: Vec<image::DynamicImage> = Vec::new();
for path in image_path {
let img = image::io::Reader::open(path)?.decode().unwrap();
images.push(img);
}

Ok(images)
}
}
132 changes: 132 additions & 0 deletions candle-examples/examples/trocr/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::trocr;

use tokenizers::Tokenizer;
mod image_processor;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Large,
}

#[derive(Parser, Debug)]
struct Args {
#[arg(long)]
model: Option<String>,

/// Choose the variant of the model to run.
#[arg(long, default_value = "base")]
which: Which,

/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Text to be translated
#[arg(long)]
image: String,
}

pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();

let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;

Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};

let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

let device = candle_examples::device(args.cpu)?;

let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};

let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
};

let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);

let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;

let encoder_xs = model.encoder().forward(&image)?;

let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);

let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;

let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;

let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);

if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == decoder_config.eos_token_id {
break;
}
}

if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();

Ok(())
}
16 changes: 16 additions & 0 deletions candle-examples/examples/trocr/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# candle-trocr

`TrOCR` is a transformer OCR Model. In this example it is used to
transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.

## Running an example

```bash
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
```

```
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod whisper;
Expand Down
Loading

0 comments on commit 6958384

Please sign in to comment.