Skip to content

Commit

Permalink
♻️ refactor: Rework current templates and add new ones
Browse files Browse the repository at this point in the history
Add Alpaca, Llama3, Llama2, ChatML templates
  • Loading branch information
Palm1r committed Nov 25, 2024
1 parent 36d5242 commit 1261f91
Show file tree
Hide file tree
Showing 17 changed files with 264 additions and 120 deletions.
12 changes: 9 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,22 @@ add_qtc_plugin(QodeAssist
QodeAssistConstants.hpp
QodeAssisttr.h
LLMClientInterface.hpp LLMClientInterface.cpp
templates/Templates.hpp
templates/CodeLlamaFim.hpp
templates/StarCoder2Fim.hpp
templates/DeepSeekCoderFim.hpp
templates/CustomFimTemplate.hpp
templates/DeepSeekCoderChat.hpp
templates/CodeLlamaChat.hpp


templates/Qwen.hpp
templates/StarCoderChat.hpp

templates/Ollama.hpp
templates/BasicChat.hpp
templates/Llama3.hpp
templates/ChatML.hpp
templates/Alpaca.hpp
templates/Llama2.hpp
providers/Providers.hpp
providers/OllamaProvider.hpp providers/OllamaProvider.cpp
providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
Expand Down
3 changes: 2 additions & 1 deletion ChatView/ClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil
{
cancelRequest();

m_chatModel->addMessage(message, ChatModel::ChatRole::User, "");

auto &chatAssistantSettings = Settings::chatAssistantSettings();

auto providerName = Settings::generalSettings().caProvider();
Expand Down Expand Up @@ -128,7 +130,6 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil
QJsonObject request;
request["id"] = QUuid::createUuid().toString();

m_chatModel->addMessage(message, ChatModel::ChatRole::User, "");
m_requestHandler->sendLLMRequest(config, request);
}

Expand Down
1 change: 1 addition & 0 deletions llmcore/PromptTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ class PromptTemplate
virtual QString promptTemplate() const = 0;
virtual QStringList stopWords() const = 0;
virtual void prepareRequest(QJsonObject &request, const ContextData &context) const = 0;
virtual QString description() const = 0;
};
} // namespace QodeAssist::LLMCore
37 changes: 37 additions & 0 deletions providers/Providers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/

#pragma once

#include "llmcore/ProvidersManager.hpp"
#include "providers/LMStudioProvider.hpp"
#include "providers/OllamaProvider.hpp"
#include "providers/OpenAICompatProvider.hpp"

namespace QodeAssist::Providers {

// Registers every built-in LLM provider with the global provider manager.
// Called once from plugin initialization.
inline void registerProviders()
{
    auto &manager = LLMCore::ProvidersManager::instance();
    manager.registerProvider<OllamaProvider>();
    manager.registerProvider<LMStudioProvider>();
    manager.registerProvider<OpenAICompatProvider>();
}

} // namespace QodeAssist::Providers
40 changes: 5 additions & 35 deletions qodeassist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,9 @@
#include "QodeAssistClient.hpp"
#include "chat/ChatOutputPane.h"
#include "chat/NavigationPanel.hpp"
#include "llmcore/PromptTemplateManager.hpp"
#include "llmcore/ProvidersManager.hpp"
#include "providers/LMStudioProvider.hpp"
#include "providers/OllamaProvider.hpp"
#include "providers/OpenAICompatProvider.hpp"

#include "templates/BasicChat.hpp"
#include "templates/CodeLlamaChat.hpp"
#include "templates/CodeLlamaFim.hpp"
#include "templates/CustomFimTemplate.hpp"
#include "templates/DeepSeekCoderChat.hpp"
#include "templates/DeepSeekCoderFim.hpp"
#include "templates/Ollama.hpp"
#include "templates/Qwen.hpp"
#include "templates/StarCoder2Fim.hpp"
#include "templates/StarCoderChat.hpp"

#include "providers/Providers.hpp"
#include "templates/Templates.hpp"

using namespace Utils;
using namespace Core;
Expand All @@ -85,25 +72,8 @@ class QodeAssistPlugin final : public ExtensionSystem::IPlugin

