Skip to content

Commit

Permalink
review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
likholat committed Jan 23, 2025
1 parent a7a60e8 commit 2de002f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
13 changes: 8 additions & 5 deletions src/cpp/src/image_generation/flux_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,21 @@ class FluxPipeline : public DiffusionPipeline {
m_custom_generation_config = m_generation_config;
m_custom_generation_config.update_generation_config(properties);

const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
const auto& transformer_config = m_transformer->get_config();

if (m_custom_generation_config.height == -1)
m_custom_generation_config.height = transformer_config.m_default_sample_size * vae_scale_factor;
if (m_custom_generation_config.width == -1)
m_custom_generation_config.width = transformer_config.m_default_sample_size * vae_scale_factor;

// Use callback if defined
std::function<bool(size_t, size_t, ov::Tensor&)> callback = nullptr;
auto callback_iter = properties.find(ov::genai::callback.name());
if (callback_iter != properties.end()) {
callback = callback_iter->second.as<std::function<bool(size_t, size_t, ov::Tensor&)>>();
}

const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
const auto& transformer_config = m_transformer->get_config();

if (m_custom_generation_config.height < 0)
compute_dim(m_custom_generation_config.height, initial_image, 1 /* assume NHWC */);
if (m_custom_generation_config.width < 0)
Expand Down Expand Up @@ -460,8 +465,6 @@ class FluxPipeline : public DiffusionPipeline {
m_generation_config.guidance_scale = 7.0f;
m_generation_config.num_inference_steps = 28;
m_generation_config.strength = 0.6f;
m_generation_config.height = 1024;
m_generation_config.width = 1024;
}
m_generation_config.max_sequence_length = 512;
} else {
Expand Down
1 change: 1 addition & 0 deletions src/python/py_image_generation_pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ void init_image_generation_pipelines(py::module_& m) {
.def_static("stable_diffusion", &ov::genai::Image2ImagePipeline::stable_diffusion, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae"))
.def_static("latent_consistency_model", &ov::genai::Image2ImagePipeline::latent_consistency_model, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae"))
.def_static("stable_diffusion_xl", &ov::genai::Image2ImagePipeline::stable_diffusion_xl, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("clip_text_model_with_projection"), py::arg("unet"), py::arg("vae"))
.def_static("flux", &ov::genai::Image2ImagePipeline::flux, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("t5_encoder_model"), py::arg("transformer"), py::arg("vae"))
.def(
"compile",
[](ov::genai::Image2ImagePipeline& pipe,
Expand Down
4 changes: 2 additions & 2 deletions tools/who_what_benchmark/tests/test_cli_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def test_image_model_types(model_id, model_type, backend):
])),
)
def test_image_model_genai(model_id, model_type):
if ("flux" in model_id or "stable-diffusion-3" in model_id) and model_type != "text-to-image":
pytest.skip(reason="FLUX or SD3 are supported as text to image only")
if ("stable-diffusion-3" in model_id) and model_type != "text-to-image":
pytest.skip(reason="SD3 is supported as text to image only")

with tempfile.TemporaryDirectory() as temp_dir:
GT_FILE = os.path.join(temp_dir, "gt.csv")
Expand Down

0 comments on commit 2de002f

Please sign in to comment.