diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/dynamo_exports/flux/export.py index 7ecd7c089..9002aae18 100644 --- a/sharktank/sharktank/dynamo_exports/flux/export.py +++ b/sharktank/sharktank/dynamo_exports/flux/export.py @@ -130,10 +130,11 @@ def get_te_model_and_inputs( class FluxAEWrapper(torch.nn.Module): - def __init__(self, height=1024, width=1024): + def __init__(self, height=1024, width=1024, precision="fp32"): super().__init__() + dtype = torch_dtypes[precision] self.ae = AutoencoderKL.from_pretrained( - "black-forest-labs/FLUX.1-dev", subfolder="vae" + "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtypes=dtype ) self.height = height self.width = width @@ -156,7 +157,7 @@ def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width) aeparams = fluxconfigs[hf_model_name].ae_params aeparams.height = height aeparams.width = width - ae = FluxAEWrapper(height, width) + ae = FluxAEWrapper(height, width, precision).to(dtype) latents_shape = ( batch_size, int(height * width / 256), diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index 5d1f58f03..b758536d3 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -24,8 +24,6 @@ from .tokenizer import Tokenizer from .metrics import measure -from einops import rearrange - logger = logging.getLogger("shortfin-flux.service") prog_isolations = { @@ -475,13 +473,26 @@ async def _prepare(self, device, requests): device, latents_shape, self.service.model_params.sampler_dtype ) - sample_host = request.sample.for_transfer() + sample_host = sfnp.device_array.for_host( + device, latents_shape, sfnp.float32 + ) with sample_host.map(discard=True) as m: m.fill(bytes(1)) - sfnp.fill_randn(sample_host, generator=generator) + if self.service.model_params.sampler_dtype != sfnp.float32: + sample_transfer = request.sample.for_transfer() + sfnp.convert( + sample_host, + dtype=self.service.model_params.sampler_dtype, + out=sample_transfer, + ) + + request.sample.copy_from(sample_transfer) + # sample_debug = torch.frombuffer(sample_transfer.items, dtype=torch.bfloat16) + # print(sample_debug) + else: + request.sample.copy_from(sample_host) - request.sample.copy_from(sample_host) await device request.timesteps = get_schedule( request.steps, @@ -634,13 +645,11 @@ async def _denoise(self, device, requests): sample_host = sfnp.device_array.for_host( device, img_shape, self.service.model_params.sampler_dtype ) + guidance_float = sfnp.device_array.for_host(device, [req_bs], sfnp.float32) + for i in range(req_bs): + guidance_float.view(i).items = [requests[i].guidance_scale] cfg_dim = i * cfg_mult - with gs_host.view(i).map(write=True, discard=True) as m: - # TODO: do this without numpy - np_arr = np.asarray(requests[i].guidance_scale, dtype="float32") - - m.fill(np_arr) # Reshape and batch sample latent inputs on device. # Currently we just generate random latents in the desired shape. Rework for img2img. @@ -664,16 +673,24 @@ async def _denoise(self, device, requests): vec = requests[i].vec for nc in range(2): denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) - + sfnp.convert( + guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host + ) denoise_inputs["guidance_scale"].copy_from(gs_host) await device ts_host = denoise_inputs["timesteps"].for_transfer() - with ts_host.map(write=True) as m: + ts_float = sfnp.device_array.for_host( + device, denoise_inputs["timesteps"].shape, dtype=sfnp.float32 + ) + with ts_float.map(write=True) as m: m.fill(float(1)) for tstep in range(len(requests[0].timesteps)): - with ts_host.view(tstep).map(write=True, discard=True) as m: + with ts_float.view(tstep).map(write=True, discard=True) as m: m.fill(np.asarray(requests[0].timesteps[tstep], dtype="float32")) + sfnp.convert( + ts_float, dtype=self.service.model_params.sampler_dtype, out=ts_host + ) denoise_inputs["timesteps"].copy_from(ts_host) await device @@ -712,7 +729,33 @@ async def _denoise(self, device, requests): req.denoised_latents = sfnp.device_array.for_device( device, img_shape, self.service.model_params.vae_dtype ) - req.denoised_latents.copy_from(denoise_inputs["img"].view(idx * cfg_mult)) + if ( + self.service.model_params.vae_dtype + != self.service.model_params.sampler_dtype + ): + pred_shape = [ + 1, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + denoised_inter = sfnp.device_array.for_host( + device, pred_shape, dtype=self.service.model_params.vae_dtype + ) + denoised_host = sfnp.device_array.for_host( + device, pred_shape, dtype=self.service.model_params.sampler_dtype + ) + denoised_host.copy_from(denoise_inputs["img"].view(idx * cfg_mult)) + await device + sfnp.convert( + denoised_host, + dtype=self.service.model_params.vae_dtype, + out=denoised_inter, + ) + req.denoised_latents.copy_from(denoised_inter) + else: + req.denoised_latents.copy_from( + denoise_inputs["img"].view(idx * cfg_mult) + ) return async def _decode(self, device, requests): @@ -726,7 +769,6 @@ async def _decode(self, device, requests): for bs, fn in entrypoints.items(): if bs == req_bs: break - await device latents_shape = [ req_bs, (requests[0].height * requests[0].width) // 256, @@ -735,10 +777,16 @@ async def _decode(self, device, requests): latents = sfnp.device_array.for_device( device, latents_shape, self.service.model_params.vae_dtype ) + # latents_host = sfnp.device_array.for_host( + # device, latents_shape, self.service.model_params.vae_dtype + # ) + # latents_host.copy_from(latents) + # print(latents_host) + # lat_arr = np.array(latents_host, dtype="float32") + # np.save("vae_in.npy", lat_arr) for i in range(req_bs): latents.view(i).copy_from(requests[i].denoised_latents) - await device # Decode the denoised latents. logger.debug( "INVOKE %r: %s", @@ -756,7 +804,12 @@ async def _decode(self, device, requests): images_host = sfnp.device_array.for_host( device, images_shape, self.service.model_params.vae_dtype ) + await device images_host.copy_from(image) + # await device + # print(images_host) + # img_arr = np.array(images_host, dtype="float32") + # np.save("vae_out.npy", img_arr) await device for idx, req in enumerate(requests): req.image_array = images_host.view(idx) diff --git a/shortfin/python/shortfin_apps/flux/simple_client.py b/shortfin/python/shortfin_apps/flux/simple_client.py index 9382ddd9d..cfe284e21 100644 --- a/shortfin/python/shortfin_apps/flux/simple_client.py +++ b/shortfin/python/shortfin_apps/flux/simple_client.py @@ -19,12 +19,12 @@ sample_request = { "prompt": [ - " A mountain with a halo cloud over it, Death Mountain, spooky, Zelda", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", ], "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], "height": [1024], "width": [1024], - "steps": [50], + "steps": [2], "guidance_scale": [3.5], "seed": [0], "output_type": ["base64"],