Skip to content

Commit

Permalink
cleanups
Browse files Browse the repository at this point in the history
- improve method names and exceptions
- implement dummy methods in detector_simulation
  • Loading branch information
denniswittich committed Dec 12, 2024
1 parent 1e24778 commit 5cbd68b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 79 deletions.
34 changes: 27 additions & 7 deletions rosys/vision/detector.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from .detections import Category
from dataclasses import field

from dataclasses import dataclass, field
import abc
import logging
from datetime import datetime
from enum import Enum
from uuid import uuid4
from typing import Literal

from ..event import Event
from .detections import Detections
from .detections import Detections, Category
from .image import Image
from .uploads import Uploads

from dataclasses import dataclass


class Autoupload(Enum):
"""Configures the auto-submitting of images to the Learning Loop"""
Expand Down Expand Up @@ -121,12 +120,33 @@ async def upload(self,
"""

@abc.abstractmethod
async def fetch_detector_information(self) -> DetectorInfo:
"""Get information about the detector.
async def fetch_detector_info(self) -> DetectorInfo:
"""Retrieve information about the detector.
Returns:
DetectorInfo: information about the detector.
Raises:
DetectorException: if the about information cannot be retrieved.
"""

async def fetch_model_version_info(self) -> ModelVersioningInfo:
"""Retrieve information about the model version and versioning mode.
Returns:
ModelVersioningInfo: the information about the model versioning as data class.
Raises:
DetectorException: if the detector is not connected or the information cannot be retrieved.
"""

async def set_model_version(self, version: Literal['follow_loop', 'pause'] | str) -> None:
"""Set the model version or versioning mode.
Set to 'follow_loop' to automatically update the model version to the latest version in the learning loop.
Set to 'pause' to stop automatic updates and keep the current model version.
Set to a version number (e.g. '1.2') to use a specific version.
Raises:
DetectorException: if the version control mode is not valid or the version could not be set.
"""
107 changes: 39 additions & 68 deletions rosys/vision/detector_hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,6 @@ async def detect(self,
source: str | None = None,
creation_date: datetime | str | None = None,
) -> Detections | None:
'''Run inference on the image.
If the detector is busy, the task is put in a lifo queue with a size of 1.
This means that if the detector is busy and `detect` is called again,
the current image is not processed and None is returned.
Emits:
NEW_DETECTIONS: when the inference was successful.
Returns:
Detections: the detections found in the image or None if the detector is too busy.
Raises:
DetectorException: if the detection fails.
'''

if not self.is_connected:
self.log.error('Detection failed: detector is not connected')
raise DetectorException('detector is not connected')
Expand All @@ -176,7 +160,6 @@ async def _detect(self,
source: str | None = None,
creation_date: datetime | str | None = None,
) -> Detections:

try:
response = await self.sio.call('detect', {
'image': image.data,
Expand All @@ -186,12 +169,12 @@ async def _detect(self,
'source': source,
'creation_date': _creation_date_to_isoformat(creation_date),
}, timeout=3)
except socketio.exceptions.TimeoutError as e:
except socketio.exceptions.TimeoutError:
self.timeout_count += 1
if self.timeout_count > 5 and self.auto_disconnect:
self.log.error('Detection timed out 5 times in a row. Disconnecting from detector %s', self.name)
await self.disconnect()
raise DetectorException('Detection timed out') from e
raise DetectorException('Detection timeout') # pylint: disable=raise-missing-from
except Exception as e:
raise DetectorException('Detection failed') from e

Expand All @@ -217,26 +200,17 @@ async def _detect(self,
def __str__(self) -> str:
return f'{type(self).__name__} ({"connected" if self.is_connected else "disconnected"})'

async def fetch_detector_information(self) -> DetectorInfo:
'''Retrieve information about the detector.
Returns:
DetectorInfo: the information about the detector as data class.
Raises:
DetectorException: if the detector is not connected or the information cannot be retrieved.
'''

async def fetch_detector_info(self) -> DetectorInfo:
if not self.is_connected:
raise DetectorException('detector is not connected')

try:
response = await self.sio.call('about', timeout=5)
except socketio.exceptions.TimeoutError as e:
self.log.error('could not get info for detector %s', self.name)
raise DetectorException('could not get info due to timeout') from e
except socketio.exceptions.TimeoutError:
self.log.error('Communication timeout for detector %s', self.name)
raise DetectorException('Communication timeut') # pylint: disable=raise-missing-from
except Exception as e:
raise DetectorException('could not get info') from e
raise DetectorException('Communication failed') from e

if not isinstance(response, dict):
raise DetectorException('Invalid response from detector')
Expand All @@ -256,27 +230,20 @@ async def fetch_detector_information(self) -> DetectorInfo:
categories=categories,
resolution=model_info_dict.get('resolution'),
target_version=response.get('target_model'),
version_control=response.get('version_control'))
version_control=response['version_control'])
except Exception as e:
raise DetectorException('Failed to parse detector info') from e

return result

async def fetch_model_versioning_information(self) -> ModelVersioningInfo:
'''Retrieve information about the model versioning.
Returns:
ModelVersioningInfo: the information about the model versioning as data class.
Raises:
DetectorException: if the detector is not connected or the information cannot be retrieved.
'''
async def fetch_model_version_info(self) -> ModelVersioningInfo:
try:
response = await self.sio.call('get_model_version', timeout=5)
except socketio.exceptions.TimeoutError as e:
raise DetectorException('could not get info due to timeout') from e
except socketio.exceptions.TimeoutError:
self.log.error('Communication timeout for detector %s', self.name)
raise DetectorException('Communication timed out') # pylint: disable=raise-missing-from
except Exception as e:
raise DetectorException('could not get info') from e
raise DetectorException('Communication failed') from e

if not isinstance(response, dict):
raise DetectorException('Invalid response from detector')
Expand All @@ -293,38 +260,42 @@ async def fetch_model_versioning_information(self) -> ModelVersioningInfo:

return result

async def set_model_version_mode(self, version_control_mode: Literal['follow_loop', 'pause'] | str) -> None:
'''Set the model version mode.
Raises:
DetectorException: if the version control mode is not valid or the version could not be set.
'''
if version_control_mode not in ['follow_loop', 'pause']:
if not version_control_mode.replace('.', '').isdigit():
async def set_model_version(self, version: Literal['follow_loop', 'pause'] | str) -> None:
if version not in ['follow_loop', 'pause']:
if not version.replace('.', '').isdigit():
raise DetectorException(
f'invalid version control mode: {version_control_mode} (allowed: follow_loop, pause or a version number like 1.2)')
f'invalid version control mode: {version} (allowed: follow_loop, pause or a version number like 1.2)')

