Skip to content

Commit

Permalink
♻️ refactor: Improve Ollama response handler
Browse files Browse the repository at this point in the history
  • Loading branch information
Palm1r committed Dec 10, 2024
1 parent 8102ba9 commit b692402
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 40 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ add_qtc_plugin(QodeAssist
providers/OllamaProvider.hpp providers/OllamaProvider.cpp
providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
providers/OllamaMessage.hpp providers/OllamaMessage.cpp
QodeAssist.qrc
LSPCompletion.hpp
LLMSuggestion.hpp LLMSuggestion.cpp
Expand Down
76 changes: 76 additions & 0 deletions providers/OllamaMessage.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/

#include "OllamaMessage.hpp"

namespace QodeAssist::Providers {

OllamaMessage OllamaMessage::fromJson(const QJsonObject &obj, Type type)
{
OllamaMessage msg;
msg.model = obj["model"].toString();
msg.createdAt = QDateTime::fromString(obj["created_at"].toString(), Qt::ISODate);
msg.done = obj["done"].toBool();
msg.doneReason = obj["done_reason"].toString();
msg.error = obj["error"].toString();

if (type == Type::Generate) {
auto &genResponse = msg.response.emplace<GenerateResponse>();
genResponse.response = obj["response"].toString();
if (msg.done && obj.contains("context")) {
const auto array = obj["context"].toArray();
genResponse.context.reserve(array.size());
for (const auto &val : array) {
genResponse.context.append(val.toInt());
}
}
} else {
auto &chatResponse = msg.response.emplace<ChatResponse>();
const auto msgObj = obj["message"].toObject();
chatResponse.role = msgObj["role"].toString();
chatResponse.content = msgObj["content"].toString();
}

if (msg.done) {
msg.metrics
= {obj["total_duration"].toVariant().toLongLong(),
obj["load_duration"].toVariant().toLongLong(),
obj["prompt_eval_count"].toVariant().toLongLong(),
obj["prompt_eval_duration"].toVariant().toLongLong(),
obj["eval_count"].toVariant().toLongLong(),
obj["eval_duration"].toVariant().toLongLong()};
}

return msg;
}

QString OllamaMessage::getContent() const
{
if (std::holds_alternative<GenerateResponse>(response)) {
return std::get<GenerateResponse>(response).response;
}
return std::get<ChatResponse>(response).content;
}

bool OllamaMessage::hasError() const
{
return !error.isEmpty();
}

} // namespace QodeAssist::Providers
71 changes: 71 additions & 0 deletions providers/OllamaMessage.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/

#pragma once

#include <QDateTime>
#include <QJsonArray>
#include <QJsonObject>
#include <QObject>

namespace QodeAssist::Providers {

class OllamaMessage
{
public:
enum class Type { Generate, Chat };

struct Metrics
{
qint64 totalDuration{0};
qint64 loadDuration{0};
qint64 promptEvalCount{0};
qint64 promptEvalDuration{0};
qint64 evalCount{0};
qint64 evalDuration{0};
};

struct GenerateResponse
{
QString response;
QVector<int> context;
};

struct ChatResponse
{
QString role;
QString content;
};

QString model;
QDateTime createdAt;
std::variant<GenerateResponse, ChatResponse> response;
bool done{false};
QString doneReason;
Metrics metrics;
QString error;

static OllamaMessage fromJson(const QJsonObject &obj, Type type);

QString getContent() const;

bool hasError() const;
};

} // namespace QodeAssist::Providers
69 changes: 29 additions & 40 deletions providers/OllamaProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <QNetworkReply>
#include <QtCore/qeventloop.h>

#include "OllamaMessage.hpp"
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
Expand Down Expand Up @@ -87,53 +88,41 @@ void OllamaProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType t

bool OllamaProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
QString endpoint = reply->url().path();

bool isComplete = false;
while (reply->canReadLine()) {
QByteArray line = reply->readLine().trimmed();
if (line.isEmpty()) {
continue;
}
const QString endpoint = reply->url().path();
auto messageType = endpoint == completionEndpoint() ? OllamaMessage::Type::Generate
: OllamaMessage::Type::Chat;

auto processMessage =
[&accumulatedResponse](const QJsonDocument &doc, OllamaMessage::Type messageType) {
if (doc.isNull()) {
LOG_MESSAGE("Invalid JSON response from Ollama");
return false;
}

QJsonDocument doc = QJsonDocument::fromJson(line);
if (doc.isNull()) {
LOG_MESSAGE("Invalid JSON response from Ollama: " + QString::fromUtf8(line));
continue;
}
auto message = OllamaMessage::fromJson(doc.object(), messageType);
if (message.hasError()) {
LOG_MESSAGE("Error in Ollama response: " + message.error);
return false;
}

QJsonObject responseObj = doc.object();
accumulatedResponse += message.getContent();
return message.done;
};

if (responseObj.contains("error")) {
QString errorMessage = responseObj["error"].toString();
LOG_MESSAGE("Error in Ollama response: " + errorMessage);
return false;
}
if (reply->canReadLine()) {
while (reply->canReadLine()) {
QByteArray line = reply->readLine().trimmed();
if (line.isEmpty())
continue;

if (endpoint == completionEndpoint()) {
if (responseObj.contains("response")) {
QString completion = responseObj["response"].toString();
accumulatedResponse += completion;
if (processMessage(QJsonDocument::fromJson(line), messageType)) {
return true;
}
} else if (endpoint == chatEndpoint()) {
if (responseObj.contains("message")) {
QJsonObject message = responseObj["message"].toObject();
if (message.contains("content")) {
QString content = message["content"].toString();
accumulatedResponse += content;
}
}
} else {
LOG_MESSAGE("Unknown endpoint: " + endpoint);
}

if (responseObj.contains("done") && responseObj["done"].toBool()) {
isComplete = true;
break;
}
return false;
} else {
return processMessage(QJsonDocument::fromJson(reply->readAll()), messageType);
}

return isComplete;
}

QList<QString> OllamaProvider::getInstalledModels(const QString &url)
Expand Down

0 comments on commit b692402

Please sign in to comment.