Commit
Add: IPAdapter. Feature: Optimize Lora training. Fix: incorrect black image generation. Add: Official ComfyUI support for CN and Lora
zml-ai committed Dec 17, 2024
1 parent 9b09a9b commit 368280e
Showing 246 changed files with 18,880 additions and 89,640 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -137,3 +137,4 @@ trt/activate.sh
 trt/deactivate.sh
 *.onnx
 ckpts/
+app/output/
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
+[submodule "comfyui-hydit"]
+    path = comfyui-hydit
+    url = https://github.com/zml-ai/comfyui-hydit
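
Note for existing checkouts: a newly added submodule is not fetched automatically by git pull. As a minimal sketch (assuming git is on PATH and the current directory is the repository root; the helper name is illustrative, not part of this commit), the checkout can be scripted from Python:

import subprocess

def fetch_submodule(name: str = "comfyui-hydit") -> None:
    # Clone and check out the submodule recorded in .gitmodules.
    # "git submodule update --init" is a standard git command; the wrapper
    # function itself is only an illustration.
    subprocess.run(["git", "submodule", "update", "--init", name], check=True)

fetch_submodule()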
6 changes: 3 additions & 3 deletions IndexKits/index_kits/__init__.py
@@ -1,11 +1,11 @@
 from .bucket import (
     MultiIndexV2,
-    MultiResolutionBucketIndexV2, MultiMultiResolutionBucketIndexV2,
-    build_multi_resolution_bucket
+    MultiResolutionBucketIndexV2,
+    MultiMultiResolutionBucketIndexV2,
+    build_multi_resolution_bucket,
 )
 from .bucket import Resolution, ResolutionGroup
 from .indexer import IndexV2Builder, ArrowIndexV2
 from .common import load_index, show_index_info

 __version__ = "0.3.5"

400 changes: 249 additions & 151 deletions IndexKits/index_kits/bucket.py

Large diffs are not rendered by default.

206 changes: 128 additions & 78 deletions IndexKits/index_kits/common.py
@@ -6,51 +6,67 @@
 from .indexer import ArrowIndexV2
 from .bucket import (
     ResolutionGroup,
-    MultiIndexV2, MultiResolutionBucketIndexV2, MultiMultiResolutionBucketIndexV2, IndexV2Builder
+    MultiIndexV2,
+    MultiResolutionBucketIndexV2,
+    MultiMultiResolutionBucketIndexV2,
+    IndexV2Builder,
 )


-def load_index(src,
-               multireso=False,
-               batch_size=1,
-               world_size=1,
-               sample_strategy='uniform',
-               probability=None,
-               shadow_file_fn=None,
-               seed=None,
-               ):
+def load_index(
+    src,
+    multireso=False,
+    batch_size=1,
+    world_size=1,
+    sample_strategy="uniform",
+    probability=None,
+    shadow_file_fn=None,
+    seed=None,
+):
     if isinstance(src, str):
         src = [src]
-    if src[0].endswith('.arrow'):
+    if src[0].endswith(".arrow"):
         if multireso:
-            raise ValueError('Arrow file does not support multiresolution. Please make base index V2 first and then'
-                             'build multiresolution index.')
+            raise ValueError(
+                "Arrow file does not support multiresolution. Please make base index V2 first and then"
+                "build multiresolution index."
+            )
         idx = IndexV2Builder(src).to_index_v2()
-    elif src[0].endswith('.json'):
+    elif src[0].endswith(".json"):
         if multireso:
             if len(src) == 1:
-                idx = MultiResolutionBucketIndexV2(src[0], batch_size=batch_size,
-                                                   world_size=world_size,
-                                                   shadow_file_fn=shadow_file_fn,
-                                                   )
+                idx = MultiResolutionBucketIndexV2(
+                    src[0],
+                    batch_size=batch_size,
+                    world_size=world_size,
+                    shadow_file_fn=shadow_file_fn,
+                )
             else:
-                idx = MultiMultiResolutionBucketIndexV2(src, batch_size=batch_size,
-                                                        world_size=world_size,
-                                                        sample_strategy=sample_strategy, probability=probability,
-                                                        shadow_file_fn=shadow_file_fn, seed=seed,
-                                                        )
+                idx = MultiMultiResolutionBucketIndexV2(
+                    src,
+                    batch_size=batch_size,
+                    world_size=world_size,
+                    sample_strategy=sample_strategy,
+                    probability=probability,
+                    shadow_file_fn=shadow_file_fn,
+                    seed=seed,
+                )
         else:
             if len(src) == 1:
-                idx = ArrowIndexV2(src[0],
-                                   shadow_file_fn=shadow_file_fn,
-                                   )
+                idx = ArrowIndexV2(
+                    src[0],
+                    shadow_file_fn=shadow_file_fn,
+                )
             else:
-                idx = MultiIndexV2(src,
-                                   sample_strategy=sample_strategy, probability=probability,
-                                   shadow_file_fn=shadow_file_fn, seed=seed,
-                                   )
+                idx = MultiIndexV2(
+                    src,
+                    sample_strategy=sample_strategy,
+                    probability=probability,
+                    shadow_file_fn=shadow_file_fn,
+                    seed=seed,
+                )
     else:
-        raise ValueError(f'Unknown file type: {src[0]}')
+        raise ValueError(f"Unknown file type: {src[0]}")
     return idx


@@ -59,7 +75,7 @@ def get_attribute(data, attr_list):
     for attr in attr_list:
         ret_data[attr] = data.get(attr, None)
         if ret_data[attr] is None:
-            raise ValueError(f'Missing {attr} in {data}')
+            raise ValueError(f"Missing {attr} in {data}")
     return ret_data


@@ -71,27 +87,38 @@ def get_optional_attribute(data, attr_list):


 def detect_index_type(data):
-    if isinstance(data['group_length'], dict):
-        return 'multireso'
+    if isinstance(data["group_length"], dict):
+        return "multireso"
     else:
-        return 'base'
+        return "base"


 def show_index_info(src, only_arrow_files=False, depth=1):
     """
     Show base/multireso index information.
     """
     if not Path(src).exists():
-        raise ValueError(f'{src} does not exist.')
+        raise ValueError(f"{src} does not exist.")
     print(f"Loading index file {src} ...")
-    with open(src, 'r') as f:
+    with open(src, "r") as f:
         src_data = json.load(f)
     print(f"Loaded.")
-    data = get_attribute(src_data, ['data_type', 'indices_file', 'arrow_files', 'cum_length',
-                                    'group_length', 'indices', 'example_indices'])
-    opt_data = get_optional_attribute(src_data, ['config_file'])
+    data = get_attribute(
+        src_data,
+        [
+            "data_type",
+            "indices_file",
+            "arrow_files",
+            "cum_length",
+            "group_length",
+            "indices",
+            "example_indices",
+        ],
+    )
+    opt_data = get_optional_attribute(src_data, ["config_file"])

     # Format arrow_files examples
-    arrow_files = data['arrow_files']
+    arrow_files = data["arrow_files"]
     if only_arrow_files:
         existed = set()
         arrow_files_output_list = []
@@ -105,79 +132,96 @@ def show_index_info(src, only_arrow_files=False, depth=1):
                 if depth >= len(parts):
                     continue
                 else:
-                    arrow_file_part = '/'.join(parts[:-depth])
+                    arrow_file_part = "/".join(parts[:-depth])
                     if arrow_file_part not in existed:
                         arrow_files_output_list.append(arrow_file_part)
                         existed.add(arrow_file_part)
             else:
-                raise ValueError(f'Depth {depth} has exceeded the limit of arrow file {arrow_file}.')
-        arrow_files_repr = '\n'.join(arrow_files_output_list)
+                raise ValueError(
+                    f"Depth {depth} has exceeded the limit of arrow file {arrow_file}."
+                )
+        arrow_files_repr = "\n".join(arrow_files_output_list)
         print(arrow_files_repr)
         return None

-    return_space = '\n' + ' ' * 21
+    return_space = "\n" + " " * 21

     if len(arrow_files) <= 4:
         arrow_files_repr = return_space.join([arrow_file for arrow_file in arrow_files])
     else:
-        arrow_files_repr = return_space.join([_ for _ in arrow_files[:2]] + ['...']
-                                             + [_ for _ in arrow_files[-2:]])
+        arrow_files_repr = return_space.join(
+            [_ for _ in arrow_files[:2]] + ["..."] + [_ for _ in arrow_files[-2:]]
+        )
     arrow_files_length = len(arrow_files)

     # Format data_type
-    data_type = data['data_type']
+    data_type = data["data_type"]
     if isinstance(data_type, str):
         data_type = [data_type]
     data_type_common = []
     src_files = []
     found_src_files = False
     for data_type_item in data_type:
-        if not found_src_files and data_type_item.strip() != 'src_files=':
+        if not found_src_files and data_type_item.strip() != "src_files=":
             data_type_common.append(data_type_item.strip())
             continue
         found_src_files = True
-        if data_type_item.endswith('.json'):
+        if data_type_item.endswith(".json"):
             src_files.append(data_type_item.strip())
         else:
             data_type_common.append(data_type_item.strip())
     data_type_part2_with_ids = []
     max_id_len = len(str(len(src_files)))
     for sid, data_type_item in enumerate(src_files, start=1):
-        data_type_part2_with_ids.append(f'{str(sid).rjust(max_id_len)}. {data_type_item}')
+        data_type_part2_with_ids.append(
+            f"{str(sid).rjust(max_id_len)}. {data_type_item}"
+        )
     data_type = data_type_common + data_type_part2_with_ids
     data_repr = return_space.join(data_type)

     # Format cum_length examples
-    cum_length = data['cum_length']
+    cum_length = data["cum_length"]
     if len(cum_length) <= 8:
-        cum_length_repr = ', '.join([str(i) for i in cum_length])
+        cum_length_repr = ", ".join([str(i) for i in cum_length])
     else:
-        cum_length_repr = ', '.join([str(i) for i in cum_length[:4]] + ['...'] + [str(i) for i in cum_length[-4:]])
+        cum_length_repr = ", ".join(
+            [str(i) for i in cum_length[:4]]
+            + ["..."]
+            + [str(i) for i in cum_length[-4:]]
+        )
     cum_length_length = len(cum_length)

-    if detect_index_type(data) == 'base':
+    if detect_index_type(data) == "base":
         # Format group_length examples
-        group_length = data['group_length']
+        group_length = data["group_length"]
         if len(group_length) <= 8:
-            group_length_repr = ', '.join([str(i) for i in group_length])
+            group_length_repr = ", ".join([str(i) for i in group_length])
         else:
-            group_length_repr = ', '.join([str(i) for i in group_length[:4]] + ['...'] + [str(i) for i in group_length[-4:]])
+            group_length_repr = ", ".join(
+                [str(i) for i in group_length[:4]]
+                + ["..."]
+                + [str(i) for i in group_length[-4:]]
+            )
         group_length_length = len(group_length)

         # Format indices examples
-        indices = data['indices']
-        if len(indices) == 0 and data['indices_file'] != '':
-            indices_file = Path(src).parent / data['indices_file']
+        indices = data["indices"]
+        if len(indices) == 0 and data["indices_file"] != "":
+            indices_file = Path(src).parent / data["indices_file"]
             if Path(indices_file).exists():
                 print(f"Loading indices from {indices_file} ...")
-                indices = np.load(indices_file)['x']
+                indices = np.load(indices_file)["x"]
                 print(f"Loaded.")
             else:
-                raise ValueError(f'This Index file contains an extra file {indices_file} which is missed.')
+                raise ValueError(
+                    f"This Index file contains an extra file {indices_file} which is missed."
+                )
         if len(indices) <= 8:
-            indices_repr = ', '.join([str(i) for i in indices])
+            indices_repr = ", ".join([str(i) for i in indices])
         else:
-            indices_repr = ', '.join([str(i) for i in indices[:4]] + ['...'] + [str(i) for i in indices[-4:]])
+            indices_repr = ", ".join(
+                [str(i) for i in indices[:4]] + ["..."] + [str(i) for i in indices[-4:]]
+            )

         # Calculate indices total length
         indices_length = len(indices)
@@ -188,7 +232,7 @@ def show_index_info(src, only_arrow_files=False, depth=1):
 \033[4mdata_type:\033[0m {data_repr}"""

         # Process optional data
-        if opt_data['config_file'] is not None:
+        if opt_data["config_file"] is not None:
             print_str += f"""
 \033[4mconfig_file:\033[0m {opt_data['config_file']}"""

@@ -205,18 +249,20 @@ def show_index_info(src, only_arrow_files=False, depth=1):
 Examples: {indices_repr}"""

     else:
-        group_length = data['group_length']
+        group_length = data["group_length"]

-        indices_file = Path(src).parent / data['indices_file']
-        assert Path(indices_file).exists(), f'indices_file {indices_file} not found'
+        indices_file = Path(src).parent / data["indices_file"]
+        assert Path(indices_file).exists(), f"indices_file {indices_file} not found"
         print(f"Loading indices from {indices_file} ...")
         indices_data = np.load(indices_file)
         print(f"Loaded.")
         indices_length = sum([len(indices) for key, indices in indices_data.items()])
         keys = [k for k in group_length.keys() if len(indices_data[k]) > 0]

         resolutions = ResolutionGroup.from_list_of_hxw(keys)
-        resolutions.attr = [f'{len(indices):>,d}' for k, indices in indices_data.items()]
+        resolutions.attr = [
+            f"{len(indices):>,d}" for k, indices in indices_data.items()
+        ]
         resolutions.prefix_space = 25

         print_str = f"""File: {Path(src).absolute()}
@@ -225,7 +271,7 @@ def show_index_info(src, only_arrow_files=False, depth=1):
 \033[4mdata_type:\033[0m {data_repr}"""

         # Process optional data
-        if opt_data['config_file'] is not None:
+        if opt_data["config_file"] is not None:
             print_str += f"""
 \033[4mconfig_file:\033[0m {opt_data['config_file']}"""

@@ -236,15 +282,19 @@ def show_index_info(src, only_arrow_files=False, depth=1):
             if src_file.exists():
                 with src_file.open() as f:
                     base_data = json.load(f)
-                if 'config_file' in base_data:
-                    config_files.append(base_data['config_file'])
+                if "config_file" in base_data:
+                    config_files.append(base_data["config_file"])
                 else:
-                    config_files.append('Unknown')
+                    config_files.append("Unknown")
             else:
-                config_files.append('Missing the src file')
+                config_files.append("Missing the src file")
         if config_files:
-            config_file_str = return_space.join([f'{str(sid).rjust(max_id_len)}. {config_file}'
-                                                 for sid, config_file in enumerate(config_files, start=1)])
+            config_file_str = return_space.join(
+                [
+                    f"{str(sid).rjust(max_id_len)}. {config_file}"
+                    for sid, config_file in enumerate(config_files, start=1)
+                ]
+            )
             print_str += f"""
 \033[4mbase config files:\033[0m {config_file_str}"""

@@ -259,4 +309,4 @@ def show_index_info(src, only_arrow_files=False, depth=1):
 \033[4mbuckets: Count = {len(keys)}\033[0m
 {resolutions}"""

-    print(print_str + '\n)\n')
+    print(print_str + "\n)\n")
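
Usage note on the load_index refactor above: the diff is a pure reformat (double quotes, one argument per line), so existing call sites are unchanged. A minimal sketch of the dispatch behavior, with a hypothetical index path and sizes not taken from this commit:

from index_kits import load_index

# load_index dispatches on the file extension: a .arrow source is wrapped by
# IndexV2Builder into a fresh IndexV2, while a .json source loads a saved
# index; multireso=True selects the multi-resolution bucket variants.
idx = load_index(
    "data/example_index.json",  # hypothetical path
    multireso=True,
    batch_size=4,   # forwarded to MultiResolutionBucketIndexV2
    world_size=1,   # forwarded to MultiResolutionBucketIndexV2
)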
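
show_index_info is exported alongside load_index in index_kits/__init__.py, so the same report can also be produced programmatically. A minimal sketch, again with a hypothetical path:

from index_kits import show_index_info

# Prints the file path, arrow-file examples, data_type, cum_length, and either
# group_length/indices stats (base index) or per-bucket counts (multireso).
show_index_info("data/example_index.json", only_arrow_files=False, depth=1)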