Training hangs indefinitely on first forward pass when using TPU v3-8 in Kaggle #3370

Open

WpythonW opened this issue Jan 27, 2025 · 0 comments

System Info

accelerate version: 1.2.1
OS: Linux (Kaggle environment)
python version: 3.10.16 (main, Dec 25 2024, 01:31:21) [GCC 12.2.0]
torch version: 2.5.0+cu124
torch_xla version: 2.5.0+libtpu
Hardware: TPU v3-8

Accelerate configuration:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: XLA
downcast_bf16: 'yes'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Environment: Kaggle Notebook with TPU v3-8

Required setup (run each cell separately):

Cell 1:

!pip install datasets evaluate
!pip uninstall -y tensorflow
!pip install tensorflow-cpu

Cell 2:

!apt update
!apt install -y lsof

Cell 3:

import os
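# Kaggle pre-sets this variable; removing it is, as far as I can tell, the usual
# workaround so torch_xla treats the notebook VM as a single host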
os.environ.pop('TPU_PROCESS_ADDRESSES', None)

Cell 4:

!mkdir -p /root/.cache/huggingface/accelerate

Cell 5:

%%writefile /root/.cache/huggingface/accelerate/default_config.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: XLA
downcast_bf16: 'yes'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Cell 6:

%%writefile train.py
from accelerate import Accelerator
from evaluate import load
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
from tqdm.auto import tqdm
import datasets
from datasets import Dataset

def main():    
    accelerator = Accelerator()
    
    model_name = "answerdotai/ModernBERT-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, 
        num_labels=3
    )
    
    ds = load_dataset("TimKoornstra/financial-tweets-sentiment")
    tweets = ds['train']['tweet']
    labels = ds['train']['sentiment']
    tokenized_input = tokenizer(tweets, padding="max_length", max_length=512, truncation=True)
    
    input_ids = tokenized_input['input_ids']
    attention_mask = tokenized_input['attention_mask']

    tokenized_dataset = Dataset.from_dict({
        "input_ids": input_ids, 
        "attention_mask": attention_mask, 
        "labels": labels
    }).with_format("torch")
     
    train_dataloader = DataLoader(
        tokenized_dataset,
        shuffle=True,
        batch_size=2,
        num_workers=0
    )
    
    optimizer = AdamW(params=model.parameters(), lr=1e-4)

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    
    model.train()
    print('Starting training')
    for epoch in range(3):
        for step, batch in enumerate(train_dataloader):
            print(step)
            outputs = model(**batch)
            print("outputs got")  # Never reaches this line
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            print("step made")

if __name__=="__main__":
    main()

Cell 7:

!accelerate launch --tpu train.py

Output:

WARNING:root:Unsupported nprocs (8), ignoring...
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Starting training
Starting training
Starting training
Starting training
Starting training
Starting training
Starting training
Starting training
0
0
0
0
0
0
0
0

Expected behavior

The training process should proceed normally through these steps:

  • Initialize the model and data across TPU cores
  • For each batch:
      • Print the step number (as implemented)
      • Complete the forward pass
      • Print the "outputs got" message
      • Calculate the loss
      • Complete the backward pass
      • Update the model parameters
      • Print the "step made" message
      • Move to the next batch
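
In the run above, none of the eight processes ever reaches the "outputs got" print, so the hang occurs inside the first forward pass (or the XLA compilation it triggers). A minimal single-core sanity check along the lines of the sketch below (my own isolation script, assuming the stock torch_xla 2.5 APIs xm.xla_device() and xm.mark_step(); it is not part of the reproduction above) could help distinguish a stuck first-step compilation from a problem specific to the 8-process accelerate launch:

import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = xm.xla_device()  # single XLA core, no multiprocessing

model_name = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=3
).to(device)

batch = tokenizer(
    ["a single test tweet"],
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
).to(device)

outputs = model(**batch, labels=torch.tensor([0], device=device))
xm.mark_step()  # flush the lazily traced graph so compilation/execution actually runs
print("loss:", outputs.loss.item())

If this single-core script also stalls around mark_step, the problem is probably the first-step XLA compilation of ModernBERT at sequence length 512 rather than anything accelerate does; if it finishes, the hang is more likely specific to the 8-process launch.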
