Skip to content

Commit

Permalink
Add initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
JoseCarlosGarcia95 committed Dec 31, 2024
1 parent 4fee75a commit 4c669dc
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 56 deletions.
89 changes: 51 additions & 38 deletions candle-examples/examples/quantized-bitnet/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,35 @@ enum Prompt {

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "falcon3-1b-1.58")]
Falcon3_1b1_58,
#[value(name = "falcon3-1b-instruct-1.58")]
Falcon3_1bInstruct1_58,
#[value(name = "falcon3-3b-instruct-1.58")]
Falcon3_3bInstruct1_58,
#[value(name = "falcon3-3b-1.58")]
Falcon3_3b1_58,
#[value(name = "falcon3-7b-instruct-1.58")]
Falcon3_7bInstruct1_58,
#[value(name = "falcon3-7b-1.58")]
Falcon3_7b1_58,
#[value(name = "falcon3-10b-instruct-1.58")]
Falcon3_10bInstruct1_58,
#[value(name = "falcon3-10b-1.58")]
Falcon3_10b1_58,
#[value(name = "llama3-8b-1.58")]
Llama3_8b1_58,
}

impl Which {
fn is_falcon(&self) -> bool {
matches!(self, Self::Falcon3_1b1_58 | Self::Falcon3_3b1_58 | Self::Falcon3_7b1_58 | Self::Falcon3_10b1_58)
}

fn is_llama(&self) -> bool {
matches!(self, Self::Llama3_8b1_58)
}

impl Which {
fn tokenizer_repo(&self) -> &'static str {
match self {
Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit",
Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit",
Self::Llama3_8b1_58 => "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
Self::Falcon3_10b1_58 => "tiiuae/Falcon3-10B-Base-1.58bit",
Self::Falcon3_7b1_58 => "tiiuae/Falcon3-7B-Instruct-1.58bit",
Self::Falcon3_1bInstruct1_58 => "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF",
Self::Falcon3_3bInstruct1_58 => "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF",
Self::Falcon3_3b1_58 => "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF",
Self::Falcon3_7bInstruct1_58 => "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF",
Self::Falcon3_10b1_58 => "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF",
Self::Falcon3_10bInstruct1_58 => "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF",
Self::Falcon3_7b1_58 => "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF",
Self::Llama3_8b1_58 => "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF",
}
}
}
Expand Down Expand Up @@ -123,7 +124,7 @@ struct Args {
repeat_last_n: usize,

/// The model size to use.
#[arg(long, default_value = "falcon3-1b-1.58")]
#[arg(long, default_value = "falcon3-1b-instruct-1.58")]
which: Which,

/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
Expand Down Expand Up @@ -154,25 +155,37 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
Which::Falcon3_1b1_58 => (
"tiiuae/Falcon3-1B-Instruct-1.58bit",
"Falcon3-1B-Instruct-1.58bit.gguf",
Which::Falcon3_1bInstruct1_58 => (
"nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF",
"Falcon3-1B-Instruct-1.58bit-q2b0.gguf",
),
Which::Falcon3_3bInstruct1_58 => (
"nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF",
"Falcon3-3B-Instruct-1.58bit-q2b0.gguf",
),
Which::Falcon3_3b1_58 => (
"tiiuae/Falcon3-3B-Instruct-1.58bit",
"Falcon3-3B-Instruct-1.58bit.gguf",
"nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF",
"Falcon3-3B-Base-1.58bit-q2b0.gguf",
),
Which::Falcon3_10b1_58 => (
"tiiuae/Falcon3-10B-Instruct-1.58bit",
"Falcon3-10B-Instruct-1.58bit.gguf",
Which::Falcon3_7bInstruct1_58 => (
"nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF",
"Falcon3-7B-Instruct-1.58bit-q2b0.gguf",
),
Which::Falcon3_7b1_58 => (
"tiiuae/Falcon3-7B-Instruct-1.58bit",
"Falcon3-7B-Instruct-1.58bit.gguf",
"nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF",
"Falcon3-7B-Base-1.58bit-q2b0.gguf",
),
Which::Falcon3_10b1_58 => (
"nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF",
"Falcon3-10B-Base-1.58bit-q2b0.gguf",
),
Which::Falcon3_10bInstruct1_58 => (
"nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF",
"Falcon3-10B-Instruct-1.58bit-q2b0.gguf",
),
Which::Llama3_8b1_58 => (
"HF1BitLLM/Llama3-8B-1.58-100B-tokens",
"Llama3-8B-1.58bit.gguf",
"nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF",
"Llama3-8B-1.58-100B-tokens-q2b0.gguf",
),
};
let revision = "main";
Expand Down Expand Up @@ -306,13 +319,7 @@ fn main() -> anyhow::Result<()> {
}
}

if args.which.is_llama() {
format!(
"<|start_header_id|>user<|end_header_id|>{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)
} else {
prompt
}
prompt
}
};

Expand Down Expand Up @@ -376,11 +383,17 @@ fn main() -> anyhow::Result<()> {
}

let eos_tokens = match args.which {
Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => {
Which::Falcon3_10b1_58 |
Which::Falcon3_10bInstruct1_58 |
Which::Falcon3_7bInstruct1_58 |
Which::Falcon3_7b1_58 |
Which::Falcon3_3bInstruct1_58 |
Which::Falcon3_3b1_58 |
Which::Falcon3_1bInstruct1_58 => {
vec!["<|endoftext|>"]
}
Which::Llama3_8b1_58 => {
vec!["<|eot_id|>"]
vec!["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>"]
}
};

Expand Down
7 changes: 6 additions & 1 deletion candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2210,7 +2210,6 @@ pub fn call_quantized_matmul_mv_t(
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2b0
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
Expand All @@ -2231,6 +2230,12 @@ pub fn call_quantized_matmul_mv_t(
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q2b0 => {
let nth0 = 8;
let nth1 = 8;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
Expand Down
9 changes: 5 additions & 4 deletions candle-metal-kernels/src/quantized.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3544,10 +3544,11 @@ void kernel_mul_mv_q2b0_f32_impl(
int bit = startBit + iBit;
int bByte = bit >> 3;
int bMask = 1 << (bit & 7);
int isPos = ((bx->qs[bByte] & bMask) != 0) ? 1 : 0;
int isNeg = ((bx->qd[bByte] & bMask) != 0) ? 1 : 0;

sumq += float(isPos - isNeg) * yl[iBit];
if ((bx->qs[bByte] & bMask) != 0) {
sumq += yl[iBit];
} else if ((bx->qd[bByte] & bMask) != 0) {
sumq -= yl[iBit];
}
}

sumf[row] += sumq;
Expand Down
19 changes: 6 additions & 13 deletions candle-transformers/src/models/quantized_llama_bitnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,13 @@ impl BitQMatMul {
Ok(Self { inner, span, weight_scale })
}

pub fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
let target_dim = x.rank().saturating_sub(1);

let max_abs = x.abs()?.max_keepdim(target_dim)?;

let scale = (127.0/ &max_abs)?;

let scaled_rounded = x
.broadcast_mul(&scale)?
.round()?
.clamp(-128f32, 127f32)?;

fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
let scale = (127.0 / scale)?;

let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?;

Ok((scaled_rounded, scale))
Ok((y, scale))
}

fn forward(&self, x: &Tensor) -> Result<Tensor> {
Expand Down

0 comments on commit 4c669dc

Please sign in to comment.