Skip to content

Commit

Permalink
zcbor.py: Performance improvements in DataTranslator
Browse files Browse the repository at this point in the history
Signed-off-by: Øyvind Rønningstad <[email protected]>
  • Loading branch information
oyvindronningstad committed Jan 21, 2025
1 parent 2b837d1 commit cbcb027
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 28 deletions.
27 changes: 27 additions & 0 deletions tests/scripts/test_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import zcbor
import cbor2
import cProfile, pstats


try:
import zcbor
except ImportError:
print("""
The zcbor package must be installed to run these tests.
During development, install with `pip3 install -e .` to install in a way
that picks up changes in the files without having to reinstall.
""")
exit(1)

cddl_contents = """
perf_int = [0*1000(int/bool)]
"""
raw_message = cbor2.dumps(list(range(1000)))
cmd_spec = zcbor.DataTranslator.from_cddl(cddl_contents, 3).my_types["perf_int"]

profiler = cProfile.Profile()
profiler.enable()
json_obj = cmd_spec.str_to_json(raw_message)
profiler.disable()

profiler.print_stats()
66 changes: 38 additions & 28 deletions zcbor/zcbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,7 @@ def __init__(self, *args, **kwargs):
# Used as a guard against endless recursion in self.dependsOn()
self.dependsOnCall = False
self.skipped = False
self.stored_id = None

def var_name(self, with_prefix=False, observe_skipped=True):
"""Name of variables and enum members for this element."""
Expand Down Expand Up @@ -1456,7 +1457,9 @@ def id(self):
If the name starts with an underscore, prepend an 'f',
since namedtuple() doesn't support identifiers that start with an underscore.
"""
return getrp(r"\A_").sub("f_", self.generate_base_name())
if self.stored_id is None:
self.stored_id = getrp(r"\A_").sub("f_", self.get_base_name())
return self.stored_id

def var_name(self):
"""Override the var_name()"""
Expand All @@ -1465,6 +1468,8 @@ def var_name(self):
def _decode_assert(self, test, msg=""):
"""Check a condition and raise a CddlValidationError if not."""
if not test:
if callable(msg):
msg = msg()
raise CddlValidationError(
f"Data did not decode correctly {'(' + msg + ')' if msg else ''}")

Expand All @@ -1473,6 +1478,9 @@ def _check_tag(self, obj):
Return whether a tag was present.
"""
if not self.tags and not isinstance(obj, CBORTag):
return obj

tags = copy(self.tags) # All expected tags
# Process all tags present in obj
while isinstance(obj, CBORTag):
Expand All @@ -1483,35 +1491,37 @@ def _check_tag(self, obj):
continue
elif self.type in ["OTHER", "GROUP", "UNION"]:
break
self._decode_assert(False, f"Tag ({obj.tag}) not expected for {self}")
self._decode_assert(False, lambda: f"Tag ({obj.tag}) not expected for {self}")
# Check that all expected tags were found in obj.
self._decode_assert(not tags, f"Expected tags ({tags}), but none present.")
self._decode_assert(not tags, lambda: f"Expected tags ({tags}), but none present.")
return obj

_exp_types = {
"UINT": (int,),
"INT": (int,),
"NINT": (int,),
"FLOAT": (float,),
"TSTR": (str,),
"BSTR": (bytes,),
"NIL": (type(None),),
"UNDEF": (type(undefined),),
"ANY": (int, float, str, bytes, type(None), type(undefined), bool, list, dict),
"BOOL": (bool,),
"LIST": (tuple, list),
"MAP": (dict,),
}

def _expected_type(self):
"""Return our expected python type as returned by cbor2."""
return {
"UINT": lambda: (int,),
"INT": lambda: (int,),
"NINT": lambda: (int,),
"FLOAT": lambda: (float,),
"TSTR": lambda: (str,),
"BSTR": lambda: (bytes,),
"NIL": lambda: (type(None),),
"UNDEF": lambda: (type(undefined),),
"ANY": lambda: (int, float, str, bytes, type(None), type(undefined), bool, list, dict),
"BOOL": lambda: (bool,),
"LIST": lambda: (tuple, list),
"MAP": lambda: (dict,),
}[self.type]()
return self._exp_types[self.type]

