Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Model loading refactor #10604

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open

[FEAT] Model loading refactor #10604

wants to merge 21 commits into from

Conversation

SunMarc
Copy link
Member

@SunMarc SunMarc commented Jan 17, 2025

What does this PR do?

Fixes #10013 . This PR refactors model loading in diffusers. Here's a list of major changes in this PR.

  • only two loading paths (low_cpu_mem_usage=True and low_cpu_mem_usage = False). We don't rely on load_checkpoint_and_dispatch anymore and we don't merge sharded checkpoint also.
  • support for sharded checkpoints for both loading paths
  • keep_module_in_fp32 support for sharded checkpoints
  • better support for displaying warning due to error/unexpected/missing/mismatched keys

For low_cpu_mem_usage = False:

  • Faster initialization (thanks to skipping the init + assign_to_params_buffers). I didn't benchmarked it but it should be as fast as low_cpu_mem_usage=True or maybe even faster. We did a similar PR in transformers thanks to @muellerzr.
  • Better torch_dtype support We don't initialize anymore the model in fp32 then cast the model to a specific dtype after finishing to load the weights.

For low_cpu_mem_usage = True or device_map!=None:

  • one path, we don't rely anymore on load_checkpoint_and_dispatch
  • device_map support for quantization
  • non persistance buffer support through dispatch_model ( the test you added is passing cc @hlky )

Single format file:

  • Simplified the single file format loading through from_pretrained. This way we have the same features as this function (device_map, quantization ...). Feel free to share your opinion @DN6, I didn't expect to touch this but I felt that we could simplify a bit

TODO (some items can be done in follow-up PRs):

  • Check if we have any regression / tests issues
  • Add more tests
  • Deal with missing keys in the model for both paths (before, it only worked when low_cpu_mem_usage=False since we are initializing the whole model)
  • Fix typing
  • Better support for offload with safetensors (like in transformers)

Please let me know your thoughts on the PR !

cc @sayakpaul, @DN6 , @yiyixuxu , @hlky , @a-r-r-o-w

@SunMarc SunMarc changed the title [FEAT ] Model loading refactor [FEAT] Model loading refactor Jan 17, 2025
@SunMarc
Copy link
Member Author

SunMarc commented Jan 18, 2025

FLAX CPU failing test is unrelated, failing in other PRs too

from huggingface_hub.utils import validate_hf_hub_args

from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils import deprecate, logging
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will let @DN6 comment on the single-file related changes.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for starting this! Left some comments from a first pass.

I think we will need to also add tests for seeing if device_map works as expected for quantization. Okay to not test that a bit later once there is consensus about the design changes. Maybe we could add that as a TODO.

Other tests could include checking if we can do low_cpu_mem_usage=True along with some changed config values. This will ensure we're well tested for cases like #9343.

src/diffusers/models/model_loading_utils.py Show resolved Hide resolved
@@ -134,15 +135,14 @@ def _fetch_remapped_cls_from_config(config, old_class):

def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variant isn't used anyway in this method, so good for me.

But let's make sure the method is invoked properly with proper arguments.

