Buckify Llama multimodal export (#7604)

jackzhxng authored Feb 4, 2025
1 parent ee6f2d9 commit b02c692
Showing 9 changed files with 104 additions and 14 deletions.
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -109,6 +109,7 @@ runtime.python_library(
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/exir/passes:init_mutable_pass",
14 changes: 14 additions & 0 deletions examples/models/llama3_2_vision/TARGETS
@@ -0,0 +1,14 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "multimodal_lib",
+    srcs = [
+        "__init__.py",
+    ],
+    deps = [
+        "//executorch/examples/models/llama3_2_vision/text_decoder:model",
+        "//executorch/examples/models/llama3_2_vision/vision_encoder:model",
+    ],
+)
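
The new multimodal_lib target packages the package __init__.py and pulls in the two model targets added below, so downstream Python code can depend on a single label and import both wrappers. A minimal consumer sketch, assuming the wrapper class names (this diff shows only the target layout, not the classes):

# Hypothetical consumer of //executorch/examples/models/llama3_2_vision:multimodal_lib.
# The class names below are assumptions; only the module paths come from the diff.
from executorch.examples.models.llama3_2_vision.text_decoder.model import (
    Llama3_2Decoder,  # assumed name
)
from executorch.examples.models.llama3_2_vision.vision_encoder.model import (
    FlamingoVisionEncoderModel,  # assumed name
)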
17 changes: 17 additions & 0 deletions examples/models/llama3_2_vision/text_decoder/TARGETS
@@ -0,0 +1,17 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "model",
+    srcs = [
+        "model.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models:checkpoint",
+        "//pytorch/torchtune:lib",
+        "//executorch/extension/llm/modules:module_lib",
+    ],
+)
+
14 changes: 0 additions & 14 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -133,20 +133,6 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")
 
-        # Prune the output layer if output_prune_map is provided.
-        output_prune_map = None
-        if self.output_prune_map_path is not None:
-            from executorch.examples.models.llama2.source_transformation.prune_output import (
-                prune_output_vocab,
-            )
-
-            with open(self.output_prune_map_path, "r") as f:
-                output_prune_map = json.load(f)
-            # Change keys from string to int (json only supports string keys)
-            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-
-            self.model_ = prune_output_vocab(self.model_, output_prune_map)
-
         if self.use_kv_cache:
             print("Setting up KV cache on the model...")
             self.model_.setup_caches(
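
For reference, the block removed above consumed a JSON file mapping original token ids to pruned ids, converted its string keys to ints, and handed the map to prune_output_vocab. A minimal sketch of that input shape (the map contents here are illustrative, not from the diff):

import json

# JSON objects can only have string keys, so the removed code rebuilt the
# map with int keys after loading; the ids below are made up for illustration.
raw = '{"0": 0, "128000": 1, "128001": 2}'
output_prune_map = {int(k): v for k, v in json.loads(raw).items()}
assert output_prune_map[128000] == 1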
17 changes: 17 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/TARGETS
@@ -0,0 +1,17 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "model",
+    srcs = [
+        "__init__.py",
+        "model.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/extension/llm/modules:module_lib",
+        "//pytorch/torchtune:lib",
+        "//executorch/examples/models:model_base",
+    ],
+)
2 changes: 2 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/model.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-ignore-all-errors
+
 from dataclasses import dataclass, field
 from typing import Optional
 
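
The same file-level # pyre-ignore-all-errors directive is applied to each of the reused modules below. For contrast, Pyre also supports suppressing a single line; a sketch (the helper is hypothetical):

def untyped_helper():
    # Returns a str, so assigning the result to an int triggers Pyre's
    # error 9 (incompatible variable type) unless suppressed.
    return "not an int"

result: int = untyped_helper()  # pyre-ignore[9]: suppress only this line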
49 changes: 49 additions & 0 deletions extension/llm/modules/TARGETS
@@ -0,0 +1,49 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "kv_cache",
+    srcs = [
+        "kv_cache.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//pytorch/torchtune:lib",
+    ],
+)
+
+python_library(
+    name = "attention",
+    srcs = [
+        "attention.py",
+    ],
+    deps = [
+        ":kv_cache",
+        "//caffe2:torch",
+        "//executorch/extension/llm/custom_ops:custom_ops",
+        "//pytorch/torchtune:lib",
+    ],
+)
+
+python_library(
+    name = "position_embeddings",
+    srcs = [
+        "_position_embeddings.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+python_library(
+    name = "module_lib",
+    srcs = [
+        "__init__.py",
+    ],
+    deps = [
+        ":position_embeddings",
+        ":attention",
+        ":kv_cache",
+    ],
+)
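
The targets above mirror the package layout: kv_cache stands alone, attention layers on top of it (plus the custom ops target), position_embeddings depends only on torch, and module_lib aggregates all three behind the package __init__.py. A sketch of a downstream import, assuming the re-exported names (the diff shows only file and target names):

# Hypothetical imports; the actual re-exports live in
# extension/llm/modules/__init__.py, which this diff does not show.
from executorch.extension.llm.modules import (
    KVCache,             # assumed, from kv_cache.py
    MultiHeadAttention,  # assumed, from attention.py
)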
2 changes: 2 additions & 0 deletions extension/llm/modules/_position_embeddings.py
@@ -8,6 +8,8 @@
 # Added torch._check() to make sure guards on symints are enforced.
 # See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py
 
+# pyre-ignore-all-errors
+
 import logging
 import math
 from typing import Any, Dict, Tuple
2 changes: 2 additions & 0 deletions extension/llm/modules/attention.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-ignore-all-errors
+
 import logging
 from typing import Optional
 
