Skip to content

Commit

Permalink
Merge pull request #1465 from hansoli68/MultiGPU
Browse files Browse the repository at this point in the history
Multi GPU arguments fixes
  • Loading branch information
hgaiser authored Oct 1, 2020
2 parents a51bcff + c05643d commit 1d4e8ac
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion keras_retinanet/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def parse_args(args):
parser.add_argument('model', help='Path to RetinaNet model.')
parser.add_argument('--convert-model', help='Convert the model to an inference model (ie. the input is a training model).', action='store_true')
parser.add_argument('--backbone', help='The backbone of the model.', default='resnet50')
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).', type=int)
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
parser.add_argument('--score-threshold', help='Threshold on score to filter detections with (defaults to 0.05).', default=0.05, type=float)
parser.add_argument('--iou-threshold', help='IoU Threshold to count for a positive detection (defaults to 0.5).', default=0.5, type=float)
parser.add_argument('--max-detections', help='Max Detections per image (defaults to 100).', default=100, type=int)
Expand Down
2 changes: 1 addition & 1 deletion keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def csv_list(string):
group.add_argument('--no-weights', help='Don\'t initialize the model with any weights.', dest='imagenet_weights', action='store_const', const=False)
parser.add_argument('--backbone', help='Backbone model used by retinanet.', default='resnet50', type=str)
parser.add_argument('--batch-size', help='Size of the batches.', default=1, type=int)
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).', type=int)
parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
parser.add_argument('--multi-gpu', help='Number of GPUs to use for parallel processing.', type=int, default=0)
parser.add_argument('--multi-gpu-force', help='Extra flag needed to enable (experimental) multi-gpu support.', action='store_true')
parser.add_argument('--initial-epoch', help='Epoch from which to begin the train, useful if resuming from snapshot.', type=int, default=0)
Expand Down
41 changes: 22 additions & 19 deletions keras_retinanet/utils/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@


def setup_gpu(gpu_id):
if gpu_id == 'cpu' or gpu_id == -1:
try:
visible_gpu_indices = [int(id) for id in gpu_id.split(',')]
available_gpus = tf.config.list_physical_devices('GPU')
visible_gpus = [gpu for idx, gpu in enumerate(available_gpus) if idx in visible_gpu_indices]

if visible_gpus:
try:
# Currently, memory growth needs to be the same across GPUs.
for gpu in available_gpus:
tf.config.experimental.set_memory_growth(gpu, True)

# Use only the selcted gpu.
tf.config.set_visible_devices(visible_gpus, 'GPU')
except RuntimeError as e:
# Visible devices must be set before GPUs have been initialized.
print(e)

logical_gpus = tf.config.list_logical_devices('GPU')
print(len(available_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
else:
tf.config.set_visible_devices([], 'GPU')
except ValueError:
tf.config.set_visible_devices([], 'GPU')
return

gpus = tf.config.list_physical_devices('GPU')
if gpus:
# Restrict TensorFlow to only use the first GPU.
try:
# Currently, memory growth needs to be the same across GPUs.
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)

# Use only the selcted gpu.
tf.config.set_visible_devices(gpus[int(gpu_id)], 'GPU')
except RuntimeError as e:
# Visible devices must be set before GPUs have been initialized.
print(e)

logical_gpus = tf.config.list_logical_devices('GPU')
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")

0 comments on commit 1d4e8ac

Please sign in to comment.