Skip to content
Snippets Groups Projects
Commit 7ecfb8bb authored by Jannis Klinkenberg's avatar Jannis Klinkenberg
Browse files

removed obsolete statements from Horovod example

parent fabbd911
Branches
No related tags found
No related merge requests found
......@@ -67,20 +67,16 @@ def load_dataset(args):
n_train, n_test = x_train.shape[0], x_test.shape[0]
resize_size = 224 # use bigger images with ResNet
# disable any automatic data sharding in TensorFlow as we handle that manually here
# options = tf.data.Options()
# options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
# Generating input pipelines
ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
.map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
.shuffle(n_train).shard(num_shards=hvd.size(), index=hvd.rank()) # Horovod: need to manually shard dataset
.cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
.cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
)
# Horovod: dont use sharding for test here. Otherwise reduction of results is necessary
ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
.map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
.shuffle(n_test).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
.shuffle(n_test).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
)
# get updated shapes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment