Skip to content

Commit

Permalink
[onert/python] Introduce OptimizerRegistry
Browse files Browse the repository at this point in the history
This commit introduces OptimizerRegistry.
  - create_optimizer : Create an optimizer instance by name
  - map_optimizer_to_enum : Map an optimizer instance to the appropriate enum value.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani committed Jan 22, 2025
1 parent e6091bc commit 2c8aa65
Showing 1 changed file with 43 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from .adam import Adam
from .sgd import SGD


class OptimizerRegistry:
"""
Registry for creating optimizers by name.
"""
_optimizers = {"adam": Adam, "sgd": SGD}

@staticmethod
def create_optimizer(name):
"""
Create an optimizer instance by name.
Args:
name (str): Name of the optimizer.
Returns:
BaseOptimizer: Optimizer instance.
"""
if name not in OptimizerRegistry._optimizers:
raise ValueError(
f"Unknown Optimizer: {name}. Custom optimizer is not supported yet")
return OptimizerRegistry._optimizers[name]()

@staticmethod
def map_optimizer_to_enum(optimizer_instance):
"""
Maps an optimizer instance to the appropriate enum value.
Args:
optimizer_instance (Optimizer): An instance of an optimizer.
Returns:
optimizer_type: Corresponding enum value for the optimizer.
Raises:
TypeError: If the optimizer_instance is not a recognized optimizer type.
"""
# Optimizer to Enum mapping
optimizer_to_enum = {SGD: "SGD", Adam: "ADAM"}
for optimizer_class, enum_value in optimizer_to_enum.items():
if isinstance(optimizer_instance, optimizer_class):
return enum_value
raise TypeError(
f"Unsupported optimizer type: {type(optimizer_instance).__name__}. "
f"Supported types are: {list(optimizer_to_enum.keys())}.")

0 comments on commit 2c8aa65

Please sign in to comment.