From 17591f1d7d40966fb5af366417bf14f5137a37ec Mon Sep 17 00:00:00 2001 From: ryankopf Date: Thu, 16 Nov 2023 16:36:15 -0600 Subject: [PATCH] Initial work on improving the book. --- candle-book/src/inference/hub.md | 27 +++++++++++++++++++-------- candle-book/src/inference/mod.rs | 25 +++++++++++++++++++++++++ candle-book/src/lib.rs | 18 +----------------- 3 files changed, 45 insertions(+), 25 deletions(-) create mode 100644 candle-book/src/inference/mod.rs diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md index e8d8b267db..1980102562 100644 --- a/candle-book/src/inference/hub.md +++ b/candle-book/src/inference/hub.md @@ -10,17 +10,23 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas ```rust -# extern crate candle_core; # extern crate hf_hub; +# extern crate candle_core; use hf_hub::api::sync::Api; use candle_core::Device; -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); +fn main() { + let api = Api::new().unwrap(); + let repo = api.model("bert-base-uncased".to_string()); -let weights = repo.get("model.safetensors").unwrap(); + let weights = repo.get("model.safetensors").unwrap(); -let weights = candle_core::safetensors::load(weights, &Device::Cpu); + let weights = candle_core::safetensors::load(weights, &Device::Cpu); + + for (name, tensor) in weights.iter() { + println!("Weight: {}, Dimensions: {:?}", name, tensor); + } +} ``` We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file. @@ -38,8 +44,13 @@ cargo add hf-hub --features tokio ```rust,ignore # This is tested directly in examples crate because it needs external dependencies unfortunately: -# See [this](https://github.com/rust-lang/mdBook/issues/706) -{{#include ../lib.rs:book_hub_1}} +# See [this mdBook issue from 2019](https://github.com/rust-lang/mdBook/issues/706) +{{#include ./mod.rs:book_hub_1_top}} + +#[tokio::main] +async fn main() { + {{#include ./mod.rs:book_hub_1}} +} ``` @@ -81,7 +92,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2 and will definitely be slower on network mounted disk, because it will issue more read calls. ```rust,ignore -{{#include ../lib.rs:book_hub_2}} +{{#include ./mod.rs:book_hub_2}} ``` **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety). diff --git a/candle-book/src/inference/mod.rs b/candle-book/src/inference/mod.rs new file mode 100644 index 0000000000..3ae22ecc90 --- /dev/null +++ b/candle-book/src/inference/mod.rs @@ -0,0 +1,25 @@ +#[cfg(test)] +mod tests { + + #[rustfmt::skip] + #[tokio::test] + async fn book_hub_1() { +// ANCHOR: book_hub_1_top +use candle::Device; +use hf_hub::api::tokio::Api; +// ANCHOR_END: book_hub_1_top +// ANCHOR: book_hub_1 + let api = Api::new().unwrap(); + let repo = api.model("bert-base-uncased".to_string()); + + let weights_filename = repo.get("model.safetensors").await.unwrap(); + + let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); + for (name, tensor) in weights.iter() { + println!("Weight: {}, Dimensions: {:?}", name, tensor); + } +// ANCHOR_END: book_hub_1 + assert_eq!(weights.len(), 206); + } + +} \ No newline at end of file diff --git a/candle-book/src/lib.rs b/candle-book/src/lib.rs index a1ec1e94b6..487e095460 100644 --- a/candle-book/src/lib.rs +++ b/candle-book/src/lib.rs @@ -1,5 +1,6 @@ #[cfg(test)] pub mod simplified; +pub mod inference; #[cfg(test)] mod tests { @@ -8,23 +9,6 @@ mod tests { use parquet::file::reader::SerializedFileReader; // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 - #[rustfmt::skip] - #[tokio::test] - async fn book_hub_1() { -// ANCHOR: book_hub_1 -use candle::Device; -use hf_hub::api::tokio::Api; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); - -let weights_filename = repo.get("model.safetensors").await.unwrap(); - -let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); -// ANCHOR_END: book_hub_1 - assert_eq!(weights.len(), 206); - } - #[rustfmt::skip] #[test] fn book_hub_2() {