Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for TrOCR Model #1303

Merged
merged 12 commits into from
Nov 9, 2023
Merged

Conversation

ToluClassics
Copy link
Contributor

Copy link
Collaborator

@LaurentMazare LaurentMazare left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a great addition, thanks! Just one small nitpick.

candle-examples/examples/trocr/readme.md Outdated Show resolved Hide resolved
@LaurentMazare LaurentMazare merged commit 6958384 into huggingface:main Nov 9, 2023
10 of 12 checks passed
@LaurentMazare
Copy link
Collaborator

Merged, thanks a lot! Would be great to advertise this new model a bit if you felt like posting on twitter / reddit / discord so that potential users are aware of its existance.

@katopz
Copy link
Contributor

katopz commented Jan 31, 2024

Thanks for nice addition! Anyway I have some issue for microsoft/trocr-large-printed when I try it with

    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                // Somehow this one is 404 
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-base-printed".to_string(),
                        hf_hub::RepoType::Model,
                        "main".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Large => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-large-printed".to_string(),
                        hf_hub::RepoType::Model,
                        "main".to_string(),
                    ))
                    .get("model.safetensors")?,
            },
        };
        println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
    };

and

cargo run --example trocr --release --  --which large --cpu --image candle-examples/examples/trocr/assets/trocr.png

i got an error

Error: shape mismatch for encoder.embeddings.cls_token, expected: [1, 1, 768], got: [1, 1, 1024]

Not sure what I miss?

@ToluClassics
Copy link
Contributor Author

Hi @katopz , I think we might be missing the encoder configurations for the large model.
Here's where the encoder config is for base

pub fn microsoft_trocr_base_handwritten() -> Self {

@katopz
Copy link
Contributor

katopz commented Jan 31, 2024

Hi @katopz , I think we might be missing the encoder configurations for the large model. Here's where the encoder config is for base

pub fn microsoft_trocr_base_handwritten() -> Self {

Oh, right! Anyway I hit another error after try to change hiden_size to 1024. 🤔

Error: shape mismatch for encoder.encoder.layer.0.attention.attention.query.weight, expected: [1020, 1024], got: [1024, 1024]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants