Commit

Fix: fmt + clippy + stub.
Narsil committed Dec 28, 2023
1 parent a825eb6 commit db3660b
Showing 11 changed files with 49 additions and 42 deletions.
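
Most of the clippy fixes below are the same small change: call sites stop writing `&device` where the binding is already a `&Device`, which otherwise trips clippy's `needless_borrow` lint. A minimal, self-contained sketch of the pattern (the `Device` enum and `read_content` function are illustrative stand-ins, not candle's actual API):

#[derive(Debug)]
enum Device {
    Cpu,
}

// Mirrors signatures like `Content::read(&mut file, device: &Device)`.
fn read_content(device: &Device) {
    println!("reading tensors on {device:?}");
}

fn main() {
    let device: &Device = &Device::Cpu;
    // `read_content(&device)` would pass a `&&Device`; it still compiles via
    // deref coercion, but clippy flags the extra borrow as `needless_borrow`.
    read_content(device);
}

The `fmt` portion of the commit is ordinary rustfmt output (re-wrapped signatures, struct-init shorthand, trailing commas), and the `stub` portion updates the Python type stubs for the new `device` argument.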
2 changes: 1 addition & 1 deletion candle-core/examples/tensor-tools.rs
@@ -196,7 +196,7 @@ fn run_ls(
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file, &device)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
11 changes: 8 additions & 3 deletions candle-core/src/quantized/gguf_file.rs
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
 use super::{GgmlDType, QTensor};
-use crate::{Result, Device};
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
 
@@ -72,7 +72,12 @@ impl TensorInfo {
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec(), device)
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }

@@ -461,7 +466,7 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
-        device: &Device
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
14 changes: 4 additions & 10 deletions candle-core/src/quantized/mod.rs
@@ -5,12 +5,12 @@ pub mod avx;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
+#[cfg(feature = "metal")]
+pub mod metal;
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
-#[cfg(feature = "metal")]
-pub mod metal;
 pub mod utils;
 
 pub use k_quants::GgmlType;
@@ -174,16 +174,10 @@ fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
 }
 
 impl QTensor {
-    pub fn new<S: Into<Shape>>(
-        data: Box<dyn QuantizedType>,
-        shape: S,
-    ) -> Result<Self> {
+    pub fn new<S: Into<Shape>>(data: Box<dyn QuantizedType>, shape: S) -> Result<Self> {
         let shape = shape.into();
         check_shape(&shape, data.block_size())?;
-        Ok(Self {
-            data,
-            shape,
-        })
+        Ok(Self { data, shape })
     }
 
     pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
4 changes: 2 additions & 2 deletions candle-examples/examples/blip/main.rs
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
 
     let config = blip::Config::image_captioning_large();
 
+    let device = candle_examples::device(args.cpu)?;
     let (image_embeds, device, mut model) = if args.quantized {
         let device = Device::Cpu;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
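
The examples touched by this commit converge on the same device-selection shape as blip above: resolve the device once before branching on `args.quantized`, and let a branch shadow it only when it must (blip's quantized path still pins `Device::Cpu`). A self-contained sketch of that flow, with stand-in types rather than candle's own:

#[derive(Debug)]
enum Device {
    Cpu,
    Gpu,
}

fn pick_device(force_cpu: bool) -> Device {
    if force_cpu {
        Device::Cpu
    } else {
        Device::Gpu
    }
}

fn main() {
    let quantized = true;
    // Resolved once, up front, as the examples now do.
    let device = pick_device(false);
    let device = if quantized {
        // A branch may still shadow the outer choice, as blip's quantized
        // path does with Cpu...
        Device::Cpu
    } else {
        // ...while the other branch simply reuses the outer selection.
        device
    };
    println!("running on {device:?}");
}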
8 changes: 4 additions & 4 deletions candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
             .shape()
@@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_real",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
         let freq_cis_imag = vb
             .get(
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_imag",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
 
         let fake_vb = candle_nn::VarBuilder::from_tensors(
             [
@@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             .into_iter()
             .collect(),
             candle::DType::F32,
-            &candle::Device::Cpu,
+            &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
         let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
5 changes: 3 additions & 2 deletions candle-examples/examples/mistral/main.rs
@@ -244,13 +244,14 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let config = Config::config_7b_v0_1(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QMistral::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
5 changes: 3 additions & 2 deletions candle-examples/examples/stable-lm/main.rs
@@ -234,13 +234,14 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QStableLM::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
6 changes: 4 additions & 2 deletions candle-examples/examples/whisper/main.rs
@@ -557,8 +557,10 @@ fn main() -> Result<()> {
     println!("loaded mel: {:?}", mel.dims());
 
     let mut model = if args.quantized {
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &weights_filename,
+            &device,
+        )?;
         Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
     } else {
         let vb =
8 changes: 6 additions & 2 deletions candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -33,15 +33,19 @@ def has_mkl() -> bool:
     pass
 
 @staticmethod
-def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
+def load_ggml(
+    path: Union[str, PathLike], device: Optional[Device] = None
+) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
     """
     Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
     a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
     """
     pass
 
 @staticmethod
-def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
+def load_gguf(
+    path: Union[str, PathLike], device: Optional[Device] = None
+) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
     """
     Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
     and the second maps metadata keys to metadata values.
26 changes: 13 additions & 13 deletions candle-transformers/src/models/quantized_llama.rs
@@ -394,34 +394,34 @@ impl ModelWeights {
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), &device)?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), &device)?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), &device)?;
+            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
+            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
+            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
             let attention_wo =
-                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), &device)?;
+                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
             let mlp_or_moe = if n_expert <= 1 {
                 let feed_forward_w1 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), &device)?;
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
                 let feed_forward_w2 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), &device)?;
+                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
                 let feed_forward_w3 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), &device)?;
+                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
                 MlpOrMoe::Mlp(Mlp {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
                     feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
                 })
             } else {
                 let feed_forward_gate_inp =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), &device)?;
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
                 let mut experts = Vec::with_capacity(n_expert);
                 for i in 0..n_expert {
                     let feed_forward_w1 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), &device)?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                     let feed_forward_w2 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), &device)?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
                     let feed_forward_w3 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), &device)?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
                     experts.push(Mlp {
                         feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                         feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -435,8 +435,8 @@ impl ModelWeights {
                 }
             };
             let attention_norm =
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), &device)?;
+                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
-            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), &device)?;
+            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
             let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
2 changes: 1 addition & 1 deletion candle-transformers/src/quantized_var_builder.rs
@@ -30,7 +30,7 @@ impl VarBuilder {
         let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
+            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
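
Taken together with the example diffs above, loading a quantized model after this change follows the sketch below. This is a hedged illustration: the gguf path is a placeholder, and only the two-argument `from_gguf` call is taken from this commit.

use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> candle::Result<()> {
    // Device selection mirrors the examples; Cpu keeps the sketch portable.
    let device = candle::Device::Cpu;
    // The target device is now supplied at load time instead of being
    // hard-coded to `Device::Cpu` inside the builder.
    let _vb = VarBuilder::from_gguf("model-q4_0.gguf", &device)?;
    Ok(())
}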
