Skip to content

Commit

Permalink
Initial CV function commit
Browse files Browse the repository at this point in the history
  • Loading branch information
tzolov committed Jan 22, 2024
1 parent d1c5c10 commit 2bded4f
Show file tree
Hide file tree
Showing 19 changed files with 1,297 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ bin/
.gradle/
.idea/
out/

.vscode/
2 changes: 2 additions & 0 deletions dependencies.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ ext {
springCloudAwsVersion = '3.0.4'

debeziumVersion = '2.5.0.Final'
djlVersion = '0.26.0'
djlSpringVersion = '0.26-SNAPSHOT'

springIntegrationAws = 'org.springframework.integration:spring-integration-aws:3.0.5'
ftpserverCore = 'org.apache.ftpserver:ftpserver-core:1.2.0'
Expand Down
214 changes: 214 additions & 0 deletions function/spring-computer-vision-function/README.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
= Computer Vision Functions

This module provides functional interface to perform common Computer Vision tasks such as Image Classification, Object Detection, Instance and Semantic Segmentation, Pose Estimation an more.

It leverages the https://docs.djl.ai/index.html[Deep Java Library] (DJL) library to enable Java developers to harness the power of deep learning.
DJL serves as a bridge between the rich ecosystem of Java programming and the cutting-edge capabilities of deep learning.
DJL provides integration with popular deep learning frameworks like `TensorFlow`, `PyTorch`, and `MXNet`, as well as support for a variety of pre-trained models using `ONNX Runtime`.

== Beans for injection

This module exposes auto-configurations for the following beans:

* `Function<Message<byte[]>, Message<byte[]>> objectDetection` - Offering `Object Detection` for finding all instances of objects from a known set of categories in an image and `Instance Segmentation`` for finding all instances of objects from a known set of categories in an image and drawing a mask on each instance.
* `Function<Message<byte[]>, Message<byte[]>> imageClassifications` - The `Image Classification` task assigns a label to an image from a set of categories.
* `Function<Message<byte[]>, Message<byte[]>> semanticSegmentation` - `Semantic Segmentation` refers to the task of detecting objects of various classes at pixel level. It colors the pixels based on the objects detected in that space.
* `Function<Message<byte[]>, Message<byte[]>> poseEstimation` - `Pose Estimation` refers to the task of detecting human figures in images and videos, and estimating the pose of the bodies.

Each of them are conditional by specific configuration properties.

[%autowidth]
|===
|Bean |Activation Properties

|objectDetection
|djl.output-class=ai.djl.modality.cv.output.DetectedObjects

|imageClassifications
|djl.output-class=ai.djl.modality.Classifications

|semanticSegmentation
|djl.output-class=ai.djl.modality.cv.output.CategoryMask

|poseEstimation
|djl.output-class=ai.djl.modality.cv.output.Joints

|===

Once injected, you can use the `apply` method of the `Function` to invoke it and get the result.

All functions take and return a `Message<byte[]>`.
The input message payload contains the image bytes to be processed.
The output message payload contains the original or the augmented image after the processing.
The `computer.vision.function.augment-enabled` property controls whether the augmented image is returned or not. Defaults to `true`.

== Configuration Options

[%autowidth]
|===
|Property |Description

|djl.application-type
|Defines the CV application task to be performed. Currently supported values are `OBJECT_DETECTION`, `IMAGE_CLASSIFICATION`, `INSTANCE_SEGMENTATION`, `SEMANTIC_SEGMENTATION` and `POSE_ESTIMATION`.

|djl.input-class
|Define input data type, a model may accept multiple input data type. Currently only the `ai.djl.modality.cv.Image` is supported.

|djl.output-class
|Define output data type, a model may generate different outputs. Supported output classes are `ai.djl.modality.cv.output.DetectedObjects`, `ai.djl.modality.cv.output.CategoryMask`, `ai.djl.modality.Classifications`, `ai.djl.modality.cv.output.Joints` .

|djl.urls
|Model repository URLs. Multiple may be supplied to search for models. Specifying a single URL can be used to load a specific model. Can be specified as comma delimited field or as an array in the configuration file.
Current supported archive formats¶ `zip`, `tar`, `tar.gz`, `tgz`, `tar.z`.

Supported URL schemes: `file://` - load a model from local directory or archive file., `http(s)://` - load a model from an archive file from web server, `jar://` - load a model from an archive file in the class path, `djl://` - load a model from the model zoo, `s3://` - load a model from S3 bucket (requires djl aws extension), `hdfs://` - load a model from HDFS file system (requires djl hadoop extension)

|djl.model-filter
| https://github.com/deepjavalibrary/djl/tree/master/model-zoo#how-to-find-a-pre-trained-model-in-the-model-zoo[Model Filters] used to lookup a model from model zoo .

|djl.group-id
|Defines the `groupId` of the model to be loaded from the zoo.

|djl.model-artifact-id
|Defines the `artifactId` of the model to be loaded from the zoo.

|djl.model-name
|(Optional) Defines the modelName of the model to be loaded.
Leave it empty if you want to load the latest version of the model.
Use "saved_model" for TensorFlow saved models.

|djl.engine
| Name of teh https://docs.djl.ai/docs/engine.html[Engine] to use https://docs.djl.ai/docs/engine.html#supported-engines[Supported engine names].

|djl.translator-factory
| https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html[Translator] provides model pre-processing and postprocessing functionality. Multiple https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/translator/package-summary.html[translators] are provided for different models, but you can implement your own translator if needed (see []). The translator-factory property allow to specify the translator to be used with the model.

|computer.vision.function.output-header-name
|Name of the header that contains the JSON payload computed by the functions.

|computer.vision.function.augment-enabled
|Enable image augmentation (false by default).

|===

=== Example Configurations

All computer vision examples use the following Java code snippet to invoke the function:

[source,Java]
----
@SpringBootApplication
public class TfObjectDetectionBootApp implements CommandLineRunner {
@Autowired
private Function<Message<byte[]>, Message<byte[]>> cvFunction;
@Override
public void run(String... args) throws Exception {
byte[] inputImage = new ClassPathResource("Image URI").getInputStream().readAllBytes();
Message<byte[]> outputMessage = cvFunction.apply(
MessageBuilder.withPayload(inputImage).build());
// Augmented output image.
byte[] outputImage = outputMessage.getPayload();
// JSON payload with the detected objects and their bounding boxes.
String jsonBoundingBoxes = outputMessage.getHeader("cvjson", String.class);
}
public static void main(String[] args) {
SpringApplication.run(TfObjectDetectionBootApp.class);
}
}
----

==== Object Detection (TensorFlow)

You can leverage any of the existing [TensorFlow models]. Just compy the url of the model archive as djl.urls property and set the djl.translator-factory to `org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory`.

----
computer.vision.function.augment-enabled=true
djl.application-type=OBJECT_DETECTION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
djl.engine=TensorFlow
djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz
djl.model-name=saved_model
djl.translator-factory=org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory
djl.arguments.threshold=0.3
----

==== Object Detection (Yolo v8)

You can use the same Java snipped above, just change the configuration to use the Yolo v8 model:

----
computer.vision.function.augment-enabled=true
djl.application-type=OBJECT_DETECTION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
djl.engine=OnnxRuntime
djl.urls=djl://ai.djl.onnxruntime/yolov8n
djl.translator-factory=ai.djl.modality.cv.translator.YoloV8TranslatorFactory
djl.arguments.threshold=0.3
djl.arguments.width=640
djl.arguments.height=640
djl.arguments.resize=true
djl.arguments.toTensor=true
djl.arguments.applyRatio=true
djl.arguments.maxBox=1000
----

==== Instance Segmentation

Same Java code snipped but with the following configuration:

----
computer.vision.function.augment-enabled=true
djl.application-type=INSTANCE_SEGMENTATION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
djl.arguments.threshold=0.3
djl.model-filter.backbone=resnet18
djl.model-filter.flavor=v1b
djl.model-filter.dataset=coco
----

Note that here we didn't specify the model to be used, but used the model-filter to find a compatible model from the model zoo.

==== Semantic Segmentation

Same Java code snipped but with the following configuration:

----
computer.vision.function.augment-enabled=true
djl.application-type=SEMANTIC_SEGMENTATION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.CategoryMask
djl.arguments.threshold=0.3
djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip
djl.translator-factory=ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory
djl.engine=PyTorch
----

==== Image Classification

----
djl.application-type=IMAGE_CLASSIFICATION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.Classifications
djl.arguments.threshold=0.3
djl.engine=MXNet
----

== Tests

See this link:src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java[test suite] for examples of how this function is used.

The link:src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java[JsonHelperTests] validates the JSON serialization and deserialization of the `ComputerVisionFunctionConfiguration` class values object classes.

== Other usage

11 changes: 11 additions & 0 deletions function/spring-computer-vision-function/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
repositories {
mavenLocal()
}

dependencies {
api "ai.djl.spring:djl-spring-boot-starter-autoconfigure:$djlSpringVersion"
api "ai.djl.spring:djl-spring-boot-starter-tensorflow-auto:$djlSpringVersion"
api "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:$djlSpringVersion"
api "ai.djl.spring:djl-spring-boot-starter-mxnet-auto:$djlSpringVersion"
runtimeOnly "ai.djl.onnxruntime:onnxruntime-engine:$djlVersion"
}
Loading

0 comments on commit 2bded4f

Please sign in to comment.