Bitsandbytes docs improvements (#18903)
carmocca authored Nov 2, 2023
1 parent 1e68c50 commit ad93f64
Showing 4 changed files with 37 additions and 8 deletions.
11 changes: 11 additions & 0 deletions docs/source-fabric/api/fabric_args.rst
@@ -144,6 +144,17 @@ Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while "
fabric = Fabric(precision="64-true", devices=1)
Precision settings can also be enabled via the plugins argument (see section below on plugins).
An example is the Bitsandbytes plugin for 4-bit and 8-bit weight quantization:

.. code-block:: python
import torch

from lightning.fabric.plugins import BitsandbytesPrecision
precision = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
fabric = Fabric(plugins=precision)
plugins
=======

10 changes: 6 additions & 4 deletions docs/source-fabric/fundamentals/precision.rst
@@ -214,7 +214,7 @@ See also: :doc:`../advanced/model_init`
Quantization via Bitsandbytes
*****************************

`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing Linear weights.
`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing :class:`torch.nn.Linear` weights.

Both 4-bit (`paper reference <https://arxiv.org/abs/2305.14314v1>`__) and 8-bit (`paper reference <https://arxiv.org/abs/2110.02861>`__) quantization is supported.
Specifically, we support the following modes:
@@ -228,20 +228,22 @@ Specifically, we support the following modes:

While these techniques store weights in 4 or 8 bit, the computation still happens in 16 or 32-bit (float16, bfloat16, float32).
This is configurable via the dtype argument in the plugin.
If your model weights can fit on a single device in 16-bit precision, it's recommended not to use this plugin, as it will slow down training.
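
As a rough illustration of that recommendation (a minimal sketch, not part of this commit; it assumes 2 bytes per parameter, i.e. float16/bfloat16 weights), you can estimate a model's 16-bit footprint before reaching for quantization:

.. code-block:: python

    import torch

    # Stand-in model; substitute your own.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096), torch.nn.Linear(4096, 4096))

    # 16-bit precision stores 2 bytes per parameter.
    n_params = sum(p.numel() for p in model.parameters())
    print(f"~{n_params * 2 / 1e9:.2f} GB of weights at 16-bit")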

Quantizing the model will dramatically reduce the weights' memory requirements but may have a negative impact on the model's performance or runtime.

The :class:`~lightning.fabric.plugins.precision.bitsandbytes.BitsandbytesPrecision` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.

.. code-block:: python
import torch

from lightning.fabric.plugins import BitsandbytesPrecision
# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecision("nf4-dq")
precision = BitsandbytesPrecision(mode="nf4-dq")
fabric = Fabric(plugins=precision)
# Customize the dtype, or ignore some modules
precision = BitsandbytesPrecision("int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
precision = BitsandbytesPrecision(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
fabric = Fabric(plugins=precision)
model = MyModel()
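
A minimal usage sketch (assuming the ``fabric`` and ``model`` objects from the snippet above; the exact setup call shown in the surrounding docs may differ): the BNB layer replacement is applied when the model is set up with Fabric.

.. code-block:: python

    # Sketch: setting up the module is when the plugin swaps torch.nn.Linear
    # layers for their quantized bitsandbytes counterparts.
    model = fabric.setup(model)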
17 changes: 16 additions & 1 deletion docs/source-fabric/glossary/index.rst
@@ -28,6 +28,11 @@ Glossary
:button_link: ../advanced/distributed_communication.html
:col_css: col-md-4

.. displayitem::
:header: Bfloat16
:button_link: ../fundamentals/precision.html
:col_css: col-md-4

.. displayitem::
:header: Broadcast
:button_link: ../advanced/distributed_communication.html
@@ -89,7 +94,7 @@ Glossary
:col_css: col-md-4

.. displayitem::
:header: Jypyter
:header: Jupyter
:button_link: ../launch/notebooks.html
:col_css: col-md-4

@@ -148,6 +153,11 @@ Glossary
:button_link: ../fundamentals/precision.html
:col_css: col-md-4

.. displayitem::
:header: Quantization
:button_link: ../fundamentals/precision.html
:col_css: col-md-4

.. displayitem::
:header: Reduce
:button_link: ../advanced/distributed_communication.html
@@ -183,6 +193,11 @@ Glossary
:button_link: ../guide/trainer_template.html
:col_css: col-md-4

.. displayitem::
:header: 16-bit, 8-bit, 4-bit
:button_link: ../fundamentals/precision.html
:col_css: col-md-4


.. raw:: html

7 changes: 4 additions & 3 deletions docs/source-pytorch/common/precision_intermediate.rst
@@ -165,7 +165,7 @@ Under the hood, we use `transformer_engine.pytorch.fp8_autocast <https://docs.nv
Quantization via Bitsandbytes
*****************************

`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing Linear weights.
`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing :class:`torch.nn.Linear` weights.

Both 4-bit (`paper reference <https://arxiv.org/abs/2305.14314v1>`__) and 8-bit (`paper reference <https://arxiv.org/abs/2110.02861>`__) quantization is supported.
Specifically, we support the following modes:
@@ -179,6 +179,7 @@ Specifically, we support the following modes:

While these techniques store weights in 4 or 8 bit, the computation still happens in 16 or 32-bit (float16, bfloat16, float32).
This is configurable via the dtype argument in the plugin.
If your model weights can fit on a single device in 16-bit precision, it's recommended not to use this plugin, as it will slow down training.

Quantizing the model will dramatically reduce the weights' memory requirements but may have a negative impact on the model's performance or runtime.

@@ -189,11 +190,11 @@ The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecis
import torch

from lightning.pytorch.plugins import BitsandbytesPrecision
# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecision("nf4-dq")
precision = BitsandbytesPrecision(mode="nf4-dq")
trainer = Trainer(plugins=precision)
# Customize the dtype, or skip some modules
precision = BitsandbytesPrecision("int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
precision = BitsandbytesPrecision(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
trainer = Trainer(plugins=precision)
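
A minimal usage sketch (assuming a hypothetical ``LitModel`` LightningModule; the surrounding docs may show a different example): the quantized layers take effect once fitting starts.

.. code-block:: python

    # Sketch: start training with the quantized precision plugin enabled.
    model = LitModel()  # hypothetical LightningModule
    trainer.fit(model)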
