Commit 203d526
Merge pull request #30 from ynashed/set_softdtw_to_non_cuda_version
remove the default use of the cuda version of soft dtw
sgasioro authored Oct 10, 2024
2 parents f64bd04 + c93cff3 commit 203d526
Showing 3 changed files with 9 additions and 7 deletions.
optimize/example_run.py (6 changes: 5 additions & 1 deletion)

@@ -71,7 +71,7 @@ def main(config):
no_adc=config.no_adc, loss_fn=config.loss_fn, shift_no_fit=config.shift_no_fit,
link_vdrift_eField=config.link_vdrift_eField, batch_memory=config.batch_memory, skip_pixels=config.skip_pixels,
set_target_vals=config.set_target_vals, vary_init=config.vary_init, seed_init=config.seed_init,
- config = config)
+ config = config, use_cuda=config.use_cuda, softdtw_gamma=config.softdtw_gamma)
param_fit.make_target_sim(seed=config.seed, fixed_range=config.fixed_range)
# run simulation
param_fit.fit(tracks_dataloader, epochs=config.epochs, iterations=iterations, shuffle=config.data_shuffle, save_freq=config.save_freq)
@@ -157,6 +157,10 @@ def main(config):
help="Number of iterations to run. Overrides epochs.")
parser.add_argument("--loss_fn", dest="loss_fn", default=None,
help="Loss function to use. Named options are SDTW and space_match.")
parser.add_argument("--use_cuda", dest="use_cuda", default=False, action="store_true",
help="If using the cuda implementation of softdtw. Warning: may contain a minor bug")
parser.add_argument("--softdtw_gamma", dest="softdtw_gamma", default=1, type=float,
help="Gamma of the soft DTW (loss function).")
parser.add_argument("--max_batch_len", dest="max_batch_len", default=None, type=float,
help="Max dx [cm] per batch. If passed, will add tracks to batch until overflow, splitting where needed")
parser.add_argument("--max_nbatch", dest="max_nbatch", default=None, type=int,
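With this change the CUDA soft-DTW kernel is opt-in rather than the default, and the smoothing parameter gamma is exposed on the command line. A hypothetical invocation (other required arguments omitted; the flag names are taken from the diff above):

python optimize/example_run.py --loss_fn SDTW --softdtw_gamma 0.5
python optimize/example_run.py --loss_fn SDTW --use_cuda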
optimize/fit_params.py (6 changes: 2 additions & 4 deletions)

@@ -45,7 +45,7 @@ def __init__(self, relevant_params, track_fields, track_chunk, pixel_chunk,
out_label="", norm_scheme="divide", max_clip_norm_val=None, fit_diffs=False, optimizer_fn="Adam",
no_adc=False, shift_no_fit=[], link_vdrift_eField=False, batch_memory=None, skip_pixels = False,
set_target_vals=[], vary_init=False, seed_init=30,
- config = {}):
+ config = {}, use_cuda = False, softdtw_gamma = 1):

if optimizer_fn == "Adam":
self.optimizer_fn = torch.optim.Adam
@@ -185,13 +185,11 @@ def __init__(self, relevant_params, track_fields, track_chunk, pixel_chunk,
t_only = self.no_adc
adc_only = not t_only
- # once cuda implementation in soft_dtw_cuda.py is set up
- use_cuda = self.device == 'cuda'

self.loss_fn_kw = {
'use_cuda' : use_cuda,
'adc_only' : adc_only,
't_only' : t_only,
- 'gamma' : 1
+ 'gamma' : softdtw_gamma
}
if t_only:
logger.info("Using Soft DTW loss on t only")
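For context on the newly exposed gamma: soft DTW (Cuturi and Blondel, 2017) replaces the hard minimum in the DTW recursion with the smoothed minimum min_gamma(a_1, ..., a_n) = -gamma * log(sum_i exp(-a_i / gamma)), which is differentiable; as gamma -> 0 it converges to ordinary DTW, while larger values trade fidelity for a smoother loss surface. A minimal standalone sketch of that smoothed minimum (illustrative only, not code from this repository):

import math

def soft_min(values, gamma):
    # Smoothed minimum used inside soft DTW: -gamma * log(sum(exp(-v / gamma))).
    # Converges to min(values) as gamma -> 0; larger gamma smooths more.
    return -gamma * math.log(sum(math.exp(-v / gamma) for v in values))

for gamma in (1.0, 0.1, 0.01):
    print(gamma, soft_min([1.0, 2.0, 3.0], gamma))
# gamma=1.0 -> ~0.59 (below the hard minimum); gamma=0.01 -> ~1.00 (effectively the hard minimum)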
optimize/utils.py (4 changes: 2 additions & 2 deletions)

@@ -175,12 +175,12 @@ def _abs_dist_func(x, y):
y = y.unsqueeze(1).expand(-1, n, m, d)
return torch.abs(x - y).sum(3)

- def calc_soft_dtw_loss(embed_out, embed_targ, adc_only=True, t_only=True, gamma=1):
+ def calc_soft_dtw_loss(embed_out, embed_targ, adc_only=True, t_only=True, gamma=1, use_cuda=False):
# Unroll embedding
x_out_nz, y_out_nz, z_out_nz, time_list_out_nz, adc_out_nz = embed_out
x_targ_nz, y_targ_nz, z_targ_nz, time_list_targ_nz, adc_targ_nz = embed_targ

- sdtw = SoftDTW(use_cuda=False, gamma=1, dist_func = _abs_dist_func)
+ sdtw = SoftDTW(use_cuda=use_cuda, gamma=gamma, dist_func = _abs_dist_func)

if adc_only:
return sdtw(adc_out_nz[None, :, None], adc_targ_nz[None, :, None])
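For reference, a hedged usage sketch of the changed code path. It assumes the SoftDTW class from this repository's soft_dtw_cuda.py with the (batch, length, dims) input convention used above; the tensor names and import path are illustrative:

import torch
from soft_dtw_cuda import SoftDTW  # assumed import path within this repository

def _abs_dist_func(x, y):
    # Same pairwise L1 distance as in optimize/utils.py
    n, m, d = x.size(1), y.size(1), x.size(2)
    x = x.unsqueeze(2).expand(-1, n, m, d)
    y = y.unsqueeze(1).expand(-1, n, m, d)
    return torch.abs(x - y).sum(3)

adc_out = torch.randn(50, requires_grad=True)   # e.g. simulated ADC values
adc_targ = torch.randn(60)                      # target ADC values; lengths may differ

sdtw = SoftDTW(use_cuda=False, gamma=0.1, dist_func=_abs_dist_func)  # CPU path, now the default
loss = sdtw(adc_out[None, :, None], adc_targ[None, :, None])  # one soft-DTW value per batch element
loss.mean().backward()  # differentiable, so it can drive the gradient-based parameter fit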
