
BUG: memory_stats() is not supported on TPU pods, causing inference on a TPU pod to throw an error #181

Open
salrowili opened this issue Jan 5, 2025 · 8 comments


@salrowili

Hi,

There is a bug in the inference code: memory_stats() is only supported on single-host TPU hardware. memory_stats() is part of the inference metrics, so as a workaround, comment out the line below (prefix it with #) and apply the change on every device in the pod.

self._start_memory_monitoring()
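
For context, here is a minimal check I would use to see whether the runtime actually exposes per-device memory stats before starting the monitor; the helper name is mine, not part of EasyDeL:

import jax

def memory_stats_supported() -> bool:
	"""Return True only if memory_stats() works on this runtime.

	On TPU pod slices the call can raise or return None, which is what
	breaks the inference metrics reported above.
	"""
	try:
		stats = jax.local_devices()[0].memory_stats()
	except Exception:
		return False
	return stats is not None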

@erfanzar
Owner

erfanzar commented Jan 5, 2025

Thanks for reporting the bug. I'll make a commit and check the inference ASAP.

@erfanzar
Owner

erfanzar commented Jan 6, 2025

Hey @salrowili, can you confirm that the code now works just fine?

@salrowili
Author

I have tested the recent update on TPUv4-8. Following commit 1caafc3, there is a new bug:

LOADING MODEL ... 
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.32it/s]
Converting Model:   0%|                                                                                                                                                                         | 0/291 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/user/EasyDeL/tests/vinference_test.py", line 140, in <module>
    main()
  File "/home/user/EasyDeL/tests/vinference_test.py", line 39, in main
    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
  File "/home/user/.local/lib/python3.10/site-packages/easydel/modules/auto/auto_modeling.py", line 101, in from_pretrained
    return cls._from_torch_pretrained(
  File "/home/user/.local/lib/python3.10/site-packages/easydel/modules/auto/auto_modeling.py", line 210, in _from_torch_pretrained
    return Base._from_torch_pretrained(
  File "/home/user/.local/lib/python3.10/site-packages/easydel/infra/mixins/bridge.py", line 693, in _from_torch_pretrained
    params = model.pure_transform_fn(
  File "/home/user/.local/lib/python3.10/site-packages/easydel/utils/parameters_transformation.py", line 236, in torch_dict_to_easydel_params
    raise e
  File "/home/user/.local/lib/python3.10/site-packages/easydel/utils/parameters_transformation.py", line 233, in torch_dict_to_easydel_params
    jax_array = shard_fns[key_tuple](jax_array)
TypeError: 'NoneType' object is not callable

If I roll back to the previous commit 1045c25, the inference code works fine. I have also noticed that in the inference example, sharding_axis_dims = (1, 1, 1, -1), the FSDP sharding was not set up correctly for TPUv4-8. This affects not only speed but also accuracy. See the example below:

# fmt:off
import os
import sys
import threading
import time

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

import easydel as ed
# fmt:on
import jax
import torch
import transformers
from huggingface_hub import HfApi
from jax import numpy as jnp
from jax import sharding
from datasets import load_dataset
from tqdm import tqdm
PartitionSpec, api = sharding.PartitionSpec, HfApi()


def calc_accuracy(actuals,preds):
    total_correct=0
    total_examples=len(actuals)
    for actual,pred in zip(actuals,preds):
        pred_letter="A"
        if "A" in pred:
            pred_letter="A"
        if "B" in pred:
            pred_letter="B"
        if "C" in pred:
            pred_letter="C"
        if "D" in pred:
            pred_letter="D"
        if actual==pred_letter:
           total_correct+=1
    acc_score=total_correct/total_examples
    return acc_score


def log_mem():
	while True:
		ed.utils.analyze_memory.SMPMemoryMonitor(5).print_current_status()
		time.sleep(5)


threading.Thread(target=log_mem)  # .start()


def main():
	sharding_axis_dims = (1, 1, 1, -1)
	max_length = 2048

	pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"
	partition_axis = ed.PartitionAxis()
	dtype = jnp.bfloat16
	print("LOADING MODEL ... ")
	model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
		pretrained_model_name_or_path,
		auto_shard_model=True,
		sharding_axis_dims=sharding_axis_dims,
		config_kwargs=ed.EasyDeLBaseConfigDict(
			freq_max_position_embeddings=max_length,
			mask_max_position_embeddings=max_length,
			attn_dtype=dtype,
			gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
			kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
			attn_mechanism=ed.AttentionMechanisms.VANILLA,
			# use_scan_mlp=True,
			# scan_mlp_chunk_size=128,
		),
		quantization_method=ed.EasyDeLQuantizationMethods.NONE,
		param_dtype=dtype,  # jnp.float8_e5m2,
		dtype=dtype,
		torch_dtype=torch.float16,
		partition_axis=partition_axis,
		precision=jax.lax.Precision("fastest"),
	)
	print("MODEL LOADED")
	tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
	tokenizer.padding_side = "left"
	tokenizer.pad_token_id = tokenizer.eos_token_id
	print("TOKENIZER LOADED")
	# model = model.quantize(
	# 	method=ed.EasyDeLQuantizationMethods.A8BIT,
	# 	block_size=128,
	# 	quantization_pattern=".*(gate_proj|up_proj).*",
	# )
	print("CREATING vInference")
	inference = ed.vInference(
		model=model,
		processor_class=tokenizer,
		generation_config=ed.vInferenceConfig(
			max_new_tokens=1024,
			temperature=0.6,
			do_sample=False,
			top_p=0.9,
			top_k=10,
			eos_token_id=model.generation_config.eos_token_id,
			streaming_chunks=32,
		),
	)

	print(model.model_task)
	print(model.model_type)
	print("Compiling")
	inference.precompile(1, inference.model_prefill_length)

	print("Done Compiling")
	print("Evaluating on MMLU Lite")
	prompts=[]
	pred_list=[]
	actual_list=[]
	data=load_dataset("CohereForAI/Global-MMLU-Lite","en",split="test")
	for item in tqdm(data,total=len(data)):
		question=item["question"]
		option_a=item["option_a"]
		option_b=item["option_b"]
		option_c=item["option_c"]
		option_d=item["option_d"]
		actual_list.append(item["answer"])
		prompt=f"Answer the following question by writing the right answer letter which can be A,B,C or D. Write only the correct answer letter in your response. \nQuestion : {question}\nA. {option_a}. \nB. {option_b}. \nC. {option_c}. \nD. {option_d}"
		prompts.append(prompt)
		messages=[{"role": "system", "content": "You are a helpful AI assistant."},
			{"role": "user","content": prompt}]
		ids = tokenizer.apply_chat_template(
			messages,
			return_tensors="jax",
			return_dict=True,
			max_length=inference.model_prefill_length,
			padding="max_length",
			add_generation_prompt=True,
			)

		pad_seq = inference.model_prefill_length
		output = ""  # accumulate streamed chunks so each prompt yields exactly one prediction
		for response in inference.generate(**ids):
			next_slice = slice(
				pad_seq,
				pad_seq + inference.generation_config.streaming_chunks,
			)
			pad_seq += inference.generation_config.streaming_chunks
			output += tokenizer.decode(response.sequences[0][next_slice], skip_special_tokens=True)
		pred_list.append(output)
	for prompt,pred in zip(prompts,pred_list):
		print("--------------------------------------")
		print(f"Prompt: {prompt}\nPrediction : {pred}")
	print("---------- Evaluation Score -----------------")
	acc_score=calc_accuracy(actual_list,pred_list)
	print(f"accuracy score : {acc_score}")

if __name__ == "__main__":
	main()

With sharding (1, 1, 1, -1), accuracy drops to 31 and the run takes more than a minute and a half.
However, with (1, 1, -1, 1), accuracy jumps to 34 and the run takes only about 40 seconds.
I also had to match the generation config of Llama 3.1 with https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/generation_config.json, which improves the result slightly.
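
For readers unsure what the tuple controls: each entry sets the size of one mesh axis, and -1 absorbs all remaining devices. Below is a rough stand-alone sketch of that resolution in plain JAX; the axis names (dp, fsdp, tp, sp) are my assumption about EasyDeL's defaults rather than something taken from its source.

import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(dims=(1, 1, -1, 1), names=("dp", "fsdp", "tp", "sp")):
	"""Resolve a single -1 entry to 'all remaining devices' and build a device mesh."""
	devices = jax.devices()
	dims = list(dims)
	if -1 in dims:
		known = int(np.prod([d for d in dims if d != -1]))
		dims[dims.index(-1)] = len(devices) // known
	return Mesh(np.array(devices).reshape(dims), names)

# e.g. on a host with 4 accelerators, (1, 1, -1, 1) resolves to a (1, 1, 4, 1) mesh
print(make_mesh().shape)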

Feel free to post this example in the docs folder; I have noticed that many people have asked for more tutorials on inference and evaluation scripts (:

@erfanzar
Owner

erfanzar commented Jan 7, 2025

Thank you! I appreciate the detailed insights.

At the moment, the HEAD commit (b9c4bd2) is working fine, so the bug seems to be resolved. Regarding sharding methods, I typically prefer Sequence Sharding over Tensor Parallelism, as it tends to perform better on GPUs.

Your example script and the explanation of the impact of sharding configurations are incredibly helpful, especially for optimizing inference. I might include this in the documentation as a reference for others.

@creatorrr

@salrowili Did you ever run the inference server on a multi-node TPU pod like v4-64?

@salrowili
Author

Hi @creatorrr,

I have managed to run inference on TPUv4-32 using a sharding method of (1, 1, 4, 4) with version 0.0.80, but the speed was not significantly higher than on TPUv4-8. For SFT training, the best setting for me on TPUv4-32 was (4, 1, 1, 4). However, it was still convenient to run inference on TPUv4-32 even though the speed was not better, because at least I can fine-tune and evaluate the model on the same TPU pod machine without creating a separate TPUv4-8 machine for inference.

I also suggest that you compare the results from TPUv4-32 inference against TPUv4-8 with the (1, 1, -1, 1) sharding method, because a different sharding setting (e.g., (1, 1, 4, 4)) can yield poor performance. TPUv4-8 with (1, 1, -1, 1) gives me almost identical performance to GPU inference with the TRL repo.

If (1, 1, 4, 8), (1, 1, 8, 4), or (1, 4, 4, 4) do not work for you on TPUv4-64, your best bet is to change the topology of the TPU pod itself when you create it through the --topology=.. flag. See https://cloud.google.com/tpu/docs/v4. If your intention in running inference on TPUv4-64 is to run large models, just to let you know, I have managed to run a 70B model on TPUv4-8 with the A8BIT flag on version 0.0.80 using this script:

......
sharding_axis_dims = (1, 1, -1, 1)   ## (1,1,1,-1) will cause generation to be very bad
max_length = 1024
num_devices = len(jax.devices())
input_shape = (num_devices, max_length)
pretrained_model_name_or_path = "..."
partition_axis = ed.PartitionAxis()
dtype = jnp.bfloat16
model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
	pretrained_model_name_or_path,
	input_shape=input_shape,
	auto_shard_params=True,
	sharding_axis_dims=sharding_axis_dims,
	config_kwargs=ed.EasyDeLBaseConfigDict(
		freq_max_position_embeddings=max_length,
		mask_max_position_embeddings=max_length,
		attn_dtype=dtype,
		gradient_checkpointing=ed.EasyDeLGradientCheckPointers.EVERYTHING_SAVEABLE,
		kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.A8BIT,
		attn_mechanism=ed.AttentionMechanisms.VANILLA,
	),
	quantization_method=ed.EasyDeLQuantizationMethods.A8BIT,
	param_dtype=dtype,
	trust_remote_code=True,
	dtype=dtype,
	from_torch=True,
	torch_dtype=torch.float16,
	partition_axis=partition_axis,
	precision=jax.lax.Precision("fastest"),
)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id
inference = ed.vInference(
	model=model,
	params=params,
	tokenizer=tokenizer,
	generation_config=ed.vInferenceConfig(
		max_new_tokens=16,
		temperature=0.6,
		top_p=0.9,
		top_k=1,
		do_sample=False,
		eos_token_id=model.generation_config.eos_token_id,
		streaming_chunks=32,
	),
)
print(model.model_task)
print(model.model_type)
print("Compiling")
inference.precompile(1, inference.model_prefill_length)
print("Done Compiling")
......

Do not forget to remove the memory-monitoring call from every machine in the pod:

self._start_memory_monitoring()

You can execute this command on all workers in the pod:

sed -i '117d' .../easydel/inference/vinference/metrics.py
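
Deleting a hard-coded line number with sed is fragile across releases; here is a slightly more robust sketch that comments the call out by name instead (the metrics.py path is taken from the sed command above and may move between versions):

from pathlib import Path
import easydel

# Hypothetical patch helper: disable the memory monitor by commenting out the
# call by name rather than deleting a fixed line number. Idempotent on re-runs.
metrics_py = Path(easydel.__file__).parent / "inference" / "vinference" / "metrics.py"
src = metrics_py.read_text()
if "# self._start_memory_monitoring()" not in src:
	metrics_py.write_text(
		src.replace(
			"self._start_memory_monitoring()",
			"# self._start_memory_monitoring()  # disabled on TPU pods",
		)
	)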

I also suggest to @erfanzar that we create a new page in the docs folder that discusses and lists the best practices for each TPU setting and compares those settings in terms of speed and performance for fine-tuning and inference tasks.

@erfanzar
Owner

@salrowili Absolutely, that's a great suggestion! I'd be happy to collaborate on documenting TPU best practices. Would you be open to discussing this further over Discord? We could combine our experiences with different sharding approaches, including custom configurations, to create a comprehensive guide.

@salrowili
Author

salrowili commented Jan 12, 2025

Great idea @erfanzar, I am in. However, I suggest that we open a discussion topic here that addresses only best practices for TPU configurations.

My idea is that others will learn more from our discussions, from the point when we identify an issue until we find the solution. If we just present the best practices to developers and researchers, they may not be convinced of the motivation behind them. So a topic discussing best practices here, plus reference documentation in the docs folder, would serve best.
