Skip to content

Commit

Permalink
Add Settings for chat
Browse files Browse the repository at this point in the history
  • Loading branch information
Palm1r committed Sep 18, 2024
1 parent 7b960a9 commit f5de1b9
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 69 deletions.
2 changes: 1 addition & 1 deletion LLMClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
auto updatedContext = prepareContext(request);

LLMConfig config;
config.provider = LLMProvidersManager::instance().getCurrentProvider();
config.provider = LLMProvidersManager::instance().getCurrentFimProvider();
config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate();
config.url = QUrl(QString("%1%2").arg(Settings::generalSettings().url(),
Settings::generalSettings().endPoint()));
Expand Down
53 changes: 38 additions & 15 deletions LLMProvidersManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,61 @@ LLMProvidersManager &LLMProvidersManager::instance()
return instance;
}

QStringList LLMProvidersManager::getProviderNames() const
void LLMProvidersManager::setCurrentFimProvider(const QString &name)
{
return m_providers.keys();
if (!m_providers.contains(name)) {
logMessage("Can't find provider with name: " + name);
return;
}

if (m_providers[name] == nullptr) {
logMessage("Provider is null");
return;
}

m_currentFimProvider = m_providers[name];

Settings::generalSettings().url.setVolatileValue(m_currentFimProvider->url());
Settings::generalSettings().endPoint.setVolatileValue(
m_currentFimProvider->completionEndpoint());
}

