-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[onert/python] Introduce OptimizerRegistry
This commit introduces OptimizerRegistry. - create_optimizer : Create an optimizer instance by name - map_optimizer_to_enum : Map an optimizer instance to the appropriate enum value. ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
- Loading branch information
Showing
1 changed file
with
43 additions
and
0 deletions.
There are no files selected for viewing
43 changes: 43 additions & 0 deletions
43
runtime/onert/api/python/package/experimental/train/optimizer/registry.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from .adam import Adam | ||
from .sgd import SGD | ||
|
||
|
||
class OptimizerRegistry: | ||
""" | ||
Registry for creating optimizers by name. | ||
""" | ||
_optimizers = {"adam": Adam, "sgd": SGD} | ||
|
||
@staticmethod | ||
def create_optimizer(name): | ||
""" | ||
Create an optimizer instance by name. | ||
Args: | ||
name (str): Name of the optimizer. | ||
Returns: | ||
BaseOptimizer: Optimizer instance. | ||
""" | ||
if name not in OptimizerRegistry._optimizers: | ||
raise ValueError( | ||
f"Unknown Optimizer: {name}. Custom optimizer is not supported yet") | ||
return OptimizerRegistry._optimizers[name]() | ||
|
||
@staticmethod | ||
def map_optimizer_to_enum(optimizer_instance): | ||
""" | ||
Maps an optimizer instance to the appropriate enum value. | ||
Args: | ||
optimizer_instance (Optimizer): An instance of an optimizer. | ||
Returns: | ||
optimizer_type: Corresponding enum value for the optimizer. | ||
Raises: | ||
TypeError: If the optimizer_instance is not a recognized optimizer type. | ||
""" | ||
# Optimizer to Enum mapping | ||
optimizer_to_enum = {SGD: "SGD", Adam: "ADAM"} | ||
for optimizer_class, enum_value in optimizer_to_enum.items(): | ||
if isinstance(optimizer_instance, optimizer_class): | ||
return enum_value | ||
raise TypeError( | ||
f"Unsupported optimizer type: {type(optimizer_instance).__name__}. " | ||
f"Supported types are: {list(optimizer_to_enum.keys())}.") |