def _check_type(self, obj):
"""Check that the decoded object has the correct type."""
if self.type not in ["OTHER", "GROUP", "UNION"]:
exp_type = self._expected_type()
self._decode_assert(
type(obj) in exp_type,
f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}")
lambda: f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}")

def _check_value(self, obj):
"""Check that the decode value conforms to the restrictions in the CDDL."""
Expand All @@ -1522,28 +1532,28 @@ def _check_value(self, obj):
value = self.value.encode("utf-8")
self._decode_assert(
self.value == obj,
f"{obj} should have value {self.value} according to {self.var_name()}")
lambda: f"{obj} should have value {self.value} according to {self.var_name()}")
if self.type in ["UINT", "INT", "NINT", "FLOAT"]:
if self.min_value is not None:
self._decode_assert(obj >= self.min_value, "Minimum value: " + str(self.min_value))
self._decode_assert(obj >= self.min_value, lambda: "Minimum value: " + str(self.min_value))
if self.max_value is not None:
self._decode_assert(obj <= self.max_value, "Maximum value: " + str(self.max_value))
self._decode_assert(obj <= self.max_value, lambda: "Maximum value: " + str(self.max_value))
if self.type == "UINT":
if self.bits:
mask = sum(((1 << b.value) for b in self.my_control_groups[self.bits].value))
self._decode_assert(not (obj & ~mask), "Allowed bitmask: " + bin(mask))
self._decode_assert(not (obj & ~mask), lambda: "Allowed bitmask: " + bin(mask))
if self.type in ["TSTR", "BSTR"]:
if self.min_size is not None:
self._decode_assert(
len(obj) >= self.min_size, "Minimum length: " + str(self.min_size))
len(obj) >= self.min_size, lambda: "Minimum length: " + str(self.min_size))
if self.max_size is not None:
self._decode_assert(
len(obj) <= self.max_size, "Maximum length: " + str(self.max_size))
len(obj) <= self.max_size, lambda: "Maximum length: " + str(self.max_size))

def _check_key(self, obj):
"""Check that the object is not a KeyTuple, which would mean it's not properly processed."""
self._decode_assert(
not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj))
not isinstance(obj, KeyTuple), lambda: "Unexpected key found: (key,value)=" + str(obj))

def _flatten_obj(self, obj):
"""Recursively remove intermediate objects that have single members. Keep lists as is."""
Expand Down Expand Up @@ -1672,13 +1682,13 @@ def _decode_single_obj(self, obj):
return self._construct_obj(retval)
except CddlValidationError as c:
self.errors.append(str(c))
self._decode_assert(False, "No matches for union: " + str(self))
self._decode_assert(False, lambda: "No matches for union: " + str(self))
assert False, "Unexpected type: " + self.type

def _handle_key(self, next_obj):
"""Decode key and value in the form of a KeyTuple"""
self._decode_assert(
isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj))
isinstance(next_obj, KeyTuple), lambda: f"Expected key: {self.key} value=" + pformat(next_obj))
key, obj = next_obj
key_res = self.key._decode_single_obj(key)
obj_res = self._decode_single_obj(obj)
Expand Down Expand Up @@ -1729,7 +1739,7 @@ def _decode_obj(self, it):
except CddlValidationError as c:
self.errors.append(str(c))
child_it = it_copy
self._decode_assert(found, "No matches for union: " + str(self))
self._decode_assert(found, lambda: "No matches for union: " + str(self))
else:
ret = (it, self._decode_single_obj(self._iter_next(it)))
return ret
Expand Down

0 comments on commit cbcb027

Please sign in to comment.