void initialize() final
{
auto &providerManager = LLMCore::ProvidersManager::instance();
providerManager.registerProvider<Providers::OllamaProvider>();
providerManager.registerProvider<Providers::LMStudioProvider>();
providerManager.registerProvider<Providers::OpenAICompatProvider>();

auto &templateManager = LLMCore::PromptTemplateManager::instance();
templateManager.registerTemplate<Templates::CodeLlamaFim>();
templateManager.registerTemplate<Templates::StarCoder2Fim>();
templateManager.registerTemplate<Templates::DeepSeekCoderFim>();
templateManager.registerTemplate<Templates::CustomTemplate>();
templateManager.registerTemplate<Templates::DeepSeekCoderChat>();
templateManager.registerTemplate<Templates::CodeLlamaChat>();
templateManager.registerTemplate<Templates::LlamaChat>();
templateManager.registerTemplate<Templates::StarCoderChat>();
templateManager.registerTemplate<Templates::QwenChat>();
templateManager.registerTemplate<Templates::QwenFim>();
templateManager.registerTemplate<Templates::OllamaAutoFim>();
templateManager.registerTemplate<Templates::OllamaAutoChat>();
templateManager.registerTemplate<Templates::BasicChat>();
Providers::registerProviders();
Templates::registerTemplates();

Utils::Icon QCODEASSIST_ICON(
{{":/resources/images/qoderassist-icon.png", Utils::Theme::IconsBaseColor}});
Expand Down
67 changes: 67 additions & 0 deletions templates/Alpaca.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/

#pragma once

#include "llmcore/PromptTemplate.hpp"
#include <QJsonArray>

namespace QodeAssist::Templates {

// Prompt template that reshapes an OpenAI-style "messages" array into the
// Alpaca instruction/response layout expected by Alpaca-finetuned models.
class Alpaca : public LLMCore::PromptTemplate
{
public:
    QString name() const override { return "Alpaca"; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString promptTemplate() const override { return {}; }
    QStringList stopWords() const override
    {
        return QStringList() << "### Instruction:" << "### Response:";
    }
    // Rewrites each message's content in place:
    //   system    -> bare content
    //   user      -> "### Instruction:\n" + content
    //   assistant -> "### Response:\n" + content
    // Any other role is passed through unchanged (previously its content was
    // silently replaced by an empty string, losing the message).
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages = request["messages"].toArray();

        for (int i = 0; i < messages.size(); ++i) {
            QJsonObject message = messages[i].toObject();
            const QString role = message["role"].toString();
            const QString content = message["content"].toString();

            QString formattedContent;
            if (role == "system") {
                formattedContent = content + "\n\n";
            } else if (role == "user") {
                formattedContent = "### Instruction:\n" + content + "\n\n";
            } else if (role == "assistant") {
                formattedContent = "### Response:\n" + content + "\n\n";
            } else {
                // Unknown role: keep the original content instead of erasing it.
                formattedContent = content;
            }

            message["content"] = formattedContent;
            messages[i] = message;
        }

        request["messages"] = messages;
    }
    QString description() const override
    {
        return "The message will contain the following tokens: ### Instruction:\n### Response:\n";
    }
};

} // namespace QodeAssist::Templates
13 changes: 2 additions & 11 deletions templates/BasicChat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,9 @@ class BasicChat : public LLMCore::PromptTemplate
QString name() const override { return "Basic Chat"; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList(); }

void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();

QJsonObject newMessage;
newMessage["role"] = "user";
newMessage["content"] = context.prefix;
messages.append(newMessage);

request["messages"] = messages;
}
{}
QString description() const override { return "chat without tokens"; }
};

} // namespace QodeAssist::Templates
30 changes: 18 additions & 12 deletions templates/DeepSeekCoderChat.hpp → templates/ChatML.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,41 @@
#pragma once

#include <QJsonArray>

#include "llmcore/PromptTemplate.hpp"

