Skip to content

Commit

Permalink
training on gqa
Browse files Browse the repository at this point in the history
  • Loading branch information
sebamenabar committed Jun 22, 2020
1 parent e3d7b5b commit a69027f
Show file tree
Hide file tree
Showing 6 changed files with 612 additions and 322 deletions.
2 changes: 1 addition & 1 deletion src/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def parse_args_and_set_config(__c=__C, blacklist=None):
if allowed_type is bool:

def allowed_type(x):
return bool(parsing.strtobool(x))
return bool(parsing.str_to_bool(x))

parser.add_argument(
f"--{arg}",
Expand Down
14 changes: 7 additions & 7 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from base_config import __C, parse_args_and_set_config, edict, _to_values_only


parse_bool = lambda x: bool(parsing.strtobool(x))
parse_bool = lambda x: bool(parsing.str_to_bool(x))

if torch.cuda.is_available():
__C.orig_dir = (
Expand All @@ -26,7 +26,7 @@

__C.train.num_plot_samples = (32, edict(type=int))
__C.train.augment = (False, edict(type=parse_bool))
__C.train.dataset = ("orig", edict(choices=["orig", "uni"]))
__C.train.dataset = ("orig", edict(choices=["orig", "uni", "gqa"]))
__C.train.gradient_clip_val = (8.0, edict(type=float))

__C.train.optimizers = edict()
Expand All @@ -37,7 +37,7 @@
__C.model.encoder = edict()
__C.model.encoder.type = (
"resnet50",
edict(choices=["scratch", "pretrained", "resnet50"]),
edict(choices=["none", "scratch", "pretrained", "resnet50", "resnet101"]),
)
__C.model.encoder.ckpt_fp = ("", edict(type=str))
__C.model.encoder.out_nc = 512
Expand All @@ -48,16 +48,16 @@
init_mem="random",
max_step=(12, edict(type=int)),
separate_syntax_semantics=False,
use_feats="spatial",
use_feats=("spatial", edict(type=str, choices=["spatial", "objects", "pixels"])),
num_gt_lobs=0,
common=edict(
module_dim=512,
module_dim=(512, edict(type=int)),
# use_feats='spatial',
),
input_unit=edict(
in_channels=512,
in_channels=(512, edict(type=int)),
wordvec_dim=300,
rnn_dim=512,
rnn_dim=(512, edict(type=int)),
bidirectional=True,
separate_syntax_semantics_embeddings=False,
stem_act="ELU",
Expand Down
Loading

0 comments on commit a69027f

Please sign in to comment.