Skip to content

Commit

Permalink
Add different random seed for I/O vs training(#27)
Browse files Browse the repository at this point in the history
You can now specify a seed in the `iotool.sampler` configuration
and a different seed in `training` configuration.
The RandomSequenceSampler should now use its own RNG, separate from
the PyTorch/NumPy global RNGs.
  • Loading branch information
Temigo committed Aug 15, 2019
1 parent 9b7f074 commit 3a74478
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ before_script:

script:
- singularity exec local.simg pip3 install pytest --user
- singularity exec local.simg CUDA_VISIBLE_DEVICES='' pytest -rxXs
- singularity exec local.simg pytest -rxXs
2 changes: 1 addition & 1 deletion contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This repository contains a framework to define, train, run, and evaluate machine

Obviously, you should test your code. Ideally, we would have a unit testing framework that would make it easy for you to prove to others that you at least didn't break something.

Use the command `pytest -rxXs` to run all the tests that are currently available (still work in progress).
Use the command `CUDA_VISIBLE_DEVICES='' pytest -rxXs` to run all the tests that are currently available (still work in progress).

## Documentation

Expand Down
27 changes: 14 additions & 13 deletions mlreco/iotools/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,46 @@
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
from torch.utils.data import Sampler


class AbstractBatchSampler(Sampler):
    """
    Base class for batch index samplers.

    Samplers that inherit from this class should work out of the box:
    just define the __iter__ function.
    __init__ defines self._data_size and self._batch_size
    as well as the self._random RNG for samplers that need randomness.
    """
    def __init__(self, data_size, batch_size, seed=0):
        """
        Parameters
        ----------
        data_size: int
            Number of entries in the dataset, must be >= 0.
        batch_size: int
            Number of entries per batch, must satisfy
            0 <= batch_size <= data_size.
        seed: int, optional
            Seed of the sampler's private random number generator.

        Raises
        ------
        ValueError
            If data_size is negative or batch_size is out of range.
        """
        self._data_size = int(data_size)
        if self._data_size < 0:
            # Format with `%`: passing the tuple as a second argument to
            # ValueError would leave the placeholders unfilled.
            raise ValueError('%s received negative data size %s'
                             % (self.__class__.__name__, str(data_size)))

        self._batch_size = int(batch_size)
        if self._batch_size < 0 or self._batch_size > self._data_size:
            raise ValueError('%s received invalid batch size %d for data size %s'
                             % (self.__class__.__name__, self._batch_size, str(self._data_size)))
        # Use an independent random number generator for random sampling,
        # so that sampling does not draw from the global NumPy RNG state.
        self._random = np.random.RandomState(seed=seed)

    def __len__(self):
        """Return the dataset size given at construction."""
        return self._data_size


class RandomSequenceSampler(AbstractBatchSampler):
    """
    Sampler yielding the indices of randomly placed contiguous sequences.

    Each iteration draws len(self) random start indices from the sampler's
    own RNG (self._random) and yields, for every start, the batch_size
    consecutive indices start, start+1, ..., start+batch_size-1.
    """
    def __iter__(self):
        # `high` is exclusive in randint: the +1 makes the last valid start
        # (data_size - batch_size) reachable and keeps the call valid when
        # batch_size == data_size, which the constructor accepts.
        starts = self._random.randint(low=0,
                                      high=self._data_size - self._batch_size + 1,
                                      size=(len(self),))
        # Broadcasting builds one row of consecutive indices per start; the
        # row-major ravel gives the same order as concatenating the rows and
        # also handles the case of zero starts.
        offsets = np.arange(self._batch_size)
        return iter((starts[:, None] + offsets[None, :]).ravel())

    @staticmethod
    def create(ds, cfg):
        """
        Build a sampler from a dataset `ds` (only len(ds) is used) and a
        configuration `cfg` providing the 'batch_size' and 'seed' entries.
        """
        return RandomSequenceSampler(len(ds), cfg['batch_size'], seed=cfg['seed'])


class SequentialBatchSampler(AbstractBatchSampler):
    """
    Sampler yielding dataset indices in order, one full batch at a time.

    Only complete batches are produced: trailing entries that do not fill
    a whole batch are left out.
    """
    def __iter__(self):
        # The stop value of arange is exclusive: the +1 keeps the last full
        # batch, which would otherwise be dropped whenever data_size is a
        # multiple of batch_size (including batch_size == data_size).
        starts = np.arange(0, self._data_size - self._batch_size + 1, self._batch_size)
        # One row of consecutive indices per start, flattened in order.
        offsets = np.arange(self._batch_size)
        return iter((starts[:, None] + offsets[None, :]).ravel())

    @staticmethod
    def create(ds, cfg):
        """
        Build a sampler from a dataset `ds` (only len(ds) is used) and a
        configuration `cfg` providing the 'batch_size' entry.
        """
        return SequentialBatchSampler(len(ds), cfg['batch_size'])
6 changes: 6 additions & 0 deletions mlreco/main_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def process_config(cfg):
cfg['training']['seed'] = int(time.time())
else:
cfg['training']['seed'] = int(cfg['training']['seed'])
# Update IO seed: a missing or negative seed is replaced by the current time
if 'seed' not in cfg['iotool']['sampler'] or cfg['iotool']['sampler']['seed'] < 0:
    import time
    cfg['iotool']['sampler']['seed'] = int(time.time())
else:
    # Store the integer seed back under the same 'sampler' key it was read
    # from ('sample' was a typo and is not the key checked above).
    cfg['iotool']['sampler']['seed'] = int(cfg['iotool']['sampler']['seed'])

# Batch size checker
if cfg['iotool']['batch_size'] < 0 and cfg['training']['minibatch_size'] < 0:
Expand Down

0 comments on commit 3a74478

Please sign in to comment.