-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathremove_lvis_rare.py
74 lines (64 loc) · 2.74 KB
/
remove_lvis_rare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
import torch
from detic.modeling.clip import clip as clip_model
from detic.prompt_engineering import get_prompt_templates
import pickle
@torch.no_grad()
def get_custom_text_feat(class_names):
    """Return L2-normalized CLIP (RN50) text embeddings for *class_names*.

    Each class is embedded as the mean over all prompt templates (with the
    LVIS/COCO-stuff suffixes '-other', '-merged', '-stuff' stripped from the
    class name), and one extra 'background' embedding is appended as the
    last row.

    Args:
        class_names: iterable of class-name strings.

    Returns:
        Tensor of shape (len(class_names) + 1, emb_dim), rows normalized to
        unit L2 norm. Runs on CUDA.
    """
    clip, _ = clip_model.load('RN50')
    clip = clip.cuda()

    def extract_mean_emb(text):
        # Mean-pool the embeddings of all prompt variants for one class.
        tokens = clip_model.tokenize(text).cuda()
        if len(text) > 10000:
            # Split the batch in two to avoid OOM on very large prompt
            # lists. BUG FIX: the original encoded raw string slices
            # (text[...]) here instead of the tokenized tensor.
            text_features = torch.cat([
                clip.encode_text(tokens[:len(tokens) // 2]),
                clip.encode_text(tokens[len(tokens) // 2:])],
                dim=0)
        else:
            text_features = clip.encode_text(tokens)
        # keepdim is PyTorch's canonical kwarg (was the NumPy-style
        # 'keepdims'); result is (1, emb_dim), so return row 0.
        text_features = torch.mean(text_features, 0, keepdim=True)
        return text_features[0]

    templates = get_prompt_templates()
    clss_embeddings = []
    for clss in class_names:
        txts = [template.format(
                    clss.replace('-other', '')
                        .replace('-merged', '')
                        .replace('-stuff', ''))
                for template in templates]
        clss_embeddings.append(extract_mean_emb(txts))
    # Append a dedicated 'background' class embedding as the last row.
    txts = ['background']
    clss_embeddings.append(extract_mean_emb(txts))

    text_emb = torch.stack(clss_embeddings, dim=0)
    text_emb /= text_emb.norm(dim=-1, keepdim=True)
    return text_emb
if __name__ == '__main__':
    # Remove LVIS 'rare'-frequency categories (and their annotations) from a
    # LVIS v1 annotation file, writing the result next to it as *_seen.json.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ann', default='datasets/lvis/lvis_v1_train.json')
    args = parser.parse_args()

    print('Loading', args.ann)
    # Context manager so the input handle is closed (original leaked it).
    with open(args.ann, 'r') as f:
        data = json.load(f)

    # category id -> frequency bucket ('r' rare / 'c' common / 'f' frequent)
    catid2freq = {x['id']: x['frequency'] for x in data['categories']}
    exclude = ['r']  # drop the rare bucket
    data['annotations'] = [x for x in data['annotations']
                           if catid2freq[x['category_id']] not in exclude]
    # remove rare categories
    data['categories'] = [x for x in data['categories']
                          if catid2freq[x['id']] not in exclude]
    print('filtered #anns', len(data['annotations']))

    # NOTE: assumes args.ann ends with '.json' (strips the 5-char suffix).
    out_path = args.ann[:-5] + '_seen.json'
    print('Saving to', out_path)
    with open(out_path, 'w') as f:
        json.dump(data, f)
    print('done')

    # A disabled code path previously exported CLIP text embeddings for the
    # base (non-rare) and full class lists via get_custom_text_feat() to
    # datasets/lvis/lvis_base_cls.pkl and lvis_cls.pkl.

    # loading categories
    cat_info = data['categories']
    print(len(cat_info))