-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
797 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,35 @@ | ||
# few-shot-generalization | ||
# A Statistical Model for Predicting Generalization in Few-Shot Classification | ||
|
||
Official implementation for the AISTATS submission "A Statistical Model for Predicting Generalization in Few-Shot Classification". | ||
|
||
To run the code for different datasets, first download the features. We use the features proposed in the article https://arxiv.org/pdf/2201.09699.pdf, which can be downloaded from the following [link](https://drive.google.com/drive/folders/1fALYAfzStWXasI-DTl6qi9moNuWbFA-j) and placed in the folder "features".
|
||
Then, run the following Bash script: | ||
``` | ||
./bash_scripts/run_mini.sh | ||
``` | ||
If you want to change elements of the runs, specify the parameters yourself and run the following commands:
|
||
``` | ||
SAVE_PATH="results"; | ||
FEATURES_PATH="features" | ||
# validation set | ||
VALIDATION_DATASET="miniimagenet_validation"; | ||
VALIDATION_FEATURES="mini11miniimagenet_validation_features"; | ||
# Test set | ||
TEST_DATASET="miniimagenet_test"; | ||
TEST_FEATURES="mini11miniimagenet_test_features"; | ||
N_RUNS=1000; #Number of few-shot problems | ||
N_WAYS=5; #Number of classes | ||
MAXK=50; #Max number of samples | ||
UNBALANCED="False"; | ||
# First run the validation split | ||
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$VALIDATION_FEATURES.pt --dataset $VALIDATION_DATASET --validation --n-ways $N_WAYS --n-runs $N_RUNS; | ||
# Run on the test set | ||
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$TEST_FEATURES.pt --dataset $TEST_DATASET --config-validation $SAVE_PATH/$VALIDATION_DATASET"/nruns"$N_RUNS"_c"$N_WAYS"_unbalanced"$UNBALANCED"_filename_"$VALIDATION_FEATURES.pt --n-ways $N_WAYS --n-runs $N_RUNS; | ||
```
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import argparse
import ast
import random
|
||
def process_arguments(parser=None, params=None):
    """Build and post-process the command-line arguments of a few-shot run.

    Args:
        parser: optional pre-built ``argparse.ArgumentParser`` to extend;
            a fresh one is created when ``None``.
        params: optional dict of overrides applied on top of the parsed
            arguments (useful when calling from code rather than the CLI).

    Returns:
        ``argparse.Namespace`` with post-processed fields: ``devices``
        (list of GPU indices or device strings), a slash-terminated
        ``dataset_path``, ``n_shots`` materialized as a list of ints, and
        a concrete random ``seed``.
    """
    if parser is None:
        parser = argparse.ArgumentParser()

    ### pytorch options
    parser.add_argument("--device", type=str, default="cuda:0", help="device(s) to use, for multiple GPUs try cuda:ijk, will not work with 10+ GPUs")
    parser.add_argument("--dataset-path", type=str, default='test/test/', help="dataset path")
    parser.add_argument("--features-path", type=str, default='', help="features directory path")

    parser.add_argument("--dataset-device", type=str, default="", help="use a different device for storing the datasets (use 'cpu' if you are lacking VRAM)")
    parser.add_argument("--deterministic", action="store_true", help="use deterministic randomness for reproducibility")

    ### run options
    parser.add_argument("--dataset", type=str, default="miniimagenet", help="dataset to use")
    parser.add_argument("--seed", type=int, default=-1, help="set random seed manually, and also use deterministic approach")

    ### few-shot parameters
    parser.add_argument("--n-shots", type=str, default="[1,5]", help="how many shots per few-shot run, can be int or list of ints. In case of episodic training, use first item of list as number of shots.")
    parser.add_argument("--n-runs", type=int, default=10000, help="number of few-shot runs")
    parser.add_argument("--n-ways", type=int, default=5, help="number of few-shot ways")
    parser.add_argument("--n-queries", type=int, default=15, help="number of few-shot queries")
    parser.add_argument("--unbalanced-queries", action="store_true", help="Unbalanced queries")

    args = parser.parse_args()

    # Apply explicit overrides on top of whatever was parsed.
    if params is not None:
        for key, value in params.items():
            args.__dict__[key] = value

    ### process arguments
    if args.dataset_device == "":
        args.dataset_device = args.device
    if args.dataset_path[-1] != '/':
        args.dataset_path += "/"

    # "cuda:ijk" selects several GPUs, one digit per device index;
    # the first one (e.g. "cuda:0") becomes the primary device.
    if args.device[:5] == "cuda:" and len(args.device) > 5:
        args.devices = [int(digit) for digit in args.device[5:]]
        args.device = args.device[:6]
    else:
        args.devices = [args.device]

    if args.seed == -1:
        args.seed = random.randint(0, 1000000000)

    # n_shots may be a single int ("5") or a list literal ("[1,5]").
    try:
        args.n_shots = [int(args.n_shots)]
    except ValueError:
        # literal_eval is safe on CLI-supplied strings, unlike eval.
        args.n_shots = ast.literal_eval(args.n_shots)

    # features-path may itself be a list literal of several paths.
    if '[' in args.features_path:
        args.features_path = ast.literal_eval(args.features_path)
    return args
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Run the full mini-ImageNet pipeline: first estimate the bias/configuration
# on the validation split, then evaluate on the test split using that saved
# configuration. Launched from inside bash_scripts/, hence the cd.
cd ..;

