Skip to content

Commit

Permalink
multi-proc on ts-guessing ok
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravoxsg committed Aug 13, 2024
1 parent 8632fb6 commit 8fb4437
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 8 deletions.
13 changes: 10 additions & 3 deletions llmsanitize/closed_data_methods/ts_guessing_question_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
"""

import os
os.environ['CLASSPATH']="/data/mathieu/stanford-postagger-full-2020-11-17/stanford-postagger.jar"
os.environ["STANFORD_MODELS"]="/data/mathieu/stanford-postagger-full-2020-11-17/models"
import numpy as np
from tqdm import tqdm
from rouge_score import rouge_scorer
from nltk.tokenize import word_tokenize
from nltk.tag import StanfordPOSTagger
from functools import partial
from datasets import Dataset

from llmsanitize.utils.logger import get_child_logger, suspend_logging
from llmsanitize.utils.dataset_utils import get_answers_list
Expand Down Expand Up @@ -88,14 +91,14 @@ def inference(
tagger = get_stanford_tagger()

prompt, masked_word = build_prompt(
example,
data_point,
tagger,
eval_data_name,
type_hint,
category_hint,
url_hint
)
data_point["masked_wor"] = masked_word
data_point["masked_word"] = masked_word
if prompt == "failed":
data_point["response"] = "failed"
else:
Expand Down Expand Up @@ -188,6 +191,7 @@ def main_ts_guessing_question_based(
if n_eval_data_points > 0:
p = np.random.permutation(len(data_points))
data_points = [data_points[x] for x in p]
data_points = Dataset.from_list(data_points)

llm = LLM(
local_model_path=local_model_path,
Expand Down Expand Up @@ -216,7 +220,10 @@ def main_ts_guessing_question_based(
url_hint=False
)

ts_guessing_results = eval_data.map(process_fn, num_proc=num_proc)
ts_guessing_results = data_points.map(process_fn, num_proc=num_proc)
ts_guessing_results = [x for x in ts_guessing_results if x["response"] != "failed"]
ts_guessing_results = ts_guessing_results[:n_eval_data_points]

masked_words = [x["masked_word"].lower() for x in ts_guessing_results]
responses = [x["response"].lower() for x in ts_guessing_results]
em = len([i for i in range(len(responses)) if responses[i] == masked_words[i]]) / len(responses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
"""

import os
os.environ["CLASSPATH"]="/data/mathieu/stanford-postagger-full-2020-11-17/stanford-postagger.jar"
os.environ["STANFORD_MODELS"]="/data/mathieu/stanford-postagger-full-2020-11-17/models"
import numpy as np
from tqdm import tqdm
from rouge_score import rouge_scorer
from nltk.tokenize import word_tokenize, sent_tokenize
from functools import partial
from datasets import Dataset

from llmsanitize.utils.logger import get_child_logger, suspend_logging
from llmsanitize.utils.dataset_utils import get_answers_list, get_answer_index
Expand Down Expand Up @@ -61,14 +64,16 @@ def process_response(response, wrong_letter):
@suspend_logging
def inference(data_point, eval_data_name, llm):
prompt, answer, wrong_letter = build_prompt(
example,
data_point,
eval_data_name
)
response, cost = llm.query(prompt)
response = process_response(response, wrong_letter)
data_point["answer"] = answer
data_point["response"] = response

return data_point


def main_ts_guessing_question_multichoice(
eval_data: list = [],
Expand Down Expand Up @@ -102,7 +107,8 @@ def main_ts_guessing_question_multichoice(
data_points = [data_points[x] for x in p]
data_points = data_points[:n_eval_data_points]
logger.info(f"We are left with {len(data_points)} data points after subsampling")

data_points = Dataset.from_list(data_points)

llm = LLM(
local_model_path=local_model_path,
local_tokenizer_path=local_tokenizer_path,
Expand All @@ -127,7 +133,7 @@ def main_ts_guessing_question_multichoice(
llm=llm,
)

ts_guessing_results = eval_data.map(process_fn, num_proc=num_proc)
ts_guessing_results = data_points.map(process_fn, num_proc=num_proc)
answers = [x["answer"].lower() for x in ts_guessing_results]
responses = [x["response"].lower() for x in ts_guessing_results]
em = len([i for i in range(len(responses)) if responses[i] == answers[i]]) / len(responses)
Expand Down
5 changes: 4 additions & 1 deletion llmsanitize/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def get_answer_index(data_point, dataset_name):
alphabet = "abcdefghijklmnopqrstuvwxyz"
if dataset_name == "allenai/ai2_arc":
key = data_point["answerKey"].lower()
answer_index = alphabet.index(key)
if key in ["1", "2", "3", "4"]:
answer_index = int(key)-1
else:
answer_index = alphabet.index(key)
if dataset_name == "Rowan/hellaswag":
answer_index = int(data_point["label"])
if dataset_name == "cais/mmlu":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python main.py \
--text_key question \
--label_key answerKey \
--n_eval_data_points 100 \
--num_proc 16 \
--method ts-guessing-question-multichoice \
--local_port $port \
--model_name $model_name
--model_name $model_name
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ python main.py \
--text_key ctx \
--label_key activity_label \
--n_eval_data_points 100 \
--n_proc 16 \
--method ts-guessing-question-multichoice \
--local_port $port \
--model_name $model_name
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python main.py \
--text_key question \
--label_key answer_text \
--n_eval_data_points 100 \
--num_proc 16 \
--method ts-guessing-question-multichoice \
--local_port $port \
--model_name $model_name
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python main.py \
--text_key question \
--label_key category \
--n_eval_data_points 100 \
--num_proc 16 \
--method ts-guessing-question-multichoice \
--local_port $port \
--model_name $model_name
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python main.py \
--text_key sentence \
--label_key answer_token \
--n_eval_data_points 100 \
--num_proc 16 \
--method ts-guessing-question-multichoice \
--local_port $port \
--model_name $model_name

0 comments on commit 8fb4437

Please sign in to comment.