Commit d625793: linting

daanelson committed Feb 1, 2025
1 parent 00923f9

Showing 5 changed files with 38 additions and 21 deletions.
2 changes: 1 addition & 1 deletion bfl_predictor.py
@@ -568,7 +568,7 @@ def predict(
         seed: int | None = None,
         width: int = 1024,
         height: int = 1024,
-        **kwargs,
+        **kwargs,  # noqa: ARG002
     ) -> tuple[List[Image.Image], List[np.ndarray]]:
         """Run a single prediction on the model"""
         print("running quantized prediction")
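For context on that one-line change: Ruff's ARG002 rule (unused method argument) fires on `**kwargs` that a method accepts only for call-signature compatibility, which is exactly the case suppressed above. A minimal sketch of the pattern, using a hypothetical class rather than anything in this repo:

# Hypothetical example; only the noqa pattern mirrors the commit.
class ExamplePredictor:
    def predict(self, prompt: str, **kwargs) -> str:  # noqa: ARG002
        # **kwargs keeps the call signature stable for callers that pass
        # extra options; it is intentionally unused, so the inline noqa
        # silences Ruff's ARG002 (unused method argument) on this line.
        return f"image for {prompt!r}"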
20 changes: 11 additions & 9 deletions diffusers_predictor.py
@@ -33,13 +33,15 @@
 class FluxConfig:
     url: str
     path: str
-    download_path: str # this only exists b/c flux-dev needs a different download_path from its path on disk. TODO: fix.
+    download_path: str  # this only exists b/c flux-dev needs a different download_path from its path on disk. TODO: fix.
     num_steps: int
     max_sequence_length: int


 CONFIGS = {
-    "flux-schnell": FluxConfig(MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH, FLUX_SCHNELL_PATH, 4, 256),
+    "flux-schnell": FluxConfig(
+        MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH, FLUX_SCHNELL_PATH, 4, 256
+    ),
     "flux-dev": FluxConfig(MODEL_URL_DEV, FLUX_DEV_PATH, MODEL_CACHE, 28, 512),
 }
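To make the CONFIGS entries concrete: each FluxConfig bundles a weights URL, the path the pipeline loads from, the path the download lands in, a default step count, and the T5 max sequence length; the TODO exists because flux-dev downloads to a different location than it loads from. A self-contained sketch with placeholder URLs and paths (the real constants are defined elsewhere in this file):

from dataclasses import dataclass

@dataclass
class FluxConfig:
    url: str
    path: str            # where the pipeline is loaded from
    download_path: str   # where the download lands; differs for flux-dev
    num_steps: int
    max_sequence_length: int

# Placeholder values; MODEL_URL_*, FLUX_*_PATH, and MODEL_CACHE are repo
# constants not shown in this diff.
CONFIGS = {
    "flux-schnell": FluxConfig(
        "https://example.com/schnell.tar", "model-cache/flux-schnell",
        "model-cache/flux-schnell", 4, 256,
    ),
    "flux-dev": FluxConfig(
        "https://example.com/dev.tar", "model-cache/flux-dev",
        "model-cache", 28, 512,
    ),
}

config = CONFIGS["flux-dev"]
assert config.path != config.download_path  # the quirk the TODO refers to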

@@ -61,18 +63,18 @@ class LoadedLoRAs:
 @dataclass
 class ModelHolster:
     vae: AutoencoderKL
-    text_encoder: CLIPTextModel
-    text_encoder_2: T5EncoderModel
-    tokenizer: CLIPTokenizer
-    tokenizer_2: T5TokenizerFast
+    text_encoder: CLIPTextModel
+    text_encoder_2: T5EncoderModel
+    tokenizer: CLIPTokenizer
+    tokenizer_2: T5TokenizerFast


 class DiffusersFlux:
     def __init__(
         self,
         model_name: str,
         weights_cache: WeightsDownloadCache,
-        shared_models: ModelHolster | None = None
+        shared_models: ModelHolster | None = None,
     ) -> None: # pyright: ignore
         """Load the model into memory to make running multiple predictions efficient"""
         start = time.time()
@@ -89,9 +91,9 @@ def __init__(
         # dependency injection hell yeah it's java time baybee
         self.weights_cache = weights_cache

-        if not os.path.exists(model_path):
+        if not Path(model_path).exists():
             print("Model path not found, downloading models")
-            # TODO: download everything separately; it will suck less.
+            # TODO: download everything separately; it will suck less.
             download_base_weights(config.url, config.download_path)

         print("Loading pipeline")
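A note on the os.path-to-pathlib swap in this hunk (Ruff's flake8-use-pathlib family, PTH110 for os.path.exists): Path.exists is an instance method, so when the variable may be a plain string the safe form is to construct a Path first, as rendered above. A minimal sketch under that assumption:

from pathlib import Path

model_path = "model-cache/flux-dev"  # hypothetical value; often a plain str
# Unbound Path.exists(model_path) only works if model_path is already a
# Path object; constructing one makes the check type-safe either way.
if not Path(model_path).exists():
    print("Model path not found, downloading models")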
8 changes: 7 additions & 1 deletion lora_loading_patch.py
@@ -12,7 +12,13 @@

 # patching inject_adapter_in_model and load_peft_state_dict with low_cpu_mem_usage=True until it's merged into diffusers
 def load_lora_into_transformer(
-    cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    cls,
+    state_dict,
+    network_alphas,
+    transformer,
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
 ):
     """
     This will load the LoRA layers specified in `state_dict` into `transformer`.
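As the comment in this hunk says, the module keeps a patched copy of diffusers' LoRA loader so low_cpu_mem_usage=True can be used before upstream support lands. A hedged sketch of how such a copy is typically rebound onto the pipeline class; the diff does not show where this repo actually attaches it:

from diffusers import FluxPipeline

from lora_loading_patch import load_lora_into_transformer  # this module

# Rebind the patched function as a classmethod so every FluxPipeline LoRA
# load goes through the low-CPU-memory code path. This mirrors a common
# monkey-patching pattern; the repo may wire it up differently.
FluxPipeline.load_lora_into_transformer = classmethod(load_lora_into_transformer)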
14 changes: 10 additions & 4 deletions predict.py
@@ -724,10 +724,14 @@ def setup(self) -> None:

         shared_models_for_fp8 = LoadedModels(
             ae=bfl_ae,
-            clip=PreLoadedHFEmbedder(True, 77, shared_models.tokenizer, shared_models.text_encoder),
-            t5=PreLoadedHFEmbedder(False, 512, shared_models.tokenizer_2, shared_models.text_encoder_2),
+            clip=PreLoadedHFEmbedder(
+                True, 77, shared_models.tokenizer, shared_models.text_encoder
+            ),
+            t5=PreLoadedHFEmbedder(
+                False, 512, shared_models.tokenizer_2, shared_models.text_encoder_2
+            ),
             flow=None,
-            config=None
+            config=None,
         )
         self.fp8_dev = BflFp8Flux(
             FLUX_DEV,
@@ -739,7 +743,9 @@ def setup(self) -> None:
         )

         self.bf16_schnell = DiffusersFlux(FLUX_SCHNELL, shared_cache, shared_models)
-        shared_models_for_fp8.t5=PreLoadedHFEmbedder(False, 256, shared_models.tokenizer_2, shared_models.text_encoder_2)
+        shared_models_for_fp8.t5 = PreLoadedHFEmbedder(
+            False, 256, shared_models.tokenizer_2, shared_models.text_encoder_2
+        )

         self.fp8_schnell = BflFp8Flux(
             FLUX_SCHNELL_FP8,
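What setup() is arranging here: one set of text encoders is loaded once and shared across the bf16 and fp8 pipelines, each wrapped with the max token length its model expects (77 for CLIP, 512 for flux-dev's T5, 256 for flux-schnell's), which is why the t5 slot is rebuilt just before the schnell fp8 model is constructed. A rough sketch of the wrapper idea; PreLoadedHFEmbedder's real definition lives in this repo, and the names below are illustrative:

import torch

class PreloadedEmbedder:
    """Wrap an already-loaded tokenizer/encoder pair with a fixed max_length."""

    def __init__(self, is_clip: bool, max_length: int, tokenizer, encoder):
        self.is_clip = is_clip        # CLIP returns a pooled vector, T5 does not
        self.max_length = max_length  # 77 (CLIP), 512 (dev T5), 256 (schnell T5)
        self.tokenizer = tokenizer
        self.encoder = encoder

    @torch.no_grad()
    def __call__(self, texts: list[str]) -> torch.Tensor:
        batch = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        out = self.encoder(batch["input_ids"])
        return out.pooler_output if self.is_clip else out.last_hidden_state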
15 changes: 9 additions & 6 deletions save_fp8_quantized.py
@@ -1,24 +1,27 @@
 import argparse
 import os
+from pathlib import Path
 from predict import DevPredictor, SchnellPredictor
 from safetensors.torch import save_file

 """
-Note - for this code to work, you'll need to tweak the config of the fp8 flux models in `predict.py` s.t. they load and quantize models.
+Note - for this code to work, you'll need to tweak the config of the fp8 flux models in `predict.py` s.t. they load and quantize models.
 in practice, this just means eliminating the '-fp8' suffix on the model names.
 """

+
 def generate_dev_img(p, img_name="cool_dog_1234.png"):
     p.predict("a cool dog", "1:1", None, 0, 1, 28, 3, 1234, "png", 100, True, True, "1")
     os.system(f"mv out-0.png {img_name}")

+
 def save_dev_fp8():
     p = DevPredictor()
     p.setup()

     fp8_weights_path = "model-cache/dev-fp8"
-    if not os.path.exists(fp8_weights_path):
-        os.makedirs(fp8_weights_path)
+    if not Path(fp8_weights_path).exists():
+        Path(fp8_weights_path).mkdir(parents=True)

     generate_dev_img(p)
     print(
@@ -47,8 +50,8 @@ def save_schnell_fp8():
     p.setup()

     fp8_weights_path = "model-cache/schnell-fp8"
-    if not os.path.exists(fp8_weights_path):
-        os.makedirs(fp8_weights_path)
+    if not Path(fp8_weights_path).exists():
+        Path(fp8_weights_path).mkdir(parents=True)

     generate_schnell_img(p)
     print(
@@ -80,4 +83,4 @@ def test_schnell_fp8():
 else:
     print("testing I guess")
     # test_dev_fp8()
-    test_schnell_fp8()
\ No newline at end of file
+    test_schnell_fp8()
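One last pathlib note on the directory-creation changes in this file: mkdir(parents=True, exist_ok=True) is idempotent, so the exists() guard could be dropped entirely. A minimal equivalent sketch:

from pathlib import Path

fp8_weights_path = Path("model-cache/schnell-fp8")
# parents=True creates model-cache/ if needed; exist_ok=True makes the
# call a no-op when the directory already exists, replacing the guard.
fp8_weights_path.mkdir(parents=True, exist_ok=True)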
