Prefix caching #57

Draft · IanMagnusson wants to merge 40 commits into main from prefix-caching

Commits (40)
f6b64b8
prefix trie
IanMagnusson Jul 8, 2022
f1b6534
prefix trie
IanMagnusson Jul 8, 2022
ab2314c
generalized caching
IanMagnusson Jul 15, 2022
d2c5af8
Merge branch 'main' into prefix-caching
IanMagnusson Jul 15, 2022
235e63c
update changelog
IanMagnusson Jul 15, 2022
2c86554
fix type hints
IanMagnusson Jul 15, 2022
fef986d
fix type hints
IanMagnusson Jul 15, 2022
bb150b5
fix type hints
IanMagnusson Jul 15, 2022
2b993e8
relax min prefix requirement
IanMagnusson Jul 22, 2022
7138a13
fix likelihood averaging
IanMagnusson Jul 22, 2022
7a4ddd2
fix cached transformer override_weights_file
IanMagnusson Jul 23, 2022
0c5dcb1
optional random_subsample_seed for PredictStep
IanMagnusson Jul 25, 2022
39b4f07
update changelog
IanMagnusson Jul 25, 2022
9e00fbc
allow different metric averaging
IanMagnusson Jul 25, 2022
054847f
expose override_weights_file in RC models
IanMagnusson Jul 25, 2022
d19deb4
expose override_weights_file in MetaICLModel
IanMagnusson Jul 25, 2022
f4057ad
Merge branch 'main' into prefix-caching
IanMagnusson Jul 25, 2022
5dd423d
oops deleted import by accident
IanMagnusson Jul 25, 2022
8fbf359
move cache data out of model attributes
IanMagnusson Jul 25, 2022
da2c641
Merge branch 'main' into prefix-caching
IanMagnusson Jul 25, 2022
451ad34
use consistent arg order
IanMagnusson Jul 25, 2022
c542493
wrestle with mypy
IanMagnusson Jul 25, 2022
5052e0f
docs and better names
IanMagnusson Jul 26, 2022
e584320
test for PrefixTrie
IanMagnusson Jul 26, 2022
8c02abe
handle full overlap corner case
IanMagnusson Jul 26, 2022
c7b2b7e
Merge branch 'fix-rc-likelihood-avg' into prefix-caching
IanMagnusson Jul 26, 2022
d0caa5a
Merge branch 'main' into prefix-caching
IanMagnusson Jul 26, 2022
3b8a370
Merge branch 'add-ia3' into prefix-caching
IanMagnusson Jul 28, 2022
f7a84d0
Merge branch 'main' into prefix-caching
IanMagnusson Jul 28, 2022
827b350
expose prefix_caching in ia3
IanMagnusson Jul 28, 2022
4e98c01
batch processing of cached prefixes
IanMagnusson Jul 29, 2022
796f447
Merge branch 'expand-metaicl' into prefix-caching
IanMagnusson Aug 9, 2022
fda8433
gpu logit processing for rc models
IanMagnusson Aug 19, 2022
4acdd95
add 2 new tasks
IanMagnusson Aug 19, 2022
975ade2
Merge branch 'main' into prefix-caching
IanMagnusson Aug 19, 2022
6c4f69e
Merge branch 'faster-rc-logit-processing' into prefix-caching
IanMagnusson Aug 19, 2022
3579282
clean up bad merge
IanMagnusson Aug 19, 2022
50220c1
Merge branch 'metaicl-race-high-and-numer-sense' into prefix-caching
IanMagnusson Aug 19, 2022
1cd38f4
Merge branch 'main' into prefix-caching
IanMagnusson Aug 30, 2022
3cab644
An example usage of prefix_caching
IanMagnusson Aug 30, 2022
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- MetaICLTask now supports fewer than 16 shots and only supports getting the test split
- set default logging level to `"WARNING"` instead of `"ERROR"` when invoking `python -m catwalk`
- changed MetaICLModel formatting to always preserve whitespace, to reproduce MetaICL results
- improved speed of rank classification models by aggregating sequence logits on GPU rather than on CPU

### Added

@@ -23,8 +24,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adds a new MetaICLModel that replicates the formatting and truncation used by MetaICL for few shot evaluation
- Optional `random_subsample_seed` for PredictStep
- An option for rank classification to average log likelihoods by token length
- Adds support for inference with IA3 adapters loaded from a file on decoder only ranked classification models
- Adds support for inference with IA3 adaptors loaded from a file on decoder only ranked classification models
- Added the ability to train `HFAutoModel`
- Add support for MetaICL's race-high and numer_sense tasks
- Prefix caching for DecoderOnlyRCModel that reuses overlapping prefixes between instances rather than recomputing them

### Fixed

16 changes: 14 additions & 2 deletions catwalk/models/ia3.py
@@ -11,8 +11,16 @@

class DecoderOnlyIA3Mixin:
@classmethod
def _make_model(self, pretrained_model_name_or_path: str, *, ia3_weights_file: str = None, **kwargs) -> GPT2LMHeadModel:
model = cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, True)
def _make_model(
self,
pretrained_model_name_or_path: str,
*,
override_weights_file: str = None,
ia3_weights_file: str = None,
**kwargs
) -> GPT2LMHeadModel:
model = cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, True, override_weights_file=override_weights_file)
assert isinstance(model, GPT2LMHeadModel)
config = IA3ForGPT2Config()
model = modify_with_ia3(model, config)
@@ -27,6 +35,8 @@ def __init__(
pretrained_model_name_or_path: str,
*,
likelihood_averaging: str = 'token',
override_weights_file: str = None,
prefix_caching: bool = True,
max_length_per_example: int = 256,
continuation_seperator: str = '\n',
example_seperator: str = '\n\n\n',
@@ -36,6 +46,8 @@
super().__init__(
pretrained_model_name_or_path,
likelihood_averaging=likelihood_averaging,
override_weights_file=override_weights_file,
prefix_caching=prefix_caching,
max_length_per_example=max_length_per_example,
continuation_seperator=continuation_seperator,
example_seperator=example_seperator,
4 changes: 4 additions & 0 deletions catwalk/models/metaicl.py
@@ -17,6 +17,8 @@ def __init__(
pretrained_model_name_or_path: str,
*,
likelihood_averaging: str = 'token',
override_weights_file: str = None,
prefix_caching: bool = True,
max_length_per_example: int = 256,
continuation_seperator: str = '\n',
example_seperator: str = '\n\n\n',
@@ -25,6 +27,8 @@
super().__init__(
pretrained_model_name_or_path,
likelihood_averaging=likelihood_averaging,
override_weights_file=override_weights_file,
prefix_caching=prefix_caching,
**model_kwargs
)
self.max_length_per_example = max_length_per_example
293 changes: 253 additions & 40 deletions catwalk/models/rank_classification.py

Large diffs are not rendered by default.
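
Since this 293-line diff is not rendered, here is a rough, hypothetical sketch (not the PR's actual code; every name below is an assumption) of the scoring pattern that prefix caching enables for rank classification: encode the shared few-shot prefix once, then score each answer continuation against the cached key/values, averaging log-likelihoods per token as the likelihood_averaging='token' option does.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

def score_continuations(model, tokenizer, prefix, continuations):
    """Log-likelihood of each continuation given `prefix`, encoding the prefix once."""
    prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids
    with torch.no_grad():
        prefix_out = model(prefix_ids, use_cache=True)  # cache the prefix key/values
    scores = []
    for text in continuations:
        cont_ids = tokenizer(text, return_tensors="pt").input_ids
        with torch.no_grad():
            # past_key_values lets the model skip re-encoding the shared prefix
            out = model(cont_ids, past_key_values=prefix_out.past_key_values)
        # the logit predicting continuation token i comes from position i - 1, so
        # prepend the last prefix logit and drop the final continuation logit
        logits = torch.cat([prefix_out.logits[:, -1:], out.logits[:, :-1]], dim=1)
        logprobs = F.log_softmax(logits, dim=-1)
        token_logprobs = logprobs.gather(-1, cont_ids.unsqueeze(-1)).squeeze(-1)
        scores.append(token_logprobs.mean().item())  # per-token likelihood averaging
    return scores

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(score_continuations(model, tokenizer, "The capital of France is", [" Paris", " London"]))

A real implementation must additionally handle batching, padding, and the model's maximum input length, as the demo script below notes.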

9 changes: 4 additions & 5 deletions catwalk/task.py
@@ -37,13 +37,12 @@
"squad_metrics": torchmetrics.SQuAD,
}


def classification_metrics(num_classes: int):
def classification_metrics(num_classes: int, *, average = None):
return {
"acc": torchmetrics.Accuracy,
"f1": partial(torchmetrics.F1Score, num_classes=num_classes, average=None),
"precision": partial(torchmetrics.Precision, num_classes=num_classes, average=None),
"recall": partial(torchmetrics.Recall, num_classes=num_classes, average=None)
"f1": partial(torchmetrics.F1Score, num_classes=num_classes, average=average),
"precision": partial(torchmetrics.Precision, num_classes=num_classes, average=average),
"recall": partial(torchmetrics.Recall, num_classes=num_classes, average=average)
}


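A brief sketch of what the new average parameter changes, assuming the torchmetrics API this code targets (the tensors are made-up examples):

import torch
from catwalk.task import classification_metrics

metrics = classification_metrics(12, average="micro")  # average=None keeps per-class scores
f1 = metrics["f1"]()  # same as torchmetrics.F1Score(num_classes=12, average="micro")
preds = torch.tensor([0, 1, 2, 11])
target = torch.tensor([0, 1, 1, 11])
print(f1(preds, target))  # one micro-averaged score instead of a length-12 tensor

The numer_sense task registration below uses the per-class default.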
5 changes: 4 additions & 1 deletion catwalk/tasks/__init__.py
@@ -270,8 +270,11 @@
"metaicl::unifiedqa:openbookqa_with_ir": MetaICLTask("unifiedqa:openbookqa_with_ir").add_metrics(MC_METRICS),
"metaicl::unifiedqa:mctest": MetaICLTask("unifiedqa:mctest").add_metrics(MC_METRICS),
"metaicl::unifiedqa:ai2_science_middle": MetaICLTask("unifiedqa:ai2_science_middle").add_metrics(MC_METRICS),

"metaicl::commonsense_qa": MetaICLTask("commonsense_qa").add_metrics(MC_METRICS),

"metaicl::numer_sense": MetaICLTask("numer_sense").add_metrics(classification_metrics(12)),
"metaicl::race-high": MetaICLTask("race-high").add_metrics(MC_METRICS),
}

for config in datasets.get_dataset_config_names("bigscience/P3"):
1 change: 1 addition & 0 deletions catwalk/utils/__init__.py
@@ -0,0 +1 @@
from catwalk.utils.prefix_trie import PrefixTrie
85 changes: 85 additions & 0 deletions catwalk/utils/prefix_trie.py
@@ -0,0 +1,85 @@
from typing import List, Optional, Sequence, Tuple, Dict
from tango.common import Tqdm

class PrefixTrie:
def __init__(self, sequences: Sequence[Sequence[int]], track_after_depth: int = 10):
"""
Builds a PrefixTrie for ordering examples by common prefixes.

# Parameters

sequences : `Sequence[Sequence[int]]`
Sequences of tokens to add to the trie
track_after_depth : `int`
Only record sequence indices in nodes at or beyond this depth. This prevents distinct
sequences that merely happen to share their first few tokens with another sequence
from being grouped under a barely overlapping prefix. Sequences shorter than this
depth only have their index recorded in their final node.
"""
self.root = PrefixTrieNode()
self.track_after_depth = track_after_depth
self.nodes: List['PrefixTrieNode'] = []
for i, sequence in Tqdm.tqdm(enumerate(sequences), desc="Building PrefixTrie for caching", total=len(sequences)):
self._add_sequence(sequence=sequence, index=i)
# only need to track sequences at forks and terminations
for node in self.nodes:
if len(node.children) == 1:
node.subsequences_on_this_path = node.subsequences_ending_here

def _add_sequence(self, sequence: Sequence[int], index: int):
seq_len = len(sequence)
current_node = self.root
for token_idx, token in enumerate(sequence):
if token not in current_node.children:
current_node.children[token] = PrefixTrieNode(parent=current_node, token=token)
self.nodes.append(current_node.children[token])
current_node = current_node.children[token]
if (token_idx + 1 >= self.track_after_depth) or (token_idx + 1 >= seq_len):
current_node.subsequences_on_this_path[index] = token_idx + 1
current_node.subsequences_ending_here[index] = len(sequence)

def get_leaf_nodes(self) -> List['PrefixTrieNode']:
return [node for node in self.nodes if len(node.children) == 0]

class PrefixTrieNode:
def __init__(self, parent: Optional['PrefixTrieNode'] = None, token: Optional[int] = None):
self.parent = parent
self.token = token
self.subsequences_on_this_path: Dict[int,int] = {}
self.subsequences_ending_here: Dict[int,int] = {}
self.lengths_covered: List[int] = []
self.children: Dict[int,'PrefixTrieNode'] = {}

def get_sequence(self) -> List[Optional[int]]:
"""Returns the sequence associated with a node"""
current_node = self
sequence = []
while current_node.parent is not None:
sequence.append(current_node.token)
current_node = current_node.parent
return sequence[::-1]

def get_subsequences(self) -> Tuple[List[int], int]:
"""
Returns a tuple of:
- a list of the indices of all subsequences of this node's sequence, including the
node's own sequence, ordered from longest to shortest
- an int, the total number of tokens covered across all these subsequences by this prefix

Note that when a PrefixTrie is built with track_after_depth > 1, some subsequences are
intentionally ignored here, as their indices are not registered in low-depth nodes.
"""
current_node = self
indices = []
already_found = set()
total_lengths_covered = 0
while current_node.parent is not None:
new_indices = []
for index, length_covered in current_node.subsequences_on_this_path.items():
if index not in already_found:
new_indices.append(index)
total_lengths_covered += length_covered
already_found.update(new_indices)
indices.extend(new_indices)
current_node = current_node.parent
return indices, total_lengths_covered
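
For reference, a small usage sketch of the trie defined above, consistent with the test added at the bottom of this PR: leaves enumerate the distinct full sequences, and get_subsequences reports which other instances are prefixes of a leaf and how many total tokens the shared prefix computation would cover.

from catwalk.utils import PrefixTrie

# three tokenized instances; instance 0 is a prefix of instance 2
sequences = [[1, 2, 3], [2, 3, 4], [1, 2, 3, 4]]
trie = PrefixTrie(sequences, track_after_depth=1)

for leaf in trie.get_leaf_nodes():
    indices, tokens_covered = leaf.get_subsequences()
    # the [1, 2, 3, 4] leaf reports indices [2, 0]: instance 2 itself plus
    # instance 0, whose 3 tokens are reusable; 4 + 3 = 7 tokens covered
    print(leaf.get_sequence(), indices, tokens_covered)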
86 changes: 86 additions & 0 deletions experiments/prefix_cache_demo.py
@@ -0,0 +1,86 @@
import argparse
import os
from tango.common.logging import initialize_logging
import time

from catwalk.models import MetaICLModel
from catwalk.steps import CalculateMetricsStep, PredictStep
from catwalk.tasks import TASK_SETS

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--zeroshot', action='store_true')
parser.add_argument('--no_prefix_caching', action='store_true')
parser.add_argument('--first_n_tasks', type=int, default=20)
args = parser.parse_args()

start = time.time()
initialize_logging(log_level="ERROR")
os.environ['TOKENIZERS_PARALLELISM'] = "false"

tasks = TASK_SETS['metaicl-classification-eval']
tasks = sorted(tasks)[:args.first_n_tasks]

num_shots = 0 if args.zeroshot else 16
if args.zeroshot:
batch_size = 64
elif args.no_prefix_caching:
batch_size = 16 # to account for larger input sizes with ICL
# caching with batching does not work close to the max model input size, since
# the largest prefix + largest continuation in a batch must fit within it
else:
batch_size = 1
limit = 1000
random_subsample_seed = 42
seeds = [100] if args.zeroshot else [100, 13, 21, 42, 87]

model = MetaICLModel('gpt2-large', continuation_seperator = ' ' if args.zeroshot else '\n', prefix_caching = not args.no_prefix_caching)

seed2metrics = {}
for fewshot_seed in seeds:
metric_task_dict = {}
for task in tasks:

predictions = PredictStep(
model=model,
task=task,
batch_size=batch_size,
limit=limit,
random_subsample_seed=random_subsample_seed,
num_shots=num_shots,
fewshot_seed=fewshot_seed,
)
metrics = CalculateMetricsStep(
model=model,
task=task,
predictions=predictions)
metric_task_dict[task] = metrics
seed2metrics[fewshot_seed] = metric_task_dict

avg_f1_per_seed = []
avg_acc_per_seed = []
for seed, metric_task_dict in seed2metrics.items():
total_sum_f1 = 0.0
total_sum_acc = 0.0
for task, metrics in metric_task_dict.items():
for metric, result in metrics.result().items():
avg_result = result.mean()
if metric == 'f1':
total_sum_f1 += avg_result.item()
elif metric == 'acc':
total_sum_acc += avg_result.item()
print(f"{task}\t{seed}\t{metric}\t{avg_result}")
avg_f1_per_seed.append(total_sum_f1 / len(tasks))
avg_acc_per_seed.append(total_sum_acc / len(tasks))

print(f"avg macro f1 over seeds {sum(avg_f1_per_seed) / len(seeds)}")
print(f"min macro f1 over seeds {min(avg_f1_per_seed)}")
print(f"avg macro acc over seeds {sum(avg_acc_per_seed) / len(seeds)}")
print(f"min macro acc over seeds {min(avg_acc_per_seed)}")

end = time.time()
print(f"total seconds elapsed: {end - start}")

if __name__ == "__main__":
main()
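
To try the demo, run e.g. `python experiments/prefix_cache_demo.py --first_n_tasks 5`; add `--no_prefix_caching` to time the uncached baseline, or `--zeroshot` to evaluate without in-context examples.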
10 changes: 10 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,10 @@
from catwalk.utils import PrefixTrie

def test_prefix_trie():
sequences = [[1,2,3],[2,3,4],[1,2,3,4]]
trie = PrefixTrie(sequences, track_after_depth=1)
leaves = trie.get_leaf_nodes()
assert leaves[0].get_sequence() == [2,3,4]
assert leaves[1].get_sequence() == [1,2,3,4]
assert leaves[0].get_subsequences() == ([1], 3)
assert leaves[1].get_subsequences() == ([2,0], 7)