Skip to content

Commit

Permalink
[executorch][serialization] Refactor flatbuffer utils into separate f…
Browse files Browse the repository at this point in the history
…ile (#7488)

Pull Request resolved: #7254

Todo: let xnnpack and vulkan serialization use these utils instead of redefining the same functions.

For usage in extension/flat_tensor/serialize.
ghstack-source-id: 260036856
@exported-using-ghexport

Differential Revision: [D66854756](https://our.internmc.facebook.com/intern/diff/D66854756/)

Co-authored-by: lucylq <[email protected]>
  • Loading branch information
pytorchbot and lucylq authored Jan 3, 2025
1 parent 01d4c31 commit 8dadccf
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 39 deletions.
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ runtime.python_library(
"_dataclass.py",
"_flatbuffer.py",
"_program.py",
"padding.py",
],
resources = {
"//executorch/schema:program.fbs": "program.fbs",
Expand Down
48 changes: 9 additions & 39 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
_program_json_to_flatbuffer,
)

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

from executorch.exir.schema import (
BackendDelegateDataReference,
BackendDelegateInlineData,
Expand Down Expand Up @@ -50,19 +52,6 @@ def _json_to_program(program_json: bytes) -> Program:
return _json_to_dataclass(json.loads(program_json), cls=Program)


def _padding_required(offset: int, alignment: int) -> int:
"""Returns the padding required to align `offset` to `alignment`."""
remainder: int = offset % alignment
if remainder != 0:
return alignment - remainder
return 0


def _aligned_size(input_size: int, alignment: int) -> int:
"""Returns input_size padded up to the next whole multiple of alignment."""
return input_size + _padding_required(input_size, alignment)


def _insert_flatbuffer_header(
flatbuffer_data: bytes, magic_regex: str, header_data: bytes
) -> bytes:
Expand Down Expand Up @@ -211,25 +200,6 @@ def to_bytes(self) -> bytes:
return data


def _pad_to(data: bytes, length: int) -> bytes:
"""Returns the input followed by enough zero bytes to become the requested length.
Args:
data: The data to pad.
length: The length of the returned data.
Returns:
The padded data.
Raises:
ValueError: If the requested length is less than the input length.
"""
if length < len(data):
raise ValueError(f"Data length {len(data)} > padded length {length}")
if length > len(data):
data = data + b"\x00" * (length - len(data))
assert len(data) == length
return data


def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
"""Returns the extended header of the program data, if present and valid."""
try:
Expand Down Expand Up @@ -330,7 +300,7 @@ def _extract_constant_segment(
constant_segment_data.append(buffer.storage)
buffer_length = len(buffer.storage)
pad_length = (
_padding_required(buffer_length, tensor_alignment)
padding_required(buffer_length, tensor_alignment)
if tensor_alignment is not None
else 0
)
Expand Down Expand Up @@ -432,11 +402,11 @@ def serialize_pte_binary(
)
program.segments.append(
DataSegment(
offset=_aligned_size(prev_end, segment_alignment), size=len(data)
offset=aligned_size(prev_end, segment_alignment), size=len(data)
)
)
# Add to aggregate segments cord with padding.
padding_length = _padding_required(len(segments_data), segment_alignment)
padding_length = padding_required(len(segments_data), segment_alignment)
if padding_length > 0:
segments_data.append(b"\x00" * padding_length)
segments_data.append(data)
Expand All @@ -454,15 +424,15 @@ def serialize_pte_binary(

# Size of the header to insert. Its size is padded to the largest
# force_align value present in the schema.
padded_header_length: int = _aligned_size(
padded_header_length: int = aligned_size(
input_size=_ExtendedHeader.EXPECTED_LENGTH,
alignment=result.max_alignment,
)
# Size of the program with the header inserted.
program_size: int = padded_header_length + len(result.data)
# Offset to the first segment, or zero if there are no segments.
segment_base_offset: int = (
_aligned_size(input_size=program_size, alignment=segment_alignment)
aligned_size(input_size=program_size, alignment=segment_alignment)
if len(segments_data) > 0
else 0
)
Expand All @@ -471,7 +441,7 @@ def serialize_pte_binary(
header_data: bytes = _ExtendedHeader(
program_size=program_size, segment_base_offset=segment_base_offset
).to_bytes()
header_data = _pad_to(header_data, padded_header_length)
header_data = pad_to(header_data, padded_header_length)

# Insert the header into the flatbuffer data.
program_data: bytes = _insert_flatbuffer_header(
Expand All @@ -496,7 +466,7 @@ def serialize_pte_binary(
# - segments data (optional); aligned to segment_alignment.
pte_data = Cord(program_data)
if len(segments_data) > 0:
padding_length = _padding_required(len(pte_data), segment_alignment)
padding_length = padding_required(len(pte_data), segment_alignment)
pte_data.append(b"\x00" * padding_length)
# The first segment after program data should start at the segment base offset.
assert (
Expand Down
35 changes: 35 additions & 0 deletions exir/_serialize/padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


def pad_to(data: bytes, length: int) -> bytes:
"""Returns the input followed by enough zero bytes to become the requested length.
Args:
data: The data to pad.
length: The length of the returned data.
Returns:
The padded data.
Raises:
ValueError: If the requested length is less than the input length.
"""
if length < len(data):
raise ValueError(f"Data length {len(data)} > padded length {length}")
if length > len(data):
data = data + b"\x00" * (length - len(data))
assert len(data) == length
return data


def padding_required(offset: int, alignment: int) -> int:
"""Returns the padding required to align `offset` to `alignment`."""
remainder: int = offset % alignment
if remainder != 0:
return alignment - remainder
return 0


def aligned_size(input_size: int, alignment: int) -> int:
"""Returns input_size padded up to the next whole multiple of alignment."""
return input_size + padding_required(input_size, alignment)

0 comments on commit 8dadccf

Please sign in to comment.