diff --git a/README.md b/README.md index 253394b..3996ca3 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master) ``` -pip install rudalle==0.0.1rc4 +pip install rudalle==0.0.1rc5 ``` ### 🤗 HF Models: [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) @@ -18,13 +18,12 @@ pip install rudalle==0.0.1rc4 [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/rudalle-example-generation) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/anton-l/rudall-e) -**English translation example** -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing) - - **Finetuning example** [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Tb7J4PvvegWOybPfUubl5O7m5I24CBg5?usp=sharing) +**English translation example** +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing) + ### generation by ruDALLE: ```python from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip @@ -95,4 +94,5 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky] - [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference - [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab) - [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt +- [@Alex Wortega](https://github.com/AlexWortega) created first FREE 
version of a Colab notebook for fine-tuning [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) on the sneakers domain 💪 - [@Anton Lozhkov](https://github.com/anton-l) Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio), see [here](https://huggingface.co/spaces/anton-l/rudall-e) diff --git a/rudalle/__init__.py b/rudalle/__init__.py index d86b14e..6b4e0a6 100644 --- a/rudalle/__init__.py +++ b/rudalle/__init__.py @@ -22,4 +22,4 @@ 'image_prompts', ] -__version__ = '0.0.1-rc4' +__version__ = '0.0.1-rc5' diff --git a/rudalle/image_prompts.py b/rudalle/image_prompts.py index 2134435..3b70be5 100644 --- a/rudalle/image_prompts.py +++ b/rudalle/image_prompts.py @@ -18,6 +18,7 @@ def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False): self.device = device img = self._preprocess_img(pil_image) self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first) + self.allow_cache = True def _preprocess_img(self, pil_img): img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255. 
@@ -25,8 +26,7 @@ def _preprocess_img(self, pil_img): img = (2 * img) - 1 return img - @staticmethod - def _get_image_prompts(img, borders, vae, crop_first): + def _get_image_prompts(self, img, borders, vae, crop_first): if crop_first: assert borders['right'] + borders['left'] + borders['down'] == 0 up_border = borders['up'] * 8 @@ -34,6 +34,9 @@ def _get_image_prompts(img, borders, vae, crop_first): else: _, _, [_, _, vqg_img] = vae.model.encode(img) + if borders['right'] + borders['left'] + borders['down'] != 0: + self.allow_cache = False # TODO fix cache in attention + bs, vqg_img_w, vqg_img_h = vqg_img.shape mask = torch.zeros(vqg_img_w, vqg_img_h) if borders['up'] != 0: diff --git a/rudalle/pipelines.py b/rudalle/pipelines.py index 4bde77e..a023e95 100644 --- a/rudalle/pipelines.py +++ b/rudalle/pipelines.py @@ -1,4 +1,8 @@ # -*- coding: utf-8 -*- +import os +from glob import glob +from os.path import join + import torch import torchvision import transformers @@ -34,10 +38,10 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image sample_scores = [] if image_prompts is not None: prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts - prompts = prompts.repeat(images_num, 1) - if use_cache: - use_cache = False + prompts = prompts.repeat(chunk_bs, 1) + if use_cache and image_prompts.allow_cache is False: print('Warning: use_cache changed to False') + use_cache = False for idx in tqdm(range(out.shape[1], total_seq_length)): idx -= text_seq_length if image_prompts is not None and idx in prompts_idx: @@ -84,7 +88,18 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu' return top_pil_images, top_scores -def show(pil_images, nrow=4): +def show(pil_images, nrow=4, save_dir=None, show=True): + """ + :param pil_images: list of PIL images + :param nrow: number of images displayed in each row of the grid + :param save_dir: directory where each image is also saved as a separate file, example: save_dir='./pics' + """ + if save_dir is not None: 
+ os.makedirs(save_dir, exist_ok=True) + count = len(glob(join(save_dir, 'img_*.png'))) + for i, pil_image in enumerate(pil_images): + pil_image.save(join(save_dir, f'img_{count+i}.png')) + imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) if not isinstance(imgs, list): imgs = [imgs.cpu()] @@ -92,7 +107,12 @@ def show(pil_images, nrow=4): for i, img in enumerate(imgs): img = img.detach() img = torchvision.transforms.functional.to_pil_image(img) - axs[0, i].imshow(np.asarray(img)) - axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) - fix.show() - plt.show() + if save_dir is not None: + count = len(glob(join(save_dir, 'group_*.png'))) + img.save(join(save_dir, f'group_{count+i}.png')) + if show: + axs[0, i].imshow(np.asarray(img)) + axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + if show: + fix.show() + plt.show() diff --git a/tests/test_show.py b/tests/test_show.py new file mode 100644 index 0000000..72a7376 --- /dev/null +++ b/tests/test_show.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +from rudalle.pipelines import show + + +def test_show(sample_image): + img = sample_image.copy() + img = img.resize((256, 256)) + pil_images = [img]*5 + show(pil_images, nrow=2, save_dir='/tmp/pics', show=False)