Skip to content

Commit

Permalink
some fixes (#30)
Browse files Browse the repository at this point in the history
* some fixes, fill readme

* edit readme, up version
  • Loading branch information
shonenkov authored Nov 5, 2021
1 parent d3e3c69 commit 9589235
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 16 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

```
pip install rudalle==0.0.1rc4
pip install rudalle==0.0.1rc5
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
Expand All @@ -18,13 +18,12 @@ pip install rudalle==0.0.1rc4
[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/rudalle-example-generation)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/anton-l/rudall-e)

**English translation example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)


**Finetuning example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Tb7J4PvvegWOybPfUubl5O7m5I24CBg5?usp=sharing)

**English translation example**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)

### generation by ruDALLE:
```python
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
Expand Down Expand Up @@ -95,4 +94,5 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
- [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
- [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
- [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt
- [@Alex Wortega](https://github.com/AlexWortega) created first FREE version colab notebook with fine-tuning [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) on sneakers domain 💪
- [@Anton Lozhkov](https://github.com/anton-l) Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio), see [here](https://huggingface.co/spaces/anton-l/rudall-e)
2 changes: 1 addition & 1 deletion rudalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
'image_prompts',
]

__version__ = '0.0.1-rc4'
__version__ = '0.0.1-rc5'
7 changes: 5 additions & 2 deletions rudalle/image_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,25 @@ def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False):
self.device = device
img = self._preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
self.allow_cache = True

def _preprocess_img(self, pil_img):
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
img = (2 * img) - 1
return img

@staticmethod
def _get_image_prompts(img, borders, vae, crop_first):
def _get_image_prompts(self, img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)

if borders['right'] + borders['left'] + borders['down'] != 0:
self.allow_cache = False # TODO fix cache in attention

bs, vqg_img_w, vqg_img_h = vqg_img.shape
mask = torch.zeros(vqg_img_w, vqg_img_h)
if borders['up'] != 0:
Expand Down
36 changes: 28 additions & 8 deletions rudalle/pipelines.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# -*- coding: utf-8 -*-
import os
from glob import glob
from os.path import join

import torch
import torchvision
import transformers
Expand Down Expand Up @@ -34,10 +38,10 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
sample_scores = []
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(images_num, 1)
if use_cache:
use_cache = False
prompts = prompts.repeat(chunk_bs, 1)
if use_cache and image_prompts.allow_cache is False:
print('Warning: use_cache changed to False')
use_cache = False
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_length
if image_prompts is not None and idx in prompts_idx:
Expand Down Expand Up @@ -84,15 +88,31 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
return top_pil_images, top_scores


def show(pil_images, nrow=4):
def show(pil_images, nrow=4, save_dir=None, show=True):
"""
:param pil_images: list of images in PIL
:param nrow: number of rows
:param save_dir: dir for separately saving of images, example: save_dir='./pics'
"""
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
count = len(glob(join(save_dir, 'img_*.png')))
for i, pil_image in enumerate(pil_images):
pil_image.save(join(save_dir, f'img_{count+i}.png'))

imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):
imgs = [imgs.cpu()]
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
for i, img in enumerate(imgs):
img = img.detach()
img = torchvision.transforms.functional.to_pil_image(img)
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
fix.show()
plt.show()
if save_dir is not None:
count = len(glob(join(save_dir, 'group_*.png')))
img.save(join(save_dir, f'group_{count+i}.png'))
if show:
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if show:
fix.show()
plt.show()
9 changes: 9 additions & 0 deletions tests/test_show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from rudalle.pipelines import show


def test_show(sample_image):
img = sample_image.copy()
img = img.resize((256, 256))
pil_images = [img]*5
show(pil_images, nrow=2, save_dir='/tmp/pics', show=False)

0 comments on commit 9589235

Please sign in to comment.