diff --git a/src/cpp/src/image_generation/flux_pipeline.hpp b/src/cpp/src/image_generation/flux_pipeline.hpp index 547efa4415..bcff7ba141 100644 --- a/src/cpp/src/image_generation/flux_pipeline.hpp +++ b/src/cpp/src/image_generation/flux_pipeline.hpp @@ -351,6 +351,14 @@ class FluxPipeline : public DiffusionPipeline { m_custom_generation_config = m_generation_config; m_custom_generation_config.update_generation_config(properties); + const size_t vae_scale_factor = m_vae->get_vae_scale_factor(); + const auto& transformer_config = m_transformer->get_config(); + + if (m_custom_generation_config.height == -1) + m_custom_generation_config.height = transformer_config.m_default_sample_size * vae_scale_factor; + if (m_custom_generation_config.width == -1) + m_custom_generation_config.width = transformer_config.m_default_sample_size * vae_scale_factor; + // Use callback if defined std::function callback = nullptr; auto callback_iter = properties.find(ov::genai::callback.name()); @@ -358,9 +366,6 @@ class FluxPipeline : public DiffusionPipeline { callback = callback_iter->second.as>(); } - const size_t vae_scale_factor = m_vae->get_vae_scale_factor(); - const auto& transformer_config = m_transformer->get_config(); - if (m_custom_generation_config.height < 0) compute_dim(m_custom_generation_config.height, initial_image, 1 /* assume NHWC */); if (m_custom_generation_config.width < 0) @@ -460,8 +465,6 @@ class FluxPipeline : public DiffusionPipeline { m_generation_config.guidance_scale = 7.0f; m_generation_config.num_inference_steps = 28; m_generation_config.strength = 0.6f; - m_generation_config.height = 1024; - m_generation_config.width = 1024; } m_generation_config.max_sequence_length = 512; } else { diff --git a/src/python/py_image_generation_pipelines.cpp b/src/python/py_image_generation_pipelines.cpp index b011aee878..dcc50234ed 100644 --- a/src/python/py_image_generation_pipelines.cpp +++ b/src/python/py_image_generation_pipelines.cpp @@ -330,6 +330,7 @@ void init_image_generation_pipelines(py::module_& m) { .def_static("stable_diffusion", &ov::genai::Image2ImagePipeline::stable_diffusion, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae")) .def_static("latent_consistency_model", &ov::genai::Image2ImagePipeline::latent_consistency_model, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae")) .def_static("stable_diffusion_xl", &ov::genai::Image2ImagePipeline::stable_diffusion_xl, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("clip_text_model_with_projection"), py::arg("unet"), py::arg("vae")) + .def_static("flux", &ov::genai::Image2ImagePipeline::flux, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("t5_encoder_model"), py::arg("transformer"), py::arg("vae")) .def( "compile", [](ov::genai::Image2ImagePipeline& pipe, diff --git a/tools/who_what_benchmark/tests/test_cli_image.py b/tools/who_what_benchmark/tests/test_cli_image.py index 1ad8236058..ec36de8efa 100644 --- a/tools/who_what_benchmark/tests/test_cli_image.py +++ b/tools/who_what_benchmark/tests/test_cli_image.py @@ -103,8 +103,8 @@ def test_image_model_types(model_id, model_type, backend): ])), ) def test_image_model_genai(model_id, model_type): - if ("flux" in model_id or "stable-diffusion-3" in model_id) and model_type != "text-to-image": - pytest.skip(reason="FLUX or SD3 are supported as text to image only") + if ("stable-diffusion-3" in model_id) and model_type != "text-to-image": + pytest.skip(reason="SD3 is supported as text to image only") with tempfile.TemporaryDirectory() as temp_dir: GT_FILE = os.path.join(temp_dir, "gt.csv")