Skip to content

Commit

Permalink
pg.typing to support __future__.annotations.
Browse files Browse the repository at this point in the history
With this CL, PyGlove is compatible with `__future__.annotations`.
For example:

```python
from __future__ import annotations

class A(pg.Object):
  child: A | None
```

PiperOrigin-RevId: 723114026
  • Loading branch information
daiyip authored and pyglove authors committed Feb 4, 2025
1 parent 50d8a38 commit 97d6446
Show file tree
Hide file tree
Showing 7 changed files with 538 additions and 44 deletions.
5 changes: 4 additions & 1 deletion pyglove/core/symbolic/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import abc
import functools
import inspect
import sys
import typing
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

Expand Down Expand Up @@ -154,7 +155,9 @@ def _infer_fields_from_annotations(cls) -> List[pg_typing.Field]:
if typing.get_origin(attr_annotation) is typing.ClassVar:
continue

field = pg_typing.Field.from_annotation(key, attr_annotation)
field = pg_typing.Field.from_annotation(
key, attr_annotation, parent_module=sys.modules[cls.__module__]
)
if isinstance(key, pg_typing.ConstStrKey):
attr_value = cls.__dict__.get(attr_name, pg_typing.MISSING_VALUE)
if attr_value != pg_typing.MISSING_VALUE:
Expand Down
3 changes: 3 additions & 0 deletions pyglove/core/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ class Foo(pg.Object):
# Interface for custom typing.
from pyglove.core.typing.custom_typing import CustomTyping

# Annotation conversion
from pyglove.core.typing.annotation_conversion import annotation_from_str

# Callable signature.
from pyglove.core.typing.callable_signature import Argument
from pyglove.core.typing.callable_signature import CallableType
Expand Down
202 changes: 196 additions & 6 deletions pyglove/core/typing/annotation_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
"""Conversion from annotations to PyGlove value specs."""

import builtins
import collections
import inspect
import types
import typing

from pyglove.core import coding
from pyglove.core import utils
from pyglove.core.typing import annotated
from pyglove.core.typing import class_schema
Expand All @@ -35,6 +37,191 @@
_UnionType = getattr(types, 'UnionType', None) # pylint: disable=invalid-name


def annotation_from_str(
annotation_str: str,
parent_module: typing.Optional[types.ModuleType] = None,
) -> typing.Any:
"""Parses annotations from str.
BNF for PyType annotations:
```
<maybe_union> ::= <type> | <type> "|" <maybe_union>
<type> ::= <literal_type> | <non_literal_type>
<literal_type> ::= "Literal"<literal_params>
<literal_params> ::= "["<python_values>"]" (parsed by `pg.coding.evaluate`)
<non_literal_type> ::= <type_id> | <type_id>"["<type_arg>"]"
<type_arg> ::= <maybe_type_list> | <maybe_type_list>","<maybe_type_list>
<maybe_type_list> ::= "["<type_arg>"]" | <maybe_union>
<type_id> ::= 'aAz_.1-9'
```
Args:
annotation_str: String form of type annotations. E.g. "list[str]"
parent_module: The module where the annotation was defined.
Returns:
Object form of the annotation.
Raises:
SyntaxError: If the annotation string is invalid.
"""
s = annotation_str
context = dict(pos=0)

def _eof() -> bool:
return context['pos'] == len(s)

def _pos() -> int:
return context['pos']

def _next(n: int = 1, offset: int = 0) -> str:
if _eof():
return '<EOF>'
return s[_pos() + offset:_pos() + offset + n]

def _advance(n: int) -> None:
context['pos'] += n

def _error_illustration() -> str:
return f'{s}\n{" " * _pos()}' + '^'

def _match(ch) -> bool:
if _next(len(ch)) == ch:
_advance(len(ch))
return True
return False

def _skip_whitespaces() -> None:
while _next() in ' \t':
_advance(1)

def _maybe_union():
t = _type()
while not _eof():
_skip_whitespaces()
if _match('|'):
t = t | _type()
else:
break
return t

def _type():
type_id = _type_id()
t = _resolve(type_id)
if t is typing.Literal:
return t[_literal_params()]
elif _match('['):
arg = _type_arg()
if not _match(']'):
raise SyntaxError(
f'Expected "]" at position {_pos()}.\n\n' + _error_illustration()
)
return t[arg]
return t

