
Commit

Refactor measure_flops function for TPU and GPU compatibility; update max_length in tests
erfanzar committed Jan 5, 2025
1 parent 72eec8f commit 9e95b25
Showing 4 changed files with 45 additions and 23 deletions.
35 changes: 24 additions & 11 deletions easydel/inference/vinference/_fn.py
@@ -33,17 +33,30 @@
 	vInferenceConfig,
 )
 
 
-def measure_flops(func, *args, **kwargs):
-	try:
-		flops = func.cost_analysis()[0]["flops"]
-	except:  # noqa
-		flops = 1
-	start_time = time.perf_counter()
-	result = func(*args, **kwargs)
-	end_time = time.perf_counter()
-	elapsed_time = end_time - start_time
-	return result, flops, flops / elapsed_time, elapsed_time
+if jax.default_backend() == "tpu":
+
+	def measure_flops(func, *args, **kwargs):
+		try:
+			flops = func.cost_analysis()[0]["flops"]
+		except Exception:
+			flops = 1
+		start_time = time.perf_counter()
+		result = jax.block_until_ready(func(*args, **kwargs))
+		end_time = time.perf_counter()
+		elapsed_time = end_time - start_time
+		return result, flops, flops / elapsed_time, elapsed_time
+else:
+	# On GPUs this will be much more efficient
+	def measure_flops(func, *args, **kwargs):
+		try:
+			flops = func.cost_analysis()[0]["flops"]
+		except Exception:
+			flops = 1
+		start_time = time.perf_counter()
+		result = func(*args, **kwargs)
+		end_time = time.perf_counter()
+		elapsed_time = end_time - start_time
+		return result, flops, flops / elapsed_time, elapsed_time
 
 
 @partial(
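Why the backend split matters: JAX dispatches device work asynchronously, so a timed call can return long before the TPU has finished executing. Wrapping the result in jax.block_until_ready makes the timer cover actual execution; the GPU branch leaves the call unblocked, per the in-diff comment that this is more efficient there. Below is a minimal, self-contained sketch of the same pattern using only standard JAX APIs; it is an illustration, not EasyDeL's exact code, and the guarded cost_analysis() lookup mirrors the fallback above (some JAX versions return a per-device list of dicts, others a single dict).

import time

import jax
import jax.numpy as jnp


def f(x):
	return (x @ x).sum()


x = jnp.ones((1024, 1024))
compiled = jax.jit(f).lower(x).compile()

# XLA's static cost estimate; fall back to 1 when the backend doesn't report it.
try:
	analysis = compiled.cost_analysis()
	entry = analysis[0] if isinstance(analysis, (list, tuple)) else analysis
	flops = entry["flops"]
except Exception:
	flops = 1

compiled(x)  # warm-up call so the timed region excludes one-off overhead

start = time.perf_counter()
result = jax.block_until_ready(compiled(x))  # wait for the device to finish
elapsed = time.perf_counter() - start
print(f"~{flops / elapsed:.3e} FLOP/s over {elapsed:.6f} s")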
1 change: 1 addition & 0 deletions easydel/inference/vinference/vinference.py
@@ -411,6 +411,7 @@ def generate(
 			interval_func_flops = np.mean(all_interval_func_flops)
 			state.generate_func_flops = generate_func_flops
 			state.interval_func_flops = interval_func_flops
+
 			state.tokens_pre_second = state.generated_tokens / interval_time
 			yield state
 			if state.is_sequence_finished:
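For context, generate() yields one state per streaming interval, and the lines above attach a tokens-per-second figure (the field is spelled tokens_pre_second in the source) alongside the averaged FLOP counts. A hedged consumption sketch follows; the state field names come from this diff, while the call signature and inputs are assumptions:

# Hypothetical usage; only the state fields are taken from this diff.
for state in inference.generate(input_ids=input_ids, attention_mask=attention_mask):
	print(
		f"interval FLOP/s: {state.interval_func_flops:.3e}  "
		f"tokens/s: {state.tokens_pre_second:.2f}"
	)
	if state.is_sequence_finished:
		break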
30 changes: 19 additions & 11 deletions tests/vinference_api.py
@@ -1,6 +1,9 @@
 import os
 import sys
 
+import jax
+import torch
+
 os.environ["JAX_TRACEBACK_FILTERING"] = "off"
 os.environ["EASYDEL_AUTO"] = "true"
 
@@ -18,31 +21,35 @@
 
 def main():
 	sharding_axis_dims = (1, 1, 1, -1)
-	max_length = 6144
-	pretrained_model_name_or_path = "meta-llama/Llama-3.2-1B-Instruct"
-	dtype = jnp.float16
+	max_length = 8192
+	pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
 	partition_axis = ed.PartitionAxis()
+	dtype = jnp.bfloat16
 	model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
 		pretrained_model_name_or_path,
 		auto_shard_model=True,
 		sharding_axis_dims=sharding_axis_dims,
 		config_kwargs=ed.EasyDeLBaseConfigDict(
 			use_scan_mlp=False,
 			partition_axis=partition_axis,
-			attn_dtype=jnp.float16,
 			freq_max_position_embeddings=max_length,
 			mask_max_position_embeddings=max_length,
+			attn_dtype=dtype,
 			gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
 			kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
 			attn_mechanism=ed.AttentionMechanisms.VANILLA,
 		),
 		quantization_method=ed.EasyDeLQuantizationMethods.NONE,
 		platform=ed.EasyDeLPlatforms.TRITON,
-		partition_axis=partition_axis,
-		param_dtype=dtype,
+		param_dtype=dtype,  # jnp.float8_e5m2,
 		dtype=dtype,
+		torch_dtype=torch.float16,
+		partition_axis=partition_axis,
+		precision=jax.lax.Precision("fastest"),
 	)
+
 	model.eval()
 	tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
 	tokenizer.padding_side = "left"
 	tokenizer.pad_token_id = tokenizer.eos_token_id
+
 	inference = ed.vInference(
 		model=model,
 		processor_class=tokenizer,
@@ -52,9 +59,10 @@ def main():
 			top_p=model.generation_config.top_p,
 			top_k=model.generation_config.top_k,
 			eos_token_id=model.generation_config.eos_token_id,
-			streaming_chunks=32,
 			pad_token_id=model.generation_config.pad_token_id,
 			bos_token_id=model.generation_config.bos_token_id,
+			streaming_chunks=64,
 		),
+		inference_name="llama3-1B",
 	)
 
 	inference.precompile()
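The updated test pads on the left so that newly generated tokens always extend the right edge of a fixed-shape batch, which suits the fixed shapes that precompile() presumably compiles against. A short sketch of preparing such inputs with the stock Hugging Face tokenizer (standard transformers calls; the prompt is an assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
tokenizer.padding_side = "left"        # pad left so generation appends on the right
tokenizer.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer.apply_chat_template(
	[{"role": "user", "content": "Hello!"}],  # example prompt (assumption)
	add_generation_prompt=True,               # append the assistant header
	return_tensors="np",
)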
2 changes: 1 addition & 1 deletion tests/vinference_test.py
@@ -30,7 +30,7 @@ def log_mem():
 
 def main():
 	sharding_axis_dims = (1, 1, 1, -1)
-	max_length = 4096
+	max_length = 8192
 
 	pretrained_model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
 	# pretrained_model_name_or_path = "AntonV/mamba2-370m-hf"
