Skip to content

Commit

Permalink
feat: imagequery
Browse files Browse the repository at this point in the history
  • Loading branch information
Angular2Guy committed May 8, 2024
1 parent 3d4146d commit 692b8ba
Show file tree
Hide file tree
Showing 7 changed files with 1,581 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ private ImageQueryDto resizeImage(ImageQueryDto imageDto) {
int targetWidth = image.getWidth();
if (image.getHeight() > 672 && image.getWidth() > 672) {
if (image.getHeight() < image.getWidth()) {
targetHeight = image.getHeight() / (image.getHeight() / 672);
targetWidth = image.getWidth() / (image.getHeight() / 672);
targetHeight = Math.round( image.getHeight() / (image.getHeight() / 672.0f));
targetWidth = Math.round(image.getWidth() / (image.getHeight() / 672.0f));
} else {
targetHeight = image.getHeight() / (image.getWidth() / 672);
targetWidth = image.getWidth() / (image.getWidth() / 672);
targetHeight = Math.round(image.getHeight() / (image.getWidth() / 672.0f));
targetWidth = Math.round(image.getWidth() / (image.getWidth() / 672.0f));
}
}
var outputImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
Expand All @@ -76,6 +76,7 @@ private ImageQueryDto resizeImage(ImageQueryDto imageDto) {
ImageIO.write(outputImage, imageDto.getImageType().toString(), ios);
imageDto.setImageContent(ios.toByteArray());
imageDto.setContentSize(ios.toByteArray().length);
LOG.info("Resized image to x: {}, y: {}", targetWidth, targetHeight);
} catch (IOException e) {
LOG.info("Image resize failed.", e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package org.springframework.ai.ollama;

import java.util.Base64;
import java.util.List;

import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

public class OllamaChatClient implements ChatClient, StreamingChatClient {

/**
* Low-level Ollama API library.
*/
private final OllamaApi chatApi;

/**
* Default options to be used for all chat requests.
*/
private OllamaOptions defaultOptions;

public OllamaChatClient(OllamaApi chatApi) {
this(chatApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}

public OllamaChatClient(OllamaApi chatApi, OllamaOptions defaultOptions) {
Assert.notNull(chatApi, "OllamaApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.chatApi = chatApi;
this.defaultOptions = defaultOptions;
}

/**
* @deprecated Use {@link OllamaOptions#setModel} instead.
*/
@Deprecated
public OllamaChatClient withModel(String model) {
this.defaultOptions.setModel(model);
return this;
}

/**
* @deprecated Use {@link OllamaOptions} constructor instead.
*/
public OllamaChatClient withDefaultOptions(OllamaOptions options) {
this.defaultOptions = options;
return this;
}

@Override
public ChatResponse call(Prompt prompt) {

OllamaApi.ChatResponse response = this.chatApi.chat(ollamaChatRequest(prompt, false));

var generator = new Generation(response.message().content());
if (response.promptEvalCount() != null && response.evalCount() != null) {
generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
}
return new ChatResponse(List.of(generator), OllamaChatResponseMetadata.from(response));
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));

return response.map(chunk -> {
Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
: new Generation("");
if (Boolean.TRUE.equals(chunk.done())) {
generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
}
return new ChatResponse(List.of(generation), OllamaChatResponseMetadata.from(chunk));
});
}

/**
* Package access for testing.
*/
OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {

List<OllamaApi.Message> ollamaMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() == MessageType.USER
|| message.getMessageType() == MessageType.ASSISTANT
|| message.getMessageType() == MessageType.SYSTEM)
.map(m -> {
var messageBuilder = OllamaApi.Message.builder(toRole(m)).withContent(m.getContent());

if (!CollectionUtils.isEmpty(m.getMedia())) {
messageBuilder
.withImages(m.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
}
return messageBuilder.build();
})
.toList();

// runtime options
OllamaOptions runtimeOptions = null;
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeChatOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(runtimeChatOptions, ChatOptions.class,
OllamaOptions.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

// Override the model.
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
}

String model = mergedOptions.getModel();
OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(model)
.withStream(stream)
.withMessages(ollamaMessages)
.withOptions(mergedOptions);

if (mergedOptions.getFormat() != null) {
requestBuilder.withFormat(mergedOptions.getFormat());
}

if (mergedOptions.getKeepAlive() != null) {
requestBuilder.withKeepAlive(mergedOptions.getKeepAlive());
}

return requestBuilder.build();
}

private String fromMediaData(Object mediaData) {
if (mediaData instanceof byte[] bytes) {
return Base64.getEncoder().encodeToString(bytes);
}
else if (mediaData instanceof String text) {
return text;
}
else {
throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName());
}

}

private OllamaApi.Message.Role toRole(Message message) {

switch (message.getMessageType()) {
case USER:
return Role.USER;
case ASSISTANT:
return Role.ASSISTANT;
case SYSTEM:
return Role.SYSTEM;
default:
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
}

}
Loading

0 comments on commit 692b8ba

Please sign in to comment.