diff --git a/qwen.cpp b/qwen.cpp
index ad7a97c..7e1756a 100644
--- a/qwen.cpp
+++ b/qwen.cpp
@@ -357,12 +357,64 @@ QwenTokenizer::QwenTokenizer(const std::string & tiktoken_path, const QwenConfig
         }
     }
 
+    // qwen
+    // std::cout<< "init qwen tokenizer" << std::endl;
     std::vector<std::string> special_tokens_s{"<|endoftext|>", "<|im_start|>", "<|im_end|>"};
     char buffer[14];
     for (size_t i = 0; i < 205; i++) { // 205 for extra control token
         snprintf(buffer, 14, "<|extra_%zu|>", i);
         special_tokens_s.push_back(buffer);
     }
+
+    size_t encoder_size = encoder.size();
+    std::unordered_map<std::string, int> special_tokens;
+    special_tokens.reserve(special_tokens_s.size());
+    for (size_t i = 0; i < special_tokens_s.size(); i++) {
+        special_tokens[special_tokens_s[i]] = encoder_size + i;
+    }
+
+    tokenizer = tiktoken::tiktoken(std::move(encoder), special_tokens, PAT_STR);
+    eos_token_id = config.eos_token_id;
+    im_start_id = config.im_start_id;
+    im_end_id = config.im_end_id;
+}
+
+LlamaTokenizer::LlamaTokenizer(const std::string & tiktoken_path, const QwenConfig &config) : QwenTokenizer(tiktoken_path, config) {
+    std::ifstream file(tiktoken_path);
+    if (!file) {
+        throw std::runtime_error("failed to open encoder file: " + tiktoken_path);
+    }
+
+    std::unordered_map<std::string, int> encoder;
+    std::string line;
+    while (std::getline(file, line)) {
+        auto [token, rank] = _parse(line);
+
+        if (!encoder.emplace(std::move(token), rank).second) {
+            throw std::runtime_error("duplicate item: " + line);
+        }
+    }
+
+    // llama3
+    // std::cout<< "init llama3 tokenizer" << std::endl;
+    std::vector<std::string> special_tokens_s{
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|reserved_special_token_2|>",
+        "<|reserved_special_token_3|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|reserved_special_token_4|>",
+        "<|eot_id|>", // end of turn
+    };
+    char buffer2[32]; // "<|reserved_special_token_249|>" needs 31 bytes; a 14-byte buffer would truncate
+    for (size_t i = 5; i < 250; i++) {
+        snprintf(buffer2, sizeof(buffer2), "<|reserved_special_token_%zu|>", i);
+        special_tokens_s.push_back(buffer2);
+    }
+
+    size_t encoder_size = encoder.size();
+    std::unordered_map<std::string, int> special_tokens;
+    special_tokens.reserve(special_tokens_s.size());
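+    // note: special-token ids simply continue after the BPE ranks (encoder_size + i),
+    // mirroring the Qwen constructor above. Assuming the tiktoken file carries Llama 3's
+    // 128,000 BPE ranks, <|begin_of_text|> lands at id 128000 and <|eot_id|> at 128009,
+    // matching the ids hard-coded in convert.py below.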
"<|start_header_id|>user<|end_header_id|>\n\n"<< messages[i].content<<"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; + } + else{ // assitant + oss_prompt << messages[i].content << ""; + } + } // std::cout << oss_prompt.str()< QwenTokenizer::encode_messages(const std::vector & return input_ids; } +std::vector LlamaTokenizer::encode_messages(const std::vector &messages, int max_length) const { + std::string prompt = build_prompt(messages); + std::vector input_ids = encode(prompt, max_length); + return input_ids; +} + auto QwenTokenizer::is_special_id(int id) const -> bool { return id == eos_token_id || id == im_start_id || id == im_end_id; } @@ -438,9 +531,20 @@ auto QwenTokenizer::is_special_id(int id) const -> bool { QwenAttention::QwenAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length) : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), - q_proj(ctx, hidden_size, hidden_size), // head_dim * num_attention_heads - k_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads)), // head_dim * num_kv_heads. Qwen1.5 32B is GQA model - v_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads)), // head_dim * num_kv_heads + q_proj(ctx, hidden_size, hidden_size, true), // head_dim * num_attention_heads + k_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads),true), // head_dim * num_kv_heads. Qwen1.5 32B is GQA model + v_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads),true), // head_dim * num_kv_heads + o_proj(ctx, hidden_size, hidden_size, false), + k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length, + num_kv_heads)), + v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads, + num_kv_heads)) {} + +LlamaAttention::LlamaAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length) + : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), + q_proj(ctx, hidden_size, hidden_size, false), // head_dim * num_attention_heads + k_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads),false), // head_dim * num_kv_heads. 
 auto QwenTokenizer::is_special_id(int id) const -> bool {
     return id == eos_token_id || id == im_start_id || id == im_end_id;
 }
 
@@ -438,9 +531,20 @@ auto QwenTokenizer::is_special_id(int id) const -> bool {
 QwenAttention::QwenAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length)
     : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads),
-      q_proj(ctx, hidden_size, hidden_size), // head_dim * num_attention_heads
-      k_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads)), // head_dim * num_kv_heads. Qwen1.5 32B is a GQA model
-      v_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads)), // head_dim * num_kv_heads
+      q_proj(ctx, hidden_size, hidden_size, true), // head_dim * num_attention_heads
+      k_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads), true), // head_dim * num_kv_heads. Qwen1.5 32B is a GQA model
+      v_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads), true), // head_dim * num_kv_heads
+      o_proj(ctx, hidden_size, hidden_size, false),
+      k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length,
+                                 num_kv_heads)),
+      v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
+                                 num_kv_heads)) {}
+
+LlamaAttention::LlamaAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length)
+    : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads),
+      q_proj(ctx, hidden_size, hidden_size, false), // head_dim * num_attention_heads
+      k_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads), false), // head_dim * num_kv_heads (GQA); Llama 3 has no qkv bias
+      v_proj(ctx, hidden_size, hidden_size/(num_attention_heads/num_kv_heads), false), // head_dim * num_kv_heads
       o_proj(ctx, hidden_size, hidden_size, false),
       k_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads, max_length,
                                  num_kv_heads)),
       v_cache(ggml_new_tensor_3d(ctx->ctx_kv.get(), GGML_TYPE_F16, max_length, hidden_size / num_attention_heads,
                                  num_kv_heads)) {}
 
@@ -467,6 +571,7 @@ auto QwenAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_
     ggml_tensor *v = v_proj.forward(ctx, hidden_states); // [qlen, kv_hidden]
     ggml_tensor *value_layer = ggml_view_3d(gctx, v, head_size, num_kv_heads, qlen, head_size * ggml_element_size(v), v->nb[1], 0);
 
+    // std::cout << "attn debug 1: qkv forward pass" << std::endl;
 #ifdef GGML_USE_CUBLAS
     if (!ggml_is_contiguous(query_layer)) {
         query_layer = tensor_assign_buffers(ggml_cont(gctx, query_layer));
     }
@@ -525,6 +630,87 @@ auto QwenAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_
     return attn_output;
 }
 
+auto LlamaAttention::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_past, int n_ctx) const -> ggml_tensor * {
+    // std::cout << "debug llamaattention" << std::endl;
+    ggml_context *gctx = ctx->ctx_b.get();
+
+    const int hidden_size = hidden_states->ne[0];
+    const int qlen = hidden_states->ne[1];
+    const int head_size = hidden_size / num_attention_heads;
+    const int rope_dim = head_size;
+    // const int n_past = static_cast<int *>(KQ_pos->data)[0];
+
+    ggml_tensor *q = q_proj.forward(ctx, hidden_states); // [qlen, hidden]
+    // [qlen, heads, head_size]
+    ggml_tensor *query_layer = ggml_view_3d(gctx, q, head_size, num_attention_heads, qlen, head_size * ggml_element_size(q), q->nb[1], 0);
+
+    ggml_tensor *k = k_proj.forward(ctx, hidden_states); // [qlen, kv_hidden]
+    // [qlen, kv_heads, head_size]
+    ggml_tensor *key_layer = ggml_view_3d(gctx, k, head_size, num_kv_heads, qlen, head_size * ggml_element_size(k), k->nb[1], 0);
+
+    ggml_tensor *v = v_proj.forward(ctx, hidden_states); // [qlen, kv_hidden]
+    ggml_tensor *value_layer = ggml_view_3d(gctx, v, head_size, num_kv_heads, qlen, head_size * ggml_element_size(v), v->nb[1], 0);
+
+    // std::cout << "attn debug 1: qkv forward pass" << std::endl;
+#ifdef GGML_USE_CUBLAS
+    if (!ggml_is_contiguous(query_layer)) {
+        query_layer = tensor_assign_buffers(ggml_cont(gctx, query_layer));
+    }
+#endif
+    query_layer = tensor_assign_buffers(ggml_rope_inplace(gctx, query_layer, KQ_pos, rope_dim, 2, n_ctx));
+    query_layer = tensor_assign_buffers(ggml_permute(gctx, query_layer, 0, 2, 1, 3)); // [heads, qlen, head_size]
+
+#ifdef GGML_USE_CUBLAS
+    if (!ggml_is_contiguous(key_layer)) {
+        key_layer = tensor_assign_buffers(ggml_cont(gctx, key_layer));
+    }
+#endif
+    key_layer = tensor_assign_buffers(ggml_rope_inplace(gctx, key_layer, KQ_pos, rope_dim, 2, n_ctx));
+    key_layer = tensor_assign_buffers(ggml_permute(gctx, key_layer, 0, 2, 1, 3));     // [kv_heads, qlen, head_size]
+    value_layer = tensor_assign_buffers(ggml_permute(gctx, value_layer, 1, 2, 0, 3)); // [kv_heads, head_size, qlen]
+
+    // store key & value to cache
+    ggml_tensor *k_cache_view = tensor_assign_buffers(
+        ggml_view_3d(gctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
+                     n_past * head_size * ggml_element_size(k_cache))); // [kv_heads, qlen, head_size]
+    ggml_build_forward_expand(ctx->gf, ggml_cpy(gctx, key_layer, k_cache_view));
+    ggml_tensor *v_cache_view = tensor_assign_buffers(
+        ggml_view_3d(gctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
+                     n_past * ggml_element_size(v_cache))); // [kv_heads, head_size, qlen]
+    ggml_build_forward_expand(ctx->gf, ggml_cpy(gctx, value_layer, v_cache_view));
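+    // cache bookkeeping: k_cache is laid out [kv_heads, max_length, head_size] and v_cache is
+    // stored transposed as [kv_heads, head_size, max_length]; the views above write the qlen new
+    // keys/values at element offset n_past along the sequence axis, so the prefill pass and
+    // single-token decoding steps share the same code path.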
+
+    // concat key & value with past kv
+    key_layer = tensor_assign_buffers(
+        ggml_view_3d(gctx, k_cache, head_size, n_past + qlen, num_kv_heads,
+                     k_cache->nb[1], k_cache->nb[2], 0)); // [kv_heads, klen, head_size]
+    value_layer = tensor_assign_buffers(
+        ggml_view_3d(gctx, v_cache, n_past + qlen, head_size, num_kv_heads,
+                     v_cache->nb[1], v_cache->nb[2], 0)); // [kv_heads, head_size, klen]
+
+    // attention
+    ggml_tensor *attn_scores =
+        tensor_assign_buffers(ggml_mul_mat(gctx, key_layer, query_layer)); // [kv_heads, mqa_scale * qlen, klen]
+    attn_scores = tensor_assign_buffers(
+        ggml_scale_inplace(gctx, attn_scores, 1.f / std::sqrt(head_size)));
+    if (n_past == 0) {
+        // build attention mask for context input
+        attn_scores = tensor_assign_buffers(ggml_diag_mask_inf_inplace(gctx, attn_scores, n_past));
+    }
+    ggml_tensor *attn_probs =
+        tensor_assign_buffers(ggml_soft_max_inplace(gctx, attn_scores)); // [kv_heads, mqa_scale * qlen, klen]
+
+    ggml_tensor *context_layer = tensor_assign_buffers(
+        ggml_mul_mat(gctx, value_layer, attn_probs)); // [kv_heads, mqa_scale * qlen, head_size]
+    context_layer = tensor_assign_buffers(
+        ggml_cont(gctx, ggml_permute(gctx, context_layer, 0, 2, 1, 3))); // [qlen, heads, head_size]
+    context_layer = tensor_assign_buffers(
+        ggml_reshape_2d(gctx, context_layer, hidden_size, qlen)); // [qlen, hidden]
+
+    ggml_tensor *attn_output = o_proj.forward(ctx, context_layer);
+    return attn_output;
+}
+
+
 auto QwenMLP::forward(ModelContext *ctx, ggml_tensor *hidden_states) const -> ggml_tensor * {
     ggml_context *gctx = ctx->ctx_b.get();
@@ -566,6 +752,9 @@ auto Qwen2MoeSparseMoeBlock::forward(ModelContext *ctx, ggml_tensor *hidden_stat
     for(int i = 0; i < num_experts_per_tok; i++) {
         ggml_tensor * cur_expert;
 
+        // Newer ggml changed ggml_mul_mat_id to support MoE models like Qwen-MoE; adapting this block to the latest API would require more changes:
+        // https://github.com/ggerganov/llama.cpp/pull/6074/commits/56a38c46f5343596f69eed9803f3ec06010aecea#diff-150dc86746a90bad4fc2c3334aeb9b5887b3adad3cc1459446717638605348ef
+        // https://github.com/ggerganov/llama.cpp/commit/f4dea7da1841a92d2788b0535063abf2f0e28461#diff-150dc86746a90bad4fc2c3334aeb9b5887b3adad3cc1459446717638605348ef
         ggml_tensor * cur_up = ggml_mul_mat_id(gctx, expert_ups.data(), n_expert, selected_experts, i, hidden_states);
         ggml_tensor * cur_gate = ggml_mul_mat_id(gctx, expert_gates.data(), n_expert, selected_experts, i, hidden_states);
@@ -627,6 +816,22 @@ auto QwenMoeBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_t
     return hidden_states;
 }
 
+auto LlamaBlock::forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_past, int n_ctx) const -> ggml_tensor * {
+    // std::cout << "llamablock forward" << std::endl;
+    ggml_context *gctx = ctx->ctx_b.get();
+
+    ggml_tensor *residual = hidden_states;
+    hidden_states = input_layernorm.forward(ctx, hidden_states, 1e-6f);
+    hidden_states = attn.forward(ctx, hidden_states, KQ_pos, n_past, n_ctx); // previously failed here with a->ne[2] == b->ne[0]
+    hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));
+
+    residual = hidden_states;
+    hidden_states = post_attention_layernorm.forward(ctx, hidden_states, 1e-6f);
+    hidden_states = mlp.forward(ctx, hidden_states);
+    hidden_states = tensor_assign_buffers(ggml_add_inplace(gctx, hidden_states, residual));
+    return hidden_states;
+}
+
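+// note: the 1e-6f epsilon passed to RMSNorm above is carried over from the Qwen blocks; Hugging
+// Face Llama 3 configs typically set rms_norm_eps = 1e-5, so this hard-coded constant may need to
+// be read from the config for exact parity with the reference implementation.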
 QwenModel::QwenModel(ModelContext *ctx, const QwenConfig &config)
     : embed_tokens(ctx, config.vocab_size, config.hidden_size), norm(ctx, config.hidden_size) {
     layers.reserve(config.num_hidden_layers);
@@ -636,6 +841,8 @@ QwenModel::QwenModel(ModelContext *ctx, const QwenConfig &config)
     }
 }
 
+
+
 auto QwenModel::forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const -> ggml_tensor * {
     ggml_context *gctx = ctx->ctx_b.get();
     ggml_tensor *KQ_pos = pos_ids_gen_(gctx, input_ids->ne[0], n_past, n_ctx);
@@ -643,16 +850,19 @@ auto QwenModel::forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, i
         tensor_to_device(KQ_pos);
     }
     ggml_tensor *hidden_states = embed_tokens.forward(ctx, input_ids);
+    // std::cout<< "embed_token forward" << std::endl;
     for (const auto &layer : layers) {
         ggml_set_scratch(gctx, ctx->scratch);
         hidden_states = layer.forward(ctx, hidden_states, KQ_pos, n_past, n_ctx);
     }
+    // std::cout << "layers forward" << std::endl;
     if(KQ_pos){
         tensor_to_cpu(KQ_pos);
     }
 
 auto QwenMoeModel::forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const -> ggml_tensor * {
     ggml_context *gctx = ctx->ctx_b.get();
     ggml_tensor *KQ_pos = pos_ids_gen_(gctx, input_ids->ne[0], n_past, n_ctx);
@@ -686,6 +907,29 @@ auto QwenMoeModel::forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past
     return hidden_states;
 }
 
+auto LlamaModel::forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const -> ggml_tensor * {
+    ggml_context *gctx = ctx->ctx_b.get();
+    ggml_tensor *KQ_pos = pos_ids_gen_(gctx, input_ids->ne[0], n_past, n_ctx);
+    if(KQ_pos){
+        tensor_to_device(KQ_pos);
+    }
+    ggml_tensor *hidden_states = embed_tokens.forward(ctx, input_ids);
+    for (const auto &layer : layers) {
+        ggml_set_scratch(gctx, ctx->scratch);
+        hidden_states = layer.forward(ctx, hidden_states, KQ_pos, n_past, n_ctx);
+    }
+    if(KQ_pos){
+        tensor_to_cpu(KQ_pos);
+    }
+
+    ggml_scratch empty_scratch = {0, 0, nullptr};
+    ggml_set_scratch(gctx, empty_scratch);
+    hidden_states = norm.forward(ctx, hidden_states, 1e-6f);
+    return hidden_states;
+}
+
+
+
 QwenForCausalLM::QwenForCausalLM(const QwenConfig &config)
     : config(config) {
     ctx_.compute_buffer.resize(MEM_SIZE);
@@ -814,6 +1058,69 @@ QwenMoeForCausalLM::QwenMoeForCausalLM(const QwenMoeConfig &config)
     state_dict_.emplace_back("lm_head.weight", lm_head.weight);
 }
 
+Llama3ForCausalLM::Llama3ForCausalLM(const Llama3Config &config)
+    : QwenForCausalLM(config), config(config) {
+
+    // std::cout << "hello here" << std::endl;
+
+    ctx_.compute_buffer.resize(MEM_SIZE);
+    ctx_.scratch_buffer.resize(SCRATCH_SIZE);
+    ctx_.scratch = {0, ctx_.scratch_buffer.size(), ctx_.scratch_buffer.data()};
+#ifdef GGML_USE_CUBLAS
+    ggml_cuda_set_scratch_size(SCRATCH_SIZE);
+#endif
+    constexpr size_t tensor_ovhd = GGML_TENSOR_SIZE + GGML_OBJECT_SIZE;
+    const size_t ctx_w_size = (3 + config.num_hidden_layers * 9) * tensor_ovhd;
+    const size_t ctx_kv_size = 2 * config.num_hidden_layers *
+        (config.max_length * config.hidden_size / config.num_attention_heads * config.num_kv_heads * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
+    ctx_.dtype = config.dtype;
+    ctx_.ctx_w = make_unique_ggml_context(ctx_w_size, nullptr, true);
+    ctx_.ctx_kv = make_unique_ggml_context(ctx_kv_size + 1 * MB, nullptr, false); // 1MB extra for MPS
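+    // rough size check (a sketch, assuming Llama-3-8B shapes: 32 layers, hidden_size 4096,
+    // 32 attention heads, 8 KV heads, max_length 2048):
+    //   per layer, per K or V tensor: 2048 * (4096 / 32) * 8 * 2 bytes (F16) = 4 MB
+    //   K + V over 32 layers:          2 * 32 * 4 MB = 256 MB of KV cache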
+
+
+    transformer = LlamaModel(&ctx_, config);
+    // std::cout << "hello here2" << std::endl;
+    lm_head = Linear(&ctx_, config.hidden_size, config.vocab_size, false);
+
+
+    QWEN_CHECK(ggml_used_mem(ctx_.ctx_w.get()) == ggml_get_mem_size(ctx_.ctx_w.get())) << "corrupted model weights";
+    QWEN_CHECK(ggml_used_mem(ctx_.ctx_kv.get()) == ctx_kv_size) << "corrupted kv cache";
+
+    // build state_dict
+    state_dict_.reserve(3 + config.num_hidden_layers * 9);
+    state_dict_.emplace_back("model.embed_tokens.weight", transformer.embed_tokens.weight);
+    for (int i = 0; i < config.num_hidden_layers; i++) {
+        std::string layer_prefix = "model.layers." + std::to_string(i) + '.';
+        state_dict_.emplace_back(layer_prefix + "input_layernorm.weight", transformer.layers[i].input_layernorm.weight);
+
+        state_dict_.emplace_back(layer_prefix + "self_attn.q_proj.weight",
+                                 transformer.layers[i].attn.q_proj.weight);
+
+        state_dict_.emplace_back(layer_prefix + "self_attn.k_proj.weight",
+                                 transformer.layers[i].attn.k_proj.weight);
+
+        state_dict_.emplace_back(layer_prefix + "self_attn.v_proj.weight",
+                                 transformer.layers[i].attn.v_proj.weight);
+
+        state_dict_.emplace_back(layer_prefix + "self_attn.o_proj.weight",
+                                 transformer.layers[i].attn.o_proj.weight);
+
+        state_dict_.emplace_back(layer_prefix + "post_attention_layernorm.weight",
+                                 transformer.layers[i].post_attention_layernorm.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.gate_proj.weight",
+                                 transformer.layers[i].mlp.gate_proj.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.up_proj.weight",
+                                 transformer.layers[i].mlp.up_proj.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.down_proj.weight",
+                                 transformer.layers[i].mlp.down_proj.weight);
+    }
+    state_dict_.emplace_back("model.norm.weight", transformer.norm.weight);
+    state_dict_.emplace_back("lm_head.weight", lm_head.weight);
+
+}
+
+
+
 QwenForCausalLM::~QwenForCausalLM() {
     for (auto &item : state_dict_) {
         tensor_to_cpu(item.second);
     }
@@ -836,12 +1143,23 @@ QwenMoeForCausalLM::~QwenMoeForCausalLM() {
     }
 }
 
+Llama3ForCausalLM::~Llama3ForCausalLM() {
+    for (auto &item : state_dict_) {
+        tensor_to_cpu(item.second);
+    }
+
+    for (auto &layer : transformer.layers) {
+        tensor_to_cpu(layer.attn.k_cache);
+        tensor_to_cpu(layer.attn.v_cache);
+    }
+}
+
 auto QwenForCausalLM::forward_graph_compute(const std::vector<int> &input_ids, int n_past, int n_ctx, int n_threads, bool is_decoding) -> ggml_tensor* {
+
     ctx_.ctx_b = make_unique_ggml_context(ctx_.compute_buffer.size(), ctx_.compute_buffer.data(), false);
-    // ctx_.gf = ggml_new_graph(ctx_.ctx_b.get()); // dafault graph size didn't fit larger model like 32b, moe
+    // ctx_.gf = ggml_new_graph(ctx_.ctx_b.get()); // default graph size didn't fit larger models
     size_t GRAPH_SIZE = 4096 * 2;
     ctx_.gf = ggml_new_graph_custom(ctx_.ctx_b.get(), GRAPH_SIZE, false);
-    // int n_threads = gen_config.num_threads; // user defined
 
     if (n_threads <= 0) {
         n_threads = get_default_num_threads(); // default thread num
@@ -1026,7 +1344,9 @@ auto QwenForCausalLM::generate(
             streamer->put({next_token_id});
         }
 
-        if (next_token_id == config.eos_token_id || next_token_id == config.im_start_id || next_token_id == config.im_end_id) {
+        // if (next_token_id == config.eos_token_id || next_token_id == config.im_start_id || next_token_id == config.im_end_id ) {
+        // temporary while testing llama3: also stop on pad_token_id (the converter maps Llama 3's <|end_of_text|> into that slot)
+        if (next_token_id == config.eos_token_id || next_token_id == config.im_start_id || next_token_id == config.im_end_id || next_token_id == config.pad_token_id) {
             break;
         }
     }
@@ -1077,6 +1397,26 @@ auto QwenMoeForCausalLM::load(ModelLoader &loader) -> void {
     ctx_.init_device_context();
 }
 
+auto Llama3ForCausalLM::load(ModelLoader &loader) -> void {
+    // std::cout << "Debug: state_dict size: " << state_dict_.size() << std::endl;
+    ...
+}
+
+auto Llama3ForCausalLM::forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx, bool is_decoding) const -> ggml_tensor * {
+    ggml_tensor *transformer_outputs = transformer.forward(ctx, input_ids, n_past, n_ctx);
+    // NOTE: only compute next_token_logits for the last token
+    if (is_decoding && input_ids->ne[0] > 1) {
+        transformer_outputs = tensor_assign_buffers(
+            ggml_view_1d(ctx->ctx_b.get(), transformer_outputs, config.hidden_size,
+                         (input_ids->ne[0] - 1) * config.hidden_size * ggml_element_size(transformer_outputs)));
+    }
+    ggml_tensor *lm_logits = lm_head.forward(ctx, transformer_outputs);
+    return lm_logits;
+}
+
+
+
 // ===== pipeline =====
 
 Pipeline::Pipeline(const std::string &path, const std::string &tiktoken_path, int max_length) {
@@ -1150,13 +1510,24 @@ Pipeline::Pipeline(const std::string &path, const std::string &tiktoken_path, in
         QwenMoeConfig* moe_config = static_cast<QwenMoeConfig*>(config.get());
         model = std::make_unique<QwenMoeForCausalLM>(*moe_config);
     }
+    else if (model_type == ModelType::LLAMA3)
+    {
+        config = std::make_unique<Llama3Config>(loader.read_basic<Llama3Config>());
+        _update_config_max_length(*config, max_length);
+        Llama3Config* llama3_config = static_cast<Llama3Config*>(config.get());
+        model = std::make_unique<Llama3ForCausalLM>(*llama3_config);
+    }
 
     // load model
     model->load(loader);
-    // std::cout<< "weight load complish" << std::endl;
-    tokenizer = std::make_unique<QwenTokenizer>(tiktoken_path, *config);
+    if (model_type == ModelType::LLAMA3){
+        tokenizer = std::make_unique<LlamaTokenizer>(tiktoken_path, *config);
+    }else{
+        tokenizer = std::make_unique<QwenTokenizer>(tiktoken_path, *config);
+    }
 }
 
 auto Pipeline::generate(
diff --git a/qwen.h b/qwen.h
index cf812e6..b04be9e 100644
--- a/qwen.h
+++ b/qwen.h
@@ -117,7 +117,7 @@ class Embedding {
 class Linear {
   public:
     Linear() : weight(nullptr), bias(nullptr) {}
-    Linear(ModelContext *ctx, int in_features, int out_features, bool use_bias = true)
+    Linear(ModelContext *ctx, int in_features, int out_features, bool use_bias = false)
        : weight(ggml_new_tensor_2d(ctx->ctx_w.get(), ctx->dtype, in_features, out_features)),
          bias(use_bias ? ggml_new_tensor_1d(ctx->ctx_w.get(), GGML_TYPE_F32, out_features) : nullptr) {}
@@ -278,6 +278,8 @@ enum class ModelType {
     QWEN1 = 1, // abort
     QWEN2 = 2,
     QWEN2MOE = 3,
+    CODEQWEN = 4,
+    LLAMA3 = 5
 };
 
 struct QwenConfig {
@@ -306,6 +308,11 @@ struct QwenMoeConfig : QwenConfig {
     int norm_topk_prob;
 };
 
+struct Llama3Config : QwenConfig {
+    // float rope_theta;
+};
+
+
 struct ChatMessage {
     std::string role;
     std::string content;
@@ -333,7 +340,7 @@ class QwenTokenizer {
 
     auto decode(const std::vector<int> &ids) const -> std::string;
 
-    std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const;
+    virtual std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const;
 
     ChatMessage decode_message(const std::vector<int> &ids) const {
         return {ChatMessage::ROLE_ASSISTANT, decode(ids)};
     }
@@ -352,6 +359,16 @@ class QwenTokenizer {
     static void check_chat_messages(const std::vector<ChatMessage> &messages);
 };
 
+class LlamaTokenizer : public QwenTokenizer {
+  public:
+    LlamaTokenizer(const std::string & tiktoken_path, const QwenConfig &config);
+
+    std::vector<int> encode_messages(const std::vector<ChatMessage> &messages, int max_length) const;
+
+    static std::string build_prompt(const std::vector<ChatMessage> &messages);
+
+};
+
 class QwenAttention {
   public:
     QwenAttention() : num_attention_heads(0), num_kv_heads(0) {}
@@ -369,6 +386,23 @@ class QwenAttention {
     ggml_tensor *v_cache; // [n_head, head_size, maxlen]
 };
 
+class LlamaAttention : public QwenAttention {
+  public:
+    LlamaAttention() : num_attention_heads(0), num_kv_heads(0) {}
+    LlamaAttention(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length);
+
+    auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_past, int n_ctx) const -> ggml_tensor *;
+
+    int num_attention_heads;
+    int num_kv_heads;
+    Linear q_proj;
+    Linear k_proj;
+    Linear v_proj;
+    Linear o_proj;
+    ggml_tensor *k_cache; // [n_head, maxlen, head_size]
+    ggml_tensor *v_cache; // [n_head, head_size, maxlen]
+};
+
 class QwenMLP {
   public:
     QwenMLP() = default;
@@ -434,6 +468,24 @@ class QwenBlock {
     QwenMLP mlp;
 };
 
+class LlamaBlock : public QwenBlock {
+  public:
+    LlamaBlock() = default;
+    LlamaBlock(ModelContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size, int max_length)
+        : input_layernorm(ctx, hidden_size, false),
+          attn(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length),
+          post_attention_layernorm(ctx, hidden_size, false),
+          mlp(ctx, hidden_size, intermediate_size) {}
+
+    auto forward(ModelContext *ctx, ggml_tensor *hidden_states, ggml_tensor *KQ_pos, int n_past, int n_ctx) const -> ggml_tensor *;
+
+    RMSNorm input_layernorm;
+    LlamaAttention attn;
+    RMSNorm post_attention_layernorm;
+    QwenMLP mlp;
+
+};
+
 struct BasicPositionIdsGenerator {
     ggml_tensor *operator()(ggml_context *ctx, int qlen, int n_past, int n_ctx) const {
         ggml_tensor *position_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, qlen);
@@ -488,9 +540,23 @@ class QwenMoeModel {
     BasicPositionIdsGenerator pos_ids_gen_;
 };
 
+class LlamaModel : public QwenModel {
+  public:
+    LlamaModel() = default;
+    LlamaModel(ModelContext *ctx, const Llama3Config &config);
+
+    auto forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const -> ggml_tensor *;
+
+    Embedding embed_tokens;
+    std::vector<LlamaBlock> layers;
+    RMSNorm norm;
+    BasicPositionIdsGenerator pos_ids_gen_;
+};
+
 class QwenForCausalLM {
   public:
     QwenForCausalLM(const QwenConfig &config);
+    QwenForCausalLM(int test){};
     ~QwenForCausalLM();
 
     auto generate_next_token(
@@ -553,6 +619,25 @@ class QwenMoeForCausalLM : public QwenForCausalLM {
     std::vector<std::pair<std::string, ggml_tensor *>> state_dict_;
 };
 
+class Llama3ForCausalLM : public QwenForCausalLM {
+  public:
+    Llama3ForCausalLM(const Llama3Config &config); // Declaration
+    ~Llama3ForCausalLM();
+    // Override methods here if needed
+
+    auto load(ModelLoader &loader) -> void override;
+    auto forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx, bool is_decoding) const -> ggml_tensor * override;
+
+    static constexpr size_t MEM_SIZE = 2000 * MB;     // 2k context
+    static constexpr size_t SCRATCH_SIZE = 2000 * MB; // 2k context
+    Llama3Config config;
+    LlamaModel transformer;
+
+  private:
+    ModelContext ctx_;
+    std::vector<std::pair<std::string, ggml_tensor *>> state_dict_;
+};
+
 // ===== pipeline =====
 
 class Pipeline {
diff --git a/qwen_cpp/convert.py b/qwen_cpp/convert.py
index 93654b4..bfda059 100644
--- a/qwen_cpp/convert.py
+++ b/qwen_cpp/convert.py
@@ -42,6 +42,7 @@ class ModelType(Enum):
     QWEN2 = 2
     QWEN2MOE = 3
     CODEQWEN = 4  # not use now
+    LLAMA3 = 5
 
 
 def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
@@ -226,6 +227,60 @@ def dump_model(f, model, ggml_type):
     ]
     dump_state_dict(f, weight_names, model.state_dict(), ggml_type)
 
+class Llama3Converter:
+    """
+    Compared with the Qwen converters, the main differences are the tokenizer
+    and the fact that the q/k/v projections carry no bias.
+    """
+    MODEL_TYPE = ModelType.LLAMA3
+    @classmethod
+    def convert(cls, f, model, tokenizer, ggml_type):
+        f.write(b"ggml")  # magic
+        f.write(struct.pack("ii", cls.MODEL_TYPE.value, 1))  # model type & version
+        cls.dump_config(f, model.config, model.generation_config, tokenizer, ggml_type)  # generation_config is not used for now
+        cls.dump_model(f, model, ggml_type)
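+    # note: the C++ side reads this header back with loader.read_basic<Llama3Config>(), so the
+    # 12 ints emitted by dump_config below must stay in the same order as the QwenConfig /
+    # Llama3Config field layout.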
+
+    @staticmethod
+    def dump_config(f, config, generation_config, tokenizer, ggml_type):
+        config_values = [
+            ggml_type.value,
+            config.vocab_size,
+            config.hidden_size,
+            config.num_attention_heads,
+            config.num_key_value_heads,
+            config.num_hidden_layers,
+            config.intermediate_size,
+            config.max_position_embeddings,
+            config.eos_token_id,  # eos
+            128001,  # pad slot -> <|end_of_text|>
+            128000,  # im_start slot -> <|begin_of_text|>
+            128009,  # im_end slot -> <|eot_id|>
+        ]
+        f.write(struct.pack("i" * len(config_values), *config_values))
+
+    @staticmethod
+    def dump_model(f, model, ggml_type):
+        weight_names = ["model.embed_tokens.weight"]
+        for i in range(model.config.num_hidden_layers):
+            weight_names += [
+                f"model.layers.{i}.input_layernorm.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+                f"model.layers.{i}.post_attention_layernorm.weight",
+                f"model.layers.{i}.mlp.gate_proj.weight",
+                f"model.layers.{i}.mlp.up_proj.weight",
+                f"model.layers.{i}.mlp.down_proj.weight",
+            ]
+
+        weight_names += [
+            "model.norm.weight",
+            "lm_head.weight",
+        ]
+        print(len(weight_names))
+        dump_state_dict(f, weight_names, model.state_dict(), ggml_type)
+
 class Qwen2Converter:
     MODEL_TYPE = ModelType.QWEN2
     @classmethod
@@ -406,7 +461,7 @@ def convert(f: BinaryIO, model_name_or_path: str, dtype: str = "q4_0"):
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
 
-    # print(model)
+    print(model)
 
     # state_dict = model.state_dict()
     # keys = state_dict.keys()
     # print(keys)
@@ -429,6 +484,10 @@ def convert(f: BinaryIO, model_name_or_path: str, dtype: str = "q4_0"):
     elif model.config.architectures[0]=="Qwen2MoeForCausalLM":
         print('Convert Qwen1.5-Moe')
         Qwen2MOEConverter.convert(f, model, tokenizer, ggml_type)
+
+    elif model.config.architectures[0]=="LlamaForCausalLM":
+        print("Convert Llama3")
+        Llama3Converter.convert(f, model, tokenizer, ggml_type)
     else:
         print('Warning: Qwen1 is not supported now')
         QwenConverter.convert(f, model, tokenizer, ggml_type)
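+# usage sketch for the Llama 3 path (illustrative; the output filename is arbitrary):
+#   from qwen_cpp.convert import convert
+#   with open("llama3-8b-instruct-ggml.bin", "wb") as fout:
+#       convert(fout, "meta-llama/Meta-Llama-3-8B-Instruct", dtype="q4_0")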