Accelerate with DeepSpeed not works #3337

Open
2 of 4 tasks
khalil-Hennara opened this issue Jan 12, 2025 · 7 comments

@khalil-Hennara

khalil-Hennara commented Jan 12, 2025

System Info

- `Accelerate` version: 1.2.1
- Platform: Linux-6.8.0-45-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /opt/conda/bin/accelerate
- Python version: 3.11.10
- Numpy version: 2.1.2
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 503.46 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
	- compute_environment: LOCAL_MACHINE
	- distributed_type: DEEPSPEED
	- mixed_precision: no
	- use_cpu: False
	- debug: False
	- num_processes: 2
	- machine_rank: 0
	- num_machines: 1
	- rdzv_backend: static
	- same_network: True
	- main_training_function: main
	- enable_cpu_affinity: False
	- deepspeed_config: {'gradient_accumulation_steps': 1, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
	- downcast_bf16: no
	- tpu_use_cluster: False
	- tpu_use_sudo: False
	- tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

import logging
import os

import transformers
from accelerate import Accelerator

# ModelConfig, DataConfig, PeftConfig, AccelerateConfig and TrainConfig are
# user-defined dataclasses (not shown here).

accelerator = Accelerator()
# parse the arguments
parser = transformers.HfArgumentParser(
    (ModelConfig, DataConfig, PeftConfig, AccelerateConfig, TrainConfig))
model_config, data_config, peft_config, accelerate_config, training_config = parser.parse_args_into_dataclasses()

accelerator_log_kwargs = {}
if accelerate_config.with_tracking:
    accelerator_log_kwargs["log_with"] = training_config.report_to
    accelerator_log_kwargs["project_dir"] = training_config.output_dir

logger = logging.getLogger(__name__)

num_proc = os.cpu_count()

# check if we will use mixed precision for training.
mixed_precision = "no"
if training_config.bf16:
    mixed_precision = 'bf16'
if accelerate_config.use_8_bit:
    mixed_precision = 'fp8'
if training_config.fp16:
    mixed_precision = 'fp16'

# initialize the accelerator
# accelerator = Accelerator(mixed_precision=mixed_precision,
#                           gradient_accumulation_steps=training_config.gradient_accumulation_steps,
#                           **accelerator_log_kwargs)

# if accelerate_config.set_wrapper_manually:
#     if accelerator.state.fsdp_plugin is not None:
#         accelerator.state.fsdp_plugin.auto_wrap_policy = get_llama_wrapper()

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# print the log from all processes not just the main process
logger.info(accelerator.state)

Expected behavior

I expect it to work. I have used this script many times before and never faced any problems; the problem is new, and I have been trying to solve it for the last two days. Every time I run the code I get this error:

--- Logging error ---
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/logging/__init__.py", line 1110, in emit
    msg = self.format(record)
          ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/logging/__init__.py", line 953, in format
    return fmt.format(record)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/logging/__init__.py", line 687, in format
    record.message = record.getMessage()
                     ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/logging/__init__.py", line 375, in getMessage
    msg = str(self.msg)
          ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 949, in __repr__
    repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 1131, in __getattr__
    raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
AttributeError: 'AcceleratorState' object has no attribute 'mixed_precision'

The script keeps running after that, but when I then try to access accelerator.state.deepspeed_plugin, as recommended in the official docs, I get the next error:

AttributeError: 'AcceleratorState' object has no attribute 'deepspeed_plugin'. This happens if `AcceleratorState._reset_state()` was called and an `Accelerator` or `PartialState` was not reinitialized. Did you mean: 'get_deepspeed_plugin'?

even though there is no such line in my code.

@BenjaminBossan
Member

Most of the code you posted appears to be unrelated. If you have just the following code, do you still get the error?

import logging
from accelerate import Accelerator

accelerator = Accelerator()
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state)
accelerator.state.deepspeed_plugin

@muellerzr
Collaborator

As @BenjaminBossan hints at, the problem is that you're using the Accelerate logger, which needs an Accelerator() or PartialState() to be created first before it can log.
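
To illustrate that ordering, here is a minimal sketch (mine, not code from this thread) that creates the Accelerator before touching the Accelerate logger; get_logger returns Accelerate's process-aware logger, and main_process_only=False makes it emit from every process:

import logging

from accelerate import Accelerator
from accelerate.logging import get_logger

accelerator = Accelerator()  # initializes AcceleratorState / PartialState first

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)
logger = get_logger(__name__)

# Log the state from every process, not just the main one.
logger.info(accelerator.state, main_process_only=False)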

@khalil-Hennara
Author

khalil-Hennara commented Jan 14, 2025

Yes, I am still getting the error when I use the Hugging Face parser, @BenjaminBossan:

def main():
    accelerator = Accelerator()
    num_proc = os.cpu_count()
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # parse the argument
    parser = transformers.HfArgumentParser(
        (ModelConfig, DataConfig, PeftConfig, AccelerateConfig, TrainConfig))
    model_config, data_config, peft_config, accelerate_config, training_config = parser.parse_args_into_dataclasses()

    logger.info(accelerator.state)

When I run this code I get an error:

File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 1131, in getattr
raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
AttributeError: 'AcceleratorState' object has no attribute 'mixed_precision'

And even when I try to use this line to log the DeepSpeed config:

def main():
    accelerator = Accelerator()
    num_proc = os.cpu_count()
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # parse the argument
    parser = transformers.HfArgumentParser(
        (ModelConfig, DataConfig, PeftConfig, AccelerateConfig, TrainConfig))
    model_config, data_config, peft_config, accelerate_config, training_config = parser.parse_args_into_dataclasses()

    if accelerator.state.deepspeed_plugin:
        logger.info(accelerator.state.deepspeed_plugin.debugging)

I also get an error. I think the error is caused by calling the HfArgumentParser.

@BenjaminBossan
Member

@khalil-Hennara Could you please provide a complete reproducer that we can run on our machines to reproduce the error?

@khalil-Hennara
Author

khalil-Hennara commented Jan 14, 2025

@BenjaminBossan I will provide the code, the error, and the accelerate env output.

First of all, the accelerate env output:

  • Accelerate version: 1.2.1
  • Platform: Linux-6.8.0-45-generic-x86_64-with-glibc2.35
  • accelerate bash location: /opt/conda/bin/accelerate
  • Python version: 3.11.10
  • Numpy version: 2.1.2
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • PyTorch XPU available: False
  • PyTorch NPU available: False
  • PyTorch MLU available: False
  • PyTorch MUSA available: False
  • System RAM: 377.45 GB
  • GPU type: NVIDIA A100-SXM4-40GB
  • Accelerate default config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: DEEPSPEED
    • mixed_precision: bf16
    • use_cpu: False
    • debug: False
    • num_processes: 2
    • machine_rank: 0
    • num_machines: 1
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: False
    • deepspeed_config: {'gradient_accumulation_steps': 1, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []

Second, the code. I have removed all dependencies on my own code to make it easy to reproduce:

import logging
import math
import os
import sys
import fire
import numpy as np
import torch

# third party dependencies
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed, DummyScheduler, DummyOptim
from transformers import default_data_collator, TrainingArguments
from transformers.utils import check_min_version, is_flash_attn_2_available
from peft import LoraConfig, get_peft_model
import datasets
from torch.utils.data import DataLoader

def main():
    # parse the argument
    accelerator = Accelerator()
    
    parser = transformers.HfArgumentParser((TrainingArguments))
    training_config = parser.parse_args_into_dataclasses()[0]

    logger = get_logger(__name__)
    
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    
    logger.info(accelerator.state)
    
if __name__ == "__main__":
    main()

The error message

"
[rank0]: Traceback (most recent call last):
[rank0]: File "/root/projects/kawn/scripts/training/test.py", line 38, in
[rank0]: main()
[rank0]: File "/root/projects/kawn/scripts/training/test.py", line 35, in main
[rank0]: print(accelerator.state)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 949, in repr
[rank0]: repr = PartialState().repr() + f"\nMixed precision type: {self.mixed_precision}\n"
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 1131, in getattr
[rank0]: raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
[rank0]: AttributeError: 'AcceleratorState' object has no attribute 'mixed_precision'
"

The launch command:

accelerate launch kawn/scripts/training/test.py --output_dir test

I am obliged to provide output_dir as it is required by TrainingArguments.

My training script was built on top of the official example run_clm_no_trainer.py. I have used it many times before to train many models and never faced such problems. Also, in https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py you can see that the logger is initialized before the accelerator, and that the accelerator is also modified according to the custom arguments from the user.

@BenjaminBossan
Member

Thanks @khalil-Hennara, I could simplify the script as follows:

import transformers
from accelerate import Accelerator
from transformers import TrainingArguments


def main1():
    accelerator = Accelerator()
    parser = transformers.HfArgumentParser(TrainingArguments)
    # parser.parse_args_into_dataclasses()
    repr(accelerator.state)
    print("main1 passes")

def main2():
    parser = transformers.HfArgumentParser(TrainingArguments)
    parser.parse_args_into_dataclasses()
    accelerator = Accelerator()
    repr(accelerator.state)
    print("main2 passes")

def main3():
    accelerator = Accelerator()
    parser = transformers.HfArgumentParser(TrainingArguments)
    parser.parse_args_into_dataclasses()
    repr(accelerator.state)
    print("main3 should pass but fails")

if __name__ == "__main__":
    main1()
    main2()
    main3()

Note that I have 3 different conditions:

  1. main1: Comment out parser.parse_args_into_dataclasses() and it works.
  2. main2: Initialize the Accelerator after parser.parse_args_into_dataclasses() and it works.
  3. main3: Initialize the Accelerator before parser.parse_args_into_dataclasses() and it fails.

In main3, the error is caused by this line:

if self.distributed_type == DistributedType.DEEPSPEED:

because the self.distributed_type attribute does not exist, even though it existed earlier on during the __init__ phase of the accelerator. It's not clear to me why the attribute disappears and how that relates to parse_args_into_dataclasses.
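
For reference, here is a small diagnostic sketch (not part of the reproducer above; it assumes the same accelerate launch + DeepSpeed setup) that just checks whether the attribute is still there before and after parsing:

import transformers
from accelerate import Accelerator
from transformers import TrainingArguments


def main():
    accelerator = Accelerator()
    # Check right after Accelerator() has finished initializing the state.
    print("before parse:", hasattr(accelerator.state, "distributed_type"))

    parser = transformers.HfArgumentParser(TrainingArguments)
    parser.parse_args_into_dataclasses()

    # In the failing setup this is expected to print False, matching the
    # AttributeError raised by AcceleratorState.__getattr__ above.
    print("after parse:", hasattr(accelerator.state, "distributed_type"))


if __name__ == "__main__":
    main()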

@khalil-Hennara
Author

Thank you @BenjaminBossan for your explanation, but I am still getting error messages. I tried the same code you provided and I get the following error from main2.

error message

"
[rank0]: Traceback (most recent call last):
[rank0]: File "/root/projects/kawn/scripts/training/test.py", line 29, in
[rank0]: main2()
[rank0]: File "/root/projects/kawn/scripts/training/test.py", line 16, in main2
[rank0]: accelerator = Accelerator()
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 302, in init
[rank0]: deepspeed_plugins = AcceleratorState().deepspeed_plugins
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 887, in init
[rank0]: raise ValueError(
[rank0]: ValueError: Please make sure to properly initialize your accelerator via accelerator = Accelerator() before using any functionality from the accelerate library.
"
Using HfArgumentParser is what causes the problem. I tried using ArgumentParser from the standard library and it works fine:

import transformers
from accelerate import Accelerator
from transformers import TrainingArguments
import argparse

def get_wandb_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add Weights & Biases arguments to parser"""
    wandb_group = parser.add_argument_group('Weights & Biases Arguments', 'W&B configuration parameters')

    # Project name
    wandb_group.add_argument('--output_dir', type=str, default='kawn_vision',
                             help='W&B project name (default: kawn_vision)')
    return parser
    

def main1():
    accelerator = Accelerator()
    parser = transformers.HfArgumentParser(TrainingArguments)
    # parser.parse_args_into_dataclasses()
    repr(accelerator.state)
    print("main1 passes")

def main2():
    parser = argparse.ArgumentParser()
    
    args = get_wandb_args(parser).parse_args()
    
    accelerator = Accelerator()
    repr(accelerator.state)
    print("main2 passes")

def main3():
    accelerator = Accelerator()
    parser = transformers.HfArgumentParser(TrainingArguments)
    parser.parse_args_into_dataclasses()
    repr(accelerator.state)
    print("main3 should pass but fails")

if __name__ == "__main__":
    main1()
    main2()
    main3()

The previous code fails in main3 with the following error.

main3 error

"
[rank0]: Traceback (most recent call last):
[rank0]: File "/root/projects/kawn/scripts/training/test.py", line 42, in
[rank0]: main3()
[rank0]: File "/root/projects/kawn/scripts/training/test.py", line 36, in main3
[rank0]: repr(accelerator.state)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 949, in repr
[rank0]: repr = PartialState().repr() + f"\nMixed precision type: {self.mixed_precision}\n"
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/accelerate/state.py", line 1131, in getattr
[rank0]: raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
[rank0]: AttributeError: 'AcceleratorState' object has no attribute 'mixed_precision'
"
One last comment on this problem: it only occurs when using distributed training; nothing happens with normal single-process training.
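
Until the root cause is tracked down, a workaround sketch based only on the combinations that passed above: parse the launch-time arguments with plain argparse (the --output_dir flag below is just a placeholder) and only create the Accelerator afterwards, mirroring the main2 ordering.

import argparse

from accelerate import Accelerator


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Placeholder flag; a real script would expose its own options here.
    parser.add_argument("--output_dir", type=str, required=True)
    return parser.parse_args()


def main():
    # Parse before touching Accelerate, then initialize the Accelerator.
    args = parse_args()
    accelerator = Accelerator()
    accelerator.print(f"output_dir: {args.output_dir}")
    accelerator.print(repr(accelerator.state))  # prints on the main process only


if __name__ == "__main__":
    main()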