src/diffusers/models/model_loading_utils.py Outdated Show resolved Hide resolved
src/diffusers/models/model_loading_utils.py Outdated Show resolved Hide resolved
src/diffusers/models/model_loading_utils.py Show resolved Hide resolved
Comment on lines +1382 to +1384
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we fully considered the consequences of this especially under things like "layerwise upcasting"? (see #10347)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch
torch.set_default_dtype(torch.float8_e4m3fn)
# TypeError: couldn't find storage object Float8_e4m3fnStorage

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback ! For torch.float8_e4m3fn dtype, we can just make an exception for this dtype and skip that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't look like the layerwise upcasting PR passes torch_dtype = torch.float8_e4m3fn in from_pretrained. cc @a-r-r-o-w . LMK if I should still take care of this case or we can deal with that in a follow-up PR when we need that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you perhaps run the test_layerwise_casting_inference tests to confirm, that would be great.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all 37 tests are passing !

for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_set_default_torch_dtype() already calls set_default_dtype(), is that still needed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is to set back the default dtype to the original dtype dtype_orig. This way, if the user continue to create tensors, it will be back to the default dtype they are expected e.g. FP32

src/diffusers/models/modeling_utils.py Outdated Show resolved Hide resolved
src/diffusers/models/modeling_utils.py Show resolved Hide resolved
src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py Outdated Show resolved Hide resolved
@sayakpaul
Copy link
Member

@SunMarc,

Additionally, I ran some tests on audace (two RTX 4090s). Some tests that are failing (they fail on main too):

Failures
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_0_hf_internal_testing_unet2d_sharded_dummy - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_1_hf_internal_testing_tiny_sd_unet_sharded_latest_format - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_local - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_local_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_subfolder_0_hf_internal_testing_unet2d_sharded_dummy_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_load_sharded_checkpoint_from_hub_subfolder_1_hf_internal_testing_tiny_sd_unet_sharded_latest_format_subfolder - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argumen...
FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_with_variant - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument...

^^ passes when using with CUDA_VISIBLE_DEVICES=0 (same with main). Expected?

Same for following:

FAILED tests/models/unets/test_models_unet_2d_condition.py::UNet2DConditionModelTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

And then I also ran:

RUN_SLOW=1 pytest tests/pipelines/stable_diffusion/test_stable_diffusion.py::StableDiffusionPipelineDeviceMapTests

Everything passes.


for param_name, param in named_buffers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep this or equivalent elsewhere, context: #10523

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes I did should also cover this use case. The test you added should pass with my PR. The is mainly due to adding the dispatch_model function at the end.

Comment on lines +1382 to +1384
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch
torch.set_default_dtype(torch.float8_e4m3fn)
# TypeError: couldn't find storage object Float8_e4m3fnStorage

src/diffusers/models/modeling_utils.py Outdated Show resolved Hide resolved
logger = logging.get_logger(__name__)

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

TORCH_INIT_FUNCTIONS = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a merge blocker, but is it possible to dynamically create this mapping? Then we could avoid having to make manual updates in case new inits are added to torch.

Although I suppose that doesn't happen too often.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this could work:

import torch.nn.init as init

init_functions = {
    name: getattr(init, name) for name in dir(init) if callable(getattr(init, name)) 
    and name.endswith("_")
    and not name.startswith("_")
}

print("Available initialization functions:")
for name in init_functions:
    print(name)

Prints:

Available initialization functions:
constant_
dirac_
eye_
kaiming_normal_
kaiming_uniform_
normal_
ones_
orthogonal_
sparse_
trunc_normal_
uniform_
xavier_normal_
xavier_uniform_
zeros_

Copy link
Member Author

@SunMarc SunMarc Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT @DN6 ? I'm fine with either. Also there are some missing function since you only choose the one finishing with "_", though I don't think these are deprecated now.

    "uniform": nn.init.uniform,
    "normal": nn.init.normal,
    "xavier_uniform": nn.init.xavier_uniform,
    "xavier_normal": nn.init.xavier_normal,
    "kaiming_uniform": nn.init.kaiming_uniform,
    "kaiming_normal": nn.init.kaiming_normal,

# in the case it is sharded, we have already the index
if is_sharded:
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
resolved_archive_file = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps resolved_model_file can work here since most of the time this variable is used with _get_model_file?

dduf_entries=dduf_entries,
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
dtype_orig = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a use case where we might want to support loading checkpoints that are in mixed precision. e.g The Mochi Video model needs to preserve norms in FP32 (we can't load in FP16/BF16 and then cast back to FP32 with _keep_in_fp32_modules)
https://huggingface.co/Kijai/Mochi_preview_comfy/blob/main/mochi_preview_dit_fp8_e4m3fn.safetensors

We were thinking of introducing an auto dtype for such cases.

Additionally, torch FP8 is a valid and popular storage type in the Diffusion community that is dynamically upcast during inference time (a feature we will add soon).
https://github.com/huggingface/diffusers/pull/10347/files

I think this might break if a user tries something like `.from_pretrained(.., torch_dtype=torch.float8_e4m3fn), which would be a breaking change for us.

Think we need to update the casting here to account for these cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, torch FP8 is a valid and popular storage type in the Diffusion community that is dynamically upcast during inference time (a feature we will add soon).
https://github.com/huggingface/diffusers/pull/10347/files

I think this might break if a user tries something like `.from_pretrained(.., torch_dtype=torch.float8_e4m3fn), which would be a breaking change for us.

Think we need to update the casting here to account for these cases.

Yes, I will update the code to reflect this.

Copy link
Member Author

@SunMarc SunMarc Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a use case where we might want to support loading checkpoints that are in mixed precision. e.g The Mochi Video model needs to preserve norms in FP32 (we can't load in FP16/BF16 and then cast back to FP32 with _keep_in_fp32_modules)
https://huggingface.co/Kijai/Mochi_preview_comfy/blob/main/mochi_preview_dit_fp8_e4m3fn.safetensors

Under low_cpu_mem_usage = True, it won't load the the model in FP16/BF16 then cast it back to FP32. With
_keep_in_fp32_modules, we should be able to make sure that the param stays in FP32.

Of course, if we have more complicated use case where the params are a mix of many stype then it would make sense to introduce the dtype auto, so that we use the dtype of the state dict.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can introduce dtype="auto" after this PR is merged. Just wanted to flag

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As said here, I don't think we need to change anything yet cc @a-r-r-o-w

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some more comments.

I am running the 4bit quantization tests currently. And so far things are looking nice! Some tests that might be worth including/consdering:

  • Device map with quantization
  • Effectiveness of keep_modules_in_fp32 when not using quantization.

WDYT?

Edit: 4bit and 8bit tests (bitsandbytes) are passing.

@@ -362,17 +362,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =

if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
unexpected_keys = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the single-file related changes to uniformize the use of load_model_dict_into_meta() (with the new signature)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that's right !

src/diffusers/models/model_loading_utils.py Show resolved Hide resolved
Comment on lines -258 to -259
if named_buffers is None:
return unexpected_keys
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines -258 to -259
if named_buffers is None:
return unexpected_keys
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, found it:

dispatch_model(model, **device_map_kwargs)

It's a tad bit easier for reviewers if we could just provide these links going forward.

logger = logging.get_logger(__name__)

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

TORCH_INIT_FUNCTIONS = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this could work:

import torch.nn.init as init

init_functions = {
    name: getattr(init, name) for name in dir(init) if callable(getattr(init, name)) 
    and name.endswith("_")
    and not name.startswith("_")
}

print("Available initialization functions:")
for name in init_functions:
    print(name)

Prints:

Available initialization functions:
constant_
dirac_
eye_
kaiming_normal_
kaiming_uniform_
normal_
ones_
orthogonal_
sparse_
trunc_normal_
uniform_
xavier_normal_
xavier_uniform_
zeros_



@contextmanager
def no_init_weights():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you briefly then elaborate what happens in this codepath?

Comment on lines +1382 to +1384
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you perhaps run the test_layerwise_casting_inference tests to confirm, that would be great.

Comment on lines -202 to +207
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
def test_missing_key_loading_warning_message(self):
with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
assert "conv_out.bias" in " ".join(logs.output)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain the changes?

Copy link
Member Author

@SunMarc SunMarc Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched from raising an error to just a warning for missing keys.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Core] refactor model loading
6 participants