void LLMProvidersManager::setCurrentProvider(const QString &name)
void LLMProvidersManager::setCurrentChatProvider(const QString &name)
{
if (m_providers.contains(name)) {
m_currentProviderName = name;
if (!m_providers.contains(name)) {
logMessage("Can't find chat provider with name: " + name);
return;
}

if (m_providers[name] == nullptr) {
logMessage(QString("Chat provider with name %1 is null").arg(name));
return;
}

m_currentChatProvider = m_providers[name];

Settings::generalSettings().chatUrl.setVolatileValue(m_currentChatProvider->url());
Settings::generalSettings().chatEndPoint.setVolatileValue(m_currentChatProvider->chatEndpoint());
}

Providers::LLMProvider *LLMProvidersManager::getCurrentProvider()
Providers::LLMProvider *LLMProvidersManager::getCurrentFimProvider()
{
if (m_currentProviderName.isEmpty()) {
if (m_currentFimProvider == nullptr) {
logMessage("Current fim provider is null");
return nullptr;
}

return m_providers[m_currentProviderName];
return m_currentFimProvider;
}

Providers::LLMProvider *LLMProvidersManager::getCurrentChatProvider()
{
auto currentChatProvider
= m_providers[Settings::generalSettings().chatLlmProviders.stringValue()];

if (!currentChatProvider) {
logMessage("Can't find chat provider");
return m_providers.first();
if (m_currentChatProvider == nullptr) {
logMessage("Current chat provider is null");
return nullptr;
}

return currentChatProvider;
return m_currentFimProvider;
}

LLMProvidersManager::~LLMProvidersManager()
Expand Down
13 changes: 7 additions & 6 deletions LLMProvidersManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class LLMProvidersManager
{
public:
static LLMProvidersManager &instance();
~LLMProvidersManager();

template<typename T>
void registerProvider()
Expand All @@ -43,20 +44,20 @@ class LLMProvidersManager
Settings::generalSettings().chatLlmProviders.addOption(name);
}

QStringList getProviderNames() const;
void setCurrentProvider(const QString &name);
Providers::LLMProvider *getCurrentProvider();
void setCurrentFimProvider(const QString &name);
void setCurrentChatProvider(const QString &name);
Providers::LLMProvider *getCurrentFimProvider();
Providers::LLMProvider *getCurrentChatProvider();

~LLMProvidersManager();

private:
LLMProvidersManager() = default;
LLMProvidersManager(const LLMProvidersManager &) = delete;
LLMProvidersManager &operator=(const LLMProvidersManager &) = delete;

QMap<QString, Providers::LLMProvider *> m_providers;
QString m_currentProviderName;
QHash<QString, Providers::LLMProvider *> m_providers;
Providers::LLMProvider *m_currentFimProvider = nullptr;
Providers::LLMProvider *m_currentChatProvider = nullptr;
};

} // namespace QodeAssist
16 changes: 13 additions & 3 deletions PromptTemplateManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#include "PromptTemplateManager.hpp"

#include "QodeAssistUtils.hpp"

namespace QodeAssist {

PromptTemplateManager &PromptTemplateManager::instance()
Expand All @@ -27,11 +29,19 @@ PromptTemplateManager &PromptTemplateManager::instance()
return instance;
}

void PromptTemplateManager::setCurrentTemplate(const QString &name)
void PromptTemplateManager::setCurrentFimTemplate(const QString &name)
{
if (m_templates.contains(name)) {
m_currentTemplateName = name;
if (!m_templates.contains(name)) {
logMessage("Can't find prompt with name: " + name);
return;
}

if (m_templates[name] == nullptr) {
logMessage("Prompt is null");
return;
}

m_currentFimPrompt
}

Templates::PromptTemplate *PromptTemplateManager::getCurrentTemplate()
Expand Down
9 changes: 5 additions & 4 deletions PromptTemplateManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class PromptTemplateManager
{
public:
static PromptTemplateManager &instance();
~PromptTemplateManager();

template<typename T>
void registerTemplate()
Expand All @@ -43,20 +44,20 @@ class PromptTemplateManager
Settings::generalSettings().fimPrompts.addOption(name);
}

void setCurrentTemplate(const QString &name);
Templates::PromptTemplate *getCurrentTemplate();
void setCurrentFimTemplate(const QString &name);
Templates::PromptTemplate *getCurrentFimTemplate();

QStringList getTemplateNames() const;

~PromptTemplateManager();

private:
PromptTemplateManager() = default;
PromptTemplateManager(const PromptTemplateManager &) = delete;
PromptTemplateManager &operator=(const PromptTemplateManager &) = delete;

QMap<QString, Templates::PromptTemplate *> m_templates;
QString m_currentTemplateName;
Templates::PromptTemplate *m_currentFimPrompt;
Templates::PromptTemplate *m_currentChatPrompt;
};

} // namespace QodeAssist
5 changes: 0 additions & 5 deletions chat/ChatClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ void ChatClientInterface::sendMessage(const QString &message)
QJsonObject providerRequest;
prepareRequest(providerRequest, message);

if (LLMProvidersManager::instance().getCurrentProvider() == nullptr)
qDebug() << "Error: No provider selected";
if (PromptTemplateManager::instance().getCurrentTemplate() == nullptr)
qDebug() << "Error: No prompt template selected";

LLMConfig config;
config.provider = LLMProvidersManager::instance().getCurrentChatProvider();
config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate();
Expand Down
5 changes: 0 additions & 5 deletions chat/ChatClientInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,9 @@ class ChatClientInterface : public QObject
LLMRequestHandler *m_requestHandler;
QString m_accumulatedResponse;

// Захардкоженные значения (временно)
const QString MODEL_NAME = "bartowski/Llama-3.1-SauerkrautLM-8b-Instruct-GGUF";
const QString SERVER_URL = "http://localhost:1234";
const QString ENDPOINT = "/v1/chat/completions";
};

} // namespace QodeAssist::Chat
// Захардкоженные значения
// const QString MODEL_NAME = "codellama:7b-instruct";
// const QString SERVER_URL = "http://localhost:11434";
// const QString ENDPOINT = "/api/chat";
1 change: 0 additions & 1 deletion chat/ChatOutputPane.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class ChatOutputPane : public Core::IOutputPane
explicit ChatOutputPane(QObject *parent = nullptr);
~ChatOutputPane() override;