def _literal_params():
if not _match('['):
raise SyntaxError(
f'Expected "[" at position {_pos()}.\n\n' + _error_illustration()
)
arg_start = _pos()
in_str = False
escape_mode = False
num_open_bracket = 1

while num_open_bracket > 0:
ch = _next()
if _eof():
raise SyntaxError(
f'Unexpected end of annotation at position {_pos()}.\n\n'
+ _error_illustration()
)
if ch == '\\':
escape_mode = not escape_mode
else:
escape_mode = False

if ch == "'" and not escape_mode:
in_str = not in_str
elif not in_str:
if ch == '[':
num_open_bracket += 1
elif ch == ']':
num_open_bracket -= 1
_advance(1)

arg_str = s[arg_start:_pos() - 1]
return coding.evaluate(
'(' + arg_str + ')', permission=coding.CodePermission.BASIC
)

def _type_arg():
t_args = []
t_args.append(_maybe_type_list())
while _match(','):
t_args.append(_maybe_type_list())
return tuple(t_args) if len(t_args) > 1 else t_args[0]

def _maybe_type_list():
if _match('['):
ret = _type_arg()
if not _match(']'):
raise SyntaxError(
f'Expected "]" at position {_pos()}.\n\n' + _error_illustration()
)
return list(ret) if isinstance(ret, tuple) else [ret]
return _maybe_union()

def _type_id() -> str:
_skip_whitespaces()
if _match('...'):
return '...'
start = _pos()
while not _eof():
c = _next()
if c.isalnum() or c in '_.':
_advance(1)
else:
break
t_id = s[start:_pos()]
if not all(x.isidentifier() for x in t_id.split('.')):
raise SyntaxError(
f'Expected type identifier, got {t_id!r} at position {start}.\n\n'
+ _error_illustration()
)
return t_id

def _resolve(type_id: str):
def _resolve_name(name: str, parent_obj: typing.Any):
if name == 'None':
return None
if parent_obj is not None and hasattr(parent_obj, name):
return getattr(parent_obj, name)
if hasattr(builtins, name):
return getattr(builtins, name)
if type_id == '...':
return ...
return utils.MISSING_VALUE
parent_obj = parent_module
for name in type_id.split('.'):
parent_obj = _resolve_name(name, parent_obj)
if parent_obj == utils.MISSING_VALUE:
return typing.ForwardRef( # pytype: disable=not-callable
type_id, False, parent_module
)
return parent_obj

root = _maybe_union()
if _pos() != len(s):
raise SyntaxError(
'Unexpected end of annotation.\n\n' + _error_illustration()
)
return root


def _field_from_annotation(
key: typing.Union[str, class_schema.KeySpec],
annotation: typing.Any,
Expand Down Expand Up @@ -91,7 +278,12 @@ def _value_spec_from_type_annotation(
parent_module: typing.Optional[types.ModuleType] = None
) -> class_schema.ValueSpec:
"""Creates a value spec from type annotation."""
if annotation is bool:
if isinstance(annotation, str) and not accept_value_as_annotation:
annotation = annotation_from_str(annotation, parent_module)

if annotation is None:
return vs.Object(type(None))
elif annotation is bool:
return vs.Bool()
elif annotation is int:
return vs.Int()
Expand Down Expand Up @@ -193,10 +385,7 @@ def _sub_value_spec_from_annotation(
elif (
inspect.isclass(annotation)
or pg_inspect.is_generic(annotation)
or (isinstance(annotation, str) and not accept_value_as_annotation)
):
if isinstance(annotation, str) and parent_module is not None:
annotation = class_schema.ForwardRef(parent_module, annotation)
return vs.Object(annotation)

if accept_value_as_annotation:
Expand Down Expand Up @@ -227,11 +416,12 @@ def _value_spec_from_annotation(
return annotation
elif annotation == inspect.Parameter.empty:
return vs.Any()
elif annotation is None:

if annotation is None:
if accept_value_as_annotation:
return vs.Any().noneable()
else:
return vs.Any().freeze(None)
return vs.Object(type(None))

if auto_typing:
return _value_spec_from_type_annotation(
Expand Down
Loading

0 comments on commit 97d6446

Please sign in to comment.