namespace QodeAssist::Templates {

class DeepSeekCoderChat : public LLMCore::PromptTemplate
class ChatML : public LLMCore::PromptTemplate
{
public:
QString name() const override { return "DeepSeekCoder Chat"; }
QString name() const override { return "ChatML"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }

QString promptTemplate() const override { return "### Instruction:\n%1\n### Response:\n"; }

QString promptTemplate() const override { return {}; }
QStringList stopWords() const override
{
return QStringList() << "### Instruction:" << "### Response:" << "\n\n### " << "<|EOT|>";
return QStringList() << "<|im_start|>" << "<|im_end|>";
}

void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix);
QJsonArray messages = request["messages"].toArray();

QJsonObject newMessage;
newMessage["role"] = "user";
newMessage["content"] = formattedPrompt;
messages.append(newMessage);
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();

message["content"] = QString("<|im_start|>%1\n%2\n<|im_end|>").arg(role, content);

messages[i] = message;
}

request["messages"] = messages;
}
QString description() const override
{
return "The message will contain the following tokens: <|im_start|>%1\n%2\n<|im_end|>";
}
};

} // namespace QodeAssist::Templates
5 changes: 4 additions & 1 deletion templates/CodeLlamaFim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ class CodeLlamaFim : public LLMCore::PromptTemplate
{
return QStringList() << "<EOT>" << "<PRE>" << "<SUF" << "<MID>";
}

// Builds the CodeLlama fill-in-the-middle prompt from the surrounding
// code context and stores it on the request object.
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
    request["prompt"] = promptTemplate().arg(context.prefix, context.suffix);
}
// Describes the FIM token layout this template produces.
QString description() const override
{
return "The message will contain the following tokens: <PRE> %1 <SUF>%2 <MID>";
}
};

} // namespace QodeAssist::Templates
2 changes: 1 addition & 1 deletion templates/CustomFimTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class CustomTemplate : public LLMCore::PromptTemplate
return Settings::customPromptSettings().customJsonTemplate();
}
QStringList stopWords() const override { return QStringList(); }

void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonDocument doc = QJsonDocument::fromJson(promptTemplate().toUtf8());
Expand All @@ -56,6 +55,7 @@ class CustomTemplate : public LLMCore::PromptTemplate
request[it.key()] = it.value();
}
}
QString description() const override { return promptTemplate(); }

private:
QJsonValue processJsonValue(const QJsonValue &value, const LLMCore::ContextData &context) const
Expand Down
5 changes: 5 additions & 0 deletions templates/DeepSeekCoderFim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class DeepSeekCoderFim : public LLMCore::PromptTemplate
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
}
// Describes the DeepSeek-Coder FIM token layout this template produces.
QString description() const override
{
return "The message will contain the following tokens: "
"<|fim▁begin|>%1<|fim▁hole|>%2<|fim▁end|>";
}
};

} // namespace QodeAssist::Templates
44 changes: 26 additions & 18 deletions templates/CodeLlamaChat.hpp → templates/Llama2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,46 @@

#pragma once

#include <QJsonArray>

#include "llmcore/PromptTemplate.hpp"
#include <QJsonArray>

namespace QodeAssist::Templates {

class CodeLlamaChat : public LLMCore::PromptTemplate
class Llama2 : public LLMCore::PromptTemplate
{
public:
QString name() const override { return "Llama 2"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "CodeLlama Chat"; }
QString promptTemplate() const override { return "[INST] %1 [/INST]"; }
QStringList stopWords() const override { return QStringList() << "[INST]" << "[/INST]"; }

QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList() << "[INST]"; }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix);
QJsonArray messages = request["messages"].toArray();

QJsonObject newMessage;
newMessage["role"] = "user";
newMessage["content"] = formattedPrompt;
messages.append(newMessage);
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();

QString formattedContent;
if (role == "system") {
formattedContent = QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(content);
} else if (role == "user") {
formattedContent = QString("[INST]%1[/INST]\n").arg(content);
} else if (role == "assistant") {
formattedContent = content + "\n";
}

message["content"] = formattedContent;
messages[i] = message;
}

request["messages"] = messages;
}
};

class LlamaChat : public CodeLlamaChat
{
public:
QString name() const override { return "Llama Chat"; }
QString description() const override
{
return "The message will contain the following tokens: [INST]%1[/INST]\n";
}
};

} // namespace QodeAssist::Templates
Loading

0 comments on commit 1261f91

Please sign in to comment.