From 0a6a3622f6593bb5c47d269d4e606e6b9c5b033e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 6 Dec 2024 15:11:39 -0600 Subject: [PATCH] Fixup model exports and service. --- sharktank/sharktank/dynamo_exports/flux/ae.py | 2 +- .../sharktank/dynamo_exports/flux/export.py | 62 +- sharktank/sharktank/dynamo_exports/flux/te.py | 535 +----------------- .../shortfin_apps/flux/components/service.py | 111 +++- .../flux/examples/flux_dev_config_mixed.json | 2 +- .../shortfin_apps/flux/simple_client.py | 6 +- 6 files changed, 152 insertions(+), 566 deletions(-) diff --git a/sharktank/sharktank/dynamo_exports/flux/ae.py b/sharktank/sharktank/dynamo_exports/flux/ae.py index a0618d63c..b9363090e 100644 --- a/sharktank/sharktank/dynamo_exports/flux/ae.py +++ b/sharktank/sharktank/dynamo_exports/flux/ae.py @@ -346,7 +346,7 @@ def decode(self, z: Tensor) -> Tensor: pw=2, ) d_in = d_in / self.scale_factor + self.shift_factor - return self.decoder(d_in) + return self.decoder(d_in).clamp(-1, 1) def forward(self, x: Tensor) -> Tensor: return self.decode(self.encode(x)) diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/dynamo_exports/flux/export.py index 7d82cab25..7ecd7c089 100644 --- a/sharktank/sharktank/dynamo_exports/flux/export.py +++ b/sharktank/sharktank/dynamo_exports/flux/export.py @@ -7,6 +7,9 @@ import os import re from dataclasses import dataclass +import math + +from einops import rearrange from iree.compiler.ir import Context from iree.turbine.aot import * @@ -16,7 +19,9 @@ import torch from diffusers.models.transformers import FluxTransformer2DModel -from te import ClipTextEncoderModule +from diffusers.models.autoencoders import AutoencoderKL +from te import HFEmbedder +from transformers import CLIPTextModel from ae import AutoEncoder, AutoEncoderParams from scheduler import FluxScheduler from mmdit import get_flux_transformer_model @@ -107,17 +112,14 @@ def get_te_model_and_inputs( ): match component: case "clip": - # te = CLIPTextModel.from_pretrained( - # model_repo_map[hf_model_name], - # subfolder="text_encoder" - # ) - te = ClipTextEncoderModule( - model_repo_map[hf_model_name], torch_dtypes[precision] + te = HFEmbedder( + "openai/clip-vit-large-patch14", + max_length=77, + torch_dtype=torch.float32, ) clip_ids_shape = ( batch_size, 77, - 2, ) input_args = [ torch.ones(clip_ids_shape, dtype=torch.int64), @@ -127,12 +129,34 @@ def get_te_model_and_inputs( return None, None +class FluxAEWrapper(torch.nn.Module): + def __init__(self, height=1024, width=1024): + super().__init__() + self.ae = AutoencoderKL.from_pretrained( + "black-forest-labs/FLUX.1-dev", subfolder="vae" + ) + self.height = height + self.width = width + + def forward(self, z): + d_in = rearrange( + z, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(self.height / 16), + w=math.ceil(self.width / 16), + ph=2, + pw=2, + ) + d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor + return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1) + + def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width): dtype = torch_dtypes[precision] aeparams = fluxconfigs[hf_model_name].ae_params aeparams.height = height aeparams.width = width - ae = AutoEncoder(params=aeparams).to(dtype) + ae = FluxAEWrapper(height, width) latents_shape = ( batch_size, int(height * width / 256), @@ -252,14 +276,14 @@ class CompiledFluxTextEncoder(CompiledModule): fxb = FxProgramsBuilder(model) - @fxb.export_program( - args=(encode_inputs,), - ) - def _encode( - module, - inputs, - ): - return module.encode(*inputs) + # @fxb.export_program( + # args=(encode_inputs,), + # ) + # def _encode( + # module, + # inputs, + # ): + # return module.encode(*inputs) @fxb.export_program( args=(decode_inputs,), @@ -268,10 +292,10 @@ def _decode( module, inputs, ): - return module.decode(*inputs) + return module.forward(*inputs) class CompiledFluxAutoEncoder(CompiledModule): - encode = _encode + # encode = _encode decode = _decode if external_weights: diff --git a/sharktank/sharktank/dynamo_exports/flux/te.py b/sharktank/sharktank/dynamo_exports/flux/te.py index 59d728d2c..74d17a665 100644 --- a/sharktank/sharktank/dynamo_exports/flux/te.py +++ b/sharktank/sharktank/dynamo_exports/flux/te.py @@ -1,524 +1,29 @@ -### This file contains impls for underlying related models (CLIP, T5, etc) +from torch import Tensor, nn +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer -import torch, math -from torch import nn -from transformers import CLIPTokenizer, T5TokenizerFast -from transformers import T5EncoderModel -from iree.turbine import ops -from huggingface_hub import hf_hub_download -from safetensors import safe_open -from sharktank.layers import T5Config -from sharktank.models import t5 - -CLIP_CONFIG = { - "hidden_act": "quick_gelu", - "hidden_size": 768, - "intermediate_size": 3172, - "num_attention_heads": 12, - "num_hidden_layers": 12, -} - - -class ClipTextEncoderModule(torch.nn.Module): - @torch.no_grad() - def __init__( - self, - repo, - precision, - ): - super().__init__() - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.clip = SDClipModel( - layer="hidden", - layer_idx=-2, - device="cpu", - dtype=self.dtype, - layer_norm_hidden_state=False, - return_projected_pooled=True, - textmodel_json_config=CLIP_CONFIG, - ) - if precision == "fp16": - self.clip = self.clip.half() - clip_weights = hf_hub_download( - repo_id=repo, - filename="text_encoder/model.safetensors", - ) - with safe_open(clip_weights, framework="pt", device="cpu") as f: - load_into(f, self.clip.transformer, "", "cpu", self.dtype) - - def forward(self, clip_ids): - vec = self.clip(clip_ids) - - return vec - - -################################################################################################# -### Core/Utility -################################################################################################# - - -def attention(q, k, v, heads, mask=None): - """Convenience wrapper around a basic attention operation""" - b, _, dim_head = q.shape - # ops.iree.trace_tensor("attention_q", q[0,0,:5]) - # ops.iree.trace_tensor("attention_k", k[0,0,:5]) - # ops.iree.trace_tensor("attention_v", v[0,0,:5]) - dim_head //= heads - q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False - ) - # ops.iree.trace_tensor("attention_out", out[0,0,:5]) - return out.transpose(1, 2).reshape(b, -1, heads * dim_head) - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - bias=True, - dtype=None, - device=None, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - self.fc1 = nn.Linear( - in_features, hidden_features, bias=bias, dtype=dtype, device=device - ) - self.act = act_layer - self.fc2 = nn.Linear( - hidden_features, out_features, bias=bias, dtype=dtype, device=device - ) - - def forward(self, x): - x = self.fc1(x) - # ops.iree.trace_tensor("mlpfx", x[0,0,:5]) - x = self.act(x) - # ops.iree.trace_tensor("mlpact", x[0,0,:5]) - x = self.fc2(x) - # ops.iree.trace_tensor("mlpanotherfc", x[0,0,:5]) - return x - - -def load_into(f, model, prefix, device, dtype=None): - """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module.""" - for key in f.keys(): - if key.startswith(prefix) and not key.startswith("loss."): - path = key[len(prefix) :].split(".") - obj = model - for p in path: - if obj is list: - obj = obj[int(p)] - else: - obj = getattr(obj, p, None) - if obj is None: - print( - f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model" - ) - break - if obj is None: - continue - try: - tensor = f.get_tensor(key).to(device=device) - if dtype is not None: - tensor = tensor.to(dtype=dtype) - obj.requires_grad_(False) - obj.set_(tensor) - except Exception as e: - print(f"Failed to load key '{key}' in safetensors file: {e}") - raise e - - -################################################################################################# -### CLIP -################################################################################################# - - -class CLIPAttention(torch.nn.Module): - def __init__(self, embed_dim, heads, dtype, device): - super().__init__() - self.heads = heads - self.q_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - self.k_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - self.v_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - self.out_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - - def forward(self, x, mask=None): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - out = attention(q, k, v, self.heads, mask) - return self.out_proj(out) - - -ACTIVATIONS = { - "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), - "gelu": torch.nn.functional.gelu, -} - - -class CLIPLayer(torch.nn.Module): - def __init__( - self, - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ): - super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) - self.mlp = Mlp( - embed_dim, - intermediate_size, - embed_dim, - act_layer=ACTIVATIONS[intermediate_activation], - dtype=dtype, - device=device, - ) - - def forward(self, x, mask=None): - x += self.self_attn(self.layer_norm1(x), mask) - x += self.mlp(self.layer_norm2(x)) - return x - - -class CLIPEncoder(torch.nn.Module): - def __init__( - self, - num_layers, - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ): - super().__init__() - self.layers = torch.nn.ModuleList( - [ - CLIPLayer( - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ) - for i in range(num_layers) - ] - ) - - def forward(self, x, mask=None, intermediate_output=None): - if intermediate_output is not None: - if intermediate_output < 0: - intermediate_output = len(self.layers) + intermediate_output - intermediate = None - for i, l in enumerate(self.layers): - x = l(x, mask) - if i == intermediate_output: - intermediate = x.clone() - return x, intermediate - - -class CLIPEmbeddings(torch.nn.Module): - def __init__( - self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None - ): +# Copied from https://github.com/black-forest-labs/flux +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): super().__init__() - self.token_embedding = torch.nn.Embedding( - vocab_size, embed_dim, dtype=dtype, device=device - ) - self.position_embedding = torch.nn.Embedding( - num_positions, embed_dim, dtype=dtype, device=device - ) - - def forward(self, input_tokens): - return self.token_embedding(input_tokens) + self.position_embedding.weight - - -class CLIPTextModel_(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - num_layers = config_dict["num_hidden_layers"] - embed_dim = config_dict["hidden_size"] - heads = config_dict["num_attention_heads"] - intermediate_size = config_dict["intermediate_size"] - intermediate_activation = config_dict["hidden_act"] - super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) - self.encoder = CLIPEncoder( - num_layers, - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - - def forward( - self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True - ): - x = self.embeddings(input_tokens) - causal_mask = ( - torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) - .fill_(float("-inf")) - .triu_(1) - ) - x, i = self.encoder( - x, mask=causal_mask, intermediate_output=intermediate_output - ) - x = self.final_layer_norm(x) - if i is not None and final_layer_norm_intermediate: - i = self.final_layer_norm(i) - pooled_output = x[ - torch.arange(x.shape[0], device=x.device), - input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), - ] - return x, i, pooled_output - - -class CLIPTextModel(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_hidden_layers"] - self.text_model = CLIPTextModel_(config_dict, dtype, device) - embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear( - embed_dim, embed_dim, bias=False, dtype=dtype, device=device - ) - self.text_projection.weight.copy_(torch.eye(embed_dim)) - self.dtype = dtype - - def get_input_embeddings(self): - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, embeddings): - self.text_model.embeddings.token_embedding = embeddings - - def forward(self, *args, **kwargs): - x = self.text_model(*args, **kwargs) - out = self.text_projection(x[2]) - return (x[0], x[1], out, x[2]) - - -class SDTokenizer: - def __init__( - self, - max_length=77, - pad_with_end=True, - tokenizer=None, - has_start_token=True, - pad_to_max_length=True, - min_length=None, - ): - self.tokenizer = tokenizer + self.is_clip = version.startswith("openai") self.max_length = max_length - self.min_length = min_length - empty = self.tokenizer("")["input_ids"] - if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] - self.end_token = empty[1] - else: - self.tokens_start = 0 - self.start_token = None - self.end_token = empty[0] - self.pad_with_end = pad_with_end - self.pad_to_max_length = pad_to_max_length - vocab = self.tokenizer.get_vocab() - self.inv_vocab = {v: k for k, v in vocab.items()} - self.max_word_length = 8 + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" - def tokenize_with_weights(self, text: str): - """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" - if self.pad_with_end: - pad_token = self.end_token - else: - pad_token = 0 - batch = [] - if self.start_token is not None: - batch.append((self.start_token, 1.0)) - to_tokenize = text.replace("\n", " ").split(" ") - to_tokenize = [x for x in to_tokenize if x != ""] - for word in to_tokenize: - batch.extend( - [ - (t, 1) - for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1] - ] + if self.is_clip: + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( + version, **hf_kwargs ) - batch.append((self.end_token, 1.0)) - if self.pad_to_max_length: - batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) - return [batch] - - -class SDXLClipGTokenizer(SDTokenizer): - def __init__(self, tokenizer): - super().__init__(pad_with_end=False, tokenizer=tokenizer) - - -class SD3Tokenizer: - def __init__(self): - clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) - self.clip_g = SDXLClipGTokenizer(clip_tokenizer) - self.t5xxl = T5XXLTokenizer() - - def tokenize_with_weights(self, text: str | list[str]): - out = {} - if isinstance(text, list): - text = text[0] - out["g"] = self.clip_g.tokenize_with_weights(text) - out["l"] = self.clip_l.tokenize_with_weights(text) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) - for k, v in out.items(): - out[k] = torch.tensor(v, dtype=torch.int64, device="cpu") - return out - - -class ClipTokenWeightEncoder: - def encode_token_weights(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - tokens = token_weight_pairs[:, :, 0] - out, pooled = self(tokens) - if pooled is not None: - first_pooled = pooled[0:1].cpu() - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled - - -class SDClipModel(torch.nn.Module): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - device="cpu", - max_length=77, - layer="last", - layer_idx=None, - textmodel_json_config=None, - dtype=None, - model_class=CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, - layer_norm_hidden_state=True, - return_projected_pooled=True, - ): - super().__init__() - assert layer in self.LAYERS - self.transformer = model_class(textmodel_json_config, dtype, device) - self.num_layers = self.transformer.num_layers - self.max_length = max_length - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - self.layer = layer - self.layer_idx = None - self.special_tokens = special_tokens - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.layer_norm_hidden_state = layer_norm_hidden_state - self.return_projected_pooled = return_projected_pooled - if layer == "hidden": - assert layer_idx is not None - assert abs(layer_idx) < self.num_layers - self.set_clip_options({"layer": layer_idx}) - self.options_default = ( - self.layer, - self.layer_idx, - self.return_projected_pooled, - ) - - def encode_token_weights(self, token_weight_pairs): - pass - - def set_clip_options(self, options): - layer_idx = options.get("layer", self.layer_idx) - self.return_projected_pooled = options.get( - "projected_pooled", self.return_projected_pooled - ) - if layer_idx is None or abs(layer_idx) > self.num_layers: - self.layer = "last" else: - self.layer = "hidden" - self.layer_idx = layer_idx - - def forward(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - tokens = token_weight_pairs[:, :, 0] - # backup_embeds = self.transformer.get_input_embeddings() - # device = backup_embeds.weight.device - # tokens = torch.LongTensor(tokens).to(device) - outputs = self.transformer( - tokens, - intermediate_output=self.layer_idx, - final_layer_norm_intermediate=self.layer_norm_hidden_state, - ) - # self.transformer.set_input_embeddings(backup_embeds) - if self.layer == "last": - z = outputs[0] - else: - z = outputs[1] - pooled_output = None - if len(outputs) >= 3: - if ( - not self.return_projected_pooled - and len(outputs) >= 4 - and outputs[3] is not None - ): - pooled_output = outputs[3].float() - elif outputs[2] is not None: - pooled_output = outputs[2].float() - out, pooled = z.float(), pooled_output - if pooled is not None: - first_pooled = pooled[0:1].cpu() - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled - + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( + version, **hf_kwargs + ) -class SDXLClipG(SDClipModel): - """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + self.hf_module = self.hf_module.eval().requires_grad_(False) - def __init__( - self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None - ): - if layer == "penultimate": - layer = "hidden" - layer_idx = -2 - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"start": 49406, "end": 49407, "pad": 0}, - layer_norm_hidden_state=False, + def forward(self, input_ids) -> Tensor: + outputs = self.hf_module( + input_ids=input_ids, + attention_mask=None, + output_hidden_states=False, ) + return outputs[self.output_key] diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index 6d723bbe6..5d1f58f03 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -6,9 +6,12 @@ import asyncio import logging +import math +import torch import numpy as np from tqdm.auto import tqdm from pathlib import Path +from typing import Callable from PIL import Image import base64 @@ -21,6 +24,8 @@ from .tokenizer import Tokenizer from .metrics import measure +from einops import rearrange + logger = logging.getLogger("shortfin-flux.service") prog_isolations = { @@ -30,6 +35,37 @@ } +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + class GenerateService: """Top level service interface for image generation.""" @@ -420,9 +456,10 @@ async def _prepare(self, device, requests): # Generate random sample latents. seed = request.seed channels = self.service.model_params.num_latents_channels + image_seq_len = (request.height) * (request.width) // 256 latents_shape = [ 1, - (requests[0].height) * (requests[0].width) // 256, + image_seq_len, 64, ] # latents_shape = ( @@ -446,6 +483,11 @@ async def _prepare(self, device, requests): request.sample.copy_from(sample_host) await device + request.timesteps = get_schedule( + request.steps, + image_seq_len, + shift=not self.service.model_params.is_schnell, + ) return async def _clip(self, device, requests): @@ -464,7 +506,7 @@ async def _clip(self, device, requests): clip_inputs = [ sfnp.device_array.for_device( device, - [req_bs, self.service.model_params.clip_max_seq_len, 2], + [req_bs, self.service.model_params.clip_max_seq_len], sfnp.sint64, ), ] @@ -486,12 +528,12 @@ async def _clip(self, device, requests): fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) - (vec, _) = await fn(*clip_inputs, fiber=self.fiber) + (vec,) = await fn(*clip_inputs, fiber=self.fiber) await device for i in range(req_bs): cfg_mult = 2 - requests[i].vec = vec.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + requests[i].vec = vec.view(slice(i, (i + 1))) return @@ -580,7 +622,9 @@ async def _denoise(self, device, requests): device, vec_shape, self.service.model_params.sampler_dtype ), "step": sfnp.device_array.for_device(device, [1], sfnp.int64), - "num_steps": sfnp.device_array.for_device(device, [1], sfnp.int64), + "timesteps": sfnp.device_array.for_device( + device, [100], self.service.model_params.sampler_dtype + ), "guidance_scale": sfnp.device_array.for_device( device, [req_bs], self.service.model_params.sampler_dtype ), @@ -618,17 +662,20 @@ async def _denoise(self, device, requests): # Batch CLIP projections. vec = requests[i].vec - denoise_inputs["vec"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( - vec - ) + for nc in range(2): + denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) denoise_inputs["guidance_scale"].copy_from(gs_host) - - ns_host = denoise_inputs["num_steps"].for_transfer() - with ns_host.map(write=True) as m: - ns_host.items = [step_count] - - denoise_inputs["num_steps"].copy_from(ns_host) + await device + ts_host = denoise_inputs["timesteps"].for_transfer() + with ts_host.map(write=True) as m: + m.fill(float(1)) + for tstep in range(len(requests[0].timesteps)): + with ts_host.view(tstep).map(write=True, discard=True) as m: + m.fill(np.asarray(requests[0].timesteps[tstep], dtype="float32")) + + denoise_inputs["timesteps"].copy_from(ts_host) + await device for i, t in tqdm( enumerate(range(step_count)), @@ -640,14 +687,25 @@ async def _denoise(self, device, requests): s_host.items = [i] denoise_inputs["step"].copy_from(s_host) - logger.debug( + logger.info( "INVOKE %r", fns["sampler"], ) + await device + # np_arrs = {} + # host_arrs = {} + # for key, value in denoise_inputs.items(): + # host_arrs[key] = denoise_inputs[key].for_transfer() + # host_arrs[key].copy_from(denoise_inputs[key]) + # await device + # np_arrs[key] = np.array(host_arrs[key]) + # for key, value in np_arrs.items(): + # np.save(f"{key}.npy", value) + (noise_pred,) = await fns["sampler"]( *denoise_inputs.values(), fiber=self.fiber ) - + await device denoise_inputs["img"].copy_from(noise_pred) for idx, req in enumerate(requests): @@ -671,7 +729,7 @@ async def _decode(self, device, requests): await device latents_shape = [ req_bs, - (requests[0].height) * (requests[0].width) // 256, + (requests[0].height * requests[0].width) // 256, 64, ] latents = sfnp.device_array.for_device( @@ -688,7 +746,6 @@ async def _decode(self, device, requests): "".join([f"\n 0: {latents.shape}"]), ) (image,) = await fn(latents, fiber=self.fiber) - await device images_shape = [ req_bs, @@ -709,23 +766,23 @@ async def _postprocess(self, device, requests): # Process output images for req in requests: image_shape = [ - 1, 3, req.height, req.width, ] + out_shape = [req.height, req.width, 3] images_planar = sfnp.device_array.for_host( device, image_shape, self.service.model_params.vae_dtype ) images_planar.copy_from(req.image_array) - for j in range(3): - data = [0.3 + j * 0.1 for _ in range(req.height * req.width)] - images_planar.view(0, j).items = data - permuted = sfnp.transpose(images_planar, (0, 2, 3, 1)) - cast_image = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) - image = sfnp.round(cast_image, dtype=sfnp.uint8) - - image_bytes = bytes(image.map(read=True)) + permuted = sfnp.device_array.for_host( + device, out_shape, self.service.model_params.vae_dtype + ) + out = sfnp.device_array.for_host(device, out_shape, sfnp.uint8) + sfnp.transpose(images_planar, (1, 2, 0), out=permuted) + permuted = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) + out = sfnp.round(permuted, dtype=sfnp.uint8) + image_bytes = bytes(out.map(read=True)) image = base64.b64encode(image_bytes).decode("utf-8") req.result_image = image diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json index cbe44680d..46a20af78 100644 --- a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -7,7 +7,7 @@ "clip_batch_sizes": [ 1 ], - "clip_dtype": "bfloat16", + "clip_dtype": "float32", "clip_module_name": "compiled_flux_text_encoder", "t5xxl_batch_sizes": [ 1 diff --git a/shortfin/python/shortfin_apps/flux/simple_client.py b/shortfin/python/shortfin_apps/flux/simple_client.py index 42a3e02f3..9382ddd9d 100644 --- a/shortfin/python/shortfin_apps/flux/simple_client.py +++ b/shortfin/python/shortfin_apps/flux/simple_client.py @@ -19,13 +19,13 @@ sample_request = { "prompt": [ - " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " A mountain with a halo cloud over it, Death Mountain, spooky, Zelda", ], "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], "height": [1024], "width": [1024], - "steps": [20], - "guidance_scale": [3], + "steps": [50], + "guidance_scale": [3.5], "seed": [0], "output_type": ["base64"], "rid": ["string"],