Skip to content

Commit

Permalink
✨ feat: Add request validator
Browse files Browse the repository at this point in the history
  • Loading branch information
Palm1r committed Dec 15, 2024
1 parent 10e8b16 commit 7376a11
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 1 deletion.
7 changes: 7 additions & 0 deletions ChatView/ClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil
QJsonObject request;
request["id"] = QUuid::createUuid().toString();

auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
LOG_MESSAGE("Validate errors for chat request:");
LOG_MESSAGES(errors);
return;
}

m_requestHandler->sendLLMRequest(config, request);
}

Expand Down
6 changes: 6 additions & 0 deletions LLMClientInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
config.promptTemplate->prepareRequest(config.providerRequest, updatedContext);
config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::Fim);

auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
LOG_MESSAGE("Validate errors for fim request:");
LOG_MESSAGES(errors);
return;
}
m_requestHandler.sendLLMRequest(config, request);
}

Expand Down
1 change: 1 addition & 0 deletions llmcore/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_library(LLMCore STATIC
RequestHandler.hpp RequestHandler.cpp
OllamaMessage.hpp OllamaMessage.cpp
OpenAIMessage.hpp OpenAIMessage.cpp
ValidationUtils.hpp ValidationUtils.cpp
)

target_link_libraries(LLMCore
Expand Down
5 changes: 4 additions & 1 deletion llmcore/Provider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
#pragma once

#include <QString>
#include "RequestType.hpp"
#include <utils/environment.h>

#include "PromptTemplate.hpp"
#include "RequestType.hpp"

class QNetworkReply;
class QJsonObject;

Expand All @@ -42,6 +44,7 @@ class Provider
virtual void prepareRequest(QJsonObject &request, RequestType type) = 0;
virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0;
virtual QList<QString> getInstalledModels(const QString &url) = 0;
virtual QList<QString> validateRequest(const QJsonObject &request, TemplateType type) = 0;
};

} // namespace QodeAssist::LLMCore
57 changes: 57 additions & 0 deletions llmcore/ValidationUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/

#include "ValidationUtils.hpp"

#include <QJsonArray>

namespace QodeAssist::LLMCore {

QStringList ValidationUtils::validateRequestFields(
const QJsonObject &request, const QJsonObject &templateObj)
{
QStringList errors;
validateFields(request, templateObj, errors);
validateNestedObjects(request, templateObj, errors);
return errors;
}

void ValidationUtils::validateFields(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors)
{
for (auto it = request.begin(); it != request.end(); ++it) {
if (!templateObj.contains(it.key())) {
errors << QString("unknown field '%1'").arg(it.key());
}
}
}

void ValidationUtils::validateNestedObjects(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors)
{
for (auto it = request.begin(); it != request.end(); ++it) {
if (templateObj.contains(it.key()) && it.value().isObject()
&& templateObj[it.key()].isObject()) {
validateFields(it.value().toObject(), templateObj[it.key()].toObject(), errors);
validateNestedObjects(it.value().toObject(), templateObj[it.key()].toObject(), errors);
}
}
}

} // namespace QodeAssist::LLMCore
41 changes: 41 additions & 0 deletions llmcore/ValidationUtils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/

#pragma once

#include <QJsonObject>
#include <QStringList>

namespace QodeAssist::LLMCore {

class ValidationUtils
{
public:
static QStringList validateRequestFields(
const QJsonObject &request, const QJsonObject &templateObj);

private:
static void validateFields(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors);

static void validateNestedObjects(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors);
};

} // namespace QodeAssist::LLMCore
19 changes: 19 additions & 0 deletions providers/LMStudioProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <QNetworkReply>

#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
Expand Down Expand Up @@ -169,4 +170,22 @@ QList<QString> LMStudioProvider::getInstalledModels(const QString &url)
return models;
}

QList<QString> LMStudioProvider::validateRequest(
const QJsonObject &request, LLMCore::TemplateType type)
{
const auto templateReq = QJsonObject{
{"model", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"temperature", {}},
{"max_tokens", {}},
{"top_p", {}},
{"top_k", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}},
{"stop", QJsonArray{}},
{"stream", {}}};

return LLMCore::ValidationUtils::validateRequestFields(request, templateReq);
}

} // namespace QodeAssist::Providers
1 change: 1 addition & 0 deletions providers/LMStudioProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LMStudioProvider : public LLMCore::Provider
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
};

} // namespace QodeAssist::Providers
37 changes: 37 additions & 0 deletions providers/OllamaProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <QtCore/qeventloop.h>

#include "llmcore/OllamaMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
Expand Down Expand Up @@ -137,4 +138,40 @@ QList<QString> OllamaProvider::getInstalledModels(const QString &url)
return models;
}

QList<QString> OllamaProvider::validateRequest(const QJsonObject &request, LLMCore::TemplateType type)
{
const auto fimReq = QJsonObject{
{"keep_alive", {}},
{"model", {}},
{"stream", {}},
{"prompt", {}},
{"suffix", {}},
{"system", {}},
{"options",
QJsonObject{
{"temperature", {}},
{"top_p", {}},
{"top_k", {}},
{"num_predict", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}}}}};

const auto messageReq = QJsonObject{
{"keep_alive", {}},
{"model", {}},
{"stream", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"options",
QJsonObject{
{"temperature", {}},
{"top_p", {}},
{"top_k", {}},
{"num_predict", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}}}}};

return LLMCore::ValidationUtils::validateRequestFields(
request, type == LLMCore::TemplateType::Fim ? fimReq : messageReq);
};

} // namespace QodeAssist::Providers
1 change: 1 addition & 0 deletions providers/OllamaProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class OllamaProvider : public LLMCore::Provider
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
};

} // namespace QodeAssist::Providers
19 changes: 19 additions & 0 deletions providers/OpenAICompatProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <QNetworkReply>

#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"

namespace QodeAssist::Providers {
Expand Down Expand Up @@ -142,4 +143,22 @@ QList<QString> OpenAICompatProvider::getInstalledModels(const QString &url)
return QStringList();
}

QList<QString> OpenAICompatProvider::validateRequest(
const QJsonObject &request, LLMCore::TemplateType type)
{
const auto templateReq = QJsonObject{
{"model", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"temperature", {}},
{"max_tokens", {}},
{"top_p", {}},
{"top_k", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}},
{"stop", QJsonArray{}},
{"stream", {}}};

return LLMCore::ValidationUtils::validateRequestFields(request, templateReq);
}

} // namespace QodeAssist::Providers
1 change: 1 addition & 0 deletions providers/OpenAICompatProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class OpenAICompatProvider : public LLMCore::Provider
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
};

} // namespace QodeAssist::Providers

0 comments on commit 7376a11

Please sign in to comment.