Skip to content

Commit

Permalink
[executorch][serialization] Serialize PTD files.
Browse files Browse the repository at this point in the history
Pull Request resolved: #7270

Introduce a top-level serialization file that calls:
- serialize_pte_binary for the PTE file
- FlatTensor.serialize_tensors for the PTD files.


ghstack-source-id: 260061339
@exported-using-ghexport

Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/)
  • Loading branch information
lucylq committed Jan 3, 2025
1 parent f2d1eba commit 10cd49a
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 19 deletions.
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ runtime.python_library(
"_dataclass.py",
"_flatbuffer.py",
"_program.py",
"_serialize.py",
"data_serializer.py",
"padding.py",
],
Expand Down
87 changes: 87 additions & 0 deletions exir/_serialize/_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from typing import Dict, Tuple

from executorch.exir._serialize import _serialize_pte_binary

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize.data_serializer import (
DataPayload,
DataSerializer,
TensorEntry,
TensorLayout,
)

from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.emit import EmitterOutput
from executorch.exir.schema import Tensor, TensorDataLocation


def serialize(
    emitter_output: EmitterOutput,
    config: ExecutorchBackendConfig,
    data_serializer: DataSerializer,
) -> Tuple[Cord, Dict[str, Cord]]:
    """Produce the PTE program blob and any PTD data blobs from emitter output.

    Args:
        emitter_output: Emitter result; its program, mutable data, external
            constant buffer and external constant map are read here.
        config: Supplies the delegate-segment and alignment options that are
            forwarded to the PTE serializer.
        data_serializer: Serializer invoked once per external data file.

    Returns:
        A `(pte, ptd_files)` tuple. `ptd_files` maps each key of the external
        constant map to its serialized contents, and is empty when no tensor
        in the program is marked as externally stored.
    """
    program = emitter_output.program

    # The PTE file is produced unconditionally.
    pte: Cord = _serialize_pte_binary(
        program=program,
        mutable_data=emitter_output.mutable_data,
        extract_delegate_segments=config.extract_delegate_segments,
        segment_alignment=config.segment_alignment,
        constant_tensor_alignment=config.constant_tensor_alignment,
        delegate_alignment=config.delegate_alignment,
    )

    # Record the layout of every externally stored tensor, keyed by its
    # fully qualified name.
    fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
    for plan in program.execution_plan:
        for evalue in plan.values:
            tensor = evalue.val
            if not isinstance(tensor, Tensor):
                continue
            info = tensor.extra_tensor_info
            if info is None or info.fully_qualified_name is None:
                continue
            if info.location is not TensorDataLocation.EXTERNAL:
                continue
            fqn_to_tensor_layout[info.fully_qualified_name] = TensorLayout(
                tensor.scalar_type, tensor.sizes, tensor.dim_order
            )

    ptd_files: Dict[str, Cord] = {}
    if not fqn_to_tensor_layout:
        # Nothing is stored externally, so there are no PTD files to build.
        return pte, ptd_files

    # Binding to a local lets the assert narrow the Optional type.
    external_constant_map = emitter_output.external_constant_map
    assert external_constant_map is not None
    for file, fqn_to_index in external_constant_map.items():
        # Pair each tensor's buffer index with the layout collected above.
        fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
        for fqn, index in fqn_to_index.items():
            assert fqn in fqn_to_tensor_layout
            fqn_to_tensor_entry[fqn] = TensorEntry(
                buffer_index=index,
                layout=fqn_to_tensor_layout[fqn],
            )

        payload = DataPayload(
            buffers=emitter_output.external_constant_buffer,
            fqn_to_tensor=fqn_to_tensor_entry,
        )
        ptd_files[file] = data_serializer.serialize(payload)

    return pte, ptd_files
1 change: 1 addition & 0 deletions exir/program/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ python_library(
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/passes:weights_to_outputs_pass",
"//executorch/exir/verification:verifier",
"//executorch/extension/flat_tensor/serialize:serialize",
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
)

Expand Down
61 changes: 42 additions & 19 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import copy
import io
import logging
import os
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union

import torch
import torch._export
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._serialize import serialize
from executorch.exir._serialize.data_serializer import DataSerializer
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner
Expand Down Expand Up @@ -59,6 +61,7 @@
EXIREdgeDialectVerifier,
get_aten_verifier,
)
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
Expand Down Expand Up @@ -497,23 +500,31 @@ def __init__(
)
self.exported_program = exir_exported_program.exported_program
self._pte_data: Optional[Cord] = None
self._data_files: Optional[Dict[str, Cord]] = None
self._buffer: Optional[bytes] = None
self._emitter_output: Optional[EmitterOutput] = None
self._emit_stacktrace: bool = emit_stacktrace
self._extract_delegate_segments: bool = extract_delegate_segments
self._segment_alignment: int = segment_alignment
self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
self._delegate_alignment: Optional[int] = delegate_alignment
self._data_serializer: DataSerializer = FlatTensorSerializer()

