
[Quantization] enable multi-backend bitsandbytes #10574

Open
wants to merge 3 commits into main
Conversation

@hlky (Collaborator) commented Jan 14, 2025

What does this PR do?

Mainly copied from the corresponding transformers PR.

May need to look at:

```python
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
    logger.warning(
        f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
    )
if is_loaded_in_8bit_bnb and device is not None:
    logger.warning(
        f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
    )
# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
    module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
    module.to(device, dtype)
```
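
For illustration (not from the PR itself): this guard lives in `DiffusionPipeline.to()`, so a bnb-quantized component exercises it roughly as below. The Flux repo and 4-bit config are just an example choice:

```python
# Illustrative sketch: exercising the guard above via DiffusionPipeline.to().
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer)

pipe.to(torch.float16)  # logs the "conversion to ... is not supported" warning; weights stay 4-bit
pipe.to("cuda")         # the 4-bit transformer is moved too, but only on transformers > 4.44.0
```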

Test results are the same as the nightly run: https://github.com/huggingface/diffusers/actions/runs/12758480172/job/35560601164

```
RUN_SLOW=1 pytest -v -s tests/quantization/bnb/
===================================================================================== short test summary info =====================================================================================
FAILED tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_generate_quality_dequantize - NotImplementedError: Only row-major format inputs are supported, but got format `col32`
FAILED tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_quality - AssertionError: False is not true
===================================================================== 2 failed, 42 passed, 26 warnings in 1601.02s (0:26:41) ======================================================================
```

Fixes #10395

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@sayakpaul

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +94 to +95:

```python
if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
    pass
```
Member:

Because bnb is supported on Intel CPUs?

if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
pass
elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
Member:

The common piece of code between the two utilities could be clubbed into a small function and reused?

Previously we didn't do this because it was relatively small and was better off inline.
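
For illustration, one way the duplicated guard could be pulled into a shared helper; the function name and signature here are invented, not part of the PR:

```python
# Hypothetical refactor sketch; helper name and signature are invented for illustration.
def _bnb_device_map_needs_cpu_offload(device_map, device_map_without_no_convert, bnb_multibackend_is_enabled):
    # A CPU-only device_map is fine when the multi-backend bnb build is active.
    if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
        return False
    # Otherwise, any CPU/disk placement still needs the fp32 CPU offload path.
    return "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values()
```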

if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
pass
elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The common piece of code between the two utilities could be clubbed into a small function and reused?

Previously we didn't do because it was relatively small and was better off in-line.

```diff
@@ -183,7 +194,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.CxB is None:
         state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
     out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype)
```
Member:

Note: #10401

Collaborator Author:

Will rebase after that PR has merged.

```diff
@@ -304,3 +318,80 @@ def _check_bnb_status(module) -> Union[bool, bool]:
         and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
     )
     return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
+
+
+def _validate_bnb_multi_backend_availability(raise_exception):
```
Member:

@matthewdouglas I wonder if it makes sense to have these as utility functions in bitsandbytes so that they can be reused in transformers and diffusers (and any other libraries)?
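
For context, a rough sketch of the kind of shared helper being suggested; the `supported_torch_devices` attribute and the function itself are hypothetical here, not a confirmed bitsandbytes API:

```python
# Hypothetical upstream helper; nothing here is a confirmed bitsandbytes API.
import torch


def validate_multi_backend_availability(raise_exception: bool = False) -> bool:
    import bitsandbytes as bnb

    # Assumed attribute for illustration: the torch device types this bnb build supports.
    supported = getattr(bnb, "supported_torch_devices", {"cuda"})
    available = {"cuda"} if torch.cuda.is_available() else {"cpu"}
    if available & supported:
        return True
    if raise_exception:
        raise RuntimeError(f"Available devices {available} are not supported by this bitsandbytes build ({supported}).")
    return False
```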

```python
    return True


@lru_cache
```
Member:

We usually don't use `lru_cache` in import_utils.py. Any specific reason?

Collaborator Author:

Copied from transformers; not sure of the context.
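
For background, `lru_cache` simply memoizes the probe so the (potentially slow) import and backend detection run once per process; a minimal standalone sketch with an invented function name:

```python
from functools import lru_cache


@lru_cache
def _bnb_probe() -> bool:
    # Illustrative probe: the first call pays for the import, later calls hit the cache.
    try:
        import bitsandbytes  # noqa: F401
    except (ImportError, RuntimeError):
        return False
    return True


_bnb_probe()  # performs the import
_bnb_probe()  # returns the cached result without re-importing
```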

@sayakpaul (Member) left a comment:

Left some comments but this is already very good!

> May need to look at:

Anything specific? I'm not seeing anything CUDA-specific.

@sayakpaul (Member):

> Test results are the same as the nightly run

Were the tests run on the aws-g6e-xlarge-plus runner? If so, `tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests::test_quality` should have passed. Will take a look.