# Where results are written and where the pre-extracted features live.
SAVE_PATH="results";
FEATURES_PATH="features"

# validation set
VALIDATION_DATASET="miniimagenet_validation";
VALIDATION_FEATURES="mini11miniimagenet_validation_features";

# Test set
TEST_DATASET="miniimagenet_test";
TEST_FEATURES="mini11miniimagenet_test_features";

N_RUNS=1000; # Number of runs
N_WAYS=5; # Number of classes
MAXK=50; # Max number of samples for the few-shot problems.
UNBALANCED="False";

# Run the validation set (different classes)
# The echo lines print the exact command before executing it, for logging.
echo python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$VALIDATION_FEATURES.pt --dataset $VALIDATION_DATASET --validation --n-ways $N_WAYS --n-runs $N_RUNS;
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$VALIDATION_FEATURES.pt --dataset $VALIDATION_DATASET --validation --n-ways $N_WAYS --n-runs $N_RUNS;

# Run the test set
# --config-validation points at the file the validation run saved above; its
# name encodes n_runs, n_ways, the balanced/unbalanced flag and the features file.
echo python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$TEST_FEATURES.pt --dataset $TEST_DATASET --config-validation $SAVE_PATH/$VALIDATION_DATASET"/nruns"$N_RUNS"_c"$N_WAYS"_unbalanced"$UNBALANCED"_filename_"$VALIDATION_FEATURES.pt --n-ways $N_WAYS --n-runs $N_RUNS;
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$TEST_FEATURES.pt --dataset $TEST_DATASET --config-validation $SAVE_PATH/$VALIDATION_DATASET"/nruns"$N_RUNS"_c"$N_WAYS"_unbalanced"$UNBALANCED"_filename_"$VALIDATION_FEATURES.pt --n-ways $N_WAYS --n-runs $N_RUNS;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import random | ||
import torch | ||
|
||
class EpisodicGenerator():
    """Sample few-shot classification episodes.

    The generator only needs the population of each class
    (``num_elements_per_class``); it draws which classes participate in an
    episode and which element indices form the support ("shots") and query
    sets for each chosen class.
    """

    def __init__(self, datasetName=None, dataset_path=None, max_classes=50, num_elements_per_class=None):
        """
        Args:
            datasetName: unused here; kept for interface compatibility.
            dataset_path: unused here; kept for interface compatibility.
            max_classes: upper bound on the number of ways drawn when
                ``sample_episode`` is called with ``ways=0``.
            num_elements_per_class: list giving the number of available
                samples in each class of the dataset.
        """
        self.dataset = None
        self.num_elements_per_class = num_elements_per_class
        # Bug fix: honor the max_classes argument (was hard-coded to 50);
        # an episode can never use more ways than there are classes.
        self.max_classes = min(len(self.num_elements_per_class), max_classes)

    def select_classes(self, ways):
        """Return a 1-D tensor of class indices for this episode.

        When ``ways`` is 0 the number of ways is drawn uniformly in
        [5, self.max_classes].
        """
        n_ways = ways if ways != 0 else random.randint(5, self.max_classes)
        # get n_ways classes randomly, without replacement
        choices = torch.randperm(len(self.num_elements_per_class))[:n_ways]
        return choices

    def get_query_size(self, choice_classes, n_queries):
        """Per-class query count (balanced generator: as requested)."""
        return n_queries

    def get_support_size(self, choice_classes, query_size, n_shots):
        """Total number of support samples over all chosen classes."""
        return len(choice_classes) * n_shots

    def get_number_of_shots(self, choice_classes, support_size, query_size, n_shots):
        """Per-class shot counts (balanced: identical for every class)."""
        return [n_shots] * len(choice_classes)

    def get_number_of_queries(self, choice_classes, query_size, unbalanced_queries):
        """Per-class query counts (balanced: identical for every class)."""
        return [query_size] * len(choice_classes)

    def sample_indices(self, num_elements_per_chosen_classes, n_shots_per_class, n_queries_per_class):
        """Draw support and query element indices for every chosen class.

        Returns:
            (shots_idx, queries_idx): two lists of per-class index lists.
        """
        shots_idx, queries_idx = [], []
        for k, q, elements_per_class in zip(n_shots_per_class, n_queries_per_class, num_elements_per_chosen_classes):
            # One shared permutation guarantees shots and queries are disjoint.
            choices = torch.randperm(elements_per_class)
            shots_idx.append(choices[:k].tolist())
            queries_idx.append(choices[k:k + q].tolist())
        return shots_idx, queries_idx

    def sample_episode(self, ways=0, n_shots=0, n_queries=0, unbalanced_queries=False, verbose=False):
        """
        Sample an episode: chosen classes plus support/query index lists.
        """
        # get n_ways classes randomly
        choice_classes = self.select_classes(ways=ways)

        query_size = self.get_query_size(choice_classes, n_queries)
        support_size = self.get_support_size(choice_classes, query_size, n_shots)

        n_shots_per_class = self.get_number_of_shots(choice_classes, support_size, query_size, n_shots)
        n_queries_per_class = self.get_number_of_queries(choice_classes, query_size, unbalanced_queries)
        shots_idx, queries_idx = self.sample_indices(
            [self.num_elements_per_class[c] for c in choice_classes],
            n_shots_per_class, n_queries_per_class)

        if verbose:
            print(f'chosen class: {choice_classes}')
            print(f'n_ways={len(choice_classes)}, q={query_size}, S={support_size}, n_shots_per_class={n_shots_per_class}')
            print(f'queries per class:{n_queries_per_class}')
            print(f'shots_idx: {shots_idx}')
            print(f'queries_idx: {queries_idx}')

        return {'choice_classes': choice_classes, 'shots_idx': shots_idx, 'queries_idx': queries_idx}

    def get_features_from_indices(self, features, episode, validation=False):
        """
        Get features from a list of all features and from a dictonnary describing an episode
        """
        # NOTE(review): assumes features[c] is a dict with a 'features' tensor
        # indexable by a list of element indices — confirm against the loader.
        choice_classes, shots_idx, queries_idx = episode['choice_classes'], episode['shots_idx'], episode['queries_idx']
        if validation:
            validation_idx = episode['validations_idx']
            val = []
        shots, queries = [], []
        for i, c in enumerate(choice_classes):
            shots.append(features[c]['features'][shots_idx[i]])
            queries.append(features[c]['features'][queries_idx[i]])
            if validation:
                val.append(features[c]['features'][validation_idx[i]])
        if validation:
            return shots, queries, val
        else:
            return shots, queries
Oops, something went wrong.