Commit: lint again
daanelson committed Feb 4, 2025
1 parent 29270f6 commit 1514583
Showing 5 changed files with 29 additions and 23 deletions.
2 changes: 1 addition & 1 deletion bfl_predictor.py
@@ -569,7 +569,7 @@ def predict(
         seed: int | None = None,
         width: int = 1024,
         height: int = 1024,
-        **kwargs,
+        **kwargs,  # noqa: ARG002
     ) -> tuple[List[Image.Image], List[np.ndarray]]:
        """Run a single prediction on the model"""
        print("running quantized prediction")
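
A note on the change above: ruff's ARG002 rule ("unused method argument") fires on the `**kwargs` catch-all because the body never reads it, and the `# noqa: ARG002` keeps the flexible signature while silencing the warning. A minimal sketch of the pattern (hypothetical names, not code from this repo):

# **kwargs is accepted so callers passing extra options don't crash, but it
# is deliberately unused, which is exactly what ARG002 flags.
class ExamplePredictor:
    def predict(self, prompt: str, **kwargs):  # noqa: ARG002
        return f"image for: {prompt}"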
20 changes: 11 additions & 9 deletions diffusers_predictor.py
@@ -33,13 +33,15 @@
 class FluxConfig:
     url: str
     path: str
-    download_path: str  # this only exists b/c flux-dev needs a different donwload_path from its path on disk. TODO: fix.
+    download_path: str  # this only exists b/c flux-dev needs a different donwload_path from its path on disk. TODO: fix.
     num_steps: int
     max_sequence_length: int


 CONFIGS = {
-    "flux-schnell": FluxConfig(MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH, FLUX_SCHNELL_PATH, 4, 256),
+    "flux-schnell": FluxConfig(
+        MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH, FLUX_SCHNELL_PATH, 4, 256
+    ),
     "flux-dev": FluxConfig(MODEL_URL_DEV, FLUX_DEV_PATH, MODEL_CACHE, 28, 512),
 }

@@ -61,18 +63,18 @@ class LoadedLoRAs:
 @dataclass
 class ModelHolster:
     vae: AutoencoderKL
-    text_encoder: CLIPTextModel
-    text_encoder_2: T5EncoderModel
-    tokenizer: CLIPTokenizer
-    tokenizer_2: T5TokenizerFast
+    text_encoder: CLIPTextModel
+    text_encoder_2: T5EncoderModel
+    tokenizer: CLIPTokenizer
+    tokenizer_2: T5TokenizerFast


 class DiffusersFlux:
     def __init__(
         self,
         model_name: str,
         weights_cache: WeightsDownloadCache,
-        shared_models: ModelHolster | None = None
+        shared_models: ModelHolster | None = None,
     ) -> None:  # pyright: ignore
         """Load the model into memory to make running multiple predictions efficient"""
         start = time.time()
@@ -89,9 +91,9 @@ def __init__(
         # dependency injection hell yeah it's java time baybee
         self.weights_cache = weights_cache

-        if not os.path.exists(model_path):
+        if not os.path.exists(model_path):  # noqa: PTH110
             print("Model path not found, downloading models")
-            # TODO: download everything separately; it will suck less.
+            # TODO: download everything separately; it will suck less.
             download_base_weights(config.url, config.download_path)

         print("Loading pipeline")
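
Context for the `FluxConfig` comment above: `download_path` differs from `path` only for flux-dev, whose downloaded archive lands under `MODEL_CACHE` rather than directly at `FLUX_DEV_PATH`; flux-schnell passes `FLUX_SCHNELL_PATH` for both. A simplified sketch of how the two fields are consumed, inferred from the `__init__` hunk (the final loader call is a hypothetical stand-in):

import os

config = CONFIGS[model_name]
if not os.path.exists(config.path):  # noqa: PTH110  -- where weights are read from
    # Fetch to download_path; for flux-dev that is MODEL_CACHE, which is
    # assumed to contain FLUX_DEV_PATH once the archive is unpacked.
    download_base_weights(config.url, config.download_path)
pipeline = load_pipeline(config.path)  # hypothetical stand-in for the real loading code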
8 changes: 7 additions & 1 deletion lora_loading_patch.py
@@ -12,7 +12,13 @@

 # patching inject_adapter_in_model and load_peft_state_dict with low_cpu_mem_usage=True until it's merged into diffusers
 def load_lora_into_transformer(
-    cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    cls,
+    state_dict,
+    network_alphas,
+    transformer,
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
 ):
     """
     This will load the LoRA layers specified in `state_dict` into `transformer`.
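
For readers unfamiliar with the pattern described in the comment: this file carries a patched copy of `load_lora_into_transformer` (adding `low_cpu_mem_usage` support) until the equivalent fix is merged into diffusers; the commit only reflows its signature. A hedged sketch of how such a monkey patch is typically attached — the exact attachment point is not shown in this diff:

from diffusers import FluxPipeline  # assumed target; not confirmed by this commit
from lora_loading_patch import load_lora_into_transformer

# Rebind the classmethod on the pipeline class so every subsequent LoRA
# load goes through the patched, low-memory implementation.
FluxPipeline.load_lora_into_transformer = classmethod(load_lora_into_transformer)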
8 changes: 2 additions & 6 deletions predict.py
@@ -11,9 +11,6 @@
     BflFp8Flux,
     BflReduxPredictor,
 )
-from diffusers_predictor import DiffusersFlux
-from flux.modules.conditioner import PreLoadedHFEmbedder
-from fp8.util import LoadedModels

 torch.set_float32_matmul_precision("high")
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -31,7 +28,6 @@
 from cog import BasePredictor, Input, Path  # type: ignore
 from flux.util import (
     download_weights,
-    load_ae,
 )

 from diffusers.pipelines.stable_diffusion.safety_checker import (
@@ -715,7 +711,7 @@ class HotswapPredictor(Predictor):
     def setup(self) -> None:
         self.base_setup()
         shared_cache = WeightsDownloadCache()
-        self.bf16_dev= BflBf16Predictor(
+        self.bf16_dev = BflBf16Predictor(
             FLUX_DEV,
             offload=self.should_offload(),
             weights_download_cache=shared_cache,
@@ -735,7 +731,7 @@ def setup(self) -> None:
             loaded_models=self.bf16_dev.get_shared_models(),
             offload=self.should_offload(),
             weights_download_cache=shared_cache,
-            restore_lora_from_cloned_weights=True
+            restore_lora_from_cloned_weights=True,
         )
         self.fp8_schnell = BflFp8Flux(
             FLUX_SCHNELL_FP8,
14 changes: 8 additions & 6 deletions save_fp8_quantized.py
@@ -4,21 +4,23 @@
 from safetensors.torch import save_file

 """
-Note - for this code to work, you'll need to tweak the config of the fp8 flux models in `predict.py` s.t. they load and quantize models.
+Note - for this code to work, you'll need to tweak the config of the fp8 flux models in `predict.py` s.t. they load and quantize models.
 in practice, this just means eliminating the '-fp8' suffix on the model names.
 """

+
 def generate_dev_img(p, img_name="cool_dog_1234.png"):
     p.predict("a cool dog", "1:1", None, 0, 1, 28, 3, 1234, "png", 100, True, True, "1")
     os.system(f"mv out-0.png {img_name}")

+
 def save_dev_fp8():
     p = DevPredictor()
     p.setup()

     fp8_weights_path = "model-cache/dev-fp8"
-    if not os.path.exists(fp8_weights_path):
-        os.makedirs(fp8_weights_path)
+    if not os.path.exists(fp8_weights_path):  # noqa: PTH110
+        os.makedirs(fp8_weights_path)  # noqa: PTH103

     generate_dev_img(p)
     print(
@@ -47,8 +49,8 @@ def save_schnell_fp8():
     p.setup()

     fp8_weights_path = "model-cache/schnell-fp8"
-    if not os.path.exists(fp8_weights_path):
-        os.makedirs(fp8_weights_path)
+    if not os.path.exists(fp8_weights_path):  # noqa: PTH110
+        os.makedirs(fp8_weights_path)  # noqa: PTH103

     generate_schnell_img(p)
     print(
@@ -80,4 +82,4 @@ def test_schnell_fp8():
 else:
     print("testing I guess")
     # test_dev_fp8()
-    test_schnell_fp8()
\ No newline at end of file
+    test_schnell_fp8()
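
On the `# noqa: PTH110` / `# noqa: PTH103` markers here and in diffusers_predictor.py: PTH110 is ruff's "use pathlib instead of os.path.exists" rule and PTH103 the os.makedirs counterpart. The commit suppresses both rather than rewriting; the lint-clean alternative would be roughly the following (a sketch of the pathlib substitution, not what the commit does):

from pathlib import Path

fp8_weights_path = Path("model-cache/schnell-fp8")
# One call replaces the exists()/makedirs() pair; exist_ok=True makes the
# explicit pre-check unnecessary.
fp8_weights_path.mkdir(parents=True, exist_ok=True)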
