Skip to content

Commit

Permalink
add flatten function
Browse files Browse the repository at this point in the history
  • Loading branch information
narumiruna committed May 20, 2024
1 parent 0d7d2af commit aaddbca
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlconfig/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .conf import flatten
from .conf import getcls
from .conf import instantiate
from .conf import load
Expand Down
17 changes: 17 additions & 0 deletions mlconfig/conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import functools
from typing import Optional

from omegaconf import OmegaConf

Expand Down Expand Up @@ -77,3 +78,19 @@ def instantiate(conf, *args, **kwargs):
func_or_cls = getcls(conf)

return func_or_cls(*args, **kwargs)


def flatten(data: dict, prefix: Optional[str] = None, sep: str = ".") -> dict:
d = {}

for key, value in data.items():
if prefix is not None:
key = prefix + sep + key

if isinstance(value, dict):
d.update(flatten(value, prefix=key))
continue

d[key] = value

return d
14 changes: 14 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from mlconfig import flatten
from mlconfig import getcls
from mlconfig import instantiate
from mlconfig import load
Expand Down Expand Up @@ -55,3 +56,16 @@ def test_instantiate(conf, obj):

def test_getcls(conf):
assert getcls(conf["a"]) == Point


@pytest.mark.parametrize(
"test_input,expected",
[
({}, {}),
({"a": "b"}, {"a": "b"}),
({"a": {"b": {"c": "d"}}}, {"a.b.c": "d"}),
({"a": {"b": "c"}, "d": {"e": "f"}}, {"a.b": "c", "d.e": "f"}),
],
)
def test_flatten(test_input: dict, expected: dict) -> None:
assert flatten(test_input) == expected

0 comments on commit aaddbca

Please sign in to comment.