Skip to content

Commit

Permalink
tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Jan 31, 2023
1 parent 1febc88 commit 40ad476
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 228 deletions.
2 changes: 1 addition & 1 deletion src/tasksource/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def sample_dataset(dataset,n=10000, n_eval=1000):

class Preprocessing(DotWiz):
default_splits = ('train','validation','test')

@staticmethod
def __map_to_target(x,fn=lambda x:None, target=None):
x[target]=fn(x)
Expand Down Expand Up @@ -170,6 +169,7 @@ class SharedFields:
config_name:str = None
pre_process: callable = lambda x:x
post_process: callable = lambda x:x
#language:str="en"

@dataclass
class Classification(SharedFields, ClassificationFields): pass
Expand Down
28 changes: 19 additions & 9 deletions src/tasksource/tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .preprocess import cat, get, regen, constant, Classification, TokenClassification, MultipleChoice
from .metadata import bigbench_discriminative_english, blimp_hard, imppres_presupposition, imppres_implicature
from datasets import get_dataset_config_names, ClassLabel

from datasets import get_dataset_config_names, ClassLabel, Dataset, DatasetDict
# variable name: dataset___config__task

###################### NLI/paraphrase ###############################
Expand Down Expand Up @@ -649,12 +648,12 @@ def _split_choices(s):
dataset_name="lucasmccabe/logiqa"
)

proto_qa = MultipleChoice(
"question",
choices_list=lambda x:x['answer-clusters']['answers'],
labels=lambda x: x['answer-clusters']['count'].index(max(x['answer-clusters']['count'])),
config_name='proto_qa'
)
#proto_qa = MultipleChoice(
# "question",
# choices_list=lambda x:x['answer-clusters']['answers'],
# labels=lambda x: x['answer-clusters']['count'].index(max(x['answer-clusters']['count'])),
# config_name='proto_qa'
#)

wiki_qa = Classification("question","answer","label")

Expand Down Expand Up @@ -705,4 +704,15 @@ def _preprocess_chatgpt_detection(ex):

moral_stories = MultipleChoice(cat(["situation","intention"]),
choices=['moral_action',"immoral_action"],labels=constant(0),
dataset_name="demelin/moral_stories", config_name="full")
dataset_name="demelin/moral_stories", config_name="full")

prost = MultipleChoice(cat(["context","ex_question"]), choices=['A','B','C','D'],labels="label",
dataset_name="corypaik/prost")

dyna_hate = Classification("text",labels="label",dataset_name="aps/dynahate",splits=['train',None,None])

syntactic_augmentation_nli = Classification('sentence1',"sentence2","gold_label",dataset_name="metaeval/syntactic-augmentation-nli")


#autotnli = Classification("premises", "hypothesis", "label", dataset_name="metaeval/autotnli")
#equate = Classification("sentence1", "sentence2", "gold_label",dataset_name="metaeval/equate")
Loading

0 comments on commit 40ad476

Please sign in to comment.