From 97d6446a2dea91c67f5f39335e0509573ccb8cdd Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Tue, 4 Feb 2025 09:19:35 -0800 Subject: [PATCH] pg.typing to support `__future__.annotations`. With this CL, PyGlove is compatible with `__future__.annotations`. For example: ```python from __future__ import annotations class A(pg.Object): child: A | None ``` PiperOrigin-RevId: 723114026 --- pyglove/core/symbolic/object.py | 5 +- pyglove/core/typing/__init__.py | 3 + pyglove/core/typing/annotation_conversion.py | 202 +++++++++++++++++- .../core/typing/annotation_conversion_test.py | 149 ++++++++++++- pyglove/core/typing/annotation_future_test.py | 135 ++++++++++++ pyglove/core/typing/class_schema.py | 66 ++++-- pyglove/core/typing/class_schema_test.py | 22 +- 7 files changed, 538 insertions(+), 44 deletions(-) create mode 100644 pyglove/core/typing/annotation_future_test.py diff --git a/pyglove/core/symbolic/object.py b/pyglove/core/symbolic/object.py index 52dbabe..10c6073 100644 --- a/pyglove/core/symbolic/object.py +++ b/pyglove/core/symbolic/object.py @@ -16,6 +16,7 @@ import abc import functools import inspect +import sys import typing from typing import Any, Dict, Iterator, List, Optional, Sequence, Union @@ -154,7 +155,9 @@ def _infer_fields_from_annotations(cls) -> List[pg_typing.Field]: if typing.get_origin(attr_annotation) is typing.ClassVar: continue - field = pg_typing.Field.from_annotation(key, attr_annotation) + field = pg_typing.Field.from_annotation( + key, attr_annotation, parent_module=sys.modules[cls.__module__] + ) if isinstance(key, pg_typing.ConstStrKey): attr_value = cls.__dict__.get(attr_name, pg_typing.MISSING_VALUE) if attr_value != pg_typing.MISSING_VALUE: diff --git a/pyglove/core/typing/__init__.py b/pyglove/core/typing/__init__.py index 2c332af..48a0b18 100644 --- a/pyglove/core/typing/__init__.py +++ b/pyglove/core/typing/__init__.py @@ -375,6 +375,9 @@ class Foo(pg.Object): # Interface for custom typing. from pyglove.core.typing.custom_typing import CustomTyping +# Annotation conversion +from pyglove.core.typing.annotation_conversion import annotation_from_str + # Callable signature. from pyglove.core.typing.callable_signature import Argument from pyglove.core.typing.callable_signature import CallableType diff --git a/pyglove/core/typing/annotation_conversion.py b/pyglove/core/typing/annotation_conversion.py index abde706..c7f1c99 100644 --- a/pyglove/core/typing/annotation_conversion.py +++ b/pyglove/core/typing/annotation_conversion.py @@ -13,11 +13,13 @@ # limitations under the License. """Conversion from annotations to PyGlove value specs.""" +import builtins import collections import inspect import types import typing +from pyglove.core import coding from pyglove.core import utils from pyglove.core.typing import annotated from pyglove.core.typing import class_schema @@ -35,6 +37,191 @@ _UnionType = getattr(types, 'UnionType', None) # pylint: disable=invalid-name +def annotation_from_str( + annotation_str: str, + parent_module: typing.Optional[types.ModuleType] = None, + ) -> typing.Any: + """Parses annotations from str. + + BNF for PyType annotations: + + ``` + ::= | "|" + ::= | + + ::= "Literal" + ::= "[""]" (parsed by `pg.coding.evaluate`) + + ::= | "[""]" + ::= | "," + ::= "[""]" | + ::= 'aAz_.1-9' + ``` + + Args: + annotation_str: String form of type annotations. E.g. "list[str]" + parent_module: The module where the annotation was defined. + + Returns: + Object form of the annotation. + + Raises: + SyntaxError: If the annotation string is invalid. + """ + s = annotation_str + context = dict(pos=0) + + def _eof() -> bool: + return context['pos'] == len(s) + + def _pos() -> int: + return context['pos'] + + def _next(n: int = 1, offset: int = 0) -> str: + if _eof(): + return '' + return s[_pos() + offset:_pos() + offset + n] + + def _advance(n: int) -> None: + context['pos'] += n + + def _error_illustration() -> str: + return f'{s}\n{" " * _pos()}' + '^' + + def _match(ch) -> bool: + if _next(len(ch)) == ch: + _advance(len(ch)) + return True + return False + + def _skip_whitespaces() -> None: + while _next() in ' \t': + _advance(1) + + def _maybe_union(): + t = _type() + while not _eof(): + _skip_whitespaces() + if _match('|'): + t = t | _type() + else: + break + return t + + def _type(): + type_id = _type_id() + t = _resolve(type_id) + if t is typing.Literal: + return t[_literal_params()] + elif _match('['): + arg = _type_arg() + if not _match(']'): + raise SyntaxError( + f'Expected "]" at position {_pos()}.\n\n' + _error_illustration() + ) + return t[arg] + return t + + def _literal_params(): + if not _match('['): + raise SyntaxError( + f'Expected "[" at position {_pos()}.\n\n' + _error_illustration() + ) + arg_start = _pos() + in_str = False + escape_mode = False + num_open_bracket = 1 + + while num_open_bracket > 0: + ch = _next() + if _eof(): + raise SyntaxError( + f'Unexpected end of annotation at position {_pos()}.\n\n' + + _error_illustration() + ) + if ch == '\\': + escape_mode = not escape_mode + else: + escape_mode = False + + if ch == "'" and not escape_mode: + in_str = not in_str + elif not in_str: + if ch == '[': + num_open_bracket += 1 + elif ch == ']': + num_open_bracket -= 1 + _advance(1) + + arg_str = s[arg_start:_pos() - 1] + return coding.evaluate( + '(' + arg_str + ')', permission=coding.CodePermission.BASIC + ) + + def _type_arg(): + t_args = [] + t_args.append(_maybe_type_list()) + while _match(','): + t_args.append(_maybe_type_list()) + return tuple(t_args) if len(t_args) > 1 else t_args[0] + + def _maybe_type_list(): + if _match('['): + ret = _type_arg() + if not _match(']'): + raise SyntaxError( + f'Expected "]" at position {_pos()}.\n\n' + _error_illustration() + ) + return list(ret) if isinstance(ret, tuple) else [ret] + return _maybe_union() + + def _type_id() -> str: + _skip_whitespaces() + if _match('...'): + return '...' + start = _pos() + while not _eof(): + c = _next() + if c.isalnum() or c in '_.': + _advance(1) + else: + break + t_id = s[start:_pos()] + if not all(x.isidentifier() for x in t_id.split('.')): + raise SyntaxError( + f'Expected type identifier, got {t_id!r} at position {start}.\n\n' + + _error_illustration() + ) + return t_id + + def _resolve(type_id: str): + def _resolve_name(name: str, parent_obj: typing.Any): + if name == 'None': + return None + if parent_obj is not None and hasattr(parent_obj, name): + return getattr(parent_obj, name) + if hasattr(builtins, name): + return getattr(builtins, name) + if type_id == '...': + return ... + return utils.MISSING_VALUE + parent_obj = parent_module + for name in type_id.split('.'): + parent_obj = _resolve_name(name, parent_obj) + if parent_obj == utils.MISSING_VALUE: + return typing.ForwardRef( # pytype: disable=not-callable + type_id, False, parent_module + ) + return parent_obj + + root = _maybe_union() + if _pos() != len(s): + raise SyntaxError( + 'Unexpected end of annotation.\n\n' + _error_illustration() + ) + return root + + def _field_from_annotation( key: typing.Union[str, class_schema.KeySpec], annotation: typing.Any, @@ -91,7 +278,12 @@ def _value_spec_from_type_annotation( parent_module: typing.Optional[types.ModuleType] = None ) -> class_schema.ValueSpec: """Creates a value spec from type annotation.""" - if annotation is bool: + if isinstance(annotation, str) and not accept_value_as_annotation: + annotation = annotation_from_str(annotation, parent_module) + + if annotation is None: + return vs.Object(type(None)) + elif annotation is bool: return vs.Bool() elif annotation is int: return vs.Int() @@ -193,10 +385,7 @@ def _sub_value_spec_from_annotation( elif ( inspect.isclass(annotation) or pg_inspect.is_generic(annotation) - or (isinstance(annotation, str) and not accept_value_as_annotation) ): - if isinstance(annotation, str) and parent_module is not None: - annotation = class_schema.ForwardRef(parent_module, annotation) return vs.Object(annotation) if accept_value_as_annotation: @@ -227,11 +416,12 @@ def _value_spec_from_annotation( return annotation elif annotation == inspect.Parameter.empty: return vs.Any() - elif annotation is None: + + if annotation is None: if accept_value_as_annotation: return vs.Any().noneable() else: - return vs.Any().freeze(None) + return vs.Object(type(None)) if auto_typing: return _value_spec_from_type_annotation( diff --git a/pyglove/core/typing/annotation_conversion_test.py b/pyglove/core/typing/annotation_conversion_test.py index 35d5dac..14338a9 100644 --- a/pyglove/core/typing/annotation_conversion_test.py +++ b/pyglove/core/typing/annotation_conversion_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The PyGlove Authors +# Copyright 2025 The PyGlove Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.core.typing.annotation_conversion.""" - import inspect import sys import typing import unittest +from pyglove.core import coding from pyglove.core.typing import annotated from pyglove.core.typing import annotation_conversion from pyglove.core.typing import key_specs as ks @@ -31,6 +30,129 @@ class Foo: pass +_MODULE = sys.modules[__name__] + + +class AnnotationFromStrTest(unittest.TestCase): + """Tests for annotation_from_str.""" + + def test_basic_types(self): + self.assertIsNone(annotation_conversion.annotation_from_str('None')) + self.assertEqual(annotation_conversion.annotation_from_str('str'), str) + self.assertEqual(annotation_conversion.annotation_from_str('int'), int) + self.assertEqual(annotation_conversion.annotation_from_str('float'), float) + self.assertEqual(annotation_conversion.annotation_from_str('bool'), bool) + self.assertEqual(annotation_conversion.annotation_from_str('list'), list) + self.assertEqual( + annotation_conversion.annotation_from_str('list[int]'), list[int] + ) + self.assertEqual(annotation_conversion.annotation_from_str('tuple'), tuple) + self.assertEqual( + annotation_conversion.annotation_from_str('tuple[int]'), tuple[int] + ) + self.assertEqual( + annotation_conversion.annotation_from_str('tuple[int, ...]'), + tuple[int, ...] + ) + self.assertEqual( + annotation_conversion.annotation_from_str('tuple[int, str]'), + tuple[int, str] + ) + + def test_generic_types(self): + self.assertEqual( + annotation_conversion.annotation_from_str('typing.List[str]', _MODULE), + typing.List[str] + ) + + def test_union(self): + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.Union[str, typing.Union[int, float]]', _MODULE), + typing.Union[str, int, float] + ) + if sys.version_info >= (3, 10): + self.assertEqual( + annotation_conversion.annotation_from_str( + 'str | int | float', _MODULE), + typing.Union[str, int, float] + ) + + def test_literal(self): + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.Literal[1, True, "a", \'"b"\', "\\"c\\"", "\\\\"]', + _MODULE + ), + typing.Literal[1, True, 'a', '"b"', '"c"', '\\'] + ) + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.Literal[(1, 1), f"A {[1]}"]', _MODULE), + typing.Literal[(1, 1), 'A [1]'] + ) + with self.assertRaisesRegex(SyntaxError, 'Expected "\\["'): + annotation_conversion.annotation_from_str('typing.Literal', _MODULE) + + with self.assertRaisesRegex(SyntaxError, 'Unexpected end of annotation'): + annotation_conversion.annotation_from_str('typing.Literal[1', _MODULE) + + with self.assertRaisesRegex( + coding.CodeError, 'Function definition is not allowed' + ): + annotation_conversion.annotation_from_str( + 'typing.Literal[lambda x: x]', _MODULE + ) + + def test_callable(self): + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.Callable[int, int]', _MODULE), + typing.Callable[[int], int] + ) + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.Callable[[int], int]', _MODULE), + typing.Callable[[int], int] + ) + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.Callable[..., None]', _MODULE), + typing.Callable[..., None] + ) + + def test_forward_ref(self): + self.assertEqual( + annotation_conversion.annotation_from_str( + 'AAA', _MODULE), + typing.ForwardRef( + 'AAA', False, _MODULE + ) + ) + self.assertEqual( + annotation_conversion.annotation_from_str( + 'typing.List[AAA]', _MODULE), + typing.List[ + typing.ForwardRef( + 'AAA', False, _MODULE + ) + ] + ) + + def test_bad_annotation(self): + with self.assertRaisesRegex(SyntaxError, 'Expected type identifier'): + annotation_conversion.annotation_from_str('typing.List[]') + + with self.assertRaisesRegex(SyntaxError, 'Expected "]"'): + annotation_conversion.annotation_from_str('typing.List[int') + + with self.assertRaisesRegex(SyntaxError, 'Unexpected end of annotation'): + annotation_conversion.annotation_from_str('typing.List[int]1', _MODULE) + + with self.assertRaisesRegex(SyntaxError, 'Expected "]"'): + annotation_conversion.annotation_from_str('typing.Callable[[x') + + class FieldFromAnnotationTest(unittest.TestCase): """Tests for Field.fromAnnotation.""" @@ -132,17 +254,24 @@ def test_no_annotation(self): def test_none(self): self.assertEqual( - ValueSpec.from_annotation(None, False), vs.Any().freeze(None)) + ValueSpec.from_annotation(None, False), vs.Object(type(None))) self.assertEqual( - ValueSpec.from_annotation(None, True), vs.Any().freeze(None)) + ValueSpec.from_annotation('None', True), vs.Object(type(None))) self.assertEqual( - ValueSpec.from_annotation( - None, accept_value_as_annotation=True), vs.Any().noneable()) + ValueSpec.from_annotation(None, True), vs.Object(type(None))) + self.assertEqual( + ValueSpec.from_annotation(None, accept_value_as_annotation=True), + vs.Any().noneable() + ) def test_any(self): self.assertEqual( ValueSpec.from_annotation(typing.Any, False), vs.Any(annotation=typing.Any)) + self.assertEqual( + ValueSpec.from_annotation('typing.Any', True, parent_module=_MODULE), + vs.Any(annotation=typing.Any) + ) self.assertEqual( ValueSpec.from_annotation(typing.Any, True), vs.Any(annotation=typing.Any)) @@ -152,6 +281,7 @@ def test_any(self): def test_bool(self): self.assertEqual(ValueSpec.from_annotation(bool, True), vs.Bool()) + self.assertEqual(ValueSpec.from_annotation('bool', True), vs.Bool()) self.assertEqual( ValueSpec.from_annotation(bool, False), vs.Any(annotation=bool)) self.assertEqual( @@ -159,6 +289,7 @@ def test_bool(self): def test_int(self): self.assertEqual(ValueSpec.from_annotation(int, True), vs.Int()) + self.assertEqual(ValueSpec.from_annotation('int', True), vs.Int()) self.assertEqual(ValueSpec.from_annotation(int, True, True), vs.Int()) self.assertEqual( ValueSpec.from_annotation(int, False), vs.Any(annotation=int)) @@ -182,7 +313,9 @@ def test_str(self): ValueSpec.from_annotation(str, False, True), vs.Any(annotation=str)) self.assertEqual( - ValueSpec.from_annotation('A', False, False), vs.Any(annotation='A')) + ValueSpec.from_annotation('A', False, False), + vs.Any(annotation='A') + ) self.assertEqual( ValueSpec.from_annotation('A', False, True), vs.Str('A')) self.assertEqual( diff --git a/pyglove/core/typing/annotation_future_test.py b/pyglove/core/typing/annotation_future_test.py new file mode 100644 index 0000000..f47ebb1 --- /dev/null +++ b/pyglove/core/typing/annotation_future_test.py @@ -0,0 +1,135 @@ +# Copyright 2025 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import sys +import typing +from typing import List, Literal, Union +import unittest + +from pyglove.core import symbolic as pg +from pyglove.core.typing import key_specs as ks +from pyglove.core.typing import value_specs as vs + + +class AnnotationFutureConversionTest(unittest.TestCase): + + # Class with forward declaration must not be defined in functions. + class A(pg.Object): + a: typing.Optional[AnnotationFutureConversionTest.A] + b: List[AnnotationFutureConversionTest.A] + + def assert_value_spec(self, cls, field_name, expected_value_spec): + self.assertEqual(cls.__schema__[field_name].value, expected_value_spec) + + def test_basics(self): + + class Foo(pg.Object): + a: int + b: float + c: bool + d: str + e: typing.Any + f: typing.Dict[str, typing.Any] + g: typing.List[str] + h: typing.Tuple[int, int] + i: typing.Callable[[int, int], None] + + self.assert_value_spec(Foo, 'a', vs.Int()) + self.assert_value_spec(Foo, 'b', vs.Float()) + self.assert_value_spec(Foo, 'c', vs.Bool()) + self.assert_value_spec(Foo, 'd', vs.Str()) + self.assert_value_spec(Foo, 'e', vs.Any(annotation=typing.Any)) + self.assert_value_spec( + Foo, 'f', vs.Dict([(ks.StrKey(), vs.Any(annotation=typing.Any))]) + ) + self.assert_value_spec(Foo, 'g', vs.List(vs.Str())) + self.assert_value_spec(Foo, 'h', vs.Tuple([vs.Int(), vs.Int()])) + self.assert_value_spec( + Foo, 'i', + vs.Callable([vs.Int(), vs.Int()], returns=vs.Object(type(None))) + ) + + def test_list(self): + if sys.version_info >= (3, 10): + + class Bar(pg.Object): + x: list[int | None] + + self.assert_value_spec(Bar, 'x', vs.List(vs.Int().noneable())) + + def test_var_length_tuple(self): + + class Foo(pg.Object): + x: typing.Tuple[int, ...] + + self.assert_value_spec(Foo, 'x', vs.Tuple(vs.Int())) + + if sys.version_info >= (3, 10): + + class Bar(pg.Object): + x: tuple[int, ...] + + self.assert_value_spec(Bar, 'x', vs.Tuple(vs.Int())) + + def test_optional(self): + + class Foo(pg.Object): + x: typing.Optional[int] + + self.assert_value_spec(Foo, 'x', vs.Int().noneable()) + + if sys.version_info >= (3, 10): + class Bar(pg.Object): + x: int | None + + self.assert_value_spec(Bar, 'x', vs.Int().noneable()) + + def test_union(self): + + class Foo(pg.Object): + x: Union[int, typing.Union[str, bool], None] + + self.assert_value_spec( + Foo, 'x', vs.Union([vs.Int(), vs.Str(), vs.Bool()]).noneable() + ) + + if sys.version_info >= (3, 10): + + class Bar(pg.Object): + x: int | str | bool + + self.assert_value_spec( + Bar, 'x', vs.Union([vs.Int(), vs.Str(), vs.Bool()]) + ) + + def test_literal(self): + + class Foo(pg.Object): + x: Literal[1, True, 'abc'] + + self.assert_value_spec( + Foo, 'x', vs.Enum(vs.MISSING_VALUE, [1, True, 'abc']) + ) + + def test_self_referencial(self): + self.assert_value_spec( + self.A, 'a', vs.Object(self.A).noneable() + ) + self.assert_value_spec( + self.A, 'b', vs.List(vs.Object(self.A)) + ) + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/typing/class_schema.py b/pyglove/core/typing/class_schema.py index 06f951f..cdf921d 100644 --- a/pyglove/core/typing/class_schema.py +++ b/pyglove/core/typing/class_schema.py @@ -97,9 +97,10 @@ def from_str(cls, key: str) -> 'KeySpec': class ForwardRef(utils.Formattable): """Forward type reference.""" - def __init__(self, module: types.ModuleType, name: str): + def __init__(self, module: types.ModuleType, qualname: str): self._module = module - self._name = name + self._qualname = qualname + self._resolved_value = None @property def module(self) -> types.ModuleType: @@ -109,35 +110,54 @@ def module(self) -> types.ModuleType: @property def name(self) -> str: """Returns the name of the type reference.""" - return self._name + return self._qualname.split('.')[-1] @property def qualname(self) -> str: """Returns the qualified name of the reference.""" - return f'{self.module.__name__}.{self.name}' + return self._qualname + + @property + def type_id(self) -> str: + """Returns the type id of the reference.""" + return f'{self.module.__name__}.{self.qualname}' def as_annotation(self) -> Union[Type[Any], str]: """Returns the forward reference as an annotation.""" - return self.cls if self.resolved else self.name + return self.cls if self.resolved else self.qualname @property def resolved(self) -> bool: """Returns True if the symbol for the name is resolved..""" - return hasattr(self.module, self.name) + if self._resolved_value is None: + self._resolved_value = self._resolve() + return self._resolved_value is not None + + def _resolve(self) -> Optional[Any]: + names = self._qualname.split('.') + parent_obj = self.module + for name in names: + parent_obj = getattr(parent_obj, name, utils.MISSING_VALUE) + if parent_obj == utils.MISSING_VALUE: + return None + if not inspect.isclass(parent_obj): + raise TypeError( + f'{self.qualname!r} from module {self.module.__name__!r} ' + 'is not a class.' + ) + return parent_obj @property def cls(self) -> Type[Any]: """Returns the resolved reference class..""" - reference = getattr(self.module, self.name, None) - if reference is None: - raise TypeError( - f'{self.name!r} does not exist in module {self.module.__name__!r}' - ) - elif not inspect.isclass(reference): - raise TypeError( - f'{self.name!r} from module {self.module.__name__!r} is not a class.' - ) - return reference + if self._resolved_value is None: + self._resolved_value = self._resolve() + if self._resolved_value is None: + raise TypeError( + f'{self.qualname!r} does not exist in ' + f'module {self.module.__name__!r}' + ) + return self._resolved_value def format( self, @@ -150,7 +170,7 @@ def format( return utils.kvlist_str( [ ('module', self.module.__name__, None), - ('name', self.name, None), + ('name', self.qualname, None), ], label=self.__class__.__name__, compact=compact, @@ -164,7 +184,7 @@ def __eq__(self, other: Any) -> bool: if self is other: return True elif isinstance(other, ForwardRef): - return self.module is other.module and self.name == other.name + return self.module is other.module and self.qualname == other.qualname elif inspect.isclass(other): return self.resolved and self.cls is other # pytype: disable=bad-return-type @@ -173,11 +193,11 @@ def __ne__(self, other: Any) -> bool: return not self.__eq__(other) def __hash__(self) -> int: - return hash((self.module, self.name)) + return hash((self.module, self.qualname)) def __deepcopy__(self, memo) -> 'ForwardRef': """Override deep copy to avoid copying module.""" - return ForwardRef(self.module, self.name) + return ForwardRef(self.module, self.qualname) class ValueSpec(utils.Formattable, utils.JSONConvertible): @@ -628,9 +648,11 @@ def from_annotation( annotation: Any, description: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, - auto_typing=True) -> 'Field': + auto_typing=True, + parent_module: Optional[types.ModuleType] = None + ) -> 'Field': """Gets a Field from annotation.""" - del key, annotation, description, metadata, auto_typing + del key, annotation, description, metadata, auto_typing, parent_module assert False, 'Overridden in `annotation_conversion.py`.' @property diff --git a/pyglove/core/typing/class_schema_test.py b/pyglove/core/typing/class_schema_test.py index 8dc3555..4a21e07 100644 --- a/pyglove/core/typing/class_schema_test.py +++ b/pyglove/core/typing/class_schema_test.py @@ -31,24 +31,31 @@ class ForwardRefTest(unittest.TestCase): """Test for `ForwardRef` class.""" + class A: + pass + def setUp(self): super().setUp() self._module = sys.modules[__name__] def test_basics(self): - r = class_schema.ForwardRef(self._module, 'FieldTest') + r = class_schema.ForwardRef(self._module, 'ForwardRefTest.A') self.assertIs(r.module, self._module) - self.assertEqual(r.name, 'FieldTest') - self.assertEqual(r.qualname, f'{self._module.__name__}.FieldTest') + self.assertEqual(r.name, 'A') + self.assertEqual(r.qualname, 'ForwardRefTest.A') + self.assertEqual(r.type_id, f'{self._module.__name__}.ForwardRefTest.A') def test_resolved(self): - self.assertTrue(class_schema.ForwardRef(self._module, 'FieldTest').resolved) + self.assertTrue( + class_schema.ForwardRef(self._module, 'ForwardRefTest.A').resolved + ) self.assertFalse(class_schema.ForwardRef(self._module, 'Foo').resolved) def test_as_annotation(self): self.assertEqual( - class_schema.ForwardRef(self._module, 'FieldTest').as_annotation(), - FieldTest, + class_schema.ForwardRef( + self._module, 'ForwardRefTest.A').as_annotation(), + ForwardRefTest.A, ) self.assertEqual( class_schema.ForwardRef(self._module, 'Foo').as_annotation(), 'Foo' @@ -56,7 +63,8 @@ def test_as_annotation(self): def test_cls(self): self.assertIs( - class_schema.ForwardRef(self._module, 'FieldTest').cls, FieldTest + class_schema.ForwardRef(self._module, 'ForwardRefTest.A').cls, + ForwardRefTest.A ) with self.assertRaisesRegex(TypeError, '.* does not exist in module'):