From 5703f6fa804d3abbedecae4eb9196f52e9fb50b1 Mon Sep 17 00:00:00 2001
From: akhoroshev
Date: Fri, 22 Nov 2024 09:34:59 +0300
Subject: [PATCH 1/2] fix out dims

---
 cpp/tensorrt_llm/runtime/loraModule.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/cpp/tensorrt_llm/runtime/loraModule.cpp b/cpp/tensorrt_llm/runtime/loraModule.cpp
index 8a8e2e559..e0540779d 100644
--- a/cpp/tensorrt_llm/runtime/loraModule.cpp
+++ b/cpp/tensorrt_llm/runtime/loraModule.cpp
@@ -42,12 +42,12 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
             modules.emplace_back(
                 t, hidden, (numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize), false, true, -1, 0);
             break;
-        case ModuleType::kATTN_Q:
-        case ModuleType::kATTN_K:
-        case ModuleType::kATTN_V:
-        case ModuleType::kCROSS_ATTN_Q:
-        case ModuleType::kCROSS_ATTN_K:
-        case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
+        case ModuleType::kATTN_Q: modules.emplace_back(t, hidden, numHeads * attnHeadSize, false, true, -1, 0); break;
+        case ModuleType::kATTN_K: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
+        case ModuleType::kATTN_V: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
+        case ModuleType::kCROSS_ATTN_Q: modules.emplace_back(t, hidden, numHeads * attnHeadSize, false, true, -1, 0); break;
+        case ModuleType::kCROSS_ATTN_K: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
+        case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
         case ModuleType::kATTN_DENSE:
         case ModuleType::kCROSS_ATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
         case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;

From bb34a856b52900502ca8043474db949ff6906df2 Mon Sep 17 00:00:00 2001
From: akhoroshev
Date: Fri, 22 Nov 2024 09:56:47 +0300
Subject: [PATCH 2/2] CROSS_ATTN rollback

---
 cpp/tensorrt_llm/runtime/loraModule.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cpp/tensorrt_llm/runtime/loraModule.cpp b/cpp/tensorrt_llm/runtime/loraModule.cpp
index e0540779d..6bdc4c0ab 100644
--- a/cpp/tensorrt_llm/runtime/loraModule.cpp
+++ b/cpp/tensorrt_llm/runtime/loraModule.cpp
@@ -45,9 +45,9 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
         case ModuleType::kATTN_Q: modules.emplace_back(t, hidden, numHeads * attnHeadSize, false, true, -1, 0); break;
         case ModuleType::kATTN_K: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
         case ModuleType::kATTN_V: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
-        case ModuleType::kCROSS_ATTN_Q: modules.emplace_back(t, hidden, numHeads * attnHeadSize, false, true, -1, 0); break;
-        case ModuleType::kCROSS_ATTN_K: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
-        case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, numKvHeads * attnHeadSize, false, true, -1, 0); break;
+        case ModuleType::kCROSS_ATTN_Q:
+        case ModuleType::kCROSS_ATTN_K:
+        case ModuleType::kCROSS_ATTN_V: modules.emplace_back(t, hidden, hidden, false, true, -1, 0); break;
         case ModuleType::kATTN_DENSE:
         case ModuleType::kCROSS_ATTN_DENSE: modules.emplace_back(t, hidden, hidden, false, true, 1, -1); break;
         case ModuleType::kMLP_H_TO_4H: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
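
Note (not part of the patches): a minimal standalone C++ sketch of the dimension arithmetic the first patch corrects. The concrete sizes below are hypothetical example values for a grouped-query-attention model; only the formulas mirror the emplace_back arguments in loraModule.cpp.

// Sketch only. Shows why separate Q/K/V LoRA modules need distinct output
// dimensions when numKvHeads < numHeads (GQA), i.e. why `hidden` is wrong
// for K and V. Example values are hypothetical, not taken from the patch.
#include <cstdio>

int main()
{
    int const hidden = 4096;      // hiddenSize * tpSize
    int const numHeads = 32;      // query heads (numAttentionHeads * tpSize)
    int const numKvHeads = 8;     // key/value heads under GQA
    int const attnHeadSize = 128; // per-head dimension

    // Fused QKV projection (unchanged by the patch): Q, K and V concatenated.
    int const qkvOut = numHeads * attnHeadSize + 2 * numKvHeads * attnHeadSize; // 6144

    // Separate projections (what the first patch fixes): only Q matches `hidden`;
    // K and V are smaller whenever numKvHeads < numHeads.
    int const qOut = numHeads * attnHeadSize;   // 4096 == hidden
    int const kOut = numKvHeads * attnHeadSize; // 1024 != hidden
    int const vOut = numKvHeads * attnHeadSize; // 1024 != hidden

    std::printf("qkv=%d q=%d k=%d v=%d hidden=%d\n", qkvOut, qOut, kOut, vOut, hidden);
    return 0;
}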