Skip to content

Commit

Permalink
Merge pull request #12 from minhnh/refactor/separate_ros_img_classifi…
Browse files Browse the repository at this point in the history
…cation

[Refactor] separate image message & OpenCV image classification, use dictionary for classes
  • Loading branch information
minhnh authored Oct 16, 2019
2 parents 69aed47 + f302aa9 commit eb0c545
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 33 deletions.
7 changes: 4 additions & 3 deletions docs/python_package.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ An extension of `SceneDetectionActionServer` which uses `SingleImageDetectionHan
from an image extracted from a `sensor_msgs/PointCloud2` message, while also fitting planes in the clouds.

## [`utils.py`](../ros/src/mas_perception_libs/utils.py)
* `get_classes_in_data_dir`: Returns a list of strings as class names for a directory. This directory structure
* `get_classes_in_data_dir`: Returns a dictionary mapping from indices to classes as names of top level directories.
This directory structure
```
data
├── class_1
└── class_2
```
should returns
should return
```
['class_1', 'class_2']
{0: 'class_1', 1: 'class_2'}
```
when called on `data`.
* `process_image_message`: Converts `sensor_msgs/Image` to CV image, then resizes and/or runs a preprocessing function
Expand Down
54 changes: 29 additions & 25 deletions ros/src/mas_perception_libs/image_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import yaml
from abc import ABCMeta, abstractmethod
import numpy as np
from cv_bridge import CvBridge
Expand All @@ -11,7 +12,7 @@ class ImageClassifier(object):
"""
__metaclass__ = ABCMeta

_classes = None # type: list
_classes = None # type: dict

def __init__(self, **kwargs):
# read information on classes, either directly, via a file, or from a data directory
Expand All @@ -20,7 +21,11 @@ def __init__(self, **kwargs):
if self._classes is None:
class_file = kwargs.get('class_file', None)
if class_file is not None and os.path.exists(class_file):
self._classes = ImageClassifier.read_classes_from_file(class_file)
with open(class_file) as infile:
if yaml.__version__ < '5.1':
self._classes = yaml.load(infile)
else:
self._classes = yaml.load(infile, Loader=yaml.FullLoader)

if self._classes is None:
data_dir = kwargs.get('data_dir', None)
Expand All @@ -34,7 +39,7 @@ def __init__(self, **kwargs):
@property
def classes(self):
"""
list of strings containing class names TODO(minhnh): make this dictionary from predicted class to class name
dictionary mapping from predicted numeric class value to class name
"""
return self._classes

Expand All @@ -52,17 +57,6 @@ def classify(self, image_messages):
"""
pass

@staticmethod
def write_classes_to_file(classes, outfile_path):
with open(outfile_path, 'w') as outfile:
outfile.write('\n'.join(classes))

@staticmethod
def read_classes_from_file(infile):
with open(infile) as f:
content = f.readlines()
return [x.strip() for x in content]


class ImageClassifierTest(ImageClassifier):
"""
Expand Down Expand Up @@ -113,26 +107,36 @@ def __init__(self, **kwargs):
# CvBridge for ROS image conversion
self._cv_bridge = CvBridge()

def classify(self, image_messages):
if len(image_messages) == 0:
return [], [], []

np_images = [process_image_message(msg, self._cv_bridge, self._target_size, self._img_preprocess_func)
for msg in image_messages]

image_array = []
def classify_np_images(self, np_images):
"""
Classify NumPy images
"""
image_tensor = []
indices = []
for i in range(len(np_images)):
if np_images[i] is None:
# skip broken images
continue

image_array.append(np_images[i])
image_tensor.append(np_images[i])
indices.append(i)

image_array = np.array(image_array)
preds = self._model.predict(image_array)
image_tensor = np.array(image_tensor)
preds = self._model.predict(image_tensor)
class_indices = np.argmax(preds, axis=1)
confidences = np.max(preds, axis=1)
predicted_classes = [self._classes[i] for i in class_indices]

return indices, predicted_classes, confidences

def classify(self, image_messages):
"""
Classify ROS `sensor_msgs/Image` messages
"""
if len(image_messages) == 0:
return [], [], []

np_images = [process_image_message(msg, self._cv_bridge, self._target_size, self._img_preprocess_func)
for msg in image_messages]

return self.classify_np_images(np_images)
2 changes: 1 addition & 1 deletion ros/src/mas_perception_libs/image_recognition_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def handle_recognize_image(self, req):
rospy.loginfo('number of images to recognize: ' + str(len(req.images)))
if req.model_name not in self._classifiers:
model_path = os.path.join(self._model_dir, req.model_name + '.h5')
class_file = os.path.join(self._model_dir, req.model_name + '.txt')
class_file = os.path.join(self._model_dir, req.model_name + '.yml')
rospy.loginfo('recognition model path: ' + model_path)
rospy.loginfo('recognition class file path: ' + class_file)
self._classifiers[req.model_name] = self._classifier_class(model_path=model_path, class_file=class_file)
Expand Down
10 changes: 6 additions & 4 deletions ros/src/mas_perception_libs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ def get_bag_file_msg_by_type(bag_file_path, msg_type):
def get_classes_in_data_dir(data_dir):
"""
:type data_dir: str
:return: list of classes as names of top level directories
:return: dictionary mapping from indices to classes as names of top level directories
"""
classes = []
class_dict = {}
index = 0
for subdir in sorted(os.listdir(data_dir)):
if os.path.isdir(os.path.join(data_dir, subdir)):
classes.append(subdir)
class_dict[index] = subdir
index += 1

return classes
return class_dict


def process_image_message(image_msg, cv_bridge, target_size=None, func_preprocess_img=None):
Expand Down

0 comments on commit eb0c545

Please sign in to comment.