// Переопределенные методы из IOutputPane
QWidget *outputWidget(QWidget *parent) override;
QList<QWidget *> toolBarWidgets() const override;
void clearContents() override;
Expand Down
4 changes: 2 additions & 2 deletions core/LLMRequestHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ void LLMRequestHandler::handleLLMResponse(QNetworkReply *reply,

QString &accumulatedResponse = m_accumulatedResponses[reply];

auto provider = LLMProvidersManager::instance().getCurrentProvider();
auto provider = LLMProvidersManager::instance().getCurrentFimProvider();
if (provider == nullptr)
qDebug() << "No provider selected";

bool isComplete = LLMProvidersManager::instance()
.getCurrentProvider()
.getCurrentFimProvider()
->handleResponse(reply, accumulatedResponse);

if (!Settings::generalSettings().multiLineCompletion()
Expand Down
4 changes: 2 additions & 2 deletions qodeassist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ class QodeAssistPlugin final : public ExtensionSystem::IPlugin
providerManager.registerProvider<Providers::OllamaProvider>();
providerManager.registerProvider<Providers::LMStudioProvider>();
providerManager.registerProvider<Providers::OpenAICompatProvider>();
providerManager.setCurrentProvider("Ollama");
providerManager.setCurrentFimProvider("Ollama");

auto &templateManager = PromptTemplateManager::instance();
templateManager.registerTemplate<Templates::StarCoder2Template>();
templateManager.registerTemplate<Templates::CodeLLamaTemplate>();
templateManager.registerTemplate<Templates::StarCoder2Template>();
templateManager.registerTemplate<Templates::DeepSeekCoderV2Template>();
templateManager.registerTemplate<Templates::CustomTemplate>();
templateManager.setCurrentTemplate("StarCoder2");
Expand Down
54 changes: 31 additions & 23 deletions settings/GeneralSettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,19 @@ GeneralSettings::GeneralSettings()
autoCompletionTypingInterval.setDefaultValue(2000);

llmProviders.setSettingsKey(Constants::LLM_PROVIDERS);
llmProviders.setDisplayName(Tr::tr("FIM Provider:"));
llmProviders.setDisplayName(Tr::tr("AI Suggest Provider:"));
llmProviders.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox);
llmProviders.setDefaultValue(0);

url.setSettingsKey(Constants::URL);
url.setLabelText(Tr::tr("URL:"));
url.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
url.setDefaultValue("http://localhost:11434");

endPoint.setSettingsKey(Constants::END_POINT);
endPoint.setLabelText(Tr::tr("FIM Endpoint:"));
endPoint.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
endPoint.setDefaultValue("/api/generate");

modelName.setSettingsKey(Constants::MODEL_NAME);
modelName.setLabelText(Tr::tr("Model name:"));
Expand All @@ -108,17 +110,19 @@ GeneralSettings::GeneralSettings()
resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults");

chatLlmProviders.setSettingsKey(Constants::CHAT_LLM_PROVIDERS);
chatLlmProviders.setDisplayName(Tr::tr("Chat Provider:"));
chatLlmProviders.setDisplayName(Tr::tr("AI Chat Provider:"));
chatLlmProviders.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox);
chatLlmProviders.setDefaultValue(0);

chatUrl.setSettingsKey(Constants::CHAT_URL);
chatUrl.setLabelText(Tr::tr("URL:"));
chatUrl.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
chatUrl.setDefaultValue("http://localhost:11434");

chatEndPoint.setSettingsKey(Constants::CHAT_END_POINT);
chatEndPoint.setLabelText(Tr::tr("Chat Endpoint:"));
chatEndPoint.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
chatEndPoint.setDefaultValue("/api/chat");

