dtype flexibility for flux.
eagarvey-amd committed Jan 3, 2025
1 parent 0a6a362 commit 10cc491
Showing 3 changed files with 75 additions and 21 deletions.
7 changes: 4 additions & 3 deletions sharktank/sharktank/dynamo_exports/flux/export.py
@@ -130,10 +130,11 @@ def get_te_model_and_inputs(


class FluxAEWrapper(torch.nn.Module):
def __init__(self, height=1024, width=1024):
def __init__(self, height=1024, width=1024, precision="fp32"):
super().__init__()
dtype = torch_dtypes[precision]
self.ae = AutoencoderKL.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="vae"
"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtypes=dtype
)
self.height = height
self.width = width
@@ -156,7 +157,7 @@ def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width)
aeparams = fluxconfigs[hf_model_name].ae_params
aeparams.height = height
aeparams.width = width
ae = FluxAEWrapper(height, width)
ae = FluxAEWrapper(height, width, precision).to(dtype)
latents_shape = (
batch_size,
int(height * width / 256),
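The export-side change leans on a torch_dtypes lookup (used above as dtype = torch_dtypes[precision]) that is defined elsewhere in export.py. A minimal sketch of what that mapping plausibly looks like; the key names and the resolve_dtype helper are assumptions, not taken from the file:

import torch

# Assumed shape of the precision lookup used by FluxAEWrapper above; the
# actual dict in export.py may use different keys or additional entries.
torch_dtypes = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}

def resolve_dtype(precision: str) -> torch.dtype:
    # Fall back to float32 when an unrecognized precision string is passed.
    return torch_dtypes.get(precision, torch.float32)

With a mapping like this, FluxAEWrapper can load the VAE weights in the requested dtype, and get_ae_model_and_inputs can cast the wrapper with .to(dtype).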
85 changes: 69 additions & 16 deletions shortfin/python/shortfin_apps/flux/components/service.py
@@ -24,8 +24,6 @@
from .tokenizer import Tokenizer
from .metrics import measure

from einops import rearrange

logger = logging.getLogger("shortfin-flux.service")

prog_isolations = {
@@ -475,13 +473,26 @@ async def _prepare(self, device, requests):
device, latents_shape, self.service.model_params.sampler_dtype
)

sample_host = request.sample.for_transfer()
sample_host = sfnp.device_array.for_host(
device, latents_shape, sfnp.float32
)
with sample_host.map(discard=True) as m:
m.fill(bytes(1))

sfnp.fill_randn(sample_host, generator=generator)
if self.service.model_params.sampler_dtype != sfnp.float32:
sample_transfer = request.sample.for_transfer()
sfnp.convert(
sample_host,
dtype=self.service.model_params.sampler_dtype,
out=sample_transfer,
)

request.sample.copy_from(sample_transfer)
# sample_debug = torch.frombuffer(sample_transfer.items, dtype=torch.bfloat16)
# print(sample_debug)
else:
request.sample.copy_from(sample_host)

request.sample.copy_from(sample_host)
await device
request.timesteps = get_schedule(
request.steps,
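The _prepare change above sets the pattern that recurs through the rest of this file: stage values in a float32 host array (so fill_randn and per-element fills stay in fp32), then sfnp.convert into the configured sampler dtype before copying to the device. A hedged sketch of that pattern as a standalone helper; the helper name and factoring are hypothetical, while the array, convert, and copy calls are the ones used in this commit (sfnp is the shortfin array namespace already imported at the top of service.py):

def copy_with_dtype(host_f32, target, target_dtype):
    # Move a float32-staged host array into `target`, converting only when
    # the destination dtype differs from float32.
    if target_dtype != sfnp.float32:
        staging = target.for_transfer()  # host-side staging array in target_dtype
        sfnp.convert(host_f32, dtype=target_dtype, out=staging)
        target.copy_from(staging)
    else:
        target.copy_from(host_f32)

# e.g. for the noise sample prepared above:
# copy_with_dtype(sample_host, request.sample, self.service.model_params.sampler_dtype)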
@@ -634,13 +645,11 @@ async def _denoise(self, device, requests):
sample_host = sfnp.device_array.for_host(
device, img_shape, self.service.model_params.sampler_dtype
)
guidance_float = sfnp.device_array.for_host(device, [req_bs], sfnp.float32)

for i in range(req_bs):
guidance_float.view(i).items = [requests[i].guidance_scale]
cfg_dim = i * cfg_mult
with gs_host.view(i).map(write=True, discard=True) as m:
# TODO: do this without numpy
np_arr = np.asarray(requests[i].guidance_scale, dtype="float32")

m.fill(np_arr)

# Reshape and batch sample latent inputs on device.
# Currently we just generate random latents in the desired shape. Rework for img2img.
@@ -664,16 +673,24 @@
vec = requests[i].vec
for nc in range(2):
denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec)

sfnp.convert(
guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host
)
denoise_inputs["guidance_scale"].copy_from(gs_host)
await device
ts_host = denoise_inputs["timesteps"].for_transfer()
with ts_host.map(write=True) as m:
ts_float = sfnp.device_array.for_host(
device, denoise_inputs["timesteps"].shape, dtype=sfnp.float32
)
with ts_float.map(write=True) as m:
m.fill(float(1))
for tstep in range(len(requests[0].timesteps)):
with ts_host.view(tstep).map(write=True, discard=True) as m:
with ts_float.view(tstep).map(write=True, discard=True) as m:
m.fill(np.asarray(requests[0].timesteps[tstep], dtype="float32"))

sfnp.convert(
ts_float, dtype=self.service.model_params.sampler_dtype, out=ts_host
)
denoise_inputs["timesteps"].copy_from(ts_host)
await device

@@ -712,7 +729,33 @@ async def _denoise(self, device, requests):
req.denoised_latents = sfnp.device_array.for_device(
device, img_shape, self.service.model_params.vae_dtype
)
req.denoised_latents.copy_from(denoise_inputs["img"].view(idx * cfg_mult))
if (
self.service.model_params.vae_dtype
!= self.service.model_params.sampler_dtype
):
pred_shape = [
1,
(requests[0].height) * (requests[0].width) // 256,
64,
]
denoised_inter = sfnp.device_array.for_host(
device, pred_shape, dtype=self.service.model_params.vae_dtype
)
denoised_host = sfnp.device_array.for_host(
device, pred_shape, dtype=self.service.model_params.sampler_dtype
)
denoised_host.copy_from(denoise_inputs["img"].view(idx * cfg_mult))
await device
sfnp.convert(
denoised_host,
dtype=self.service.model_params.vae_dtype,
out=denoised_inter,
)
req.denoised_latents.copy_from(denoised_inter)
else:
req.denoised_latents.copy_from(
denoise_inputs["img"].view(idx * cfg_mult)
)
return

async def _decode(self, device, requests):
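When the VAE runs in a different dtype than the sampler, the block above hops through two host arrays: copy the denoised latents off the device in the sampler dtype, await the device, convert to the VAE dtype, and only then fill the per-request device buffer. A compressed sketch of that hop; the function name and signature are hypothetical, while the device_array and convert calls mirror the ones above:

async def latents_to_vae_dtype(device, src_view, shape, sampler_dtype, vae_dtype):
    # Device -> host copy in the sampler dtype.
    host_in = sfnp.device_array.for_host(device, shape, dtype=sampler_dtype)
    host_in.copy_from(src_view)
    await device  # ensure the copy has landed before converting
    if vae_dtype == sampler_dtype:
        return host_in
    # Host-side conversion into the VAE's dtype.
    host_out = sfnp.device_array.for_host(device, shape, dtype=vae_dtype)
    sfnp.convert(host_in, dtype=vae_dtype, out=host_out)
    return host_out

The caller would then copy_from the returned host array into req.denoised_latents, as the original code does with denoised_inter.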
@@ -726,7 +769,6 @@ async def _decode(self, device, requests):
for bs, fn in entrypoints.items():
if bs == req_bs:
break
await device
latents_shape = [
req_bs,
(requests[0].height * requests[0].width) // 256,
@@ -735,10 +777,16 @@
latents = sfnp.device_array.for_device(
device, latents_shape, self.service.model_params.vae_dtype
)
# latents_host = sfnp.device_array.for_host(
# device, latents_shape, self.service.model_params.vae_dtype
# )
# latents_host.copy_from(latents)
# print(latents_host)
# lat_arr = np.array(latents_host, dtype="float32")
# np.save("vae_in.npy", lat_arr)
for i in range(req_bs):
latents.view(i).copy_from(requests[i].denoised_latents)

await device
# Decode the denoised latents.
logger.debug(
"INVOKE %r: %s",
@@ -756,7 +804,12 @@
images_host = sfnp.device_array.for_host(
device, images_shape, self.service.model_params.vae_dtype
)
await device
images_host.copy_from(image)
# await device
# print(images_host)
# img_arr = np.array(images_host, dtype="float32")
# np.save("vae_out.npy", img_arr)
await device
for idx, req in enumerate(requests):
req.image_array = images_host.view(idx)
4 changes: 2 additions & 2 deletions shortfin/python/shortfin_apps/flux/simple_client.py
@@ -19,12 +19,12 @@

sample_request = {
"prompt": [
" A mountain with a halo cloud over it, Death Mountain, spooky, Zelda",
" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
],
"neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
"height": [1024],
"width": [1024],
"steps": [50],
"steps": [2],
"guidance_scale": [3.5],
"seed": [0],
"output_type": ["base64"],
