-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathmzkolors.py
94 lines (74 loc) · 2.5 KB
/
mzkolors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import uuid
import torch
from bizyair import BizyAirBaseNode, BizyAirNodeIO, create_node_data
from bizyair.common.env_var import BIZYAIR_SERVER_ADDRESS
from bizyair.data_types import CONDITIONING
from bizyair.image_utils import encode_data
from .utils import (
decode_and_deserialize,
get_api_key,
send_post_request,
serialize_and_encode,
)
# Category label shared by the node classes in this module (assigned to their CATEGORY attribute).
CATEGORY_NAME = "☁️BizyAir/Kolors"
class BizyAirMZChatGLM3TextEncode:
    """Encode a text prompt into conditioning via the remote ChatGLM3 endpoint.

    The prompt is POSTed to ``API_URL``; the decoded response is converted
    from NumPy arrays into torch tensors and returned as a conditioning list.
    """

    # Endpoint built from the configured BizyAir server address.
    API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/mzkolorschatglm3"

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: one required multiline ``text`` string."""
        return {
            "required": {
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = CATEGORY_NAME

    def encode(self, text):
        """Send ``text`` to the remote encoder and return its conditioning.

        Args:
            text: Prompt to encode; at most 4096 characters.

        Returns:
            A one-element tuple holding a list of ``[tensor, dict]`` pairs,
            where the tensor and every dict value were converted with
            ``torch.from_numpy``.

        Raises:
            AssertionError: If ``text`` is longer than 4096 characters.
        """
        API_KEY = get_api_key()

        # Raised explicitly instead of via `assert` so the limit is still
        # enforced when Python runs with -O (which strips assert statements).
        # The exception type and message are unchanged for existing callers.
        if len(text) > 4096:
            raise AssertionError(f"the prompt is too long, length: {len(text)}")

        payload = {
            "text": text,
        }
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {API_KEY}",
        }
        response: str = send_post_request(
            self.API_URL, payload=payload, headers=headers
        )

        # Each decoded item is unpacked as an (array, dict-of-arrays) pair.
        tensors_np = decode_and_deserialize(response)
        ret_conditioning = [
            [torch.from_numpy(t), {k: torch.from_numpy(v) for k, v in d.items()}]
            for t, d in tensors_np
        ]
        return (ret_conditioning,)
class BizyAir_MinusZoneChatGLM3TextEncode(BizyAirMZChatGLM3TextEncode, BizyAirBaseNode):
    """Variant of the ChatGLM3 text encoder that returns a ``BizyAirNodeIO``.

    The conditioning produced by the inherited ``encode`` is wrapped in a
    ``ComfyAirLoadData`` node description instead of being returned directly.
    """

    RETURN_TYPES = (CONDITIONING,)
    FUNCTION = "mz_encode"

    def mz_encode(self, text):
        """Encode ``text`` and wrap the conditioning in a ``BizyAirNodeIO``.

        Returns:
            A one-element tuple containing the ``BizyAirNodeIO`` whose ``nodes``
            mapping holds the encoded node data under ``self.assigned_id``.
        """
        encoded_outputs = self.encode(text)
        conditioning = encoded_outputs[0]

        load_node = create_node_data(
            class_type="ComfyAirLoadData",
            inputs={"conditioning": {"relay": conditioning}},
            outputs={"slot_index": 3},
        )
        # A fresh random hex value is stored on every call.
        load_node["is_changed"] = uuid.uuid4().hex

        node_io = BizyAirNodeIO(
            self.assigned_id,
            nodes={self.assigned_id: encode_data(load_node, old_version=True)},
        )
        return (node_io,)
# Maps each node identifier string to the class that implements it
# (presumably read by the host application to register nodes — confirm against the loader).
NODE_CLASS_MAPPINGS = {
"BizyAir_MinusZoneChatGLM3TextEncode": BizyAir_MinusZoneChatGLM3TextEncode,
}
# Maps the same node identifier to its human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
"BizyAir_MinusZoneChatGLM3TextEncode": "☁️BizyAir MinusZone ChatGLM3 Text Encode",
}