From 2c8aa656504ed1410099cd245f05a203fb5b3f78 Mon Sep 17 00:00:00 2001 From: ragmani Date: Wed, 22 Jan 2025 04:05:56 +0000 Subject: [PATCH] [onert/python] Introduce OptimizerRegistry This commit introduces OptimizerRegistry. - create_optimizer : Create an optimizer instance by name - map_optimizer_to_enum : Map an optimizer instance to the appropriate enum value. ONE-DCO-1.0-Signed-off-by: ragmani --- .../experimental/train/optimizer/registry.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 runtime/onert/api/python/package/experimental/train/optimizer/registry.py diff --git a/runtime/onert/api/python/package/experimental/train/optimizer/registry.py b/runtime/onert/api/python/package/experimental/train/optimizer/registry.py new file mode 100644 index 00000000000..3e0d2ce7f55 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/optimizer/registry.py @@ -0,0 +1,43 @@ +from .adam import Adam +from .sgd import SGD + + +class OptimizerRegistry: + """ + Registry for creating optimizers by name. + """ + _optimizers = {"adam": Adam, "sgd": SGD} + + @staticmethod + def create_optimizer(name): + """ + Create an optimizer instance by name. + Args: + name (str): Name of the optimizer. + Returns: + BaseOptimizer: Optimizer instance. + """ + if name not in OptimizerRegistry._optimizers: + raise ValueError( + f"Unknown Optimizer: {name}. Custom optimizer is not supported yet") + return OptimizerRegistry._optimizers[name]() + + @staticmethod + def map_optimizer_to_enum(optimizer_instance): + """ + Maps an optimizer instance to the appropriate enum value. + Args: + optimizer_instance (Optimizer): An instance of an optimizer. + Returns: + optimizer_type: Corresponding enum value for the optimizer. + Raises: + TypeError: If the optimizer_instance is not a recognized optimizer type. + """ + # Optimizer to Enum mapping + optimizer_to_enum = {SGD: "SGD", Adam: "ADAM"} + for optimizer_class, enum_value in optimizer_to_enum.items(): + if isinstance(optimizer_instance, optimizer_class): + return enum_value + raise TypeError( + f"Unsupported optimizer type: {type(optimizer_instance).__name__}. " + f"Supported types are: {list(optimizer_to_enum.keys())}.")