From dccb1595e25c1bb5802eaaa4a69bddc48b021c34 Mon Sep 17 00:00:00 2001
From: whitneyfoster
Date: Thu, 16 Jan 2025 15:25:21 -0800
Subject: [PATCH] Update main.py

Add an input image field to the UI (not used yet), and download and convert
the dreamlike-anime 1.0 model. Some logic is added for using two different
models (text2image and image2image), setting up for adding an image2image
pipeline later on.
---
 demos/paint_your_dreams_demo/main.py | 40 +++++++++++++++++-----------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/demos/paint_your_dreams_demo/main.py b/demos/paint_your_dreams_demo/main.py
index 50f6aeab..ea96e202 100644
--- a/demos/paint_your_dreams_demo/main.py
+++ b/demos/paint_your_dreams_demo/main.py
@@ -30,7 +30,7 @@
 ov_pipelines = {}
 
 stop_generating: bool = True
-hf_model_name: Optional[str] = None
+hf_model_name_t2i: Optional[str] = None
 
 
 def get_available_devices() -> list[str]:
@@ -39,17 +39,21 @@ def get_available_devices() -> list[str]:
     return list({device.split(".")[0] for device in core.available_devices if device != "NPU"})
 
 
-def download_models(model_name: str, safety_checker_model: str) -> None:
+def download_models(model_name_t2i: str, model_name_i2i: str, safety_checker_model: str) -> None:
     global safety_checker
 
-    is_openvino_model = model_name.split("/")[0] == "OpenVINO"
+    is_openvino_model = model_name_t2i.split("/")[0] == "OpenVINO"
 
-    output_dir = MODEL_DIR / model_name
-    if not output_dir.exists():
+    output_dir_t2i = MODEL_DIR / model_name_t2i
+    if not output_dir_t2i.exists():
         if is_openvino_model:
-            snapshot_download(model_name, local_dir=output_dir)
+            snapshot_download(model_name_t2i, local_dir=output_dir_t2i, resume_download=True)
         else:
-            raise ValueError(f"Model {model_name} is not from OpenVINO Hub and not supported")
+            raise ValueError(f"Model {model_name_t2i} is not from OpenVINO Hub and not supported")
+
+    output_dir_t2i = MODEL_DIR / "dreamlike_anime_1_0_ov" / "FP16"
+    if not output_dir_t2i.exists():
+        os.system("optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --task stable-diffusion --weight-format fp16 " + str(output_dir_t2i))
 
     safety_checker_dir = MODEL_DIR / safety_checker_model
     if not safety_checker_dir.exists():
@@ -82,7 +86,7 @@ async def generate_images(prompt: str, seed: int, size: int, guidance_scale: flo
     global stop_generating
     stop_generating = not endless_generation
 
-    ov_pipeline = await load_pipeline(hf_model_name, device)
+    ov_pipeline = await load_pipeline(hf_model_name_t2i, device)
 
     while True:
         if randomize_seed:
@@ -125,6 +129,7 @@ def build_ui():
     with gr.Blocks() as demo:
         with gr.Group():
             with gr.Row():
+                inputs = gr.Image(label="Input Image, leave blank for Text2Image Generation")
                 prompt_text = gr.Text(
                     label="Prompt",
                     placeholder="Enter your prompt here",
@@ -195,12 +200,14 @@ def build_ui():
     return demo
 
 
-def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: bool = False, public_interface: bool = False):
-    global hf_model_name
-    hf_model_name = model_name
+def run_endless_lcm(model_name_t2i: str, model_name_i2i: str, safety_checker_model: str, local_network: bool = False, public_interface: bool = False):
+    global hf_model_name_t2i
+    global hf_model_name_i2i
+    hf_model_name_t2i = model_name_t2i
+    hf_model_name_i2i = model_name_i2i
 
     server_name = "0.0.0.0" if local_network else None
 
-    download_models(model_name, safety_checker_model)
+    download_models(model_name_t2i, model_name_i2i, safety_checker_model)
 
     demo = build_ui()
     log.info("Demo is ready!")
@@ -212,13 +219,16 @@ def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: b
     log.getLogger().setLevel(log.INFO)
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, default="OpenVINO/LCM_Dreamshaper_v7-fp16-ov",
+    parser.add_argument("--model_name_t2i", type=str, default="OpenVINO/LCM_Dreamshaper_v7-fp16-ov",
                         choices=["OpenVINO/LCM_Dreamshaper_v7-int8-ov", "OpenVINO/LCM_Dreamshaper_v7-fp16-ov"],
-                        help="Visual GenAI model to be used")
+                        help="Text2Image GenAI model to be used")
+    parser.add_argument("--model_name_i2i", type=str, default="dreamlike-art/dreamlike-anime-1.0",
+                        choices=["dreamlike-art/dreamlike-anime-1.0"],
+                        help="Image2Image GenAI model to be used")
     parser.add_argument("--safety_checker_model", type=str, default="Falconsai/nsfw_image_detection",
                         choices=["Falconsai/nsfw_image_detection"],
                         help="The model to verify if the generated image is NSFW")
     parser.add_argument("--local_network", action="store_true", help="Whether demo should be available in local network")
     parser.add_argument("--public", default=False, action="store_true", help="Whether interface should be available publicly")
 
     args = parser.parse_args()
 
-    run_endless_lcm(args.model_name, args.safety_checker_model, args.local_network, args.public)
+    run_endless_lcm(args.model_name_t2i, args.model_name_i2i, args.safety_checker_model, args.local_network, args.public)
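
Usage note: with the flags added above, the demo would be launched along these
lines (model names taken from the argparse defaults in the patch):

    python demos/paint_your_dreams_demo/main.py \
        --model_name_t2i OpenVINO/LCM_Dreamshaper_v7-fp16-ov \
        --model_name_i2i dreamlike-art/dreamlike-anime-1.0

For the image2image path this change prepares for, the following is a minimal
sketch of how the exported dreamlike-anime model could later be wired in. It
assumes OpenVINO GenAI's Image2ImagePipeline and the dreamlike_anime_1_0_ov/FP16
directory produced by the optimum-cli export above; file names, the prompt and
the strength value are illustrative and not part of this patch:

    import numpy as np
    import openvino as ov
    import openvino_genai
    from PIL import Image

    # Load the model directory exported by download_models();
    # adjust the path to match MODEL_DIR in the demo.
    i2i_pipeline = openvino_genai.Image2ImagePipeline("model/dreamlike_anime_1_0_ov/FP16", "CPU")

    # The new Gradio image input arrives as a numpy array / PIL image; wrap it
    # as an ov.Tensor with shape [1, H, W, 3] (uint8) before calling generate().
    input_image = Image.open("input.png").convert("RGB")
    image_tensor = ov.Tensor(np.asarray(input_image, dtype=np.uint8)[None])

    # strength controls how far the result may deviate from the input image
    # (assumption: same semantics as the diffusers image2image parameter).
    result = i2i_pipeline.generate("anime, masterpiece, detailed portrait", image_tensor, strength=0.7)
    Image.fromarray(result.data[0]).save("output.png")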