From 33225ff7383f0556de3daf42c92d6b47740e037c Mon Sep 17 00:00:00 2001
From: Taha YASSINE
Date: Tue, 17 Dec 2024 22:10:30 +0100
Subject: [PATCH] Update caching to work with FA2

---
 .devcontainer/devcontainer.json   |  46 +++++++++++
 requirements.txt                  |   7 ++
 sae_auto_interp/features/cache.py |   8 +-
 sae_auto_interp/flash_attn.py     | 122 ++++++++++++++++++++++++++++++
 4 files changed, 182 insertions(+), 1 deletion(-)
 create mode 100644 .devcontainer/devcontainer.json
 create mode 100644 requirements.txt
 create mode 100644 sae_auto_interp/flash_attn.py

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 00000000..23904d4f
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,46 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
+{
+    "name": "SAE dev",
+    "build": {
+        // Sets the run context to two levels up instead of the .devcontainer folder.
+        "context": "../..",
+        // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
+        "dockerfile": "../../Dockerfile",
+        "target": "gpu"
+    },
+
+
+    "runArgs": [
+        "--gpus=all",
+        "--shm-size=8g"
+    ],
+
+    // Features to add to the dev container. More info: https://containers.dev/features.
+    // "features": {},
+
+    // Use 'forwardPorts' to make a list of ports inside the container available locally.
+    // "forwardPorts": [],
+
+    // Uncomment the next line to run commands after the container is created.
+    // "postCreateCommand": "cat /etc/os-release",
+
+    "mounts": [
+        "source=/home/tyassine/.cache/huggingface,target=/root/.cache/huggingface,type=bind"
+    ],
+
+    // Configure tool-specific properties.
+    "customizations": {
+        "vscode": {
+            "extensions": [
+                "ms-python.python",
+                "ms-python.vscode-pylance",
+                "ms-toolsai.jupyter",
+                "mhutchie.git-graph"
+            ]
+        }
+    }
+
+    // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
+    // "remoteUser": "devcontainer"
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..abe8b27b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+torch
+datasets
+# flash-attn --no-build-isolation
+nnsight
+setuptools
+ipykernel
+git+https://github.com/taha-yassine/transformers.git@patch_fa2
\ No newline at end of file
diff --git a/sae_auto_interp/features/cache.py b/sae_auto_interp/features/cache.py
index 400f45ef..3dafc278 100644
--- a/sae_auto_interp/features/cache.py
+++ b/sae_auto_interp/features/cache.py
@@ -154,6 +154,11 @@ def __init__(
             batch_size (int): Size of batches for processing.
             filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features.
         """
+
+        # Model must use FA2 to allow for efficient packing
+        if not hasattr(model.config, "_attn_implementation") or model.config._attn_implementation != "flash_attention_2":
+            raise ValueError("Model must use FlashAttention-2. Please enable it before initializing FeatureCache.")
+
         self.model = model
         self.submodule_dict = submodule_dict

@@ -224,7 +229,8 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
                 with torch.no_grad():
                     buffer = {}

-                    with self.model.trace(batch):
+                    # position_ids is required for FA2
+                    with self.model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"]):
                         for module_path, submodule in self.submodule_dict.items():
                             buffer[module_path] = submodule.ae.output.save()
                     for module_path, latents in buffer.items():
diff --git a/sae_auto_interp/flash_attn.py b/sae_auto_interp/flash_attn.py
new file mode 100644
index 00000000..13157f4c
--- /dev/null
+++ b/sae_auto_interp/flash_attn.py
@@ -0,0 +1,122 @@
+# %% Imports
+import torch
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator
+from typing import Dict, List, Tuple
+import time
+
+# %% Functions
+def data_collator(features: Dict[str, List[List[int]]], return_tensors: str = "pt"):
+    batch = {"input_ids": [], "position_ids": []}
+    for x in features["input_ids"]:
+        batch["input_ids"] += x
+        batch["position_ids"] += list(range(len(x)))
+
+    return default_data_collator([batch], return_tensors=return_tensors)
+
+def prepare_packed_dataset(
+    texts: List[str],
+    tokenizer: AutoTokenizer,
+) -> Dict[str, torch.Tensor]:
+    """
+    Prepare a packed dataset using continuous batching without padding.
+    """
+    # Tokenize all texts
+    tokenized = tokenizer(
+        texts,
+        add_special_tokens=False,
+        return_attention_mask=False,
+    )
+
+    # Use collator to flatten and pack sequences
+    packed = data_collator(tokenized, return_tensors="pt")
+
+    return packed
+
+def prepare_padded_dataset(
+    texts: List[str],
+    tokenizer: AutoTokenizer,
+    max_seq_length: int = 2048
+) -> Dict[str, torch.Tensor]:
+    """
+    Prepare a dataset using traditional padding.
+    """
+    return tokenizer(
+        texts,
+        padding=True,
+        add_special_tokens=False,
+        max_length=max_seq_length,
+        return_tensors="pt"
+    )
+
+# %% Load model and tokenizer
+model_name = "EleutherAI/pythia-70m" # Small model for testing
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.pad_token = tokenizer.eos_token
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
+    device_map="auto",
+)
+
+# %% Load and prepare datasets
+dataset = load_dataset("Open-Orca/FLAN", split="train", streaming=True).take(100)
+
+dataset_padded = dataset.map(lambda x: prepare_padded_dataset(x["inputs"], tokenizer), batched=True, batch_size=10)
+
+dataset_packed = dataset.map(lambda x: prepare_packed_dataset(x["inputs"], tokenizer), batched=True, batch_size=10, remove_columns=dataset.column_names)
+
+# %% Test traditional padding approach
+start_time = time.time()
+
+with torch.no_grad():
+    for batch in dataset_padded:
+        input_ids = batch["input_ids"].to(model.device)
+        attention_mask = batch["attention_mask"].to(model.device)
+        padded_output = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+padding_time = time.time() - start_time
+
+# %% Test sequence packing approach
+start_time = time.time()
+
+# Process in chunks that fit the model's context window
+
+with torch.no_grad():
+    for batch in dataset_packed:
+        input_ids = batch["input_ids"].to(model.device)
+        position_ids = batch["position_ids"].to(model.device)
+        packed_output = model(
+            input_ids=input_ids.unsqueeze(0),
+            position_ids=position_ids.unsqueeze(0),
+        )
+packing_time = time.time() - start_time
+
+# %%
+print(f"Traditional padding processing time: {padding_time:.2f} seconds")
+print(f"Sequence packing processing time: {packing_time:.2f} seconds")
+print(f"Speedup: {padding_time/packing_time:.2f}x")
+
+# %%
+from nnsight import LanguageModel
+nnsight_model = LanguageModel(
+    model_name,
+    torch_dtype=torch.float16,
+    device_map='cuda:0',
+    attn_implementation='flash_attention_2'
+)
+# print(nnsight_model)
+
+nnsight_model.tokenizer = tokenizer
+
+batch = next(iter(dataset_packed))
+
+with nnsight_model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"].to(model.device).unsqueeze(0)):
+    logits = nnsight_model.embed_out.output.save()
+
+print(logits.value)
+
+# %%
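The packing convention this patch relies on is visible in data_collator above: every sequence in a batch is concatenated into a single row of input_ids, and position_ids restart at 0 at each sequence boundary so the FlashAttention-2 path can treat the packed documents as independent sequences rather than one long context. Below is a minimal, self-contained sketch of that convention; the helper name pack_sequences and the toy token ids are illustrative only and are not part of the patch.

import torch
from typing import Dict, List

def pack_sequences(seqs: List[List[int]]) -> Dict[str, torch.Tensor]:
    # Concatenate all sequences into one row; restart position_ids at every
    # sequence boundary so downstream code can tell the documents apart.
    input_ids: List[int] = []
    position_ids: List[int] = []
    for seq in seqs:
        input_ids += seq
        position_ids += list(range(len(seq)))
    return {
        "input_ids": torch.tensor([input_ids]),        # shape (1, total_len)
        "position_ids": torch.tensor([position_ids]),  # shape (1, total_len)
    }

batch = pack_sequences([[5, 6, 7], [8, 9]])
print(batch["input_ids"])     # tensor([[5, 6, 7, 8, 9]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1]])

A batch shaped this way appears to be what the updated FeatureCache.run consumes, since it forwards batch["input_ids"] and batch["position_ids"] straight into model.trace; the flash_attn.py benchmark builds the same structure through default_data_collator.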