diff --git a/tensorflow/cifar10_distributed/train_model.py b/tensorflow/cifar10_distributed/train_model.py
index 293921bcb764a7649bc745e063936f6dff8291c6..8012cc3cd9c5eb34bb9ba1f2c5024507ffa31225 100644
--- a/tensorflow/cifar10_distributed/train_model.py
+++ b/tensorflow/cifar10_distributed/train_model.py
@@ -66,19 +66,15 @@ def load_dataset(args):
     n_train, n_test = x_train.shape[0], x_test.shape[0]
     resize_size = 224 # use bigger images with ResNet
 
-    # disable any automatic data sharding in TensorFlow as we handle that manually here
-    # options = tf.data.Options()
-    # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
-
     # Generating input pipelines
     ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-                .shuffle(n_train) # .shard(num_shards=args.world_size, index=args.world_rank)
-                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
+                .shuffle(n_train)
+                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
                 )
     ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
               .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-              .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
+              .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
               )
 
     # get updated shapes
@@ -107,7 +103,7 @@ def setup(args):
         print("")
         sys.stdout.flush()
 
-    tf.config.set_visible_devices(gpu_devices[args.local_rank], "GPU")
+    tf.config.set_visible_devices(gpu_devices[0], "GPU")
     tf.keras.backend.clear_session()
     tf.config.optimizer.set_jit(True)
 
@@ -118,6 +114,9 @@ def setup(args):
         )
     )
 
+    print("MultiWorkerMirroredStrategy.num_replicas_in_sync", strategy.num_replicas_in_sync)
+    print("MultiWorkerMirroredStrategy.worker_index", strategy.cluster_resolver.task_id)
+
     return strategy
 
 def main():
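
Note on the change above (not part of the patch): the manual .shard() call and the AutoShardPolicy.OFF options block can be dropped because MultiWorkerMirroredStrategy automatically shards any tf.data pipeline passed to model.fit across workers (AutoShardPolicy.AUTO by default), so each worker already sees only its slice of the global batch. As a minimal sketch of the inverse setup, i.e. what keeping the manual sharding would have required, where world_size and world_rank are hypothetical stand-ins for args.world_size and args.world_rank from the training script:

import tensorflow as tf

# Switch automatic sharding off so each worker's explicit .shard() selection
# is not sharded a second time by the strategy.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

world_size, world_rank = 2, 0  # hypothetical stand-ins; differ per worker
ds = (tf.data.Dataset.range(100)
      .shard(num_shards=world_size, index=world_rank)  # keep every world_size-th element
      .batch(8)
      .with_options(options))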