try:
response = await self.sio.call('set_model_version_mode', version_control_mode, timeout=5)
except socketio.exceptions.TimeoutError as e:
self.log.error('could not get info for detector %s', self.name)
raise DetectorException('could not get info due to timeout') from e
response = await self.sio.call('set_model_version_mode', version, timeout=5)
except socketio.exceptions.TimeoutError:
self.log.error('Communication timeout for detector %s', self.name)
raise DetectorException('Communication timeout') # pylint: disable=raise-missing-from
except Exception as e:
self.log.error('could not get info for detector %s: %s', self.name, e)
raise DetectorException('could not get info') from e
self.log.error('Communication failed for detector %s: %s', self.name, e)
raise DetectorException('Communication failed') from e

if not isinstance(response, dict) or response.get('status') != 'OK':
if not isinstance(response, dict):
raise DetectorException('Failed to set model version mode')

async def trigger_soft_reload(self) -> None:
if response.get('status') != 'OK':
error_message = response.get('error', 'unknown error')
raise DetectorException(f'Failed to set model version mode: {error_message}')

async def soft_reload(self) -> None:
"""Trigger a soft reload of the detector.
Raises:
DetectorException: if the communication fails.
"""
try:
await self.sio.call('soft_reload', timeout=5)
except socketio.exceptions.TimeoutError as e:
self.log.error('could not get info for detector %s', self.name)
raise DetectorException('could not get info due to timeout') from e
except socketio.exceptions.TimeoutError:
self.log.error('Communication timeout for detector %s', self.name)
raise DetectorException('Communication timeout') # pylint: disable=raise-missing-from
except Exception as e:
self.log.error('could not get info for detector %s: %s', self.name, e)
raise DetectorException('could not get info') from e
self.log.error('Communication failed for detector %s: %s', self.name, e)
raise DetectorException('Communication failed') from e


def _box_detections_to_int(detections: list[dict]) -> list[dict]:
Expand Down
18 changes: 14 additions & 4 deletions rosys/vision/detector_simulation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from dataclasses import dataclass, field
from datetime import datetime
from uuid import uuid4

from typing import Literal
import numpy as np

from .. import rosys
from ..geometry import Point3d
from .calibratable_camera_provider import CalibratableCameraProvider
from .camera import CalibratableCamera
from .detections import BoxDetection, Detections, PointDetection, Category
from .detector import Autoupload, Detector, DetectorInfo
from .detections import BoxDetection, Detections, PointDetection
from .detector import Autoupload, Detector, DetectorInfo, ModelVersioningInfo
from .image import Image


Expand Down Expand Up @@ -82,12 +82,22 @@ async def upload(self,
) -> None:
self.log.info('Uploading %s', image.id)

async def fetch_detector_information(self) -> DetectorInfo:
async def fetch_detector_info(self) -> DetectorInfo:
return DetectorInfo(operation_mode='simulation',
version_control='pause',
state=None,
categories=[])

async def fetch_model_version_info(self) -> ModelVersioningInfo:
return ModelVersioningInfo(current_version='None',
target_version='None',
loop_version='None',
local_versions=[],
version_control='pause')

async def set_model_version(self, version: Literal['follow_loop', 'pause'] | str) -> None:
self.log.info('Setting model version to %s', version)

def update_simulated_objects(self, image: Image) -> None:
pass

Expand Down

0 comments on commit 5cbd68b

Please sign in to comment.