# stable_diffusion_handler.py
# Forked from pytorch/serve.
import logging
import os
import zipfile
from abc import ABC

import diffusers
import numpy as np
import torch
from diffusers import DiffusionPipeline
from ts.torch_handler.base_handler import BaseHandler

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Log the installed diffusers version once, when this module is imported.
logger.info("Diffusers version %s", diffusers.__version__)
class DiffusersHandler(BaseHandler, ABC):
    """TorchServe handler for text-to-image generation with a Diffusers pipeline.

    ``initialize`` unpacks ``model.zip`` from the model directory and loads it
    with ``DiffusionPipeline.from_pretrained``. The generation settings used by
    ``inference`` are class attributes so a subclass can override them without
    re-implementing the method.
    """

    # Settings forwarded to the pipeline call in ``inference``.
    GUIDANCE_SCALE = 7.5
    NUM_INFERENCE_STEPS = 50
    IMAGE_HEIGHT = 768
    IMAGE_WIDTH = 768

    def __init__(self):
        # Run the base-class initializer so that whatever state BaseHandler
        # sets up is present before this handler adds its own.
        super().__init__()
        self.initialized = False
        self.pipe = None

    def initialize(self, ctx):
        """Load the diffusion pipeline from the model directory.

        Args:
            ctx (context): Context object supplied by the server. This method
                reads ``ctx.manifest`` and, from ``ctx.system_properties``,
                the ``model_dir`` and ``gpu_id`` entries.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")

        # Use the assigned GPU only when CUDA is usable and an id was given.
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device("cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        # The pipeline ships as model.zip; unpack it next to the archive and
        # load it from the extracted directory.
        archive_path = os.path.join(model_dir, "model.zip")
        extracted_dir = os.path.join(model_dir, "model")
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(extracted_dir)

        self.pipe = DiffusionPipeline.from_pretrained(extracted_dir)
        self.pipe.to(self.device)
        logger.info("Diffusion model from path %s loaded successfully", model_dir)
        self.initialized = True

    def preprocess(self, requests):
        """Extract one prompt from every request in the batch.

        Args:
            requests (list): Batch of dict-like requests. The prompt is read
                from the ``"data"`` key, falling back to ``"body"``.

        Returns:
            list: One prompt per request, in request order. ``bytes`` and
            ``bytearray`` payloads are decoded as UTF-8.

        Raises:
            ValueError: If a request carries neither ``"data"`` nor ``"body"``.
        """
        inputs = []
        for data in requests:
            input_text = data.get("data")
            if input_text is None:
                input_text = data.get("body")
            if input_text is None:
                # Fail here with a clear message instead of handing None to
                # the pipeline.
                raise ValueError(
                    "Request contains no prompt: expected a 'data' or 'body' field"
                )
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            logger.info("Received text: '%s'", input_text)
            inputs.append(input_text)
        return inputs

    def inference(self, inputs):
        """Generate images for a batch of prompts.

        Args:
            inputs (list): Prompts returned by ``preprocess``.

        Returns:
            list: The ``images`` attribute of the pipeline output.
        """
        inferences = self.pipe(
            inputs,
            guidance_scale=self.GUIDANCE_SCALE,
            num_inference_steps=self.NUM_INFERENCE_STEPS,
            height=self.IMAGE_HEIGHT,
            width=self.IMAGE_WIDTH,
        ).images
        logger.info("Generated image: '%s'", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Convert generated images into nested Python lists.

        Args:
            inference_output (list): Images returned by ``inference``.

        Returns:
            list: One nested list per image, produced by converting the image
            to a NumPy array and calling ``tolist()`` on it.
        """
        return [np.array(image).tolist() for image in inference_output]