Skip to content

Commit

Permalink
Fix issues 402 update example mnist (#422)
Browse files Browse the repository at this point in the history
* fix #402
Modify the example mnist to make it meet the new usage requirements of "_get_features_and_labels_from_input_fn" and update docs

* remove debug code

* add mnist test.sh and update quick_start.md
  • Loading branch information
errord authored Nov 26, 2020
1 parent fa57e44 commit d306eb2
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 23 deletions.
19 changes: 16 additions & 3 deletions docs/tutorials/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,25 @@ apt-get install libgmp-dev libmpc-dev libmpfr-dev

## Run Example

To quickly run a simple training example locally:
There are two ways to run a simple training example locally:

* run test.sh

```
cd example/mnist
./test.sh
```

* run it manually and view summary from TensorBoard

```
cd example/mnist
python leader.py --local-addr=localhost:50051 --peer-addr=localhost:50052 --data-path=data/leader &
python follower.py --local-addr=localhost:50052 --peer-addr=localhost:50051 --data-path=data/follower/ &
python make_data.py
python leader.py --local-addr=localhost:50051 --peer-addr=localhost:50052 --data-path=data/leader --checkpoint-path=log/checkpoint --save-checkpoint-steps=10 --summary-path=log/summary --summary-save-steps=10 &
python follower.py --local-addr=localhost:50052 --peer-addr=localhost:50051 --data-path=data/follower/ --checkpoint-path=log/checkpoint --save-checkpoint-steps=10 --summary-path=log/summary --summary-save-steps=10
tensorboard --logdir=log
```

For better display, run the last two commands in two different terminals.
28 changes: 19 additions & 9 deletions example/mnist/follower.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,27 @@
import fedlearner.trainer as flt


ROLE = 'follower'
parser = flt.trainer_worker.create_argument_parser()
parser.add_argument('--batch-size', type=int, default=256,
help='Training batch size.')
args = parser.parse_args()


def input_fn(bridge, trainer_master):
dataset = flt.data.DataBlockLoader(256, 'follower', bridge, trainer_master)
feature_map = {
"example_id": tf.FixedLenFeature([], tf.string),
"x": tf.FixedLenFeature([28 * 28 // 2], tf.float32),
}
record_batch = dataset.make_batch_iterator().get_next()
features = tf.parse_example(record_batch, features=feature_map)
return features, {}
dataset = flt.data.DataBlockLoader(args.batch_size, ROLE,
bridge, trainer_master).make_dataset()

def parse_fn(example):
feature_map = dict()
feature_map["example_id"] = tf.FixedLenFeature([], tf.string)
feature_map["x"] = tf.FixedLenFeature([28 * 28 // 2], tf.float32)
features = tf.parse_example(example, features=feature_map)
return features, dict(y=tf.constant(0))

dataset = dataset.map(map_func=parse_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset

def serving_input_receiver_fn():
feature_map = {
Expand Down Expand Up @@ -83,5 +93,5 @@ def model_fn(model, features, labels, mode):
logging.basicConfig(level=logging.INFO)
parser = flt.trainer_worker.create_argument_parser()
flt.trainer_worker.train(
'follower', parser.parse_args(), input_fn,
ROLE, parser.parse_args(), input_fn,
model_fn, serving_input_receiver_fn)
30 changes: 19 additions & 11 deletions example/mnist/leader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,28 @@
import tensorflow.compat.v1 as tf
import fedlearner.trainer as flt

ROLE = 'leader'

parser = flt.trainer_worker.create_argument_parser()
parser.add_argument('--batch-size', type=int, default=256,
help='Training batch size.')
args = parser.parse_args()

def input_fn(bridge, trainer_master):
dataset = flt.data.DataBlockLoader(256, 'leader', bridge, trainer_master)
feature_map = {
"example_id": tf.FixedLenFeature([], tf.string),
"x": tf.FixedLenFeature([28 * 28 // 2], tf.float32),
"y": tf.FixedLenFeature([], tf.int64)
}
record_batch = dataset.make_batch_iterator().get_next()
features = tf.parse_example(record_batch, features=feature_map)
labels = {'y': features.pop('y')}
return features, labels
dataset = flt.data.DataBlockLoader(args.batch_size, ROLE,
bridge, trainer_master).make_dataset()

def parse_fn(example):
feature_map = dict()
feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
feature_map['x'] = tf.FixedLenFeature([28 * 28 // 2], tf.float32)
feature_map['y'] = tf.FixedLenFeature([], tf.int64)
features = tf.parse_example(example, features=feature_map)
return features, dict(y=features.pop('y'))

dataset = dataset.map(map_func=parse_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset


def serving_input_receiver_fn():
Expand Down Expand Up @@ -109,5 +117,5 @@ def model_fn(model, features, labels, mode):
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
flt.trainer_worker.train(
'leader', args, input_fn,
ROLE, args, input_fn,
model_fn, serving_input_receiver_fn)
21 changes: 21 additions & 0 deletions example/mnist/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=""

python make_data.py

python leader.py --local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--data-path=data/leader \
--checkpoint-path=log/checkpoint \
--save-checkpoint-steps=10 &

python follower.py --local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--data-path=data/follower/ \
--checkpoint-path=log/checkpoint \
--save-checkpoint-steps=10

wait

rm -rf data log
echo "test done"

0 comments on commit d306eb2

Please sign in to comment.