Skip to content

Commit

Permalink
feat: add image data import
Browse files Browse the repository at this point in the history
  • Loading branch information
Angular2Guy committed May 12, 2024
1 parent ec392a6 commit 7a3c580
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,10 @@ public ImageDto postImageQuery(@RequestParam("query") String query,@RequestParam
var result = this.imageService.queryImage(this.imageMapper.map(imageQuery, query));
return result;
}

@PostMapping("/import")
public ImageDto postImportImage(@RequestParam("query") String query,@RequestParam("type") String type, @RequestParam("file") MultipartFile imageQuery) {
var result = this.imageService.importImage(this.imageMapper.map(imageQuery, query), this.imageMapper.map(imageQuery));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package ch.xxx.aidoclibchat.usecase.mapping;

import java.io.IOException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
Expand All @@ -20,11 +22,24 @@

import ch.xxx.aidoclibchat.domain.common.MetaData.ImageType;
import ch.xxx.aidoclibchat.domain.model.dto.ImageQueryDto;
import ch.xxx.aidoclibchat.domain.model.entity.Image;

@Component
public class ImageMapper {
private static final Logger LOG = LoggerFactory.getLogger(ImageMapper.class);

public Image map(MultipartFile multipartFile) {
var image = new Image();
try {
image.setImageContent(multipartFile.getBytes());
image.setImageName(multipartFile.getName());
image.setImageType(this.toImageType(multipartFile.getContentType()));
} catch (IOException e) {
LOG.info("Mapping failed.", e);
}
return image;
}

public ImageQueryDto map(MultipartFile multipartFile, String query) {
var imageDto = new ImageQueryDto();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ch.xxx.aidoclibchat.usecase.service;

import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -28,31 +27,65 @@
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import org.springframework.util.MimeType;

import ch.xxx.aidoclibchat.domain.common.MetaData;
import ch.xxx.aidoclibchat.domain.common.MetaData.ImageType;
import ch.xxx.aidoclibchat.domain.model.dto.ImageDto;
import ch.xxx.aidoclibchat.domain.model.dto.ImageQueryDto;
import ch.xxx.aidoclibchat.domain.model.entity.DocumentVsRepository;
import ch.xxx.aidoclibchat.domain.model.entity.Image;
import ch.xxx.aidoclibchat.domain.model.entity.ImageRepository;
import jakarta.transaction.Transactional;

@Service
@Transactional
public class ImageService {
private static final Logger LOG = LoggerFactory.getLogger(ImageService.class);
private ChatClient chatClient;
private ImageRepository imageRepository;
private DocumentVsRepository documentVsRepository;

public ImageService(ChatClient chatClient) {
private record ResultData(String answer, ImageQueryDto imageQueryDto) {
}

public ImageService(ChatClient chatClient, ImageRepository imageRepository, DocumentVsRepository documentVsRepository) {
this.chatClient = chatClient;
this.imageRepository = imageRepository;
this.documentVsRepository = documentVsRepository;
}

public ImageDto importImage(ImageQueryDto imageDto, Image image) {
var resultData = createAIResult(imageDto);
image.setImageContent(resultData.imageQueryDto().getImageContent());
var myImage = this.imageRepository.save(image);
var aiDocument = new Document(resultData.answer());
aiDocument.getMetadata().put(MetaData.ID, myImage.getId().toString());
aiDocument.getMetadata().put(MetaData.DATATYPE, MetaData.DataType.IMAGE.toString());
this.documentVsRepository.add(List.of(aiDocument));
return new ImageDto(resultData.answer(),
Base64.getEncoder().encodeToString(resultData.imageQueryDto().getImageContent()),
resultData.imageQueryDto().getImageType());
}

public ImageDto queryImage(ImageQueryDto imageDto) {
if(ImageType.JPEG.equals(imageDto.getImageType()) || ImageType.PNG.equals(imageDto.getImageType())) {
var resultData = createAIResult(imageDto);
return new ImageDto(resultData.answer(),
Base64.getEncoder().encodeToString(resultData.imageQueryDto().getImageContent()),
resultData.imageQueryDto().getImageType());
}

private ResultData createAIResult(ImageQueryDto imageDto) {
if (ImageType.JPEG.equals(imageDto.getImageType()) || ImageType.PNG.equals(imageDto.getImageType())) {
imageDto = this.resizeImage(imageDto);
}
var prompt = new Prompt(new UserMessage(imageDto.getQuery(), List
.of(new Media(MimeType.valueOf(imageDto.getImageType().getMediaType()), imageDto.getImageContent()))));
var response = this.chatClient.call(prompt);
var answer = response.getResult().getOutput().getContent();
return new ImageDto(answer, Base64.getEncoder().encodeToString(imageDto.getImageContent()), imageDto.getImageType());
var resultData = new ResultData(response.getResult().getOutput().getContent(), imageDto);
return resultData;
}

private ImageQueryDto resizeImage(ImageQueryDto imageDto) {
Expand All @@ -62,16 +95,16 @@ private ImageQueryDto resizeImage(ImageQueryDto imageDto) {
int targetWidth = image.getWidth();
if (image.getHeight() > 672 && image.getWidth() > 672) {
if (image.getHeight() < image.getWidth()) {
targetHeight = Math.round( image.getHeight() / (image.getHeight() / 672.0f));
targetWidth = Math.round(image.getWidth() / (image.getHeight() / 672.0f));
targetHeight = Math.round(image.getHeight() / (image.getHeight() / 672.0f));
targetWidth = Math.round(image.getWidth() / (image.getHeight() / 672.0f));
} else {
targetHeight = Math.round(image.getHeight() / (image.getWidth() / 672.0f));
targetWidth = Math.round(image.getWidth() / (image.getWidth() / 672.0f));
}
}
var outputImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
outputImage.getGraphics().drawImage(image.getScaledInstance(targetWidth, targetHeight, Image.SCALE_SMOOTH),
0, 0, null);
outputImage.getGraphics().drawImage(
image.getScaledInstance(targetWidth, targetHeight, java.awt.Image.SCALE_SMOOTH), 0, 0, null);
var ios = new ByteArrayOutputStream();
ImageIO.write(outputImage, imageDto.getImageType().toString(), ios);
imageDto.setImageContent(ios.toByteArray());
Expand Down

0 comments on commit 7a3c580

Please sign in to comment.