Skip to content

Commit

Permalink
Initial work on improving the book.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryankopf committed Nov 16, 2023
1 parent 9ab3f97 commit 17591f1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
27 changes: 19 additions & 8 deletions candle-book/src/inference/hub.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,23 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas


```rust
# extern crate candle_core;
# extern crate hf_hub;
# extern crate candle_core;
use hf_hub::api::sync::Api;
use candle_core::Device;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
fn main() {
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights = repo.get("model.safetensors").unwrap();
let weights = repo.get("model.safetensors").unwrap();

let weights = candle_core::safetensors::load(weights, &Device::Cpu);
let weights = candle_core::safetensors::load(weights, &Device::Cpu);

for (name, tensor) in weights.iter() {
println!("Weight: {}, Dimensions: {:?}", name, tensor);
}
}
```

We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
Expand All @@ -38,8 +44,13 @@ cargo add hf-hub --features tokio

```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../lib.rs:book_hub_1}}
# See [this mdBook issue from 2019](https://github.com/rust-lang/mdBook/issues/706)
{{#include ./mod.rs:book_hub_1_top}}
#[tokio::main]
async fn main() {
{{#include ./mod.rs:book_hub_1}}
}
```


Expand Down Expand Up @@ -81,7 +92,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
and will definitely be slower on network mounted disk, because it will issue more read calls.

```rust,ignore
{{#include ../lib.rs:book_hub_2}}
{{#include ./mod.rs:book_hub_2}}
```

**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
Expand Down
25 changes: 25 additions & 0 deletions candle-book/src/inference/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#[cfg(test)]
mod tests {

#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1_top
use candle::Device;
use hf_hub::api::tokio::Api;
// ANCHOR_END: book_hub_1_top
// ANCHOR: book_hub_1
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights_filename = repo.get("model.safetensors").await.unwrap();

let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
for (name, tensor) in weights.iter() {
println!("Weight: {}, Dimensions: {:?}", name, tensor);
}
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}

}
18 changes: 1 addition & 17 deletions candle-book/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[cfg(test)]
pub mod simplified;
pub mod inference;

#[cfg(test)]
mod tests {
Expand All @@ -8,23 +9,6 @@ mod tests {
use parquet::file::reader::SerializedFileReader;

// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights_filename = repo.get("model.safetensors").await.unwrap();

let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}

#[rustfmt::skip]
#[test]
fn book_hub_2() {
Expand Down

0 comments on commit 17591f1

Please sign in to comment.