Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
Add input image field to UI (not used yet) and download and convert dreamlike_anime 1.0 model with some logic added for using two different model (text2image, and image2image) setting up for adding image2image pipeline later on.
  • Loading branch information
whitneyfoster committed Jan 16, 2025
1 parent ebf9101 commit dccb159
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions demos/paint_your_dreams_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
ov_pipelines = {}

stop_generating: bool = True
hf_model_name: Optional[str] = None
hf_model_name_t2i: Optional[str] = None


def get_available_devices() -> list[str]:
Expand All @@ -39,17 +39,21 @@ def get_available_devices() -> list[str]:
return list({device.split(".")[0] for device in core.available_devices if device != "NPU"})


def download_models(model_name: str, safety_checker_model: str) -> None:
def download_models(model_name_t2i: str, model_name_i2i: str, safety_checker_model: str) -> None:
global safety_checker

is_openvino_model = model_name.split("/")[0] == "OpenVINO"
is_openvino_model = model_name_t2i.split("/")[0] == "OpenVINO"

output_dir = MODEL_DIR / model_name
if not output_dir.exists():
output_dir_t2i = MODEL_DIR / model_name_t2i
if not output_dir_t2i.exists():
if is_openvino_model:
snapshot_download(model_name, local_dir=output_dir)
snapshot_download(model_name_t2i, local_dir=output_dir_t2i, resume_download=True)
else:
raise ValueError(f"Model {model_name} is not from OpenVINO Hub and not supported")
raise ValueError(f"Model {model_name_t2i} is not from OpenVINO Hub and not supported")

output_dir_t2i = MODEL_DIR / "dreamlike_anime_1_0_ov" / "FP16"
if not output_dir_t2i.exists():
os.system("optimum-cli export openvino --model dreamlike-art/dreamlike-anime-1.0 --task stable-diffusion --weight-format fp16 " + str(output_dir_t2i))

safety_checker_dir = MODEL_DIR / safety_checker_model
if not safety_checker_dir.exists():
Expand Down Expand Up @@ -82,7 +86,7 @@ async def generate_images(prompt: str, seed: int, size: int, guidance_scale: flo
global stop_generating
stop_generating = not endless_generation

ov_pipeline = await load_pipeline(hf_model_name, device)
ov_pipeline = await load_pipeline(hf_model_name_t2i, device)

while True:
if randomize_seed:
Expand Down Expand Up @@ -125,6 +129,7 @@ def build_ui():
with gr.Blocks() as demo:
with gr.Group():
with gr.Row():
inputs=gr.Image(label="Input Image, leave blank for Text2Text Generation")
prompt_text = gr.Text(
label="Prompt",
placeholder="Enter your prompt here",
Expand Down Expand Up @@ -195,12 +200,14 @@ def build_ui():
return demo


def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: bool = False, public_interface: bool = False):
global hf_model_name
hf_model_name = model_name
def run_endless_lcm(model_name_t2i: str, model_name_i2i: str, safety_checker_model: str, local_network: bool = False, public_interface: bool = False):
global hf_model_name_t2i
global hf_model_name_i2i
hf_model_name_t2i = model_name_t2i
hf_model_name_i2i = model_name_i2i
server_name = "0.0.0.0" if local_network else None

download_models(model_name, safety_checker_model)
download_models(model_name_t2i, model_name_i2i, safety_checker_model)

demo = build_ui()
log.info("Demo is ready!")
Expand All @@ -212,13 +219,16 @@ def run_endless_lcm(model_name: str, safety_checker_model: str, local_network: b
log.getLogger().setLevel(log.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="OpenVINO/LCM_Dreamshaper_v7-fp16-ov",
parser.add_argument("--model_name_t2i", type=str, default="OpenVINO/LCM_Dreamshaper_v7-fp16-ov",
choices=["OpenVINO/LCM_Dreamshaper_v7-int8-ov", "OpenVINO/LCM_Dreamshaper_v7-fp16-ov"],
help="Visual GenAI model to be used")
help="Text2Image GenAI model to be used")
parser.add_argument("--model_name_i2i", type=str, default="dreamlike-art/dreamlike-anime-1.0",
choices=["dreamlike-art/dreamlike-anime-1.0"],
help="Image2Image GenAI model to be used")
parser.add_argument("--safety_checker_model", type=str, default="Falconsai/nsfw_image_detection",
choices=["Falconsai/nsfw_image_detection"], help="The model to verify if the generated image is NSFW")
parser.add_argument("--local_network", action="store_true", help="Whether demo should be available in local network")
parser.add_argument("--public", default=False, action="store_true", help="Whether interface should be available publicly")

args = parser.parse_args()
run_endless_lcm(args.model_name, args.safety_checker_model, args.local_network, args.public)
run_endless_lcm(args.model_name_t2i, args.model_name_i2i, args.safety_checker_model, args.local_network, args.public)

0 comments on commit dccb159

Please sign in to comment.