From 258ba29013c435c6082d5f4e7930efd0377cb51c Mon Sep 17 00:00:00 2001
From: Damien Sileo
Date: Wed, 10 Jul 2024 17:06:38 +0200
Subject: [PATCH] new tasks

---
 .../.ipynb_checkpoints/access-checkpoint.py   | 108 +++++++
 .../preprocess-checkpoint.py                  | 295 ++++++++++++++++++
 .../.ipynb_checkpoints/recast-checkpoint.py   | 115 +++++++
 .../.ipynb_checkpoints/tasks-checkpoint.py    |  53 +++-
 src/tasksource/preprocess.py                  |  38 ++-
 src/tasksource/recast.py                      |   6 +-
 src/tasksource/tasks.py                       |  53 +++-
 7 files changed, 629 insertions(+), 39 deletions(-)
 create mode 100644 src/tasksource/.ipynb_checkpoints/access-checkpoint.py
 create mode 100755 src/tasksource/.ipynb_checkpoints/preprocess-checkpoint.py
 create mode 100644 src/tasksource/.ipynb_checkpoints/recast-checkpoint.py

diff --git a/src/tasksource/.ipynb_checkpoints/access-checkpoint.py b/src/tasksource/.ipynb_checkpoints/access-checkpoint.py
new file mode 100644
index 0000000..bb49a19
--- /dev/null
+++ b/src/tasksource/.ipynb_checkpoints/access-checkpoint.py
@@ -0,0 +1,108 @@
+from .preprocess import Preprocessing
+import re
+import pandas as pd
+from . import tasks, recast
+from .metadata import dataset_rank
+from datasets import load_dataset
+import funcy as fc
+import os
+import copy
+from sorcery import dict_of
+from functools import cache
+import random
+
+
+class lazy_mtasks:
+    def __getattr__(self, name):
+        from . import mtasks
+        return getattr(mtasks, name)
+
+    def __dir__(self):
+        from . import mtasks
+        return dir(mtasks)
+lmtasks=lazy_mtasks()
+
+def parse_var_name(s):
+    config_name,task_name = None,None
+    if '__' in s and '___' not in s: # dataset__task
+        dataset_name, task_name = s.split('__')
+    elif '__' not in s.replace('___','') and '___' in s: #dataset___config
+        dataset_name, config_name = s.split('___')
+    elif '___' in s and '__' in s.split('___')[1]: #dataset___config__task
+        dataset_name, config_task=s.split('___')
+        config_name,task_name = config_task.split('__')
+    else: # dataset
+        dataset_name = s
+    return dataset_name,config_name,task_name
+
+def pretty_name(x):
+    dn = x.dataset_name.split("/")[-1]
+    cn = x.config_name if x.config_name else ""
+    tn = x.task_name if x.task_name else ""
+    return f"{dn}/{cn}/{tn}".replace('//','/').rstrip('/')
+
+@cache
+def list_tasks(tasks_path=f'{os.path.dirname(__file__)}/tasks.py',multilingual=False,instruct=False, excluded=[]):
+    if multilingual:
+        tasks_path=tasks_path.replace('/tasks.py','/mtasks.py')
+    task_order = open(tasks_path).readlines()
+    task_order = [x.split('=')[0].rstrip() for x in task_order if '=' in x]
+    task_order = [x for x in task_order if x.isidentifier()]
+    task_order = fc.flip(dict(enumerate(task_order)))
+
+    l = []
+    _tasks = (lmtasks if multilingual else tasks)
+
+    for key in dir(_tasks):
+        if key not in task_order:
+            continue
+        value=getattr(_tasks, key)
+        if isinstance(value,Preprocessing):
+            dataset_name, config_name, task_name = parse_var_name(key)
+            dataset_name = (value.dataset_name if value.dataset_name else dataset_name)
+            config_name = (value.config_name if value.config_name else config_name)
+            hasattr(value,key)
+            l+=[{'dataset_name': dataset_name,
+                'config_name' : config_name,
+                'task_name': task_name,
+                'preprocessing_name': key,
+                'task_type': value.__class__.__name__,'mapping': value,
+                'rank':task_order.get(key,None)}]
+    df=pd.DataFrame(l).explode('config_name')
+    df = df.sort_values('rank').reset_index(drop=True)
+    df['id'] = df.apply(lambda x: pretty_name(x), axis=1)
+    df.insert(0, 'id', df.pop('id'))
+    del df['rank']
+    if instruct:
+        df=df[df.id.map(lambda x: not any(a in x for a in recast.improper_labels))]
+    df=df[df.id.map(lambda x: not any(x in a for a in excluded))]
+    return df
+
+#task_df =list_tasks()
+#mtask_df =list_tasks(multilingual=True)
+
+def dict_to_query(d=dict(), **kwargs):
+    d={**d,**kwargs}
+    return '&'.join([f'`{k}`=="{v}"' for k,v in d.items()])
+
+def load_preprocessing(tasks=tasks, **kwargs):
+    _tasks_df = list_tasks(multilingual=tasks==lmtasks)
+    y = _tasks_df.copy().query(dict_to_query(**kwargs)).iloc[0]
+    preprocessing= copy.copy(getattr(tasks, y.preprocessing_name))
+    for c in 'dataset_name','config_name':
+        if not isinstance(getattr(preprocessing,c), str):
+            setattr(preprocessing,c,getattr(y,c))
+    return preprocessing
+
+def load_task(id=None, dataset_name=None,config_name=None,task_name=None,preprocessing_name=None,
+              max_rows=None, max_rows_eval=None, multilingual=False, instruct=False, seed=0, **load_dataset_kwargs):
+    query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)
+    query = {k:v for k,v in query.items() if v}
+    _tasks = (lmtasks if multilingual else tasks)
+    preprocessing = load_preprocessing(_tasks, **query)
+    dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name, **load_dataset_kwargs)
+    dataset= preprocessing(dataset,max_rows, max_rows_eval)
+    dataset.task_type = preprocessing.__class__.__name__
+    if instruct:
+        dataset=recast.recast_instruct(dataset)
+    return dataset
\ No newline at end of file
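For orientation, the file above defines tasksource's naming convention: a variable name in tasks.py encodes a (dataset, config, task) triple, parse_var_name decodes it, and load_task resolves an id end to end via list_tasks and datasets.load_dataset. A minimal usage sketch, assuming these helpers are importable as in tasksource's access module; the task names below (e.g. "glue___cola", "glue/cola") are illustrative placeholders, not entries verified against this patch:

# Sketch only: exercising the naming convention parsed by parse_var_name.
from tasksource.access import parse_var_name, load_task

print(parse_var_name("anli__a1"))             # ('anli', None, 'a1')        dataset__task
print(parse_var_name("glue___cola"))          # ('glue', 'cola', None)      dataset___config
print(parse_var_name("blimp___syntax__npi"))  # ('blimp', 'syntax', 'npi')  dataset___config__task

# load_task: looks the id up in list_tasks(), loads the dataset with
# datasets.load_dataset, then applies the registered Preprocessing.
dataset = load_task(id="glue/cola", max_rows=1000)  # placeholder id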
diff --git a/src/tasksource/.ipynb_checkpoints/preprocess-checkpoint.py b/src/tasksource/.ipynb_checkpoints/preprocess-checkpoint.py
new file mode 100755
index 0000000..e737242
--- /dev/null
+++ b/src/tasksource/.ipynb_checkpoints/preprocess-checkpoint.py
@@ -0,0 +1,295 @@
+from collections.abc import Iterable
+from dotwiz import DotWiz
+from dataclasses import dataclass
+from typing import Union
+import itertools
+import funcy as fc
+import exrex
+import magicattr
+import numpy as np
+import copy
+import datasets
+import time
+
+MAX_MC_OPTIONS = 4
+
+def get_column_names(dataset):
+    cn = dataset.column_names
+    if type(cn)==dict:
+        return set(fc.flatten(cn.values()))
+    else:
+        return set(cn)
+
+
+def sample_dataset(dataset,n=10000, n_eval=1000,seed=0):
+    for k in dataset:
+        n_k=(n if k=='train' else n_eval)
+        if n_k and len(dataset[k])>n_k:
+            dataset[k]=dataset[k].train_test_split(train_size=n_k,seed=seed)['train']
+    return dataset
+
+class Preprocessing(DotWiz):
+    default_splits = ('train','validation','test')
+    _instances = []
+
+    def __post_init__(self):
+        Preprocessing._instances+=[self]
+
+    @staticmethod
+    def __map_to_target(x,fn=lambda x:None, target=None):
+        x[target]=fn(x)
+        return x
+
+    def load(self):
+        return self(datasets.load_dataset(self.dataset_name,self.config_name))
+
+    def __call__(self,dataset, max_rows=None, max_rows_eval=None,seed=0):
+        dataset = self.pre_process(dataset)
+
+        # manage splits
+        for k,v in zip(self.default_splits, self.splits):
+            if v and k!=v:
+                dataset[k]=dataset[v]
+                del dataset[v]
+            if k in dataset and not v: # obfuscated label
+                del dataset[k]
+        dataset = fix_splits(dataset)
+
+        for k in list(dataset.keys()):
+            if k not in self.default_splits:
+                del dataset[k]
+        dataset = sample_dataset(dataset, max_rows, max_rows_eval,seed=seed)
+
+        # field annotated with a string
+        substitutions = {v:k for k,v in self.to_dict().items()
+            if (k and k not in {'splits','dataset_name','config_name'}
+            and type(v)==str and k!=v)}
+
+        dataset=dataset.remove_columns([c for c in substitutions.values() if c in
+            dataset['train'].features and c not in substitutions])
+        dataset=dataset.rename_columns(substitutions)
+
+        # field annotated with a function
+        for k in self.to_dict().keys():
+            v=getattr(self, k)
+            if callable(v) and k not in {"post_process","pre_process","load"}:
+                dataset=dataset.map(self.__map_to_target,
+                    fn_kwargs={'fn':v,'target':k})
+
+        dataset=dataset.remove_columns(
+            get_column_names(dataset)-set(self.to_dict().keys()))
+        dataset = fix_labels(dataset)
+        dataset = fix_splits(dataset) # again: label mapping changed
+        dataset = self.post_process(dataset)
+        return dataset
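To make the string-annotation step in __call__ above concrete: every field of the mapping whose value is a plain string is inverted into a rename map, so source columns are renamed to the standardized field names. A self-contained sketch of that comprehension, with made-up column names ('premise', 'hypothesis', 'label' are illustrations, not taken from this patch):

# Sketch only: the `substitutions` logic above, in isolation.
fields = {'sentence1': 'premise', 'sentence2': 'hypothesis', 'labels': 'label'}
substitutions = {v: k for k, v in fields.items()
                 if k not in {'splits', 'dataset_name', 'config_name'}
                 and type(v) == str and k != v}
assert substitutions == {'premise': 'sentence1',
                         'hypothesis': 'sentence2',
                         'label': 'labels'}
# dataset.rename_columns(substitutions) then maps each source column
# onto the field name expected by the task type (e.g. ClassificationFields).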
+
+
+@dataclass
+class cat(Preprocessing):
+    fields:Union[str,list]=None
+    separator:str=' '
+
+    def __call__(self, example=None):
+        y=[np.char.array(example[f]) + sep
+            for f,sep in zip(self.fields[::-1],itertools.repeat(self.separator))]
+        y=list(sum(*y))
+        if len(y)==1:
+            y=y[0]
+        return y
+
+
+def pretty(f):
+    class pretty_f(DotWiz):
+        def __init__(self,*args):
+            self.__f_arg = f(*args)
+            for a in args:
+                setattr(self,'value',a)
+
+        def __call__(self, *args,**kwargs):
+            return self.__f_arg(*args,**kwargs)
+
+        def __repr__(self):
+            return f"{self.__f_arg.__qualname__ .split('.')[0]}({self.value})"
+    return pretty_f
+
+class dotgetter:
+    def __init__(self, path=''):
+        self.path=path
+
+    def __bool__(self):
+        return bool(self.path)
+
+    def __getattr__(self, k):
+        return self.__class__(f'{self.path}.{k}'.lstrip('.'))
+
+    def __getitem__(self, i):
+        return self.__class__(f'{self.path}[{i}]')
+
+    def __call__(self, example=None):
+        return magicattr.get(DotWiz(example), self.path)
+
+    def __hash__(self):
+        return hash(self.path)
+
+
+@dataclass
+class ClassificationFields(Preprocessing):
+    sentence1:str='sentence1'
+    sentence2:str='sentence2'
+    labels:str='labels'
+
+@dataclass
+class Seq2SeqLMFields(Preprocessing):
+    prompt:str='prompt'
+    output:str='output'
+
+@dataclass
+class TokenClassificationFields(Preprocessing):
+    tokens:str='tokens'
+    labels:str='labels'
+
+@dataclass
+class MultipleChoiceFields(Preprocessing):
+    inputs:str='input'
+    choices:Iterable=tuple()
+    labels:str='labels'
+    choices_list:str=None
+    def __post_init__(self):
+        for i, c in enumerate(self.choices):
+            setattr(self,f'choice{i}',c)
+        delattr(self,'choices')
+        if not self.choices_list:
+            delattr(self,'choices_list')
+
+    def __call__(self,dataset, *args, **kwargs):
+        dataset = super().__call__(dataset, *args, **kwargs)
+        if self.choices_list:
+            dataset = dataset.filter(lambda x: 1