Feat/add instance segmentation #67
Merged

Commits (29):
3d3c88e  Add FastSAM (HonzaCuhel)
f2dbf33  Update (HonzaCuhel)
535d09a  Update Colab notebook (HonzaCuhel)
454d749  Add vizualization (HonzaCuhel)
7bb93e9  Update README.md and tests (HonzaCuhel)
3fcb736  Update COCO converter (HonzaCuhel)
5a0795d  Refactor YOLO converter (HonzaCuhel)
c0cf6ab  Refactor visualize function (HonzaCuhel)
a1c6b6a  [Automated] Updated coverage badge (actions-user)
7879220  fix: different color for different classes in the segmenetation visua… (sokovninn)
4fae718  Switch to SlimSAM (HonzaCuhel)
f40e5a0  Switch to SlimSAM (HonzaCuhel)
853d5ad  Update instance segmentation example (HonzaCuhel)
04e91fd  Update tests (HonzaCuhel)
ff771ad  Fix: annotator tests (HonzaCuhel)
335cc05  [Automated] Updated coverage badge (actions-user)
f887910  Update docs & luxonis dataset creation (HonzaCuhel)
b8151cb  fix: return SliamSAM processor (sokovninn)
af08e4b  fix: handle empty polygon list (sokovninn)
c566bea  Fix: remove long outputs from Jupyter Notebook (HonzaCuhel)
07a58f0  Fix: README.md (HonzaCuhel)
057a9b4  Add OWLv2 non-square pixel fix (HonzaCuhel)
437d067  Rename vars (HonzaCuhel)
cd819c4  Fix: correct all SlimSAM mentions (HonzaCuhel)
5e45347  fix: different image sizes for owlv2 postprocessing (sokovninn)
3b915ba  Update OWLv2 bbox correction (HonzaCuhel)
68487e4  fix: pass segmentation annotator size (sokovninn)
5401431  fix: shifted annotations when tta is used (sokovninn)
d47253a  Fix OWLv2 device (HonzaCuhel)
New file (+153 lines), the SlimSAMAnnotator module:
from __future__ import annotations

import logging
from typing import List

import numpy as np
import PIL
import torch
from transformers import SamModel, SamProcessor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
from datadreamer.dataset_annotation.utils import mask_to_polygon

logger = logging.getLogger(__name__)

class SlimSAMAnnotator(BaseAnnotator):
    """A class for image annotation using the SlimSAM model, specializing in
    instance segmentation.

    Attributes:
        model (SamModel): The SAM model for instance segmentation.
        processor (SamProcessor): The processor for the SAM model.
        device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
        size (str): The size of the SAM model to use ('base' or 'large').

    Methods:
        _init_model(): Initializes the SAM model.
        _init_processor(): Initializes the processor for the SAM model.
        annotate_batch(images, boxes_batch, iou_threshold): Annotates the given images with segmentation polygons.
        release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
    """

    def __init__(
        self,
        seed: float = 42,
        device: str = "cuda",
        size: str = "base",
    ) -> None:
        """Initializes the SlimSAMAnnotator with a specific seed, device and
        model size.

        Args:
            seed (float): Seed for reproducibility. Defaults to 42.
            device (str): The device to run the model on. Defaults to 'cuda'.
            size (str): The size of the SlimSAM model to use ('base' or 'large'). Defaults to 'base'.
        """
        super().__init__(seed)
        self.size = size
        self.model = self._init_model()
        self.processor = self._init_processor()
        self.device = device
        self.model.to(self.device)

    def _init_model(self) -> SamModel:
        """Initializes the SlimSAM model for instance segmentation.

        Returns:
            SamModel: The initialized SlimSAM model.
        """
        logger.info(f"Initializing SlimSAM {self.size} model...")
        if self.size == "large":
            return SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50")
        return SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")

    def _init_processor(self) -> SamProcessor:
        """Initializes the processor for the SlimSAM model.

        Returns:
            SamProcessor: The initialized processor.
        """
        if self.size == "large":
            return SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50")
        return SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")

    def annotate_batch(
        self,
        images: List[PIL.Image.Image],
        boxes_batch: List[np.ndarray],
        iou_threshold: float = 0.2,
    ) -> List[List[List[float]]]:
        """Annotates images for the task of instance segmentation using the
        SlimSAM model.

        Args:
            images: The images to be annotated.
            boxes_batch: The bounding boxes of found objects.
            iou_threshold (float, optional): Threshold on the predicted IoU scores used to filter mask proposals. Defaults to 0.2.

        Returns:
            List: A list containing the final segmentation masks represented as polygons.
        """
        final_segments = []

        n = len(images)

        for i in range(n):
            boxes = boxes_batch[i].tolist()
            if len(boxes) == 0:
                final_segments.append([])
                continue

            inputs = self.processor(
                images[i], input_boxes=[boxes], return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, return_dict=True)

            # Resize the predicted masks back to the original image size
            masks = self.processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )[0]

            iou_scores = outputs.iou_scores.cpu()

            image_masks = []
            for j in range(len(boxes)):
                # Keep only mask proposals above the IoU threshold,
                # then merge them by taking their union
                keep_idx = iou_scores[0, j] >= iou_threshold
                filtered_masks = masks[j, keep_idx].cpu().float()
                final_masks = filtered_masks.permute(1, 2, 0)
                final_masks = final_masks.mean(dim=-1)
                final_masks = (final_masks > 0).int()
                final_masks = final_masks.numpy().astype(np.uint8)
                polygon = mask_to_polygon(final_masks)
                if len(polygon) != 0:
                    image_masks.append(polygon)

            final_segments.append(image_masks)

        return final_segments

    def release(self, empty_cuda_cache: bool = False) -> None:
        """Releases the model and optionally empties the CUDA cache.

        Args:
            empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
        """
        self.model = self.model.to("cpu")
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()

if __name__ == "__main__": | ||
import requests | ||
from PIL import Image | ||
|
||
url = "https://ultralytics.com/images/bus.jpg" | ||
im = Image.open(requests.get(url, stream=True).raw) | ||
annotator = SlimSAMAnnotator(device="cpu", size="large") | ||
final_segments = annotator.annotate_batch([im], [np.array([[3, 229, 559, 650]])]) | ||
print(len(final_segments), len(final_segments[0])) | ||
print(final_segments[0][0][:5]) |
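
The mask_to_polygon helper is imported from datadreamer.dataset_annotation.utils and its implementation is not part of this diff. As a rough illustration only, such a helper could be written with OpenCV along the following lines (a hypothetical sketch, not the repo's actual implementation):

import cv2
import numpy as np


def mask_to_polygon_sketch(mask: np.ndarray) -> list:
    """Hypothetical sketch: convert a binary uint8 mask of shape (H, W)
    into a polygon, returned as a list of [x, y] points for the largest
    contour, or an empty list if the mask contains no foreground."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    # Pick the largest contour so that one polygon is returned per instance
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()

Returning an empty list for empty masks matches the `if len(polygon) != 0` guard in annotate_batch above (and the "fix: handle empty polygon list" commit).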
Review discussion:

SlimSAM doesn't support batched inference?

The thing is that each image can have a different number of detected objects, so batched inference isn't possible straight away; that's why I implemented it per image. But now that you've mentioned it, I thought about it again and realized that we could "pad" the bboxes with dummy bboxes so that we can have batch inference. I'm currently testing it. Let me know @sokovninn if you'd find this small hack better.

Oh, I see. Dummy bboxes are a good solution. However, I am not sure if it will bring any boost in inference speed, but it is worth a try, I think.

Exactly, I'll test it and let you know.

It turned out not to be faster, so I'm not going to use it.
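
For reference, the dummy-bbox padding idea discussed above could look roughly like the sketch below. This is hypothetical code, not part of the PR; as noted, the approach was benchmarked, found no faster, and not merged.

from typing import List, Tuple

import numpy as np


def pad_boxes_batch(
    boxes_batch: List[np.ndarray], dummy_box=(0.0, 0.0, 1.0, 1.0)
) -> Tuple[List[List[List[float]]], List[int]]:
    """Hypothetical sketch: pad every image's box list to the same length
    with dummy boxes so a single batched SAM forward pass is possible.
    Returns the padded boxes (nested lists suitable for the processor's
    input_boxes argument) and the original per-image counts."""
    counts = [len(b) for b in boxes_batch]
    max_n = max(counts) if counts else 0
    padded = []
    for boxes in boxes_batch:
        boxes = boxes.tolist()
        # Append dummy boxes until every image has max_n boxes
        boxes += [list(dummy_box)] * (max_n - len(boxes))
        padded.append(boxes)
    return padded, counts

After one batched forward pass, only the first counts[i] masks of image i would be kept and the dummy-box predictions discarded.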