Skip to content

Commit

Permalink
feat(BACK-7924): improve error handling (#134)
Browse files Browse the repository at this point in the history
* feat(BACK-7924): improve HTTP client errors

* feat(BACK-7924): define pydantic annotation to customize errors

* feat(BACK-7924): define pydantic annotation to customize errors / context manager

* feat(BACK-7924): use context manager to catch pydantic errors

* feat(BACK-7924): remove unnecessary error wrapping

* feat(BACK-7924): improve/fix linters error messages

* feat(BACK-7924): customize error type labels for paths

* feat(BACK-7924): fix nativeCurrencyAddress bad management of descriptor paths

* feat(BACK-7924): add constraints on intent

* feat(BACK-7924): add more custom error type labels

* feat(BACK-7924): detect constant addresses and print errors

* feat(BACK-7924): use new stuff for linter

* feat(BACK-7924): update/add tests
  • Loading branch information
jnicoulaud-ledger authored Nov 4, 2024
1 parent 6791a59 commit d1b2867
Show file tree
Hide file tree
Showing 23 changed files with 624 additions and 188 deletions.
29 changes: 19 additions & 10 deletions src/erc7730/common/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from httpx._content import IteratorByteStream
from httpx_file import FileTransport
from limiter import Limiter
from pydantic import ConfigDict, TypeAdapter
from pydantic import ConfigDict, TypeAdapter, ValidationError
from pydantic_string_url import FileUrl, HttpUrl
from xdg_base_dirs import xdg_cache_home

Expand Down Expand Up @@ -50,14 +50,19 @@ def get_contract_abis(chain_id: int, contract_address: Address) -> list[ABI]:
:return: deserialized list of ABIs
:raises NotImplementedError: if chain id not supported, API key not setup, or unexpected response
"""
return get(
url=HttpUrl(f"https://{ETHERSCAN}/v2/api"),
chainid=chain_id,
module="contract",
action="getabi",
address=contract_address,
model=list[ABI],
)
try:
return get(
url=HttpUrl(f"https://{ETHERSCAN}/v2/api"),
chainid=chain_id,
module="contract",
action="getabi",
address=contract_address,
model=list[ABI],
)
except Exception as e:
if "Contract source code not verified" in str(e):
raise Exception("contract source is not available on Etherscan") from e
raise e


def get_contract_explorer_url(chain_id: int, contract_address: Address) -> HttpUrl:
Expand Down Expand Up @@ -91,7 +96,11 @@ def get(model: type[_T], url: HttpUrl | FileUrl, **params: Any) -> _T:
:raises Exception: if URL type is not supported, API key not setup, or unexpected response
"""
with _client() as client:
return TypeAdapter(model).validate_json(client.get(url, params=params).raise_for_status().content)
response = client.get(url, params=params).raise_for_status().content
try:
return TypeAdapter(model).validate_json(response)
except ValidationError as e:
raise Exception(f"Received unexpected response from {url}: {response.decode()}") from e


def _client() -> Client:
Expand Down
117 changes: 111 additions & 6 deletions src/erc7730/common/output.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import re
import threading
from builtins import print as builtin_print
from contextlib import AbstractContextManager
from enum import IntEnum, auto
from itertools import groupby
from types import TracebackType
from typing import assert_never, final, override

from pydantic import BaseModel, FilePath
from pydantic import BaseModel, ConfigDict, FilePath, ValidationError
from pydantic_core import ErrorDetails
from rich import print

MUX = threading.Lock()
Expand All @@ -14,6 +17,17 @@
class Output(BaseModel):
"""An output info/debug/warning/error."""

model_config = ConfigDict(
strict=True,
frozen=True,
extra="forbid",
validate_default=True,
validate_return=True,
validate_assignment=True,
arbitrary_types_allowed=False,
allow_inf_nan=False,
)

class Level(IntEnum):
"""ERC7730Linter output level."""

Expand Down Expand Up @@ -88,6 +102,19 @@ def add(self, output: Output) -> None:
self.outputs.append(output)


@final
class SetOutputAdder(OutputAdder):
"""An output adder that stores outputs in a set."""

def __init__(self) -> None:
super().__init__()
self.outputs: set[Output] = set()

def add(self, output: Output) -> None:
super().add(output)
self.outputs.add(output)


class ConsoleOutputAdder(OutputAdder):
"""An output adder that prints to the console."""

Expand Down Expand Up @@ -120,7 +147,10 @@ def add(self, output: Output) -> None:
log += ", ".join(context) + ": "
if output.title is not None:
log += f"{output.title}: "
log += f"[/{style}]{output.message}"
log += f"[/{style}]"
if "\n" in output.message:
log += "\n"
log += output.message

print(log)

Expand Down Expand Up @@ -210,10 +240,10 @@ def add(self, output: Output) -> None:

@final
class BufferAdder(AbstractContextManager[OutputAdder]):
"""A context manager that buffers outputs and outputs them all at once."""
"""A context manager that buffers outputs and outputs them all at once, sorted and deduplicated."""

def __init__(self, delegate: OutputAdder, prolog: str | None = None, epilog: str | None = None) -> None:
self._buffer = ListOutputAdder()
self._buffer = SetOutputAdder()
self._delegate = delegate
self._prolog = prolog
self._epilog = epilog
Expand All @@ -228,10 +258,85 @@ def __exit__(self, etype: type[BaseException] | None, e: BaseException | None, t
try:
if self._prolog is not None:
print(self._prolog)
for output in self._buffer.outputs:
self._delegate.add(output)
if self._buffer.outputs:
for output in sorted(self._buffer.outputs, key=lambda x: (x.file, x.line, x.level, x.title, x.message)):
self._delegate.add(output)
else:
print("no issue found ✔️")
if self._epilog is not None:
print(self._epilog)
finally:
MUX.release()
return None


@final
class ExceptionsToOutput(AbstractContextManager[None]):
"""A context manager that catches exceptions and redirects them to an OutputAdder."""

def __init__(self, delegate: OutputAdder) -> None:
self._delegate = delegate

@override
def __enter__(self) -> None:
return None

@override
def __exit__(self, etype: type[BaseException] | None, e: BaseException | None, tb: TracebackType | None) -> bool:
if isinstance(e, Exception):
exception_to_output(e, self._delegate)
return True
return False


def exception_to_output(e: Exception, out: OutputAdder) -> None:
"""
Sanitize an exception and add it to an OutputAdder.
:param e: exception to handle
:param out: output handler
"""
match e:
case ValidationError() as e:
pydantic_error_to_output(e, out)
case Exception() as e:
out.error(title="Failed processing descriptor", message=str(e))
case _:
assert_never(e)


def pydantic_error_to_output(e: ValidationError, out: OutputAdder) -> None:
"""
Sanitize a pydantic validation exception and add it to an OutputAdder.
This cleans up location, and groups errors by location to avoid outputting multiple errors when not necessary, for
instance for union types.
:param e: exception to handle
:param out: output handler
"""

def filter_location(loc: int | str) -> bool:
if isinstance(loc, int):
return True
return bool(re.match(r"(list|set)\[.*", loc))

def get_location(ex: ErrorDetails) -> str:
if not (loc := ex.get("loc")):
return "unknown location"
return ".".join(map(str, filter(filter_location, loc[:-1])))

def get_value(ex: ErrorDetails) -> str:
return str(ex.get("input", "unknown value"))

def get_details(ex: ErrorDetails) -> str:
return ex.get("msg", "unknown error")

def get_message(ex: ErrorDetails) -> str:
return f"""Value "{get_value(ex)}" is not valid: {get_details(ex)}"""

for location, location_errors in groupby(e.errors(include_url=False), get_location):
if (len(errors := list(location_errors))) > 1:
out.error(title=f"Invalid value at {location}", message="* " + "\n * ".join(map(get_message, errors)))
else:
out.error(title=f"Invalid value at {location}", message=get_message(errors[0]))
30 changes: 28 additions & 2 deletions src/erc7730/common/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import os
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar
from typing import Any, LiteralString, TypeVar

from pydantic import BaseModel
from pydantic import BaseModel, ValidationInfo, WrapValidator
from pydantic_core import PydanticCustomError

from erc7730.common.json import CompactJSONEncoder, read_json_with_includes

Expand Down Expand Up @@ -46,3 +49,26 @@ def model_to_json_file(path: Path, model: _BaseModel) -> None:
with open(path, "w") as f:
f.write(model_to_json_str(model))
f.write("\n")


@dataclass(frozen=True)
class ErrorTypeLabel(WrapValidator):
"""
Wrapper validator that replaces all errors with a simple message "expected a <type label>".
It is useful for annotating union types where pydantic returns multiple errors for each type it tries, or custom
base types such as pattern validated strings to get more user-friendly errors.
"""

def __init__(self, type_label: LiteralString) -> None:
super().__init__(self._validator(type_label))

@staticmethod
def _validator(type_label: LiteralString) -> Callable[[Any, Any, ValidationInfo], Any]:
def validate(v: Any, next_: Any, ctx: ValidationInfo) -> Any:
try:
return next_(v, ctx)
except Exception:
raise PydanticCustomError("custom_error", "expected a " + type_label) from None

return validate
81 changes: 41 additions & 40 deletions src/erc7730/convert/ledger/eip712/convert_eip712_to_erc7730.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from eip712.utils import MissingRootTypeError, MultipleRootTypesError, get_primary_type
from pydantic_string_url import HttpUrl

from erc7730.common.output import OutputAdder
from erc7730.common.output import ExceptionsToOutput, OutputAdder
from erc7730.convert import ERC7730Converter
from erc7730.model.context import EIP712JsonSchema
from erc7730.model.display import (
Expand Down Expand Up @@ -41,50 +41,51 @@ class EIP712toERC7730Converter(ERC7730Converter[ResolvedEIP712DAppDescriptor, In
def convert(
self, descriptor: ResolvedEIP712DAppDescriptor, out: OutputAdder
) -> dict[str, InputERC7730Descriptor] | None:
descriptors: dict[str, InputERC7730Descriptor] = {}
with ExceptionsToOutput(out):
descriptors: dict[str, InputERC7730Descriptor] = {}

for contract in descriptor.contracts:
formats: dict[str, InputFormat] = {}
schemas: list[EIP712JsonSchema | HttpUrl] = []
for contract in descriptor.contracts:
formats: dict[str, InputFormat] = {}
schemas: list[EIP712JsonSchema | HttpUrl] = []

for message in contract.messages:
if (primary_type := self._get_primary_type(message.schema_, out)) is None:
return None
for message in contract.messages:
if (primary_type := self._get_primary_type(message.schema_, out)) is None:
return None

schemas.append(EIP712JsonSchema(primaryType=primary_type, types=message.schema_))
schemas.append(EIP712JsonSchema(primaryType=primary_type, types=message.schema_))

formats[primary_type] = InputFormat(
intent=message.mapper.label,
fields=[self._convert_field(field) for field in message.mapper.fields],
required=None,
screens=None,
)

descriptors[contract.address] = InputERC7730Descriptor(
context=InputEIP712Context(
eip712=InputEIP712(
domain=InputDomain(
name=descriptor.name,
version=None,
chainId=descriptor.chainId,
verifyingContract=contract.address,
),
schemas=schemas,
deployments=[InputDeployment(chainId=descriptor.chainId, address=contract.address)],
formats[primary_type] = InputFormat(
intent=message.mapper.label,
fields=[self._convert_field(field) for field in message.mapper.fields],
required=None,
screens=None,
)
),
metadata=InputMetadata(
owner=contract.contractName,
info=None,
token=None,
constants=None,
enums=None,
),
display=InputDisplay(
definitions=None,
formats=formats,
),
)

descriptors[contract.address] = InputERC7730Descriptor(
context=InputEIP712Context(
eip712=InputEIP712(
domain=InputDomain(
name=descriptor.name,
version=None,
chainId=descriptor.chainId,
verifyingContract=contract.address,
),
schemas=schemas,
deployments=[InputDeployment(chainId=descriptor.chainId, address=contract.address)],
)
),
metadata=InputMetadata(
owner=contract.contractName,
info=None,
token=None,
constants=None,
enums=None,
),
display=InputDisplay(
definitions=None,
formats=formats,
),
)

return descriptors

Expand Down
Loading

0 comments on commit d1b2867

Please sign in to comment.