diff --git a/.gitignore b/.gitignore index 5e2c6ab5..6204d5d0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,4 @@ bin/ .gradle/ .idea/ out/ - +.vscode/ diff --git a/build.gradle b/build.gradle index b4d87717..af5b0c2c 100644 --- a/build.gradle +++ b/build.gradle @@ -59,6 +59,7 @@ allprojects { mavenBom "io.awspring.cloud:spring-cloud-aws-dependencies:$springCloudAwsVersion" mavenBom "org.springframework.boot:spring-boot-dependencies:$springBootVersion" mavenBom "org.springframework.cloud:spring-cloud-dependencies:$springCloudVersion" + mavenBom "ai.djl:bom:$djlVersion" } } } diff --git a/dependencies.gradle b/dependencies.gradle index e12b00e8..2ec92bae 100644 --- a/dependencies.gradle +++ b/dependencies.gradle @@ -4,6 +4,8 @@ ext { springCloudAwsVersion = '3.0.4' debeziumVersion = '2.5.0.Final' + djlVersion = '0.26.0' + djlSpringVersion = '0.26' springIntegrationAws = 'org.springframework.integration:spring-integration-aws:3.0.5' ftpserverCore = 'org.apache.ftpserver:ftpserver-core:1.2.0' diff --git a/function/spring-computer-vision-function/README.adoc b/function/spring-computer-vision-function/README.adoc new file mode 100644 index 00000000..c778a454 --- /dev/null +++ b/function/spring-computer-vision-function/README.adoc @@ -0,0 +1,216 @@ += Computer Vision Functions + +This module provides a functional interface to perform common Computer Vision tasks such as Image Classification, Object Detection, Instance and Semantic Segmentation, Pose Estimation and more. + +It leverages the https://docs.djl.ai/index.html[Deep Java Library] (DJL) to enable Java developers to harness the power of deep learning. +DJL serves as a bridge between the rich ecosystem of Java programming and the cutting-edge capabilities of deep learning. +DJL provides integration with popular deep learning frameworks like `TensorFlow`, `PyTorch`, and `MXNet`, as well as support for a variety of pre-trained models using `ONNX Runtime`. 
+ +== Beans for injection + +This module exposes auto-configurations for the following beans: + +* `Function, Message> objectDetection` - Offering `Object Detection` for finding all instances of objects from a known set of categories in an image and `Instance Segmentation` for additionally drawing a mask on each detected instance. +* `Function, Message> imageClassifications` - The `Image Classification` task assigns a label to an image from a set of categories. +* `Function, Message> semanticSegmentation` - `Semantic Segmentation` refers to the task of detecting objects of various classes at pixel level. +It colors the pixels based on the objects detected in that space. +* `Function, Message> poseEstimation` - `Pose Estimation` refers to the task of detecting human figures in images and videos, and estimating the pose of the bodies. + +Each of them is conditional on specific configuration properties. + +[%autowidth] +|=== +|Bean |Activation Properties + +|objectDetection +|djl.output-class=ai.djl.modality.cv.output.DetectedObjects + +|imageClassifications +|djl.output-class=ai.djl.modality.Classifications + +|semanticSegmentation +|djl.output-class=ai.djl.modality.cv.output.CategoryMask + +|poseEstimation +|djl.output-class=ai.djl.modality.cv.output.Joints + +|=== + +Once injected, you can use the `apply` method of the `Function` to invoke it and get the result. + +All functions take and return a `Message`. +The input message payload contains the image bytes to be processed. +The output message payload contains the original or the augmented image after the processing. +The `computer.vision.function.augment-enabled` property controls whether the augmented image is returned or not. +Defaults to `false`. + +== Configuration Options + +[%autowidth] +|=== +|Property |Description + +|djl.application-type +|Defines the CV application task to be performed. 
Currently supported values are `OBJECT_DETECTION`, `IMAGE_CLASSIFICATION`, `INSTANCE_SEGMENTATION`, `SEMANTIC_SEGMENTATION` and `POSE_ESTIMATION`. + +|djl.input-class +|Define input data type, a model may accept multiple input data type. Currently only the `ai.djl.modality.cv.Image` is supported. + +|djl.output-class +|Define output data type, a model may generate different outputs. Supported output classes are `ai.djl.modality.cv.output.DetectedObjects`, `ai.djl.modality.cv.output.CategoryMask`, `ai.djl.modality.Classifications`, `ai.djl.modality.cv.output.Joints` . + +|djl.urls +|Model repository URLs. Multiple may be supplied to search for models. Specifying a single URL can be used to load a specific model. Can be specified as comma delimited field or as an array in the configuration file. +Current supported archive formats: `zip`, `tar`, `tar.gz`, `tgz`, `tar.z`. + +Supported URL schemes: `file://` - load a model from local directory or archive file., `http(s)://` - load a model from an archive file from web server, `jar://` - load a model from an archive file in the class path, `djl://` - load a model from the model zoo, `s3://` - load a model from S3 bucket (requires djl aws extension), `hdfs://` - load a model from HDFS file system (requires djl hadoop extension) + +|djl.model-filter +| https://github.com/deepjavalibrary/djl/tree/master/model-zoo#how-to-find-a-pre-trained-model-in-the-model-zoo[Model Filters] used to lookup a model from model zoo . + +|djl.group-id +|Defines the `groupId` of the model to be loaded from the zoo. + +|djl.model-artifact-id +|Defines the `artifactId` of the model to be loaded from the zoo. + +|djl.model-name +|(Optional) Defines the modelName of the model to be loaded. +Leave it empty if you want to load the latest version of the model. +Use "saved_model" for TensorFlow saved models. 
+ +|djl.engine +| Name of the https://docs.djl.ai/docs/engine.html[Engine] to use. See https://docs.djl.ai/docs/engine.html#supported-engines[Supported engine names]. + +|djl.translator-factory +| https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html[Translator] provides model pre-processing and postprocessing functionality. Multiple https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/translator/package-summary.html[translators] are provided for different models, but you can implement your own translator if needed. The translator-factory property allows specifying the translator to be used with the model. + +|computer.vision.function.output-header-name +|Name of the header that contains the JSON payload computed by the functions. + +|computer.vision.function.augment-enabled +|Enable image augmentation (false by default). + +|=== + +=== Example Configurations + +All computer vision examples use the following Java code snippet to invoke the function: + +[source,Java] +---- +@SpringBootApplication +public class TfObjectDetectionBootApp implements CommandLineRunner { + + @Autowired + private Function, Message> cvFunction; + + @Override + public void run(String... args) throws Exception { + byte[] inputImage = new ClassPathResource("Image URI").getInputStream().readAllBytes(); + + Message outputMessage = cvFunction.apply( + MessageBuilder.withPayload(inputImage).build()); + + // Augmented output image. + byte[] outputImage = outputMessage.getPayload(); + + // JSON payload with the detected objects and their bounding boxes. + String jsonBoundingBoxes = outputMessage.getHeader("cvjson", String.class); + } + + public static void main(String[] args) { + SpringApplication.run(TfObjectDetectionBootApp.class); + } +} +---- + +==== Object Detection (TensorFlow) + +You can leverage any of the existing TensorFlow object detection models. 
Just copy the URL of the model archive as the djl.urls property and set the djl.translator-factory to `org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory`. + +---- +computer.vision.function.augment-enabled=true +djl.application-type=OBJECT_DETECTION +djl.input-class=ai.djl.modality.cv.Image +djl.output-class=ai.djl.modality.cv.output.DetectedObjects +djl.engine=TensorFlow +djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz +djl.model-name=saved_model +djl.translator-factory=org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory +djl.arguments.threshold=0.3 +---- + +==== Object Detection (Yolo v8) + +You can use the same Java snippet above, just change the configuration to use the Yolo v8 model: + +---- +computer.vision.function.augment-enabled=true +djl.application-type=OBJECT_DETECTION +djl.input-class=ai.djl.modality.cv.Image +djl.output-class=ai.djl.modality.cv.output.DetectedObjects +djl.engine=OnnxRuntime +djl.urls=djl://ai.djl.onnxruntime/yolov8n +djl.translator-factory=ai.djl.modality.cv.translator.YoloV8TranslatorFactory +djl.arguments.threshold=0.3 +djl.arguments.width=640 +djl.arguments.height=640 +djl.arguments.resize=true +djl.arguments.toTensor=true +djl.arguments.applyRatio=true +djl.arguments.maxBox=1000 +---- + +==== Instance Segmentation + +Same Java code snippet but with the following configuration: + +---- +computer.vision.function.augment-enabled=true +djl.application-type=INSTANCE_SEGMENTATION +djl.input-class=ai.djl.modality.cv.Image +djl.output-class=ai.djl.modality.cv.output.DetectedObjects +djl.arguments.threshold=0.3 + +djl.model-filter.backbone=resnet18 +djl.model-filter.flavor=v1b +djl.model-filter.dataset=coco +---- + +Note that here we didn't specify the model to be used, but used the model-filter to find a compatible model from the model 
zoo. + +==== Semantic Segmentation + +Same Java code snippet but with the following configuration: + +---- +computer.vision.function.augment-enabled=true +djl.application-type=SEMANTIC_SEGMENTATION +djl.input-class=ai.djl.modality.cv.Image +djl.output-class=ai.djl.modality.cv.output.CategoryMask +djl.arguments.threshold=0.3 + +djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip +djl.translator-factory=ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory +djl.engine=PyTorch +---- + +==== Image Classification + +---- +djl.application-type=IMAGE_CLASSIFICATION +djl.input-class=ai.djl.modality.cv.Image +djl.output-class=ai.djl.modality.Classifications +djl.arguments.threshold=0.3 +djl.engine=MXNet +---- + +== Tests + +See this link:src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java[test suite] for examples of how this function is used. + +The link:src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java[JsonHelperTests] validates the JSON serialization and deserialization of the computer vision value object classes (`DetectedObjects`, `Classifications`, `CategoryMask` and `Joints`). 
+ +== Other usage + diff --git a/function/spring-computer-vision-function/build.gradle b/function/spring-computer-vision-function/build.gradle new file mode 100644 index 00000000..19271a03 --- /dev/null +++ b/function/spring-computer-vision-function/build.gradle @@ -0,0 +1,7 @@ +dependencies { + api "ai.djl.spring:djl-spring-boot-starter-autoconfigure:$djlSpringVersion" + api "ai.djl.spring:djl-spring-boot-starter-tensorflow-auto:$djlSpringVersion" + api "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:$djlSpringVersion" + api "ai.djl.spring:djl-spring-boot-starter-mxnet-auto:$djlSpringVersion" + runtimeOnly "ai.djl.onnxruntime:onnxruntime-engine" +} diff --git a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfiguration.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfiguration.java new file mode 100644 index 00000000..0823fa45 --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfiguration.java @@ -0,0 +1,178 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.cloud.fn.computer.vision; + +import java.awt.image.RenderedImage; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +import javax.imageio.ImageIO; + +import ai.djl.inference.Predictor; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.CategoryMask; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Joints; +import ai.djl.spring.configuration.DjlAutoConfiguration; +import ai.djl.spring.configuration.DjlConfigurationProperties; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.integration.support.MessageBuilder; +import org.springframework.messaging.Message; + +/** + * A configuration class that provides the necessary beans for the Computer Vision + * Function. 
+ * + * @author Christian Tzolov + */ +@AutoConfiguration(after = DjlAutoConfiguration.class) +@EnableConfigurationProperties({ DjlConfigurationProperties.class, ComputerVisionFunctionProperties.class }) +public class ComputerVisionFunctionConfiguration { + + private static final Logger log = LoggerFactory.getLogger(ComputerVisionFunctionConfiguration.class); + + private ImageFactory imageFactory = ImageFactory.getInstance(); + + @Bean + @ConditionalOnProperty(prefix = "djl", name = "output-class", + havingValue = "ai.djl.modality.cv.output.DetectedObjects") + public Function, Message> objectDetection(Supplier> predictorProvider, + ComputerVisionFunctionProperties cvProperties, DjlConfigurationProperties djlProperties) { + + Function toJson = (detectedObjects) -> JsonHelper.toJson(detectedObjects); + + BiFunction augmentImage = (detectedObjects, image) -> { + Image newImage = image.duplicate(); + newImage.drawBoundingBoxes(detectedObjects); + return getByteArray((RenderedImage) newImage.getWrappedImage(), cvProperties.getOutputImageFormatName()); + }; + + return predictor(predictorProvider, cvProperties, djlProperties, DetectedObjects.class, toJson, augmentImage); + } + + @Bean + @ConditionalOnProperty(prefix = "djl", name = "output-class", + havingValue = "ai.djl.modality.cv.output.CategoryMask") + public Function, Message> semanticSegmentation(Supplier> predictorProvider, + ComputerVisionFunctionProperties cvProperties, DjlConfigurationProperties djlProperties) { + + Function toJson = (mask) -> JsonHelper.toJson(mask); + + BiFunction augmentImage = (mask, image) -> { + Image newImage = image.duplicate(); + mask.drawMask(newImage, 200, 0); + return getByteArray((RenderedImage) newImage.getWrappedImage(), cvProperties.getOutputImageFormatName()); + }; + + return predictor(predictorProvider, cvProperties, djlProperties, CategoryMask.class, toJson, augmentImage); + } + + @Bean + @ConditionalOnProperty(prefix = "djl", name = "output-class", havingValue = 
"ai.djl.modality.Classifications") + public Function, Message> imageClassifications(Supplier> predictorProvider, + ComputerVisionFunctionProperties cvProperties, DjlConfigurationProperties djlProperties) { + + Function toJson = (classifications) -> JsonHelper.toJson(classifications); + + BiFunction augmentImage = (classifications, image) -> { + Image newImage = image.duplicate(); + return getByteArray((RenderedImage) newImage.getWrappedImage(), cvProperties.getOutputImageFormatName()); + }; + + return predictor(predictorProvider, cvProperties, djlProperties, Classifications.class, toJson, augmentImage); + } + + @Bean + @ConditionalOnProperty(prefix = "djl", name = "output-class", havingValue = "ai.djl.modality.cv.output.Joints") + public Function, Message> poseEstimation(Supplier> predictorProvider, + ComputerVisionFunctionProperties cvProperties, DjlConfigurationProperties djlProperties) { + + Function toJson = (joins) -> JsonHelper.toJson(joins); + + BiFunction augmentImage = (joints, image) -> { + Image newImage = image.duplicate(); + newImage.drawJoints(joints); + return getByteArray((RenderedImage) newImage.getWrappedImage(), cvProperties.getOutputImageFormatName()); + }; + + return predictor(predictorProvider, cvProperties, djlProperties, Joints.class, toJson, augmentImage); + } + + private Function, Message> predictor(Supplier> predictorProvider, + ComputerVisionFunctionProperties cvProperties, DjlConfigurationProperties djlProperties, + Class outputClass, Function toJsonFunction, + BiFunction augmentImageFunction) { + + return (input) -> { + + Predictor predictor = (Predictor) predictorProvider.get(); + + try { + + Image image = this.imageFactory.fromInputStream(new ByteArrayInputStream(input.getPayload())); + + T output = predictor.predict(image); + + String outputJson = toJsonFunction.apply(output); + + byte[] outputImageBytes = input.getPayload(); + + if (cvProperties.isAugmentEnabled()) { + outputImageBytes = augmentImageFunction.apply(output, image); + 
} + + String headerName = cvProperties.getOutputHeaderName(); + Message outMessage = MessageBuilder.withPayload(outputImageBytes) + .setHeader(headerName, outputJson) + .build(); + + return outMessage; + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + finally { + predictor.close(); + } + }; + } + + private static byte[] getByteArray(RenderedImage image, String formatName) { + try { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + ImageIO.write(image, formatName, byteArrayOutputStream); + return byteArrayOutputStream.toByteArray(); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + } + +} diff --git a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionProperties.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionProperties.java new file mode 100644 index 00000000..7688d4bc --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionProperties.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.cloud.fn.computer.vision; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Configuration properties for the Computer Vision Function. + * + * @author Christian Tzolov + */ +@ConfigurationProperties("computer.vision.function") +public class ComputerVisionFunctionProperties { + + /** + * Enable image augmentation. + */ + private boolean augmentEnabled = false; + + /** + * Output augmented image format name. + */ + private String outputImageFormatName = "png"; + + /** + * Name of the header that contains the JSON payload computed by the functions. + */ + private String outputHeaderName = "cvjson"; + + public boolean isAugmentEnabled() { + return this.augmentEnabled; + } + + public void setAugmentEnabled(boolean augmentImage) { + this.augmentEnabled = augmentImage; + } + + public String getOutputImageFormatName() { + return this.outputImageFormatName; + } + + public void setOutputImageFormatName(String outputImageFormatName) { + this.outputImageFormatName = outputImageFormatName; + } + + public String getOutputHeaderName() { + return this.outputHeaderName; + } + + public void setOutputHeaderName(String jsonHeaderName) { + this.outputHeaderName = jsonHeaderName; + } + +} diff --git a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/JsonHelper.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/JsonHelper.java new file mode 100644 index 00000000..378d8813 --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/JsonHelper.java @@ -0,0 +1,112 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.fn.computer.vision; + +import java.lang.reflect.Type; +import java.util.List; + +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.CategoryMask; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Joints; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.util.JsonUtils; +import com.google.gson.Gson; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonParseException; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; + +/** + * Helper class to serialize and deserialize {@link DetectedObjects}, + * {@link Classifications}, {@link CategoryMask} and {@link Joints} to/from JSON. 
+ * + * @author Christian Tzolov + */ +public final class JsonHelper { + + private static final Gson GSON = JsonUtils.builder().create(); + + private JsonHelper() { + } + + public static String toJson(Joints joints) { + return GSON.toJson(joints); + } + + public static Joints toJoints(String json) { + return GSON.fromJson(json, Joints.class); + } + + public static String toJson(CategoryMask categoryMask) { + return GSON.toJson(Mask.fromCategoryMask(categoryMask)); + } + + public static CategoryMask toCategoryMask(String json) { + return GSON.fromJson(json, Mask.class).toCategoryMask(); + } + + public static String toJson(Classifications classifications) { + return GSON.toJson(classifications); + } + + public static Classifications toClassifications(String json) { + return GSON.fromJson(json, Classifications.class); + } + + private static final Gson GSON2 = JsonUtils.builder() + .registerTypeAdapter(BoundingBox.class, new BoundingBoxAdapter()) + .create(); + + public static String toJson(DetectedObjects detectedObject) { + return GSON2.toJson(detectedObject); + } + + public static DetectedObjects toDetectedObjects(String json) { + return GSON2.fromJson(json, DetectedObjects.class); + } + + public record Mask(List classes, int[][] mask) { + + public static Mask fromCategoryMask(CategoryMask categoryMask) { + return new Mask(categoryMask.getClasses(), categoryMask.getMask()); + } + + public CategoryMask toCategoryMask() { + return new CategoryMask(this.classes, this.mask); + } + } + + public static class BoundingBoxAdapter implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(BoundingBox boundingBox, Type typeOfSrc, JsonSerializationContext context) { + return context.serialize(boundingBox); + } + + @Override + public BoundingBox deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) + throws JsonParseException { + return context.deserialize(json, Rectangle.class); + } + + } + +} diff --git 
a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/package-info.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/package-info.java new file mode 100644 index 00000000..d389c3e8 --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/package-info.java @@ -0,0 +1,4 @@ +/** + * Provides classes for the Computer Vision Function. + */ +package org.springframework.cloud.fn.computer.vision; diff --git a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/TensorflowSavedModelObjectDetectionTranslator.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/TensorflowSavedModelObjectDetectionTranslator.java new file mode 100644 index 00000000..71228c1e --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/TensorflowSavedModelObjectDetectionTranslator.java @@ -0,0 +1,184 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.cloud.fn.computer.vision.translator; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Scanner; +import java.util.concurrent.ConcurrentHashMap; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; +import com.google.gson.annotations.SerializedName; + +/** + * A {@link NoBatchifyTranslator} that post-processes the output of a TensorFlow + * SavedModel Object Detection model. 
+ * + * @author Christian Tzolov + */ +public final class TensorflowSavedModelObjectDetectionTranslator + implements NoBatchifyTranslator { + + private static final String ITEM_DELIMITER = "item "; + + private static final String DEFAULT_MSCOCO_LABELS_URL = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt"; + + private static final String DETECTION_BOXES = "detection_boxes"; + + private static final String DETECTION_SCORES = "detection_scores"; + + private static final String DETECTION_CLASSES = "detection_classes"; + + private String classLabelsUrl; + + private Map classLabels; + + private int maxBoxes; + + private float threshold; + + public TensorflowSavedModelObjectDetectionTranslator() { + this(DEFAULT_MSCOCO_LABELS_URL, 10, 0.3f); + } + + public TensorflowSavedModelObjectDetectionTranslator(String categoryLabelsUrl, int maxBoxes, float threshold) { + this.classLabelsUrl = categoryLabelsUrl; + this.maxBoxes = maxBoxes; + this.threshold = threshold; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + // input to tf object-detection models is a list of tensors, hence NDList + NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR); + // optionally resize the image for faster processing + array = NDImageUtils.resize(array, 224); + // tf object-detection models expect 8 bit unsigned integer tensor + array = array.toType(DataType.UINT8, true); + // tf object-detection models expect a 4 dimensional input + array = array.expandDims(0); + + return new NDList(array); + } + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws IOException { + if (this.classLabels == null) { + this.classLabels = loadSynset(); + } + } + + private Map loadSynset() throws IOException { + Map map = new ConcurrentHashMap<>(); + int maxId = 0; + try (InputStream is = new BufferedInputStream(new 
URL(this.classLabelsUrl).openStream()); + Scanner scanner = new Scanner(is, StandardCharsets.UTF_8.name())) { + + scanner.useDelimiter(ITEM_DELIMITER); + while (scanner.hasNext()) { + String content = scanner.next(); + content = content.replaceAll("(\"|\\d)\\n\\s", "$1,"); + Item item = JsonUtils.GSON.fromJson(content, Item.class); + map.put(item.id, item.displayName); + if (item.id > maxId) { + maxId = item.id; + } + } + } + return map; + } + + /** {@inheritDoc} */ + @Override + public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { + // output of tf object-detection models is a list of tensors, hence NDList in djl + // output NDArray order in the list are not guaranteed + + int[] classIds = null; + float[] probabilities = null; + NDArray boundingBoxes = null; + for (NDArray array : list) { + if (DETECTION_BOXES.equals(array.getName())) { + boundingBoxes = array.get(0); + } + else if (DETECTION_SCORES.equals(array.getName())) { + probabilities = array.get(0).toFloatArray(); + } + else if (DETECTION_CLASSES.equals(array.getName())) { + // class id is between 1 - number of classes + classIds = array.get(0).toType(DataType.INT32, true).toIntArray(); + } + } + Objects.requireNonNull(classIds); + Objects.requireNonNull(probabilities); + Objects.requireNonNull(boundingBoxes); + + List retNames = new ArrayList<>(); + List retProbs = new ArrayList<>(); + List retBB = new ArrayList<>(); + + // result are already sorted + for (int i = 0; i < Math.min(classIds.length, this.maxBoxes); ++i) { + int classId = classIds[i]; + double probability = probabilities[i]; + // classId starts from 1, -1 means background + if (classId > 0 && probability > this.threshold) { + String className = this.classLabels.getOrDefault(classId, "#" + classId); + float[] box = boundingBoxes.get(i).toFloatArray(); + float yMin = box[0]; + float xMin = box[1]; + float yMax = box[2]; + float xMax = box[3]; + Rectangle rect = new Rectangle(xMin, yMin, xMax - xMin, yMax - yMin); + 
retNames.add(className); + retProbs.add(probability); + retBB.add(rect); + } + } + + return new DetectedObjects(retNames, retProbs, retBB); + } + + private static final class Item { + + int id; + + @SerializedName("display_name") + String displayName; + + } + +} diff --git a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/TensorflowSavedModelObjectDetectionTranslatorFactory.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/TensorflowSavedModelObjectDetectionTranslatorFactory.java new file mode 100644 index 00000000..4bf61f78 --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/TensorflowSavedModelObjectDetectionTranslatorFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.fn.computer.vision.translator; + +import java.util.Map; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.ObjectDetectionTranslatorFactory; +import ai.djl.translate.Translator; + +/** + * Translator for TensorFlow Object Detection SavedModel. 
+ * + * @author Christian Tzolov + */ +public class TensorflowSavedModelObjectDetectionTranslatorFactory extends ObjectDetectionTranslatorFactory { + + @Override + protected Translator buildBaseTranslator(Model model, Map arguments) { + return new TensorflowSavedModelObjectDetectionTranslator(); + } + +} diff --git a/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/package-info.java b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/package-info.java new file mode 100644 index 00000000..efe1aa03 --- /dev/null +++ b/function/spring-computer-vision-function/src/main/java/org/springframework/cloud/fn/computer/vision/translator/package-info.java @@ -0,0 +1,4 @@ +/** + * Provides classes for translating the output of the computer vision function. + */ +package org.springframework.cloud.fn.computer.vision.translator; diff --git a/function/spring-computer-vision-function/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/function/spring-computer-vision-function/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000..e0eed01e --- /dev/null +++ b/function/spring-computer-vision-function/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +org.springframework.cloud.fn.computer.vision.ComputerVisionFunctionConfiguration diff --git a/function/spring-computer-vision-function/src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java b/function/spring-computer-vision-function/src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java new file mode 100644 index 00000000..a2382384 --- /dev/null +++ 
b/function/spring-computer-vision-function/src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java @@ -0,0 +1,388 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.fn.computer.vision; + +import java.awt.image.BufferedImage; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.util.function.Function; + +import javax.imageio.ImageIO; + +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.CategoryMask; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Joints; +import ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory; +import ai.djl.modality.cv.translator.YoloV8TranslatorFactory; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.spring.configuration.ApplicationType; +import ai.djl.spring.configuration.DjlAutoConfiguration; +import ai.djl.spring.configuration.DjlConfigurationProperties; +import ai.djl.util.JsonUtils; +import com.google.gson.Gson; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import 
org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory; +import org.springframework.core.io.ClassPathResource; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import static org.assertj.core.api.Assertions.assertThat; + +public class ComputerVisionFunctionConfigurationTests { + + private static final Logger log = LoggerFactory.getLogger(ComputerVisionFunctionConfigurationTests.class); + + private Gson gson = JsonUtils.builder().create(); + + private ApplicationContextRunner applicationContextRunner; + + @BeforeEach + public void setUp() { + applicationContextRunner = new ApplicationContextRunner().withConfiguration( + AutoConfigurations.of(DjlAutoConfiguration.class, ComputerVisionFunctionConfiguration.class)); + } + + /** + * This configuration can be used to load any of the Tensorflow2 models for object + * detection from here: + * https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md + */ + @Test + public void tf2SavedModel() { + + applicationContextRunner.withPropertyValues( + // @formatter:off + "computer.vision.function.augment-enabled=true", + "djl.application-type=" + ApplicationType.OBJECT_DETECTION, + "djl.input-class=" + Image.class.getName(), + "djl.output-class=" + DetectedObjects.class.getName(), + "djl.engine=TensorFlow", + "djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz", + "djl.model-name=saved_model", + "djl.translator-factory=" + TensorflowSavedModelObjectDetectionTranslatorFactory.class.getName(), + "djl.arguments.threshold=0.3") + // @formatter:on + .run((context) -> { + assertThat(context).hasSingleBean(ZooModel.class); + assertThat(context).hasBean("predictorProvider"); + assertThat(context).hasBean("objectDetection"); + assertThat(context).doesNotHaveBean("semanticSegmentation"); + + 
Function, Message> predictor = (Function, Message>) context + .getBean("objectDetection"); + + var djlProperties = context.getBean(DjlConfigurationProperties.class); + + assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.OBJECT_DETECTION); + assertThat(djlProperties.getInputClass()).isEqualTo(Image.class); + assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class); + assertThat(djlProperties.getEngine()).isEqualTo("TensorFlow"); + assertThat(djlProperties.getUrls()).contains( + "http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz"); + assertThat(djlProperties.getModelName()).isEqualTo("saved_model"); + assertThat(djlProperties.getTranslatorFactory()) + .isEqualTo(TensorflowSavedModelObjectDetectionTranslatorFactory.class.getName()); + + byte[] inputImage = new ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes(); + + Message outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build()); + + assertThat(outputMessage).isNotNull(); + assertThat(outputMessage.getPayload()).isNotNull(); + assertThat(outputMessage.getPayload().length).isGreaterThan(0); + assertThat(outputMessage.getHeaders()).containsKey("cvjson"); + + var json = outputMessage.getHeaders().get("cvjson", String.class); + log.info(json); + + assertThat(JsonHelper.toDetectedObjects(json)).isNotNull(); + + // save(outputMessage.getPayload(), + // "tf2-sm-object-detection-augmented.jpg"); + }); + } + + @Test + public void yolov8Detection() { + applicationContextRunner.withPropertyValues( + // @formatter:off + "computer.vision.function.augment-enabled=true", + "djl.application-type=" + ApplicationType.OBJECT_DETECTION, + "djl.input-class=" + Image.class.getName(), + "djl.output-class=" + DetectedObjects.class.getName(), + "djl.engine=OnnxRuntime", + "djl.urls=djl://ai.djl.onnxruntime/yolov8n", + "djl.translator-factory=" + 
YoloV8TranslatorFactory.class.getName(), + "djl.arguments.threshold=0.3", + "djl.arguments.width=640", + "djl.arguments.height=640", + "djl.arguments.resize=true", + "djl.arguments.toTensor=true", + "djl.arguments.applyRatio=true", + "djl.arguments.maxBox=1000") + // @formatter:on + .run((context) -> { + assertThat(context).hasSingleBean(ZooModel.class); + assertThat(context).hasBean("predictorProvider"); + assertThat(context).hasBean("objectDetection"); + + Function, Message> predictor = (Function, Message>) context + .getBean("objectDetection", Function.class); + + var djlProperties = context.getBean(DjlConfigurationProperties.class); + + assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.OBJECT_DETECTION); + assertThat(djlProperties.getInputClass()).isEqualTo(Image.class); + assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class); + assertThat(djlProperties.getEngine()).isEqualTo("OnnxRuntime"); + assertThat(djlProperties.getUrls()).contains("djl://ai.djl.onnxruntime/yolov8n"); + assertThat(djlProperties.getTranslatorFactory()).isEqualTo(YoloV8TranslatorFactory.class.getName()); + + byte[] inputImage = new ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes(); + + Message outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build()); + + assertThat(outputMessage).isNotNull(); + assertThat(outputMessage.getPayload()).isNotNull(); + assertThat(outputMessage.getPayload().length).isGreaterThan(0); + assertThat(outputMessage.getHeaders()).containsKey("cvjson"); + + var json = "" + outputMessage.getHeaders().get("cvjson", String.class); + log.info(json); + + var detectionObjects = JsonHelper.toDetectedObjects(json); + + assertThat(detectionObjects).isNotNull(); + + // save(outputMessage.getPayload(), + // "yolo-v8-object-detection-augmented.jpg"); + }); + + } + + @Test + public void instanceSegmentation() { + applicationContextRunner.withPropertyValues( + // @formatter:off + 
"computer.vision.function.augment-enabled=true", + "djl.application-type=" + ApplicationType.INSTANCE_SEGMENTATION, + "djl.input-class=" + Image.class.getName(), + "djl.output-class=" + DetectedObjects.class.getName(), + "djl.arguments.threshold=0.3", + + "djl.model-filter.backbone=resnet18", + "djl.model-filter.flavor=v1b", + "djl.model-filter.dataset=coco") + // @formatter:on + .run((context) -> { + assertThat(context).hasSingleBean(ZooModel.class); + assertThat(context).hasBean("predictorProvider"); + assertThat(context).hasBean("objectDetection"); + + Function, Message> predictor = (Function, Message>) context + .getBean("objectDetection", Function.class); + + var djlProperties = context.getBean(DjlConfigurationProperties.class); + + assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.INSTANCE_SEGMENTATION); + assertThat(djlProperties.getInputClass()).isEqualTo(Image.class); + assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class); + + // byte[] inputImage = new + // ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes(); + byte[] inputImage = new ClassPathResource("/amsterdam-cityscape.jpg").getInputStream().readAllBytes(); + + Message outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build()); + + assertThat(outputMessage).isNotNull(); + assertThat(outputMessage.getPayload()).isNotNull(); + assertThat(outputMessage.getPayload().length).isGreaterThan(0); + assertThat(outputMessage.getHeaders()).containsKey("cvjson"); + String json = outputMessage.getHeaders().get("cvjson", String.class); + log.info(json); + // save(outputMessage.getPayload(), + // "instance-segmentation-augmented.jpg"); + + assertThat(JsonHelper.toDetectedObjects(json)).isNotNull(); + }); + } + + @Test + public void semanticSegmentation() { + applicationContextRunner.withPropertyValues( + // @formatter:off + "computer.vision.function.augment-enabled=true", + "djl.application-type=" + 
ApplicationType.SEMANTIC_SEGMENTATION, + "djl.input-class=" + Image.class.getName(), + "djl.output-class=" + CategoryMask.class.getName(), + "djl.arguments.threshold=0.3", + + "djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip", + "djl.translator-factory=" + SemanticSegmentationTranslatorFactory.class.getName(), + "djl.engine=PyTorch") + // @formatter:on + .run((context) -> { + assertThat(context).hasSingleBean(ZooModel.class); + assertThat(context).hasBean("predictorProvider"); + assertThat(context).hasBean("semanticSegmentation"); + + Function, Message> predictor = (Function, Message>) context + .getBean("semanticSegmentation", Function.class); + + var djlProperties = context.getBean(DjlConfigurationProperties.class); + + assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.SEMANTIC_SEGMENTATION); + assertThat(djlProperties.getInputClass()).isEqualTo(Image.class); + assertThat(djlProperties.getOutputClass()).isEqualTo(CategoryMask.class); + + byte[] inputImage = new ClassPathResource("/amsterdam-cityscape.jpg").getInputStream().readAllBytes(); + + Message outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build()); + + assertThat(outputMessage).isNotNull(); + assertThat(outputMessage.getPayload()).isNotNull(); + assertThat(outputMessage.getPayload().length).isGreaterThan(0); + assertThat(outputMessage.getHeaders()).containsKey("cvjson"); + + String ssJson = outputMessage.getHeaders().get("cvjson", String.class); + + // log.info(ssJson); + + // save(outputMessage.getPayload(), + // "semantic-segmentation-augmented.jpg"); + + assertThat(JsonHelper.toCategoryMask(ssJson)).isNotNull(); + }); + } + + @Test + public void imageClassifications() { + applicationContextRunner.withPropertyValues( + // @formatter:off + "computer.vision.function.augment-enabled=false", + "djl.application-type=" + ApplicationType.IMAGE_CLASSIFICATION, + "djl.input-class=" + Image.class.getName(), + 
"djl.output-class=" + Classifications.class.getName(), + "djl.arguments.threshold=0.3", + "djl.engine=MXNet") + // @formatter:on + .run((context) -> { + assertThat(context).hasSingleBean(ZooModel.class); + assertThat(context).hasBean("predictorProvider"); + assertThat(context).hasBean("imageClassifications"); + + Function, Message> predictor = (Function, Message>) context + .getBean("imageClassifications", Function.class); + + var djlProperties = context.getBean(DjlConfigurationProperties.class); + + assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.IMAGE_CLASSIFICATION); + assertThat(djlProperties.getInputClass()).isEqualTo(Image.class); + assertThat(djlProperties.getOutputClass()).isEqualTo(Classifications.class); + + byte[] inputImage = new ClassPathResource("/karakatschan.jpg").getInputStream().readAllBytes(); + + Message outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build()); + + assertThat(outputMessage).isNotNull(); + assertThat(outputMessage.getPayload()).isNotNull(); + assertThat(outputMessage.getPayload().length).isGreaterThan(0); + assertThat(outputMessage.getHeaders()).containsKey("cvjson"); + + String json = outputMessage.getHeaders().get("cvjson", String.class); + + log.info(json); + + assertThat(JsonHelper.toClassifications(json)).isNotNull(); + }); + } + + @Test + public void poseEstimation() { + applicationContextRunner.withPropertyValues( + // @formatter:off + "computer.vision.function.augment-enabled=true", + "djl.application-type=" + ApplicationType.POSE_ESTIMATION, + "djl.input-class=" + Image.class.getName(), + "djl.output-class=" + Joints.class.getName(), + "djl.arguments.threshold=0.3", + "djl.model-filter.backbone=resnet18", + "djl.model-filter.flavor=v1b", + "djl.model-filter.dataset=imagenet") + // @formatter:on + .run((context) -> { + assertThat(context).hasSingleBean(ZooModel.class); + assertThat(context).hasBean("predictorProvider"); + assertThat(context).hasBean("poseEstimation"); + + 
Function, Message> predictor = (Function, Message>) context + .getBean("poseEstimation", Function.class); + + var djlProperties = context.getBean(DjlConfigurationProperties.class); + + assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.POSE_ESTIMATION); + assertThat(djlProperties.getInputClass()).isEqualTo(Image.class); + assertThat(djlProperties.getOutputClass()).isEqualTo(Joints.class); + + byte[] inputImage = new ClassPathResource("/pose.png").getInputStream().readAllBytes(); + + Message outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build()); + + assertThat(outputMessage).isNotNull(); + assertThat(outputMessage.getPayload()).isNotNull(); + assertThat(outputMessage.getPayload().length).isGreaterThan(0); + assertThat(outputMessage.getHeaders()).containsKey("cvjson"); + + String ssJson = outputMessage.getHeaders().get("cvjson", String.class); + + log.info(ssJson); + + assertThat(JsonHelper.toJoints(ssJson)).isNotNull(); + }); + } + + private static void save(byte[] imageBytes, String outputFileName) { + BufferedImage image = createImageFromBytes(imageBytes); + try { + // Use ImageIO.write() to save the RenderedImage to the specified file + ImageIO.write(image, "png", new File("build" + File.separator + outputFileName)); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + } + + private static BufferedImage createImageFromBytes(byte[] imageData) { + ByteArrayInputStream bais = new ByteArrayInputStream(imageData); + try { + return ImageIO.read(bais); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + } + +} diff --git a/function/spring-computer-vision-function/src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java b/function/spring-computer-vision-function/src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java new file mode 100644 index 00000000..073489d5 --- /dev/null +++ 
b/function/spring-computer-vision-function/src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.fn.computer.vision; + +import java.util.List; + +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.output.CategoryMask; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class JsonHelperTests { + + @Test + public void categoryMask() { + + var categoryMask = new CategoryMask(List.of("a", "b", "c"), new int[][] { { 1, 2, 3 }, { 4, 5, 6 } }); + + var json = JsonHelper.toJson(categoryMask); + + assertThat(json).isNotEmpty(); + + var categoryMask2 = JsonHelper.toCategoryMask(json); + + assertThat(categoryMask.getClasses()).isEqualTo(categoryMask2.getClasses()); + assertThat(categoryMask.getMask()).isEqualTo(categoryMask2.getMask()); + } + + @Test + public void classifications() { + + var classifications = new Classifications(List.of("a", "b", "c"), List.of(0.1, 0.2, 0.3)); + classifications.setTopK(3); + + var json = JsonHelper.toJson(classifications); + + assertThat(json).isNotEmpty(); + + var classifications2 = JsonHelper.toClassifications(json); + + 
assertThat(classifications2.getClassNames()).isEqualTo(classifications.getClassNames()); + assertThat(classifications2.getProbabilities()).isEqualTo(classifications.getProbabilities()); + assertThat(classifications2.topK()).hasSize(3); + } + + @Test + public void detectedObjects() { + DetectedObjects detectedObjects = new DetectedObjects(List.of("a", "b", "c"), List.of(0.1, 0.2, 0.3), + List.of(new Rectangle(1, 2, 3, 4), new Rectangle(5, 6, 7, 8), new Rectangle(9, 10, 11, 12))); + detectedObjects.setTopK(3); + + var json = JsonHelper.toJson(detectedObjects); + + assertThat(json).isNotEmpty(); + + var detectedObjects2 = JsonHelper.toDetectedObjects(json); + + assertThat(detectedObjects2.getClassNames()).isEqualTo(detectedObjects.getClassNames()); + assertThat(detectedObjects2.getProbabilities()).isEqualTo(detectedObjects.getProbabilities()); + assertThat(detectedObjects2.topK()).hasSize(3); + + assertThat(detectedObjects2.getNumberOfObjects()).isEqualTo(3); + } + +} diff --git a/function/spring-computer-vision-function/src/test/resources/amsterdam-cityscape.jpg b/function/spring-computer-vision-function/src/test/resources/amsterdam-cityscape.jpg new file mode 100644 index 00000000..d77cae40 Binary files /dev/null and b/function/spring-computer-vision-function/src/test/resources/amsterdam-cityscape.jpg differ diff --git a/function/spring-computer-vision-function/src/test/resources/karakatschan.jpg b/function/spring-computer-vision-function/src/test/resources/karakatschan.jpg new file mode 100644 index 00000000..0a8b2808 Binary files /dev/null and b/function/spring-computer-vision-function/src/test/resources/karakatschan.jpg differ diff --git a/function/spring-computer-vision-function/src/test/resources/object-detection.jpg b/function/spring-computer-vision-function/src/test/resources/object-detection.jpg new file mode 100644 index 00000000..9eb325ac Binary files /dev/null and b/function/spring-computer-vision-function/src/test/resources/object-detection.jpg differ 
diff --git a/function/spring-computer-vision-function/src/test/resources/pose.png b/function/spring-computer-vision-function/src/test/resources/pose.png new file mode 100644 index 00000000..7ff6034e Binary files /dev/null and b/function/spring-computer-vision-function/src/test/resources/pose.png differ diff --git a/function/spring-computer-vision-function/src/test/resources/test1.png b/function/spring-computer-vision-function/src/test/resources/test1.png new file mode 100644 index 00000000..c1fed0b2 Binary files /dev/null and b/function/spring-computer-vision-function/src/test/resources/test1.png differ