Skip to content

Commit

Permalink
refacto generate + use simpler rotary for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
3outeille committed Jun 25, 2024
1 parent ee785d6 commit f33e818
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 139 deletions.
216 changes: 139 additions & 77 deletions run_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
torchrun --nproc_per_node=4 run_generate.py ---ckpt-path checkpoints/test/4
```
"""

import argparse
import os
from pathlib import Path

import torch
Expand All @@ -21,20 +19,21 @@
ParallelismArgs,
get_config_from_file,
)
from nanotron.distributed import get_global_rank
from nanotron.generation.decode import (
GenerationInput,
TokenizerConfig,
decode_text,
decode_tokenized,
GenerationInputs,
GenerationStates,
run_one_inference_step,
)
from nanotron.generation.generate_store import Store
from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
OneForwardOneBackwardPipelineEngine,
)
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.random import (
RandomStates,
Expand All @@ -50,16 +49,18 @@
except ImportError:
AutoTokenizer = None


logger = logging.get_logger(__name__)


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path")
parser.add_argument("--dp", type=int, default=1)
parser.add_argument("--pp", type=int, default=0)
parser.add_argument("--tp", type=int, default=0)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate")
parser.add_argument("--use-cache", action="store_true", help="Use cache for generation")
return parser.parse_args()


Expand All @@ -73,9 +74,9 @@ def main():
tokenizer_path = config.tokenizer.tokenizer_name_or_path

parallel_config = ParallelismArgs(
dp=args.dp or config.parallelism.dp,
pp=args.pp or config.parallelism.pp,
tp=args.tp or config.parallelism.tp,
dp=args.dp,
pp=args.pp,
tp=args.tp,
pp_engine=OneForwardOneBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
Expand Down Expand Up @@ -163,86 +164,147 @@ def main():
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left" # TODO @nouamane: do we want this?

dummy_inputs = [
"The future of AI is",
# "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
"def fib(n)",
# "def fib(n)",
# 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.',
# "Advancements in technology will lead to",
# "Tomorrow's world is shaped by",
]

outputs = decode_text(
input_iter=(GenerationInput(text=text) for text in dummy_inputs),
tokenizer=tokenizer,
# TODO @thomasw21: From ModelWithLoss extract the model.
model=model.model,
parallel_context=parallel_context,
max_new_tokens=args.max_new_tokens,
max_micro_batch_size=2,
generation_config=GenerationArgs(sampler="greedy", use_cache=True),
tokenizer_config=TokenizerConfig(max_input_length=None),
is_bench=os.environ.get("USE_BENCH", "0") == "1",
log_rank(f"Using cache for generation: {args.use_cache}", logger=logger, level=logging.INFO, rank=0)

# NOTE: This doesn't support micro-batches and batch inference
device = torch.cuda.current_device()
generation_config = GenerationArgs(sampler="greedy", use_cache=args.use_cache)
logits_are_batch_first = True

if generation_config:
if isinstance(generation_config.sampler, str):
sampler_type = SamplerType(generation_config.sampler.upper())
else:
sampler_type = generation_config.sampler
else:
sampler_type = SamplerType.GREEDY

tokenized_prompts = tokenizer(
dummy_inputs,
return_tensors="pt",
return_attention_mask=True,
padding=True,
)
for output in outputs:
input_ids = output.input_ids
generated_ids = output.generation_ids
if isinstance(input_ids, TensorPointer):
assert isinstance(generated_ids, TensorPointer)
continue
assert isinstance(generated_ids, torch.Tensor)

log_rank(
f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}",
logger=logger,
level=logging.INFO,
rank=0,
)
tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device)
tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device)

store = Store()
batch_prompts = None

for i in range(args.max_new_tokens):

if generation_config.use_cache:
# Prepare the batch prompts
batch_prompts = GenerationStates(
new_input_ids=tokenized_prompts["input_ids"]
if i == 0
else tokenized_prompts["input_ids"][:, -1].unsqueeze(0),
new_input_mask=tokenized_prompts["attention_mask"]
if i == 0
else tokenized_prompts["attention_mask"][:, -1].unsqueeze(0),
store=store,
generation_ids=tokenized_prompts["input_ids"],
generation_mask=tokenized_prompts["attention_mask"],
)
else:
batch_prompts = GenerationInputs(
input_ids=tokenized_prompts["input_ids"],
input_masks=tokenized_prompts["attention_mask"],
)

log_rank(
f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}",
logger=logger,
level=logging.INFO,
rank=0,
logits = run_one_inference_step(
model, batch_prompts, parallel_context, device, use_cache=generation_config.use_cache, store=store
)

log_rank(
"--------------------------------------------------",
logger=logger,
level=logging.INFO,
rank=0,
# Sample new token
if parallel_context.is_pipeline_last_stage:
assert logits is not None and isinstance(logits, torch.Tensor)

# Get sampler
if sampler_type == SamplerType.GREEDY:
sampler = GreedySampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_K:
sampler = TopKSampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_P:
sampler = TopPSampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.BASIC:
sampler = BasicSampler(pg=parallel_context.tp_pg)
else:
raise NotImplementedError(f"Sampler type {sampler_type} is not implemented")

if logits_are_batch_first:
logits = logits.transpose(0, 1)

# Predict next token
next_token = sampler(sharded_logits=logits[:, -1])

# Extend the tokenized prompts to insert the new token
tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1)
tokenized_prompts["attention_mask"] = torch.cat(
[
tokenized_prompts["attention_mask"],
torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device),
],
dim=-1,
)
else:
# Extend the tokenized prompts to receive the new token
tokenized_prompts["input_ids"] = torch.zeros(
(tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1),
dtype=torch.int64,
device=device,
)
tokenized_prompts["attention_mask"] = torch.zeros(
(
tokenized_prompts["attention_mask"].shape[0],
tokenized_prompts["attention_mask"].shape[1] + 1,
),
dtype=torch.bool,
device=device,
)

# Broadcast the new token to all the pipeline stages
dist.broadcast(
tokenized_prompts["input_ids"],
src=get_global_rank(
group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank
),
group=parallel_context.pp_pg,
)
else:
outputs = decode_tokenized(
input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"),
input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"),
model=model.model,
parallel_context=parallel_context,
generation_config=GenerationArgs(sampler="greedy", use_cache=True),
max_micro_batch_size=1,
max_new_tokens=12,
returns_logits=False,
)
for output in outputs:
input_ids = output.input_ids
generated_ids = output.generation_ids
if isinstance(input_ids, TensorPointer):
assert isinstance(generated_ids, TensorPointer)
continue
assert isinstance(generated_ids, torch.Tensor)
log_rank(
f"generation: {generated_ids[len(input_ids) :]}",
logger=logger,
level=logging.INFO,
rank=0,
dist.broadcast(
tokenized_prompts["attention_mask"],
src=get_global_rank(
group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank
),
group=parallel_context.pp_pg,
)

log_rank(
"--------------------------------------------------",
logger=logger,
level=logging.INFO,
rank=0,
)
# Decode the generated text
if dist.get_rank() == 0:
for i, prompt in enumerate(dummy_inputs):
if generation_config.use_cache:
tokenized_outputs = torch.cat(
[tokens.view(1, -1) for tokens in batch_prompts.generation_ids], dim=1
)
outputs = tokenizer.decode(tokenized_outputs[0], clean_up_tokenization_spaces=False)
else:
tokenized_outputs = tokenized_prompts["input_ids"][
i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens :
]
outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False)

log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0)
log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0)

dist.barrier()

Expand Down
3 changes: 3 additions & 0 deletions src/nanotron/config/models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
True # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
Expand Down
61 changes: 61 additions & 0 deletions src/nanotron/generation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,67 @@ def generator():
)


@torch.inference_mode()
def run_one_inference_step(model, batch, parallel_context, device, use_cache, store):
if dist.get_world_size(group=parallel_context.pp_pg) == 1:
if use_cache:
with attach_store(model=model, store=store):
return model.model(batch.new_input_ids, batch.new_input_mask)
return model.model(batch.input_ids, batch.input_masks)

pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
batch_size = batch.new_input_ids.shape[0] if use_cache else batch.input_ids.shape[0]
seq_len = batch.new_input_ids.shape[1] if use_cache else batch.input_ids.shape[1]

# Preallocate memory for output logits.
logits = None
if parallel_context.is_pipeline_last_stage:
logits = torch.empty((seq_len, batch_size, model.config.vocab_size), dtype=torch.float32, device=device)

if use_cache:
batch2use = GenerationStates(
new_input_ids=batch.new_input_ids
if parallel_context.is_pipeline_first_stage
else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank),
new_input_mask=batch.new_input_mask
if parallel_context.is_pipeline_first_stage
else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank),
store=store,
generation_ids=batch.generation_ids,
generation_mask=batch.generation_mask,
)
with attach_store(model=model, store=store):
output_tensor = model.model(batch2use.new_input_ids, batch2use.new_input_mask)
else:
batch2use = GenerationInputs(
input_ids=batch.input_ids
if parallel_context.is_pipeline_first_stage
else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank),
input_masks=batch.input_masks
if parallel_context.is_pipeline_first_stage
else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank),
)

output_tensor = model.model(batch2use.input_ids, batch2use.input_masks)

nb_send = len(pipeline_state.microbatches_activations_to_send)
assert nb_send <= 2
for _ in range(nb_send):
# Send activations to the next stage
# Send attention_mask to the next stage
pipeline_state.run_communication()

# Copy logits.
if parallel_context.is_pipeline_last_stage:
logits = output_tensor

# Wait for all the communication to complete.
dist.barrier(group=parallel_context.world_pg)

return logits


# Distributed utilities
def broadcast_tensors(
tensors: List[Union[torch.Tensor, TensorPointer]], group_src: int, group: Optional[ProcessGroup] = None
Expand Down
Loading

0 comments on commit f33e818

Please sign in to comment.