Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization to devtools #7554

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -627,6 +628,15 @@ def check_node_count(self, input: Dict[Any, int]):

return self

def visualize(
    self, reuse_server: bool = True, stage: Optional[str] = None, **kwargs
):
    """Open the artifact of `stage` (default: the current stage) in model-explorer.

    Returns self to allow call chaining with the other Tester stages.
    """
    # Deferred import: model_explorer is only needed when visualization is
    # actually requested, which is not the common case.
    from executorch.devtools.visualization import visualize

    artifact = self.get_artifact(stage)
    visualize(artifact, reuse_server=reuse_server, **kwargs)
    return self

def run_method_and_compare_outputs(
self,
stage: Optional[str] = None,
Expand Down
11 changes: 11 additions & 0 deletions devtools/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from executorch.devtools.visualization.visualization_utils import ( # noqa: F401
ModelExplorerServer,
SingletonModelExplorerServer,
visualize,
)
119 changes: 119 additions & 0 deletions devtools/visualization/visualization_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import subprocess
import time

from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
from model_explorer import config, consts, visualize_from_config # type: ignore
from torch.export.exported_program import ExportedProgram


class SingletonModelExplorerServer:
    """Singleton context manager for starting a model-explorer server.

    If multiple SingletonModelExplorerServer contexts are nested, a single
    server process is shared between them and shut down only when the last
    context exits.
    """

    # Class-level state so every instance shares one server process.
    server: None | subprocess.Popen = None
    # Number of currently-open contexts sharing the server.
    num_open: int = 0
    # Seconds to wait after starting the server, also used as shutdown timeout.
    wait_after_start = 2.0

    def __init__(self, open_in_browser: bool = True, port: int | None = None):
        # Only the first instance actually spawns the server process.
        if SingletonModelExplorerServer.server is None:
            command = ["model-explorer"]
            if not open_in_browser:
                command.append("--no_open_in_browser")
            if port is not None:
                command.append("--port")
                command.append(str(port))
            SingletonModelExplorerServer.server = subprocess.Popen(command)

    def __enter__(self):
        SingletonModelExplorerServer.num_open += 1
        # Give the server some time to come up before the caller uses it.
        time.sleep(SingletonModelExplorerServer.wait_after_start)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        SingletonModelExplorerServer.num_open -= 1
        if SingletonModelExplorerServer.num_open == 0:
            if SingletonModelExplorerServer.server is not None:
                # Graceful shutdown first (SIGTERM), escalating to SIGKILL if
                # the process does not exit in time. The previous code had the
                # order reversed (kill first, then terminate), which made the
                # fallback a no-op.
                SingletonModelExplorerServer.server.terminate()
                try:
                    SingletonModelExplorerServer.server.wait(
                        SingletonModelExplorerServer.wait_after_start
                    )
                except subprocess.TimeoutExpired:
                    SingletonModelExplorerServer.server.kill()
                SingletonModelExplorerServer.server = None


class ModelExplorerServer:
    """Context manager for starting a dedicated model-explorer server."""

    # Seconds to wait after starting the server, also used as shutdown timeout.
    wait_after_start = 2.0

    def __init__(self, open_in_browser: bool = True, port: int | None = None):
        command = ["model-explorer"]
        if not open_in_browser:
            command.append("--no_open_in_browser")
        if port is not None:
            command.append("--port")
            command.append(str(port))
        self.server = subprocess.Popen(command)

    def __enter__(self):
        # Give the server some time to come up before the caller uses it.
        time.sleep(self.wait_after_start)
        # Return self so `with ModelExplorerServer() as server:` works
        # (previously this returned None), matching SingletonModelExplorerServer.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Graceful shutdown first (SIGTERM), escalating to SIGKILL if the
        # process does not exit in time. The previous code had the order
        # reversed (kill first, then terminate), making the fallback a no-op.
        self.server.terminate()
        try:
            self.server.wait(self.wait_after_start)
        except subprocess.TimeoutExpired:
            self.server.kill()


def _get_exported_program(
visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
) -> ExportedProgram:
if isinstance(visualizable, ExportedProgram):
return visualizable
if isinstance(visualizable, (EdgeProgramManager, ExecutorchProgramManager)):
return visualizable.exported_program()
raise RuntimeError(f"Cannot get ExportedProgram from {visualizable}")


def visualize(
    visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
    reuse_server: bool = True,
    no_open_in_browser: bool = False,
    **kwargs,
):
    """Thin wrapper around model_explorer's visualize_from_config.

    For convenience, the ExportedProgram is extracted automatically when an
    EdgeProgramManager or ExecutorchProgramManager is passed.

    See https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#visualize-pytorch-models
    for full documentation.
    """
    model_config = config()
    model_config.add_model_from_pytorch(
        "Executorch",
        exported_program=_get_exported_program(visualizable),
        settings=consts.DEFAULT_SETTINGS,
    )
    if reuse_server:
        model_config.set_reuse_server()
    visualize_from_config(
        model_config,
        no_open_in_browser=no_open_in_browser,
        **kwargs,
    )
153 changes: 153 additions & 0 deletions devtools/visualization/visualization_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import time

import pytest
import torch
from executorch.backends.xnnpack.test.tester import Tester

from executorch.devtools.visualization import (
ModelExplorerServer,
SingletonModelExplorerServer,
visualization_utils,
visualize,
)
from executorch.exir import ExportedProgram
from model_explorer.config import ModelExplorerConfig # type: ignore


@pytest.fixture
def server():
    """Mock relevant calls in visualization.visualize and check that parameters have their expected value."""
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        _called_reuse_server = False

        def mock_set_reuse_server(self):
            nonlocal _called_reuse_server
            _called_reuse_server = True

        def mock_add_model_from_pytorch(self, name, exported_program, settings):
            # visualize() must always hand an ExportedProgram to model-explorer.
            assert isinstance(exported_program, ExportedProgram)

        def mock_visualize_from_config(cur_config, no_open_in_browser):
            pass

        # Patch through the context-managed MonkeyPatch `m` so the patches
        # are undone when the context exits. (Previously the patches were
        # applied through the outer `monkeypatch` object and were therefore
        # never reverted at teardown.)
        m.setattr(ModelExplorerConfig, "set_reuse_server", mock_set_reuse_server)
        m.setattr(
            ModelExplorerConfig, "add_model_from_pytorch", mock_add_model_from_pytorch
        )
        m.setattr(
            visualization_utils, "visualize_from_config", mock_visualize_from_config
        )
        yield monkeypatch.context
        assert _called_reuse_server, "Did not call reuse_server"


class Linear(torch.nn.Module):
    """Minimal linear model used as a fixture by the visualization tests.

    Carries its own random example input so tests can call get_inputs()
    instead of constructing tensors themselves.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int = 3,
        bias: bool = True,
    ):
        super().__init__()
        # Fixed 4-D example input whose last dim matches in_features.
        self.inputs = (torch.randn(5, 10, 25, in_features),)
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)

    def get_inputs(self) -> tuple[torch.Tensor]:
        return self.inputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def test_visualize_manual_export(server):
    """Visualize a plain torch.export ExportedProgram directly."""
    with server():
        model = Linear(20, 30)
        ep = torch.export.export(model, model.get_inputs())
        visualize(ep)
        # Leave the browser tab time to load before the server goes away.
        time.sleep(3.0)


def test_visualize_exported_program(server):
    """Visualize the Tester artifact after the export stage."""
    with server():
        model = Linear(20, 30)
        tester = Tester(model, example_inputs=model.get_inputs())
        tester.export().visualize()


def test_visualize_to_edge(server):
    """Visualize the Tester artifact after the to_edge stage."""
    with server():
        model = Linear(20, 30)
        tester = Tester(model, example_inputs=model.get_inputs())
        tester.export().to_edge().visualize()


def test_visualize_partition(server):
    """Visualize the Tester artifact after the partition stage."""
    with server():
        model = Linear(20, 30)
        tester = Tester(model, example_inputs=model.get_inputs())
        tester.export().to_edge().partition().visualize()


def test_visualize_to_executorch(server):
    """Visualize the Tester artifact after the to_executorch stage."""
    with server():
        model = Linear(20, 30)
        tester = Tester(model, example_inputs=model.get_inputs())
        tester.export().to_edge().partition().to_executorch().visualize()


if __name__ == "__main__":
    """A test to run locally to make sure that the web browser opens up
    automatically as intended.
    """

    # NOTE: under pytest the test functions receive the `server` fixture,
    # which yields a context-manager factory. When run manually we pass a
    # server context-manager *class* instead — each test only evaluates
    # `with server():`, so both spellings work and here the real
    # model-explorer server is started instead of the mocked one.

    # First test runs against its own dedicated (non-singleton) server.
    test_visualize_manual_export(ModelExplorerServer)

    # The remaining tests share a single singleton server process.
    with SingletonModelExplorerServer():
        test_visualize_manual_export(SingletonModelExplorerServer)
        test_visualize_exported_program(SingletonModelExplorerServer)
        test_visualize_to_edge(SingletonModelExplorerServer)
        test_visualize_partition(SingletonModelExplorerServer)
        test_visualize_to_executorch(SingletonModelExplorerServer)
2 changes: 2 additions & 0 deletions install_requirements.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-25 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -169,6 +170,7 @@ def python_is_compatible():
"tomli", # Imported by extract_sources.py when using python < 3.11.
"wheel", # For building the pip package archive.
"zstd", # Imported by resolve_buck.py.
"ai-edge-model-explorer>=0.1.16", # For visualizing ExportedPrograms
]

# Assemble the list of requirements to actually install.
Expand Down