Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ybendou committed Dec 7, 2022
1 parent d216bc0 commit fd00c54
Show file tree
Hide file tree
Showing 6 changed files with 797 additions and 1 deletion.
36 changes: 35 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,35 @@
# few-shot-generalization
# A Statistical Model for Predicting Generalization in Few-Shot Classification

Official implementation for the AISTATS submission "A Statistical Model for Predicting Generalization in Few-Shot Classification".

To run the code for different datasets, first download the features. We use the features proposed in the article https://arxiv.org/pdf/2201.09699.pdf, which can be downloaded from the following [link](https://drive.google.com/drive/folders/1fALYAfzStWXasI-DTl6qi9moNuWbFA-j) and placed in the folder "features".

Then, run the following Bash script:
```
./bash_scripts/run_mini.sh
```
If you want to change elements of the runs, specify the parameters yourself and run the following commands:

```
SAVE_PATH="results";
FEATURES_PATH="features"
# validation set
VALIDATION_DATASET="miniimagenet_validation";
VALIDATION_FEATURES="mini11miniimagenet_validation_features";
# Test set
TEST_DATASET="miniimagenet_test";
TEST_FEATURES="mini11miniimagenet_test_features";
N_RUNS=1000; #Number of few-shot problems
N_WAYS=5; #Number of classes
MAXK=50; #Max number of samples
UNBALANCED="False";
# First run the validation split
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$VALIDATION_FEATURES.pt --dataset $VALIDATION_DATASET --validation --n-ways $N_WAYS --n-runs $N_RUNS;
# Run on the test set
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$TEST_FEATURES.pt --dataset $TEST_DATASET --config-validation $SAVE_PATH/$VALIDATION_DATASET"/nruns"$N_RUNS"_c"$N_WAYS"_unbalanced"$UNBALANCED"_filename_"$VALIDATION_FEATURES.pt --n-ways $N_WAYS --n-runs $N_RUNS;
```
59 changes: 59 additions & 0 deletions args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import argparse
import ast
import random

def process_arguments(parser=None, params=None):
if parser == None:
parser = argparse.ArgumentParser()

### pytorch options
parser.add_argument("--device", type=str, default="cuda:0", help="device(s) to use, for multiple GPUs try cuda:ijk, will not work with 10+ GPUs")
parser.add_argument("--dataset-path", type=str, default='test/test/', help="dataset path")
parser.add_argument("--features-path", type=str, default='', help="features directory path")

parser.add_argument("--dataset-device", type=str, default="", help="use a different device for storing the datasets (use 'cpu' if you are lacking VRAM)")
parser.add_argument("--deterministic", action="store_true", help="use desterministic randomness for reproducibility")

### run options
parser.add_argument("--dataset", type=str, default="miniimagenet", help="dataset to use")
parser.add_argument("--seed", type=int, default=-1, help="set random seed manually, and also use deterministic approach")

### few-shot parameters
parser.add_argument("--n-shots", type=str, default="[1,5]", help="how many shots per few-shot run, can be int or list of ints. In case of episodic training, use first item of list as number of shots.")
parser.add_argument("--n-runs", type=int, default=10000, help="number of few-shot runs")
parser.add_argument("--n-ways", type=int, default=5, help="number of few-shot ways")
parser.add_argument("--n-queries", type=int, default=15, help="number of few-shot queries")
parser.add_argument("--unbalanced-queries", action="store_true", help="Unbalanced queries")

args = parser.parse_args()

if params!=None:
for key, value in params.items():
args.__dict__[key]= value
### process arguments
if args.dataset_device == "":
args.dataset_device = args.device
if args.dataset_path[-1] != '/':
args.dataset_path += "/"

if args.device[:5] == "cuda:" and len(args.device) > 5:
args.devices = []
for i in range(len(args.device) - 5):
args.devices.append(int(args.device[i+5]))
args.device = args.device[:6]
else:
args.devices = [args.device]

if args.seed == -1:
args.seed = random.randint(0, 1000000000)

try:
n_shots = int(args.n_shots)
args.n_shots = [n_shots]
except:
args.n_shots = eval(args.n_shots)