chatModelName.setSettingsKey(Constants::CHAT_MODEL_NAME);
chatModelName.setLabelText(Tr::tr("Model name:"));
Expand Down Expand Up @@ -172,16 +176,17 @@ void GeneralSettings::setupConnections()
{
connect(&llmProviders, &Utils::SelectionAspect::volatileValueChanged, this, [this]() {
int index = llmProviders.volatileValue();
logMessage(QString("currentProvider %1").arg(llmProviders.displayForIndex(index)));
LLMProvidersManager::instance().setCurrentProvider(llmProviders.displayForIndex(index));
updateProviderSettings();
LLMProvidersManager::instance().setCurrentFimProvider(llmProviders.displayForIndex(index));
modelName.setVolatileValue("");
});
connect(&fimPrompts, &Utils::SelectionAspect::volatileValueChanged, this, [this]() {
int index = fimPrompts.volatileValue();
logMessage(QString("currentPrompt %1").arg(fimPrompts.displayForIndex(index)));
PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.displayForIndex(index));
// connect(&fimPrompts, &Utils::SelectionAspect::volatileValueChanged, this, [this]() {
// int index = fimPrompts.volatileValue();
// logMessage(QString("currentPrompt %1").arg(fimPrompts.displayForIndex(index)));
// PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.displayForIndex(index));
// });
connect(&selectModels, &ButtonAspect::clicked, this, [this]() {
showModelSelectionDialog(&modelName);
});
connect(&selectModels, &ButtonAspect::clicked, this, [this]() { showModelSelectionDialog(); });
connect(&enableLogging, &Utils::BoolAspect::volatileValueChanged, this, [this]() {
setLoggingEnabled(enableLogging.volatileValue());
});
Expand All @@ -195,22 +200,21 @@ void GeneralSettings::setupConnections()
&Utils::StringAspect::volatileValueChanged,
this,
&GeneralSettings::updateStatusIndicators);
}

void GeneralSettings::updateProviderSettings()
{
const auto provider = LLMProvidersManager::instance().getCurrentProvider();

if (provider) {
url.setVolatileValue(provider->url());
endPoint.setVolatileValue(provider->completionEndpoint());
modelName.setVolatileValue("");
}
connect(&chatLlmProviders, &Utils::SelectionAspect::volatileValueChanged, this, [this]() {
int index = chatLlmProviders.volatileValue();
LLMProvidersManager::instance().setCurrentChatProvider(
chatLlmProviders.displayForIndex(index));
chatModelName.setVolatileValue("");
});
connect(&chatSelectModels, &ButtonAspect::clicked, this, [this]() {
showModelSelectionDialog(&chatModelName);
});
}

void GeneralSettings::showModelSelectionDialog()
void GeneralSettings::showModelSelectionDialog(Utils::StringAspect *modelNameObj)
{
auto *provider = LLMProvidersManager::instance().getCurrentProvider();
auto *provider = LLMProvidersManager::instance().getCurrentFimProvider();
Utils::Environment env = Utils::Environment::systemEnvironment();

if (provider) {
Expand All @@ -225,7 +229,7 @@ void GeneralSettings::showModelSelectionDialog()
&ok);

if (ok && !selectedModel.isEmpty()) {
modelName.setVolatileValue(selectedModel);
modelNameObj->setVolatileValue(selectedModel);
writeSettings();
}
}
Expand Down Expand Up @@ -255,6 +259,10 @@ void GeneralSettings::resetPageToDefaults()
resetAspect(llmProviders);
resetAspect(fimPrompts);
resetAspect(chatLlmProviders);
resetAspect(url);
resetAspect(endPoint);
resetAspect(chatUrl);
resetAspect(chatEndPoint);
}

updateStatusIndicators();
Expand Down
3 changes: 1 addition & 2 deletions settings/GeneralSettings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ class GeneralSettings : public Utils::AspectContainer

private:
void setupConnections();
void updateProviderSettings();
void showModelSelectionDialog();
void showModelSelectionDialog(Utils::StringAspect *modelNameObj);
void resetPageToDefaults();

void updateStatusIndicators();
Expand Down

0 comments on commit f5de1b9

Please sign in to comment.