-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add bce with logit loss * add bce with logit loss * remove imports * fix tiny bug * add test documentation and refactor function * fix test cases and formatting * add trocr model * fix formatting * commit the actual model lol * more formatting * remove tokenizer config
- Loading branch information
1 parent
e669747
commit 6958384
Showing
7 changed files
with
767 additions
and
15 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
use image::{DynamicImage, ImageBuffer}; | ||
use serde::Deserialize; | ||
use std::collections::HashMap; | ||
|
||
use candle::{DType, Device, Result, Tensor}; | ||
|
||
#[derive(Debug, Clone, PartialEq, Deserialize)] | ||
pub struct ProcessorConfig { | ||
do_resize: bool, | ||
height: u32, | ||
width: u32, | ||
do_rescale: bool, | ||
do_normalize: bool, | ||
image_mean: Vec<f32>, | ||
image_std: Vec<f32>, | ||
} | ||
|
||
impl Default for ProcessorConfig { | ||
fn default() -> Self { | ||
Self { | ||
do_resize: true, | ||
height: 384, | ||
width: 384, | ||
do_rescale: true, | ||
do_normalize: true, | ||
image_mean: vec![0.5, 0.5, 0.5], | ||
image_std: vec![0.5, 0.5, 0.5], | ||
} | ||
} | ||
} | ||
|
||
/// Image preprocessor holding the resize / normalize settings that are applied
/// to images before they are turned into tensors.
#[derive(Debug, Clone, PartialEq)]
pub struct ViTImageProcessor {
    /// Resize every image to `height` x `width` when true.
    do_resize: bool,
    /// Target height in pixels.
    height: u32,
    /// Target width in pixels.
    width: u32,
    /// Apply per-channel mean / std normalization when true.
    do_normalize: bool,
    /// Per-channel mean used for normalization (three values, one per RGB channel).
    image_mean: Vec<f32>,
    /// Per-channel standard deviation used for normalization (three values).
    image_std: Vec<f32>,
}
|
||
impl ViTImageProcessor { | ||
pub fn new(config: &ProcessorConfig) -> Self { | ||
Self { | ||
do_resize: config.do_resize, | ||
height: config.height, | ||
width: config.width, | ||
do_normalize: config.do_normalize, | ||
image_mean: config.image_mean.clone(), | ||
image_std: config.image_std.clone(), | ||
} | ||
} | ||
|
||
pub fn preprocess(&self, images: Vec<&str>) -> Result<Tensor> { | ||
let height = self.height as usize; | ||
let width = self.width as usize; | ||
let channels = 3; | ||
|
||
let images = self.load_images(images)?; | ||
|
||
let resized_images: Vec<DynamicImage> = if self.do_resize { | ||
images | ||
.iter() | ||
.map(|image| self.resize(image.clone(), None).unwrap()) | ||
.collect() | ||
} else { | ||
images | ||
}; | ||
|
||
let normalized_images: Vec<Tensor> = if self.do_normalize { | ||
resized_images | ||
.iter() | ||
.map(|image| self.normalize(image.clone(), None, None).unwrap()) | ||
.collect() | ||
} else { | ||
let resized_images: Vec<ImageBuffer<image::Rgb<u8>, Vec<u8>>> = | ||
resized_images.iter().map(|image| image.to_rgb8()).collect(); | ||
let data = resized_images | ||
.into_iter() | ||
.map(|image| image.into_raw()) | ||
.collect::<Vec<Vec<u8>>>(); | ||
|
||
data.iter() | ||
.map(|image| { | ||
Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu) | ||
.unwrap() | ||
.permute((2, 0, 1)) | ||
.unwrap() | ||
}) | ||
.collect::<Vec<Tensor>>() | ||
}; | ||
|
||
Tensor::stack(&normalized_images, 0) | ||
} | ||
|
||
fn resize( | ||
&self, | ||
image: image::DynamicImage, | ||
size: Option<HashMap<String, u32>>, | ||
) -> Result<image::DynamicImage> { | ||
let (height, width) = match &size { | ||
Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()), | ||
None => (&self.height, &self.width), | ||
}; | ||
|
||
let resized_image = | ||
image.resize_exact(*width, *height, image::imageops::FilterType::Triangle); | ||
|
||
Ok(resized_image) | ||
} | ||
|
||
fn normalize( | ||
&self, | ||
image: image::DynamicImage, | ||
mean: Option<Vec<f32>>, | ||
std: Option<Vec<f32>>, | ||
) -> Result<Tensor> { | ||
let mean = match mean { | ||
Some(mean) => mean, | ||
None => self.image_mean.clone(), | ||
}; | ||
|
||
let std = match std { | ||
Some(std) => std, | ||
None => self.image_std.clone(), | ||
}; | ||
|
||
let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?; | ||
let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?; | ||
|
||
let image = image.to_rgb8(); | ||
let data = image.into_raw(); | ||
|
||
let height = self.height as usize; | ||
let width = self.width as usize; | ||
let channels = 3; | ||
|
||
let data = | ||
Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?; | ||
|
||
(data.to_dtype(DType::F32)? / 255.)? | ||
.broadcast_sub(&mean)? | ||
.broadcast_div(&std) | ||
} | ||
|
||
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> { | ||
let mut images: Vec<image::DynamicImage> = Vec::new(); | ||
for path in image_path { | ||
let img = image::io::Reader::open(path)?.decode().unwrap(); | ||
images.push(img); | ||
} | ||
|
||
Ok(images) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
|
||
use anyhow::Error as E; | ||
use clap::{Parser, ValueEnum}; | ||
|
||
use candle::{DType, Tensor}; | ||
use candle_examples::token_output_stream::TokenOutputStream; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::models::trocr; | ||
|
||
use tokenizers::Tokenizer; | ||
mod image_processor; | ||
|
||
#[derive(Clone, Debug, Copy, ValueEnum)] | ||
enum Which { | ||
Base, | ||
Large, | ||
} | ||
|
||
#[derive(Parser, Debug)] | ||
struct Args { | ||
#[arg(long)] | ||
model: Option<String>, | ||
|
||
/// Choose the variant of the model to run. | ||
#[arg(long, default_value = "base")] | ||
which: Which, | ||
|
||
/// Run on CPU rather than on GPU. | ||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Text to be translated | ||
#[arg(long)] | ||
image: String, | ||
} | ||
|
||
pub fn main() -> anyhow::Result<()> { | ||
use hf_hub::api::sync::Api; | ||
let args = Args::parse(); | ||
|
||
let tokenizer_dec = { | ||
let tokenizer = Api::new()? | ||
.model(String::from("ToluClassics/candle-trocr-tokenizer")) | ||
.get("tokenizer.json")?; | ||
|
||
Tokenizer::from_file(&tokenizer).map_err(E::msg)? | ||
}; | ||
|
||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec); | ||
|
||
let device = candle_examples::device(args.cpu)?; | ||
|
||
let vb = { | ||
let model = match args.model { | ||
Some(model) => std::path::PathBuf::from(model), | ||
None => match args.which { | ||
Which::Base => Api::new()? | ||
.repo(hf_hub::Repo::with_revision( | ||
"microsoft/trocr-base-handwritten".to_string(), | ||
hf_hub::RepoType::Model, | ||
"refs/pr/3".to_string(), | ||
)) | ||
.get("model.safetensors")?, | ||
Which::Large => Api::new()? | ||
.repo(hf_hub::Repo::with_revision( | ||
"microsoft/trocr-large-handwritten".to_string(), | ||
hf_hub::RepoType::Model, | ||
"refs/pr/6".to_string(), | ||
)) | ||
.get("model.safetensors")?, | ||
}, | ||
}; | ||
println!("model: {:?}", model); | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? } | ||
}; | ||
|
||
let encoder_config = match args.which { | ||
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(), | ||
Which::Large => { | ||
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten() | ||
} | ||
}; | ||
|
||
let decoder_config = trocr::TrOCRConfig::default(); | ||
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?; | ||
|
||
let config = image_processor::ProcessorConfig::default(); | ||
let processor = image_processor::ViTImageProcessor::new(&config); | ||
|
||
let image = vec![args.image.as_str()]; | ||
let image = processor.preprocess(image)?; | ||
|
||
let encoder_xs = model.encoder().forward(&image)?; | ||
|
||
let mut logits_processor = | ||
candle_transformers::generation::LogitsProcessor::new(1337, None, None); | ||
|
||
let mut token_ids: Vec<u32> = vec![decoder_config.decoder_start_token_id]; | ||
for index in 0..1000 { | ||
let context_size = if index >= 1 { 1 } else { token_ids.len() }; | ||
let start_pos = token_ids.len().saturating_sub(context_size); | ||
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; | ||
|
||
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?; | ||
|
||
let logits = logits.squeeze(0)?; | ||
let logits = logits.get(logits.dim(0)? - 1)?; | ||
let token = logits_processor.sample(&logits)?; | ||
token_ids.push(token); | ||
|
||
if let Some(t) = tokenizer_dec.next_token(token)? { | ||
use std::io::Write; | ||
print!("{t}"); | ||
std::io::stdout().flush()?; | ||
} | ||
if token == decoder_config.eos_token_id { | ||
break; | ||
} | ||
} | ||
|
||
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? { | ||
print!("{rest}"); | ||
} | ||
println!(); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# candle-trocr | ||
|
||
`TrOCR` is a transformer-based OCR model. In this example it is used to
transcribe text from an image. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-handwritten) for details on
the model itself.
|
||
## Running an example | ||
|
||
```bash | ||
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png | ||
``` | ||
|
||
``` | ||
<s> industry , Mr. Brown commented icily . " Let us have a</s> | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.