From d2591ce631df63e9f411aa7630ed48d2658bbf93 Mon Sep 17 00:00:00 2001 From: narumi Date: Mon, 20 May 2024 19:07:59 +0800 Subject: [PATCH] add flatten function --- mlconfig/__init__.py | 1 + mlconfig/conf.py | 16 ++++++++++++++++ tests/test_conf.py | 6 ++++++ 3 files changed, 23 insertions(+) diff --git a/mlconfig/__init__.py b/mlconfig/__init__.py index 354db64..324e22b 100644 --- a/mlconfig/__init__.py +++ b/mlconfig/__init__.py @@ -1,3 +1,4 @@ +from .conf import flatten from .conf import getcls from .conf import instantiate from .conf import load diff --git a/mlconfig/conf.py b/mlconfig/conf.py index f28b27b..c3c6df3 100644 --- a/mlconfig/conf.py +++ b/mlconfig/conf.py @@ -77,3 +77,19 @@ def instantiate(conf, *args, **kwargs): func_or_cls = getcls(conf) return func_or_cls(*args, **kwargs) + + +def flatten(data: dict, prefix: str | None = None, sep: str = ".") -> dict: + d = {} + + for key, value in data.items(): + if prefix is not None: + key = prefix + sep + key + + if isinstance(value, dict): + d.update(flatten(value, prefix=key)) + continue + + d[key] = value + + return d diff --git a/tests/test_conf.py b/tests/test_conf.py index ff4f55d..e00dcf1 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1,5 +1,6 @@ import pytest +from mlconfig import flatten from mlconfig import getcls from mlconfig import instantiate from mlconfig import load @@ -55,3 +56,8 @@ def test_instantiate(conf, obj): def test_getcls(conf): assert getcls(conf["a"]) == Point + + +def test_flatten() -> None: + d = {"a": {"b": "c"}, "d": {"e": "f"}} + assert flatten(d) == {"a.b": "c", "d.e": "f"}