diff --git a/rosys/vision/detector.py b/rosys/vision/detector.py index 6efac13d..e2eae157 100644 --- a/rosys/vision/detector.py +++ b/rosys/vision/detector.py @@ -1,18 +1,17 @@ -from .detections import Category -from dataclasses import field + +from dataclasses import dataclass, field import abc import logging from datetime import datetime from enum import Enum from uuid import uuid4 +from typing import Literal from ..event import Event -from .detections import Detections +from .detections import Detections, Category from .image import Image from .uploads import Uploads -from dataclasses import dataclass - class Autoupload(Enum): """Configures the auto-submitting of images to the Learning Loop""" @@ -121,8 +120,8 @@ async def upload(self, """ @abc.abstractmethod - async def fetch_detector_information(self) -> DetectorInfo: - """Get information about the detector. + async def fetch_detector_info(self) -> DetectorInfo: + """Retrieve information about the detector. Returns: DetectorInfo: information about the detector. @@ -130,3 +129,24 @@ async def fetch_detector_information(self) -> DetectorInfo: Raises: DetectorException: if the about information cannot be retrieved. """ + + async def fetch_model_version_info(self) -> ModelVersioningInfo: + """Retrieve information about the model version and versioning mode. + + Returns: + ModelVersioningInfo: the information about the model versioning as data class. + + Raises: + DetectorException: if the detector is not connected or the information cannot be retrieved. + """ + + async def set_model_version(self, version: Literal['follow_loop', 'pause'] | str) -> None: + """Set the model version or versioning mode. + + Set to 'follow_loop' to automatically update the model version to the latest version in the learning loop. + Set to 'pause' to stop automatic updates and keep the current model version. + Set to a version number (e.g. '1.2') to use a specific version. + + Raises: + DetectorException: if the version control mode is not valid or the version could not be set. + """ diff --git a/rosys/vision/detector_hardware.py b/rosys/vision/detector_hardware.py index 3ac71c35..3321d35a 100644 --- a/rosys/vision/detector_hardware.py +++ b/rosys/vision/detector_hardware.py @@ -141,22 +141,6 @@ async def detect(self, source: str | None = None, creation_date: datetime | str | None = None, ) -> Detections | None: - '''Run inference on the image. - - If the detector is busy, the task is put in a lifo queue with a size of 1. - This means that if the detector is busy and `detect` is called again, - the current image is not processed and None is returned. - - Emits: - NEW_DETECTIONS: when the inference was successful. - - Returns: - Detections: the detections found in the image or None if the detector is too busy. - - Raises: - DetectorException: if the detection fails. - ''' - if not self.is_connected: self.log.error('Detection failed: detector is not connected') raise DetectorException('detector is not connected') @@ -176,7 +160,6 @@ async def _detect(self, source: str | None = None, creation_date: datetime | str | None = None, ) -> Detections: - try: response = await self.sio.call('detect', { 'image': image.data, @@ -186,12 +169,12 @@ async def _detect(self, 'source': source, 'creation_date': _creation_date_to_isoformat(creation_date), }, timeout=3) - except socketio.exceptions.TimeoutError as e: + except socketio.exceptions.TimeoutError: self.timeout_count += 1 if self.timeout_count > 5 and self.auto_disconnect: self.log.error('Detection timed out 5 times in a row. Disconnecting from detector %s', self.name) await self.disconnect() - raise DetectorException('Detection timed out') from e + raise DetectorException('Detection timeout') # pylint: disable=raise-missing-from except Exception as e: raise DetectorException('Detection failed') from e @@ -217,26 +200,17 @@ async def _detect(self, def __str__(self) -> str: return f'{type(self).__name__} ({"connected" if self.is_connected else "disconnected"})' - async def fetch_detector_information(self) -> DetectorInfo: - '''Retrieve information about the detector. - - Returns: - DetectorInfo: the information about the detector as data class. - - Raises: - DetectorException: if the detector is not connected or the information cannot be retrieved. - ''' - + async def fetch_detector_info(self) -> DetectorInfo: if not self.is_connected: raise DetectorException('detector is not connected') try: response = await self.sio.call('about', timeout=5) - except socketio.exceptions.TimeoutError as e: - self.log.error('could not get info for detector %s', self.name) - raise DetectorException('could not get info due to timeout') from e + except socketio.exceptions.TimeoutError: + self.log.error('Communication timeout for detector %s', self.name) + raise DetectorException('Communication timeut') # pylint: disable=raise-missing-from except Exception as e: - raise DetectorException('could not get info') from e + raise DetectorException('Communication failed') from e if not isinstance(response, dict): raise DetectorException('Invalid response from detector') @@ -256,27 +230,20 @@ async def fetch_detector_information(self) -> DetectorInfo: categories=categories, resolution=model_info_dict.get('resolution'), target_version=response.get('target_model'), - version_control=response.get('version_control')) + version_control=response['version_control']) except Exception as e: raise DetectorException('Failed to parse detector info') from e return result - async def fetch_model_versioning_information(self) -> ModelVersioningInfo: - '''Retrieve information about the model versioning. - - Returns: - ModelVersioningInfo: the information about the model versioning as data class. - - Raises: - DetectorException: if the detector is not connected or the information cannot be retrieved. - ''' + async def fetch_model_version_info(self) -> ModelVersioningInfo: try: response = await self.sio.call('get_model_version', timeout=5) - except socketio.exceptions.TimeoutError as e: - raise DetectorException('could not get info due to timeout') from e + except socketio.exceptions.TimeoutError: + self.log.error('Communication timeout for detector %s', self.name) + raise DetectorException('Communication timed out') # pylint: disable=raise-missing-from except Exception as e: - raise DetectorException('could not get info') from e + raise DetectorException('Communication failed') from e if not isinstance(response, dict): raise DetectorException('Invalid response from detector') @@ -293,38 +260,42 @@ async def fetch_model_versioning_information(self) -> ModelVersioningInfo: return result - async def set_model_version_mode(self, version_control_mode: Literal['follow_loop', 'pause'] | str) -> None: - '''Set the model version mode. - - Raises: - DetectorException: if the version control mode is not valid or the version could not be set. - ''' - if version_control_mode not in ['follow_loop', 'pause']: - if not version_control_mode.replace('.', '').isdigit(): + async def set_model_version(self, version: Literal['follow_loop', 'pause'] | str) -> None: + if version not in ['follow_loop', 'pause']: + if not version.replace('.', '').isdigit(): raise DetectorException( - f'invalid version control mode: {version_control_mode} (allowed: follow_loop, pause or a version number like 1.2)') + f'invalid version control mode: {version} (allowed: follow_loop, pause or a version number like 1.2)') try: - response = await self.sio.call('set_model_version_mode', version_control_mode, timeout=5) - except socketio.exceptions.TimeoutError as e: - self.log.error('could not get info for detector %s', self.name) - raise DetectorException('could not get info due to timeout') from e + response = await self.sio.call('set_model_version_mode', version, timeout=5) + except socketio.exceptions.TimeoutError: + self.log.error('Communication timeout for detector %s', self.name) + raise DetectorException('Communication timeout') # pylint: disable=raise-missing-from except Exception as e: - self.log.error('could not get info for detector %s: %s', self.name, e) - raise DetectorException('could not get info') from e + self.log.error('Communication failed for detector %s: %s', self.name, e) + raise DetectorException('Communication failed') from e - if not isinstance(response, dict) or response.get('status') != 'OK': + if not isinstance(response, dict): raise DetectorException('Failed to set model version mode') - async def trigger_soft_reload(self) -> None: + if response.get('status') != 'OK': + error_message = response.get('error', 'unknown error') + raise DetectorException(f'Failed to set model version mode: {error_message}') + + async def soft_reload(self) -> None: + """Trigger a soft reload of the detector. + + Raises: + DetectorException: if the communication fails. + """ try: await self.sio.call('soft_reload', timeout=5) - except socketio.exceptions.TimeoutError as e: - self.log.error('could not get info for detector %s', self.name) - raise DetectorException('could not get info due to timeout') from e + except socketio.exceptions.TimeoutError: + self.log.error('Communication timeout for detector %s', self.name) + raise DetectorException('Communication timeout') # pylint: disable=raise-missing-from except Exception as e: - self.log.error('could not get info for detector %s: %s', self.name, e) - raise DetectorException('could not get info') from e + self.log.error('Communication failed for detector %s: %s', self.name, e) + raise DetectorException('Communication failed') from e def _box_detections_to_int(detections: list[dict]) -> list[dict]: diff --git a/rosys/vision/detector_simulation.py b/rosys/vision/detector_simulation.py index 82adf7c6..d94ad97c 100644 --- a/rosys/vision/detector_simulation.py +++ b/rosys/vision/detector_simulation.py @@ -1,15 +1,15 @@ from dataclasses import dataclass, field from datetime import datetime from uuid import uuid4 - +from typing import Literal import numpy as np from .. import rosys from ..geometry import Point3d from .calibratable_camera_provider import CalibratableCameraProvider from .camera import CalibratableCamera -from .detections import BoxDetection, Detections, PointDetection, Category -from .detector import Autoupload, Detector, DetectorInfo +from .detections import BoxDetection, Detections, PointDetection +from .detector import Autoupload, Detector, DetectorInfo, ModelVersioningInfo from .image import Image @@ -82,12 +82,22 @@ async def upload(self, ) -> None: self.log.info('Uploading %s', image.id) - async def fetch_detector_information(self) -> DetectorInfo: + async def fetch_detector_info(self) -> DetectorInfo: return DetectorInfo(operation_mode='simulation', version_control='pause', state=None, categories=[]) + async def fetch_model_version_info(self) -> ModelVersioningInfo: + return ModelVersioningInfo(current_version='None', + target_version='None', + loop_version='None', + local_versions=[], + version_control='pause') + + async def set_model_version(self, version: Literal['follow_loop', 'pause'] | str) -> None: + self.log.info('Setting model version to %s', version) + def update_simulated_objects(self, image: Image) -> None: pass