Skip to content

Commit

Permalink
Fixup model exports and service.
Browse files: browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Jan 3, 2025
1 parent 56c3e4f commit 0a6a362
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 566 deletions.
2 changes: 1 addition & 1 deletion sharktank/sharktank/dynamo_exports/flux/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def decode(self, z: Tensor) -> Tensor:
pw=2,
)
d_in = d_in / self.scale_factor + self.shift_factor
return self.decoder(d_in)
return self.decoder(d_in).clamp(-1, 1)

def forward(self, x: Tensor) -> Tensor:
    """Encode ``x`` and return the decoding of the encoded result.

    Args:
        x: Input passed unchanged to ``self.encode``.

    Returns:
        Whatever ``self.decode`` returns for the encoded value.
    """
    encoded = self.encode(x)
    return self.decode(encoded)
62 changes: 43 additions & 19 deletions sharktank/sharktank/dynamo_exports/flux/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import os
import re
from dataclasses import dataclass
import math

from einops import rearrange

from iree.compiler.ir import Context
from iree.turbine.aot import *
Expand All @@ -16,7 +19,9 @@
import torch

from diffusers.models.transformers import FluxTransformer2DModel
from te import ClipTextEncoderModule
from diffusers.models.autoencoders import AutoencoderKL
from te import HFEmbedder
from transformers import CLIPTextModel
from ae import AutoEncoder, AutoEncoderParams
from scheduler import FluxScheduler
from mmdit import get_flux_transformer_model
Expand Down Expand Up @@ -107,17 +112,14 @@ def get_te_model_and_inputs(
):
match component:
case "clip":
# te = CLIPTextModel.from_pretrained(
# model_repo_map[hf_model_name],
# subfolder="text_encoder"
# )
te = ClipTextEncoderModule(
model_repo_map[hf_model_name], torch_dtypes[precision]
te = HFEmbedder(
"openai/clip-vit-large-patch14",
max_length=77,
torch_dtype=torch.float32,
)
clip_ids_shape = (
batch_size,
77,
2,
)
input_args = [
torch.ones(clip_ids_shape, dtype=torch.int64),
Expand All @@ -127,12 +129,34 @@ def get_te_model_and_inputs(
return None, None


class FluxAEWrapper(torch.nn.Module):
    """Module whose ``forward`` unpacks a latent sequence and decodes it.

    The wrapped model is the ``AutoencoderKL`` loaded from the ``vae``
    subfolder of ``black-forest-labs/FLUX.1-dev``.

    Args:
        height: Target image height; the latent grid has
            ``ceil(height / 16)`` rows of 2x2 patches.
        width: Target image width; the latent grid has
            ``ceil(width / 16)`` columns of 2x2 patches.
    """

    def __init__(self, height=1024, width=1024):
        super().__init__()
        self.height = height
        self.width = width
        self.ae = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev", subfolder="vae"
        )

    def forward(self, z):
        """Decode ``z`` and clamp the result to ``[-1, 1]``.

        Args:
            z: Tensor laid out as ``b (h w) (c ph pw)`` with ``ph = pw = 2``,
                as required by the rearrange pattern below.

        Returns:
            The first element of ``self.ae.decode(..., return_dict=False)``,
            clamped to the range ``[-1, 1]``.
        """
        grid_rows = math.ceil(self.height / 16)
        grid_cols = math.ceil(self.width / 16)
        # Fold each 2x2 patch back into the spatial dimensions.
        unpacked = rearrange(
            z,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=grid_rows,
            w=grid_cols,
            ph=2,
            pw=2,
        )
        vae_config = self.ae.config
        # Apply the scaling and shift factors from the VAE's config before decoding.
        rescaled = unpacked / vae_config.scaling_factor + vae_config.shift_factor
        decoded = self.ae.decode(rescaled, return_dict=False)[0]
        return decoded.clamp(-1, 1)


def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width):
dtype = torch_dtypes[precision]
aeparams = fluxconfigs[hf_model_name].ae_params
aeparams.height = height
aeparams.width = width
ae = AutoEncoder(params=aeparams).to(dtype)
ae = FluxAEWrapper(height, width)
latents_shape = (
batch_size,
int(height * width / 256),
Expand Down Expand Up @@ -252,14 +276,14 @@ class CompiledFluxTextEncoder(CompiledModule):

fxb = FxProgramsBuilder(model)

@fxb.export_program(
args=(encode_inputs,),
)
def _encode(
module,
inputs,
):
return module.encode(*inputs)
# @fxb.export_program(
# args=(encode_inputs,),
# )
# def _encode(
# module,
# inputs,
# ):
# return module.encode(*inputs)

@fxb.export_program(
args=(decode_inputs,),
Expand All @@ -268,10 +292,10 @@ def _decode(
module,
inputs,
):
return module.decode(*inputs)
return module.forward(*inputs)

class CompiledFluxAutoEncoder(CompiledModule):
encode = _encode
# encode = _encode
decode = _decode

if external_weights:
Expand Down
Loading

0 comments on commit 0a6a362

Please sign in to comment.