From f5de1b91ec251ded2a46c1fbd1ef8b18dd511d8e Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:05:59 +0200 Subject: [PATCH] Add Settings for chat --- LLMClientInterface.cpp | 2 +- LLMProvidersManager.cpp | 53 +++++++++++++++++++++++++---------- LLMProvidersManager.hpp | 13 +++++---- PromptTemplateManager.cpp | 16 +++++++++-- PromptTemplateManager.hpp | 9 +++--- chat/ChatClientInterface.cpp | 5 ---- chat/ChatClientInterface.hpp | 5 ---- chat/ChatOutputPane.h | 1 - core/LLMRequestHandler.cpp | 4 +-- qodeassist.cpp | 4 +-- settings/GeneralSettings.cpp | 54 +++++++++++++++++++++--------------- settings/GeneralSettings.hpp | 3 +- 12 files changed, 100 insertions(+), 69 deletions(-) diff --git a/LLMClientInterface.cpp b/LLMClientInterface.cpp index 01e17fa..aaba9b1 100644 --- a/LLMClientInterface.cpp +++ b/LLMClientInterface.cpp @@ -148,7 +148,7 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request) auto updatedContext = prepareContext(request); LLMConfig config; - config.provider = LLMProvidersManager::instance().getCurrentProvider(); + config.provider = LLMProvidersManager::instance().getCurrentFimProvider(); config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate(); config.url = QUrl(QString("%1%2").arg(Settings::generalSettings().url(), Settings::generalSettings().endPoint())); diff --git a/LLMProvidersManager.cpp b/LLMProvidersManager.cpp index 49b669c..5283d3e 100644 --- a/LLMProvidersManager.cpp +++ b/LLMProvidersManager.cpp @@ -30,38 +30,61 @@ LLMProvidersManager &LLMProvidersManager::instance() return instance; } -QStringList LLMProvidersManager::getProviderNames() const +void LLMProvidersManager::setCurrentFimProvider(const QString &name) { - return m_providers.keys(); + if (!m_providers.contains(name)) { + logMessage("Can't find provider with name: " + name); + return; + } + + if (m_providers[name] == nullptr) { + logMessage("Provider is null"); + return; + } + + m_currentFimProvider = m_providers[name]; + + Settings::generalSettings().url.setVolatileValue(m_currentFimProvider->url()); + Settings::generalSettings().endPoint.setVolatileValue( + m_currentFimProvider->completionEndpoint()); } -void LLMProvidersManager::setCurrentProvider(const QString &name) +void LLMProvidersManager::setCurrentChatProvider(const QString &name) { - if (m_providers.contains(name)) { - m_currentProviderName = name; + if (!m_providers.contains(name)) { + logMessage("Can't find chat provider with name: " + name); + return; } + + if (m_providers[name] == nullptr) { + logMessage(QString("Chat provider with name %1 is null").arg(name)); + return; + } + + m_currentChatProvider = m_providers[name]; + + Settings::generalSettings().chatUrl.setVolatileValue(m_currentChatProvider->url()); + Settings::generalSettings().chatEndPoint.setVolatileValue(m_currentChatProvider->chatEndpoint()); } -Providers::LLMProvider *LLMProvidersManager::getCurrentProvider() +Providers::LLMProvider *LLMProvidersManager::getCurrentFimProvider() { - if (m_currentProviderName.isEmpty()) { + if (m_currentFimProvider == nullptr) { + logMessage("Current fim provider is null"); return nullptr; } - return m_providers[m_currentProviderName]; + return m_currentFimProvider; } Providers::LLMProvider *LLMProvidersManager::getCurrentChatProvider() { - auto currentChatProvider - = m_providers[Settings::generalSettings().chatLlmProviders.stringValue()]; - - if (!currentChatProvider) { - logMessage("Can't find chat provider"); - return m_providers.first(); + if (m_currentChatProvider == nullptr) { + logMessage("Current chat provider is null"); + return nullptr; } - return currentChatProvider; + return m_currentFimProvider; } LLMProvidersManager::~LLMProvidersManager() diff --git a/LLMProvidersManager.hpp b/LLMProvidersManager.hpp index bf81c9b..9dcc7a5 100644 --- a/LLMProvidersManager.hpp +++ b/LLMProvidersManager.hpp @@ -30,6 +30,7 @@ class LLMProvidersManager { public: static LLMProvidersManager &instance(); + ~LLMProvidersManager(); template void registerProvider() @@ -43,20 +44,20 @@ class LLMProvidersManager Settings::generalSettings().chatLlmProviders.addOption(name); } - QStringList getProviderNames() const; - void setCurrentProvider(const QString &name); - Providers::LLMProvider *getCurrentProvider(); + void setCurrentFimProvider(const QString &name); + void setCurrentChatProvider(const QString &name); + Providers::LLMProvider *getCurrentFimProvider(); Providers::LLMProvider *getCurrentChatProvider(); - ~LLMProvidersManager(); private: LLMProvidersManager() = default; LLMProvidersManager(const LLMProvidersManager &) = delete; LLMProvidersManager &operator=(const LLMProvidersManager &) = delete; - QMap m_providers; - QString m_currentProviderName; + QHash m_providers; + Providers::LLMProvider *m_currentFimProvider = nullptr; + Providers::LLMProvider *m_currentChatProvider = nullptr; }; } // namespace QodeAssist diff --git a/PromptTemplateManager.cpp b/PromptTemplateManager.cpp index 80de111..1eea5c9 100644 --- a/PromptTemplateManager.cpp +++ b/PromptTemplateManager.cpp @@ -19,6 +19,8 @@ #include "PromptTemplateManager.hpp" +#include "QodeAssistUtils.hpp" + namespace QodeAssist { PromptTemplateManager &PromptTemplateManager::instance() @@ -27,11 +29,19 @@ PromptTemplateManager &PromptTemplateManager::instance() return instance; } -void PromptTemplateManager::setCurrentTemplate(const QString &name) +void PromptTemplateManager::setCurrentFimTemplate(const QString &name) { - if (m_templates.contains(name)) { - m_currentTemplateName = name; + if (!m_templates.contains(name)) { + logMessage("Can't find prompt with name: " + name); + return; + } + + if (m_templates[name] == nullptr) { + logMessage("Prompt is null"); + return; } + + m_currentFimPrompt } Templates::PromptTemplate *PromptTemplateManager::getCurrentTemplate() diff --git a/PromptTemplateManager.hpp b/PromptTemplateManager.hpp index e935750..0abadbe 100644 --- a/PromptTemplateManager.hpp +++ b/PromptTemplateManager.hpp @@ -31,6 +31,7 @@ class PromptTemplateManager { public: static PromptTemplateManager &instance(); + ~PromptTemplateManager(); template void registerTemplate() @@ -43,12 +44,11 @@ class PromptTemplateManager Settings::generalSettings().fimPrompts.addOption(name); } - void setCurrentTemplate(const QString &name); - Templates::PromptTemplate *getCurrentTemplate(); + void setCurrentFimTemplate(const QString &name); + Templates::PromptTemplate *getCurrentFimTemplate(); QStringList getTemplateNames() const; - ~PromptTemplateManager(); private: PromptTemplateManager() = default; @@ -56,7 +56,8 @@ class PromptTemplateManager PromptTemplateManager &operator=(const PromptTemplateManager &) = delete; QMap m_templates; - QString m_currentTemplateName; + Templates::PromptTemplate *m_currentFimPrompt; + Templates::PromptTemplate *m_currentChatPrompt; }; } // namespace QodeAssist diff --git a/chat/ChatClientInterface.cpp b/chat/ChatClientInterface.cpp index 8a7ecae..5ec8b9a 100644 --- a/chat/ChatClientInterface.cpp +++ b/chat/ChatClientInterface.cpp @@ -64,11 +64,6 @@ void ChatClientInterface::sendMessage(const QString &message) QJsonObject providerRequest; prepareRequest(providerRequest, message); - if (LLMProvidersManager::instance().getCurrentProvider() == nullptr) - qDebug() << "Error: No provider selected"; - if (PromptTemplateManager::instance().getCurrentTemplate() == nullptr) - qDebug() << "Error: No prompt template selected"; - LLMConfig config; config.provider = LLMProvidersManager::instance().getCurrentChatProvider(); config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate(); diff --git a/chat/ChatClientInterface.hpp b/chat/ChatClientInterface.hpp index c44600b..8341fd1 100644 --- a/chat/ChatClientInterface.hpp +++ b/chat/ChatClientInterface.hpp @@ -46,14 +46,9 @@ class ChatClientInterface : public QObject LLMRequestHandler *m_requestHandler; QString m_accumulatedResponse; - // Захардкоженные значения (временно) const QString MODEL_NAME = "bartowski/Llama-3.1-SauerkrautLM-8b-Instruct-GGUF"; const QString SERVER_URL = "http://localhost:1234"; const QString ENDPOINT = "/v1/chat/completions"; }; } // namespace QodeAssist::Chat -// Захардкоженные значения -// const QString MODEL_NAME = "codellama:7b-instruct"; -// const QString SERVER_URL = "http://localhost:11434"; -// const QString ENDPOINT = "/api/chat"; diff --git a/chat/ChatOutputPane.h b/chat/ChatOutputPane.h index 7c340cb..12e8267 100644 --- a/chat/ChatOutputPane.h +++ b/chat/ChatOutputPane.h @@ -32,7 +32,6 @@ class ChatOutputPane : public Core::IOutputPane explicit ChatOutputPane(QObject *parent = nullptr); ~ChatOutputPane() override; - // Переопределенные методы из IOutputPane QWidget *outputWidget(QWidget *parent) override; QList toolBarWidgets() const override; void clearContents() override; diff --git a/core/LLMRequestHandler.cpp b/core/LLMRequestHandler.cpp index eb5d459..e48ce4b 100644 --- a/core/LLMRequestHandler.cpp +++ b/core/LLMRequestHandler.cpp @@ -77,12 +77,12 @@ void LLMRequestHandler::handleLLMResponse(QNetworkReply *reply, QString &accumulatedResponse = m_accumulatedResponses[reply]; - auto provider = LLMProvidersManager::instance().getCurrentProvider(); + auto provider = LLMProvidersManager::instance().getCurrentFimProvider(); if (provider == nullptr) qDebug() << "No provider selected"; bool isComplete = LLMProvidersManager::instance() - .getCurrentProvider() + .getCurrentFimProvider() ->handleResponse(reply, accumulatedResponse); if (!Settings::generalSettings().multiLineCompletion() diff --git a/qodeassist.cpp b/qodeassist.cpp index 1853984..472dc72 100644 --- a/qodeassist.cpp +++ b/qodeassist.cpp @@ -79,11 +79,11 @@ class QodeAssistPlugin final : public ExtensionSystem::IPlugin providerManager.registerProvider(); providerManager.registerProvider(); providerManager.registerProvider(); - providerManager.setCurrentProvider("Ollama"); + providerManager.setCurrentFimProvider("Ollama"); auto &templateManager = PromptTemplateManager::instance(); - templateManager.registerTemplate(); templateManager.registerTemplate(); + templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.setCurrentTemplate("StarCoder2"); diff --git a/settings/GeneralSettings.cpp b/settings/GeneralSettings.cpp index 84925f8..b549421 100644 --- a/settings/GeneralSettings.cpp +++ b/settings/GeneralSettings.cpp @@ -84,17 +84,19 @@ GeneralSettings::GeneralSettings() autoCompletionTypingInterval.setDefaultValue(2000); llmProviders.setSettingsKey(Constants::LLM_PROVIDERS); - llmProviders.setDisplayName(Tr::tr("FIM Provider:")); + llmProviders.setDisplayName(Tr::tr("AI Suggest Provider:")); llmProviders.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox); llmProviders.setDefaultValue(0); url.setSettingsKey(Constants::URL); url.setLabelText(Tr::tr("URL:")); url.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + url.setDefaultValue("http://localhost:11434"); endPoint.setSettingsKey(Constants::END_POINT); endPoint.setLabelText(Tr::tr("FIM Endpoint:")); endPoint.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + endPoint.setDefaultValue("/api/generate"); modelName.setSettingsKey(Constants::MODEL_NAME); modelName.setLabelText(Tr::tr("Model name:")); @@ -108,17 +110,19 @@ GeneralSettings::GeneralSettings() resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults"); chatLlmProviders.setSettingsKey(Constants::CHAT_LLM_PROVIDERS); - chatLlmProviders.setDisplayName(Tr::tr("Chat Provider:")); + chatLlmProviders.setDisplayName(Tr::tr("AI Chat Provider:")); chatLlmProviders.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox); chatLlmProviders.setDefaultValue(0); chatUrl.setSettingsKey(Constants::CHAT_URL); chatUrl.setLabelText(Tr::tr("URL:")); chatUrl.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + chatUrl.setDefaultValue("http://localhost:11434"); chatEndPoint.setSettingsKey(Constants::CHAT_END_POINT); chatEndPoint.setLabelText(Tr::tr("Chat Endpoint:")); chatEndPoint.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + chatEndPoint.setDefaultValue("/api/chat"); chatModelName.setSettingsKey(Constants::CHAT_MODEL_NAME); chatModelName.setLabelText(Tr::tr("Model name:")); @@ -172,16 +176,17 @@ void GeneralSettings::setupConnections() { connect(&llmProviders, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { int index = llmProviders.volatileValue(); - logMessage(QString("currentProvider %1").arg(llmProviders.displayForIndex(index))); - LLMProvidersManager::instance().setCurrentProvider(llmProviders.displayForIndex(index)); - updateProviderSettings(); + LLMProvidersManager::instance().setCurrentFimProvider(llmProviders.displayForIndex(index)); + modelName.setVolatileValue(""); }); - connect(&fimPrompts, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { - int index = fimPrompts.volatileValue(); - logMessage(QString("currentPrompt %1").arg(fimPrompts.displayForIndex(index))); - PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.displayForIndex(index)); + // connect(&fimPrompts, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { + // int index = fimPrompts.volatileValue(); + // logMessage(QString("currentPrompt %1").arg(fimPrompts.displayForIndex(index))); + // PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.displayForIndex(index)); + // }); + connect(&selectModels, &ButtonAspect::clicked, this, [this]() { + showModelSelectionDialog(&modelName); }); - connect(&selectModels, &ButtonAspect::clicked, this, [this]() { showModelSelectionDialog(); }); connect(&enableLogging, &Utils::BoolAspect::volatileValueChanged, this, [this]() { setLoggingEnabled(enableLogging.volatileValue()); }); @@ -195,22 +200,21 @@ void GeneralSettings::setupConnections() &Utils::StringAspect::volatileValueChanged, this, &GeneralSettings::updateStatusIndicators); -} - -void GeneralSettings::updateProviderSettings() -{ - const auto provider = LLMProvidersManager::instance().getCurrentProvider(); - if (provider) { - url.setVolatileValue(provider->url()); - endPoint.setVolatileValue(provider->completionEndpoint()); - modelName.setVolatileValue(""); - } + connect(&chatLlmProviders, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { + int index = chatLlmProviders.volatileValue(); + LLMProvidersManager::instance().setCurrentChatProvider( + chatLlmProviders.displayForIndex(index)); + chatModelName.setVolatileValue(""); + }); + connect(&chatSelectModels, &ButtonAspect::clicked, this, [this]() { + showModelSelectionDialog(&chatModelName); + }); } -void GeneralSettings::showModelSelectionDialog() +void GeneralSettings::showModelSelectionDialog(Utils::StringAspect *modelNameObj) { - auto *provider = LLMProvidersManager::instance().getCurrentProvider(); + auto *provider = LLMProvidersManager::instance().getCurrentFimProvider(); Utils::Environment env = Utils::Environment::systemEnvironment(); if (provider) { @@ -225,7 +229,7 @@ void GeneralSettings::showModelSelectionDialog() &ok); if (ok && !selectedModel.isEmpty()) { - modelName.setVolatileValue(selectedModel); + modelNameObj->setVolatileValue(selectedModel); writeSettings(); } } @@ -255,6 +259,10 @@ void GeneralSettings::resetPageToDefaults() resetAspect(llmProviders); resetAspect(fimPrompts); resetAspect(chatLlmProviders); + resetAspect(url); + resetAspect(endPoint); + resetAspect(chatUrl); + resetAspect(chatEndPoint); } updateStatusIndicators(); diff --git a/settings/GeneralSettings.hpp b/settings/GeneralSettings.hpp index c4b2a3b..9db53be 100644 --- a/settings/GeneralSettings.hpp +++ b/settings/GeneralSettings.hpp @@ -60,8 +60,7 @@ class GeneralSettings : public Utils::AspectContainer private: void setupConnections(); - void updateProviderSettings(); - void showModelSelectionDialog(); + void showModelSelectionDialog(Utils::StringAspect *modelNameObj); void resetPageToDefaults(); void updateStatusIndicators();