if '[' in args.features_path :
args.features_path = eval(args.features_path)
print("args, ", end='')
return args
25 changes: 25 additions & 0 deletions bash_scripts/run_mini.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Run bias estimation on miniImageNet: first on the validation split (to fit
# the configuration), then on the test split (re-using that configuration).
#
# Fix: the original `cd ..` assumed the script was launched from inside
# bash_scripts/. Anchor on the script's own location instead, and abort if
# the directory change fails, so the python commands never run from the
# wrong working directory.
cd "$(dirname "$0")/.." || exit 1;

SAVE_PATH="results";
FEATURES_PATH="features"

# validation set
VALIDATION_DATASET="miniimagenet_validation";
VALIDATION_FEATURES="mini11miniimagenet_validation_features";

# Test set
TEST_DATASET="miniimagenet_test";
TEST_FEATURES="mini11miniimagenet_test_features";

N_RUNS=1000; # Number of runs
N_WAYS=5; # Number of classes
MAXK=50; # Max number of samples for the few-shot problems.
UNBALANCED="False";

# Run the validation set (different classes)
echo python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$VALIDATION_FEATURES.pt --dataset $VALIDATION_DATASET --validation --n-ways $N_WAYS --n-runs $N_RUNS;
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$VALIDATION_FEATURES.pt --dataset $VALIDATION_DATASET --validation --n-ways $N_WAYS --n-runs $N_RUNS;

# Run the test set
echo python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$TEST_FEATURES.pt --dataset $TEST_DATASET --config-validation $SAVE_PATH/$VALIDATION_DATASET"/nruns"$N_RUNS"_c"$N_WAYS"_unbalanced"$UNBALANCED"_filename_"$VALIDATION_FEATURES.pt --n-ways $N_WAYS --n-runs $N_RUNS;
python main_bias_estimate.py --save-folder $SAVE_PATH --maxK $MAXK --features-path $FEATURES_PATH/$TEST_FEATURES.pt --dataset $TEST_DATASET --config-validation $SAVE_PATH/$VALIDATION_DATASET"/nruns"$N_RUNS"_c"$N_WAYS"_unbalanced"$UNBALANCED"_filename_"$VALIDATION_FEATURES.pt --n-ways $N_WAYS --n-runs $N_RUNS;
82 changes: 82 additions & 0 deletions few_shot_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import random
import torch

class EpisodicGenerator():
    """Sample few-shot episodes (support/query index sets per class) from a
    dataset described only by its number of elements per class."""

    def __init__(self, datasetName=None, dataset_path=None, max_classes=50, num_elements_per_class=None):
        """
        Args:
            datasetName: unused here; kept for API compatibility.
            dataset_path: unused here; kept for API compatibility.
            max_classes: upper bound on the number of ways of an episode
                (default 50, the original behavior).
            num_elements_per_class: sequence giving the population of each class.
        """
        self.dataset = None
        self.num_elements_per_class = num_elements_per_class
        # Bug fix: the original hard-coded 50 here, silently ignoring the
        # max_classes argument. The default keeps the old behavior.
        self.max_classes = min(len(self.num_elements_per_class), max_classes)

    def select_classes(self, ways):
        """Return a tensor of `ways` class indices chosen at random; when
        ways == 0, the number of ways itself is drawn uniformly in
        [5, max_classes]."""
        # number of ways for this episode
        n_ways = ways if ways != 0 else random.randint(5, self.max_classes)

        # get n_ways classes randomly
        choices = torch.randperm(len(self.num_elements_per_class))[:n_ways]
        return choices

    def get_query_size(self, choice_classes, n_queries):
        """Number of queries per class (fixed to n_queries in the balanced base class)."""
        return n_queries

    def get_support_size(self, choice_classes, query_size, n_shots):
        """Total support size: n_shots examples for each chosen class."""
        support_size = len(choice_classes)*n_shots
        return support_size

    def get_number_of_shots(self, choice_classes, support_size, query_size, n_shots):
        """Per-class shot counts (uniform n_shots in the balanced base class)."""
        n_shots_per_class = [n_shots]*len(choice_classes)
        return n_shots_per_class

    def get_number_of_queries(self, choice_classes, query_size, unbalanced_queries):
        """Per-class query counts (uniform query_size; unbalanced_queries is
        ignored in the base class and left to subclasses)."""
        n_queries_per_class = [query_size]*len(choice_classes)
        return n_queries_per_class

    def sample_indices(self, num_elements_per_chosen_classes, n_shots_per_class, n_queries_per_class):
        """For each chosen class, draw disjoint support and query element
        indices from a single random permutation of that class."""
        shots_idx = []
        queries_idx = []
        for k, q, elements_per_class in zip(n_shots_per_class, n_queries_per_class, num_elements_per_chosen_classes):
            choices = torch.randperm(elements_per_class)
            # First k permuted indices are shots, next q are queries, so the
            # two sets never overlap within a class.
            shots_idx.append(choices[:k].tolist())
            queries_idx.append(choices[k:k+q].tolist())
        return shots_idx, queries_idx

    def sample_episode(self, ways=0, n_shots=0, n_queries=0, unbalanced_queries=False, verbose=False):
        """
        Sample an episode.

        Returns a dict with 'choice_classes' (tensor of class indices),
        'shots_idx' and 'queries_idx' (lists of per-class element indices).
        """
        # get n_ways classes randomly
        choice_classes = self.select_classes(ways=ways)

        query_size = self.get_query_size(choice_classes, n_queries)
        support_size = self.get_support_size(choice_classes, query_size, n_shots)

        n_shots_per_class = self.get_number_of_shots(choice_classes, support_size, query_size, n_shots)
        n_queries_per_class = self.get_number_of_queries(choice_classes, query_size, unbalanced_queries)
        shots_idx, queries_idx = self.sample_indices([self.num_elements_per_class[c] for c in choice_classes], n_shots_per_class, n_queries_per_class)

        if verbose:
            print(f'chosen class: {choice_classes}')
            print(f'n_ways={len(choice_classes)}, q={query_size}, S={support_size}, n_shots_per_class={n_shots_per_class}')
            print(f'queries per class:{n_queries_per_class}')
            print(f'shots_idx: {shots_idx}')
            print(f'queries_idx: {queries_idx}')

        return {'choice_classes':choice_classes, 'shots_idx':shots_idx, 'queries_idx':queries_idx}

    def get_features_from_indices(self, features, episode, validation=False):
        """
        Get features from a list of all features and from a dictionary
        describing an episode.

        `features` is indexed by class, each entry holding a 'features'
        tensor indexable by element indices. When validation is True the
        episode must also carry 'validations_idx'.
        """
        choice_classes, shots_idx, queries_idx = episode['choice_classes'], episode['shots_idx'], episode['queries_idx']
        if validation:
            validation_idx = episode['validations_idx']
            val = []
        shots, queries = [], []
        for i, c in enumerate(choice_classes):
            shots.append(features[c]['features'][shots_idx[i]])
            queries.append(features[c]['features'][queries_idx[i]])
            if validation:
                val.append(features[c]['features'][validation_idx[i]])
        if validation:
            return shots, queries, val
        else:
            return shots, queries
Loading

0 comments on commit fd00c54

Please sign in to comment.