From 692c1958f071c39fdecbc665296fbca6a6a897fd Mon Sep 17 00:00:00 2001
From: jaydeepborkar
Date: Fri, 25 Aug 2023 19:06:36 +0000
Subject: [PATCH] parallelize entropy

Compute the per-head attention entropy in a multiprocessing worker pool
instead of serially inside the head loop; the serial entropy lines are
commented out and the per-head results are collected from the pool.

---
 inference.py | 36 +++++++++++++++++++++++++++++-------
 1 file changed, 29 insertions(+), 7 deletions(-)

diff --git a/inference.py b/inference.py
index 4701229..b82716a 100644
--- a/inference.py
+++ b/inference.py
@@ -15,6 +15,7 @@ from datetime import datetime
 
 import pandas as pd
 import numpy as np
+import multiprocessing
 import torch
 import os
 
@@ -178,6 +179,7 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
     pile_dataset = PileDataset(pile_sequences, tokenizer)
     batch_size = get_batch_size(split_name)
     data_loader = DataLoader(pile_dataset, batch_size=batch_size)
+    pythia_model.to(device)
 
     with torch.no_grad():
         desc = f"Collecting {dataset} inference responses for {split_name}"
@@ -215,9 +217,17 @@ def gini(array):
    return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))
 
 
+def process_attention_head(args):
+    e = 1e-8  # added to attention weights that are 0 to avoid a log-zero error
+    head_index, head = args
+    attention_head = head.detach().cpu().numpy()
+    attention_head += e
+    head_entropy = -np.sum(attention_head * np.log(attention_head))  # entropy = -∑(w * log(w))
+    return head_index, head_entropy
+
+
 def accumilate_inference_log(
-    batch_sequence_ids: list, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
-):
+    batch_sequence_ids: list, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list):
     """
     Extract the desired data from the model response and save it to a CSV file.
 
@@ -244,18 +254,30 @@ def accumilate_inference_log(
             inference_log["sequence_perplexity"] = perplexities[index][2]
         if "attn" in features:
             for layer_index, attention_layer in enumerate(outputs.attentions):
-                sequence_attention = attention_layer[index].detach()
+                sequence_attention = attention_layer[index].detach() 
                 head_e = []
                 gini_head = []
 
                 for head_index, head in enumerate(sequence_attention):
                     attention_head = head.detach().cpu().numpy()
-                    attention_head += e #adding 'e' to attention weights that are 0 to avoid log zero error while calculating entropy. Entropy = - ∑(w * log(w))
+                    attention_head += e  # epsilon offset retained for the gini computation
                     gini_coefficient = gini(attention_head)
                     gini_head.append(gini_coefficient)
-                    head_entropy = -np.sum(attention_head * np.log(attention_head))
-                    head_e.append(head_entropy)
+                    #head_entropy = -np.sum(attention_head * np.log(attention_head))
+                    #head_e.append(head_entropy)
                     inference_log[f"gini_head{head_index+1}_layer{layer_index+1}"] = gini_coefficient
+                    #inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy
+
+                attention_head_args = [(head_index, head) for head_index, head in enumerate(sequence_attention)]
+                num_processes = multiprocessing.cpu_count()  # one worker per CPU core
+                ctx = multiprocessing.get_context('spawn')  # 'spawn' is safer than fork once CUDA is initialized
+                pool = ctx.Pool(processes=num_processes)
+                head_results = pool.map(process_attention_head, attention_head_args)
+                pool.close()
+                pool.join()
+
+                for head_index, head_entropy in head_results:
+                    head_e.append(head_entropy)
                     inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy
 
                 avg_head = np.mean(head_e)
@@ -269,7 +291,7 @@
                 inference_log[f"avg gini"] = average_gini
 
         inference_logs.append(inference_log)
-
+    
     return inference_logs
 
 def save_inference_log(split_name: str, run_id: str, dataset: pd.DataFrame, inference_logs: list):
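
Note on the technique: the hunks above move the entropy calculation into
process_attention_head and fan it out over a spawn-context pool. Below is a
minimal standalone sketch of that approach; the tensor shape (4 heads over a
length-8 sequence) is a hypothetical stand-in for one layer's attention on a
single sequence, not a value taken from inference.py.

import multiprocessing

import numpy as np
import torch


def process_attention_head(args):
    e = 1e-8  # epsilon so zero attention weights don't hit log(0)
    head_index, head = args
    attention_head = head.detach().cpu().numpy()
    attention_head += e
    head_entropy = -np.sum(attention_head * np.log(attention_head))
    return head_index, head_entropy


if __name__ == "__main__":
    # Hypothetical stand-in for attention_layer[index]: rows normalized
    # like softmax attention.
    sequence_attention = torch.softmax(torch.randn(4, 8, 8), dim=-1)

    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=multiprocessing.cpu_count()) as pool:
        head_results = pool.map(process_attention_head, list(enumerate(sequence_attention)))

    for head_index, head_entropy in head_results:
        print(f"entropy_head{head_index + 1}: {head_entropy:.4f}")

With 'spawn', the worker must be importable from the main module, which is why
the driver code sits under the __main__ guard. The tensors in this sketch are
already on the CPU; in inference.py the attention tensors may live on the GPU
(the model is moved to `device`), and moving them to the CPU before handing
them to the pool would avoid pickling CUDA storage into the workers.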
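
One cost worth noting: the patch constructs and tears down a pool inside the
layer loop, so spawn startup is paid once per layer per sequence. A sketch of
the alternative, creating the pool once and reusing it across layers, follows;
head_entropy_worker is a hypothetical name mirroring process_attention_head,
and the two random layers are again stand-ins.

import multiprocessing

import numpy as np
import torch


def head_entropy_worker(args):
    # Same computation as process_attention_head above.
    head_index, head = args
    attention_head = head.detach().cpu().numpy() + 1e-8
    return head_index, -np.sum(attention_head * np.log(attention_head))


if __name__ == "__main__":
    # Hypothetical attentions: 2 layers of 4 heads over a length-8 sequence.
    attentions = [torch.softmax(torch.randn(4, 8, 8), dim=-1) for _ in range(2)]

    ctx = multiprocessing.get_context("spawn")
    # One pool for the whole loop: worker processes are started once,
    # not once per layer.
    with ctx.Pool(processes=multiprocessing.cpu_count()) as pool:
        for layer_index, sequence_attention in enumerate(attentions):
            head_results = pool.map(head_entropy_worker, list(enumerate(sequence_attention)))
            head_e = [entropy for _, entropy in head_results]
            print(f"layer {layer_index + 1}: avg entropy {np.mean(head_e):.4f}")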