def _get_emitter_output(self) -> EmitterOutput:
    """Return the emitter output, running `emit_program` on first use.

    The result is stored on `self._emitter_output` so later calls reuse it.
    """
    emitted = self._emitter_output
    if emitted is None:
        emitted = emit_program(self.exported_program, self._emit_stacktrace)
        self._emitter_output = emitted
    return emitted

def _get_pte_data(self) -> Cord:
    """Return the serialized PTE program, serializing on first use.

    The first call runs `serialize`, which also yields the data files; both
    results are cached on `self._pte_data` and `self._data_files`.
    """
    if self._pte_data is None:
        # Forward the options stored on this instance. A default-constructed
        # config would silently drop the caller's segment/alignment choices.
        # NOTE(review): assumes ExecutorchBackendConfig accepts these four
        # fields as keyword arguments; they are the ones `serialize` reads.
        config = ExecutorchBackendConfig(
            extract_delegate_segments=self._extract_delegate_segments,
            segment_alignment=self._segment_alignment,
            constant_tensor_alignment=self._constant_tensor_alignment,
            delegate_alignment=self._delegate_alignment,
        )
        self._pte_data, self._data_files = serialize(
            self._get_emitter_output(),
            config,
            self._data_serializer,
        )
    assert self._pte_data is not None
    return self._pte_data

@property
Expand All @@ -532,11 +543,7 @@ def buffer(self) -> bytes:

@property
def program(self) -> Program:
if self._emitter_output is None:
self._emitter_output = emit_program(
self.exported_program, self._emit_stacktrace
)
return self._emitter_output.program
return self._get_emitter_output().program

@property
def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
Expand Down Expand Up @@ -571,6 +578,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
"""
self._get_pte_data().write_to_file(open_file)

def write_data_to_file(self, outdir: str) -> None:
    """
    Writes the serialized ExecuTorch data files to the directory at `outdir`.

    Each entry of the data-file mapping is written to `<outdir>/<name>.ptd`.
    Serialization is triggered here if it has not run yet, so this method
    can be called before `write_to_file`.
    """
    # `_data_files` is only populated as a side effect of `_get_pte_data()`;
    # without this call the assert below fails on a fresh instance.
    self._get_pte_data()
    assert self._data_files is not None
    for filename, cord in self._data_files.items():
        with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
            logging.info(f"Writing data file to {filename}.ptd")
            cord.write_to_file(f)


def _get_aten_to_edge_passes(config: EdgeCompileConfig):
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
Expand Down Expand Up @@ -1453,13 +1470,9 @@ def __init__(
)

# Serialize emitter output, ready to be written to a file.
self._pte_data: Cord = _serialize_pte_binary(
program=self._emitter_output.program,
mutable_data=self._emitter_output.mutable_data,
extract_delegate_segments=backend_config.extract_delegate_segments,
segment_alignment=backend_config.segment_alignment,
constant_tensor_alignment=backend_config.constant_tensor_alignment,
delegate_alignment=backend_config.delegate_alignment,
self._data_serializer = FlatTensorSerializer()
self._pte_data, self._data_files = serialize(
self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
)
self._buffer: Optional[bytes] = None

Expand Down Expand Up @@ -1542,6 +1555,16 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
"""
self._pte_data.write_to_file(open_file)

def write_data_to_file(self, outdir: str) -> None:
    """
    Writes the serialized ExecuTorch data files to the directory at `outdir`.

    Each entry of the data-file mapping is written to `<outdir>/<name>.ptd`.
    """
    assert self._data_files is not None
    for filename, cord in self._data_files.items():
        with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f:
            # Log the name actually written, including the `.ptd` suffix.
            logging.info(f"Writing data file to {filename}.ptd")
            cord.write_to_file(f)

def save(self, path: str) -> None:
"""
Saves the serialized ExecuTorch binary to the file at `path`.
Expand Down
3 changes: 3 additions & 0 deletions extension/export_util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ def save_pte_program(
filename = os.path.join(output_dir, f"{model_name}.pte")

try:
# Write program to file.
with open(filename, "wb") as file:
prog.write_to_file(file)
logging.info(f"Saved exported program to {filename}")
# Write data to file/s.
prog.write_data_to_file(outdir=output_dir)
except Exception as e:
logging.error(f"Error while saving to {filename}: {e}")

Expand Down

0 comments on commit 10cd49a

Please sign in to comment.