diff --git a/tensorflow/cifar10_distributed/train_model.py b/tensorflow/cifar10_distributed/train_model.py
index 293921bcb764a7649bc745e063936f6dff8291c6..8012cc3cd9c5eb34bb9ba1f2c5024507ffa31225 100644
--- a/tensorflow/cifar10_distributed/train_model.py
+++ b/tensorflow/cifar10_distributed/train_model.py
@@ -66,19 +66,15 @@ def load_dataset(args):
     n_train, n_test = x_train.shape[0], x_test.shape[0]
     resize_size = 224 # use bigger images with ResNet
 
-    # disable any automatic data sharding in TensorFlow as we handle that manually here
-    # options = tf.data.Options()
-    # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
-
     # Generating input pipelines
     ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-                .shuffle(n_train) # .shard(num_shards=args.world_size, index=args.world_rank)
-                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
+                .shuffle(n_train)
+                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
     )
     ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-               .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
+               .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
     )
 
     # get updated shapes
@@ -107,7 +103,7 @@ def setup(args):
         print("")
         sys.stdout.flush()
 
-    tf.config.set_visible_devices(gpu_devices[args.local_rank], "GPU")
+    tf.config.set_visible_devices(gpu_devices[0], "GPU")
     tf.keras.backend.clear_session()
     tf.config.optimizer.set_jit(True)
 
@@ -118,6 +114,9 @@ def setup(args):
         )
     )
 
+    print("MultiWorkerMirroredStrategy.num_replicas_in_sync", strategy.num_replicas_in_sync)
+    print("MultiWorkerMirroredStrategy.worker_index", strategy.cluster_resolver.task_id)
+
     return strategy
 
 def main():