diff --git a/tensorflow/cifar10_distributed/train_model.py b/tensorflow/cifar10_distributed/train_model.py
index 42932f2a28e7ad331542f32f89f403fb2e6fb5fc..11691599639fcbc9fd5199dee9a1ad7e8a94b7fd 100644
--- a/tensorflow/cifar10_distributed/train_model.py
+++ b/tensorflow/cifar10_distributed/train_model.py
@@ -1,6 +1,7 @@
 from __future__ import print_function
 import numpy as np
 import os, sys
+import random
 import argparse
 import datetime
 import tensorflow as tf
@@ -9,12 +10,25 @@ from tensorflow.keras import backend as K
 from tensorflow.keras.datasets import cifar10
 import tensorflow.keras.applications as applications
 
+class TrainLoggerModel(tf.keras.Model):
+    """Debug model with a no-op train_step: no forward/backward pass is run,
+    so it can be used to inspect which batches each worker receives."""
+
+    def train_step(self, data):
+        # uncomment to print the labels of every batch this worker processes:
+        # x, y = data
+        # tf.print('new batch:')
+        # tf.print(x, summarize=-1)  # optionally inspect the images too
+        # tf.print(y, summarize=-1)
+
+        # return a dict mapping metric names to their current value
+        return {m.name: m.result() for m in self.metrics}
+
 def parse_command_line():
     parser = argparse.ArgumentParser()
     parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
-    parser.add_argument("--num_epochs", required=False, type=int, default=5)
+    parser.add_argument("--num_epochs", required=False, type=int, default=3)
     parser.add_argument("--batch_size", required=False, type=int, default=128)
-    parser.add_argument("--distributed", required=False, action="store_true", default=False)
     parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
     parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
     parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
@@ -23,31 +37,22 @@ def parse_command_line():
     args = parser.parse_args()
 
-    # default args for distributed
+    # distributed settings are taken from the launcher environment
-    args.world_size = 1
-    args.world_rank = 0
-    args.local_rank = 0
-    args.global_batches = args.batch_size
-
-    if args.distributed:
-        args.world_size = int(os.environ["WORLD_SIZE"])
-        args.world_rank = int(os.environ["RANK"])
-        args.local_rank = int(os.environ["LOCAL_RANK"])
-        args.global_batches = args.batch_size * args.world_size
-
-    # only use verbose for master process
-    if args.world_rank != 0:
-        args.verbosity = 0
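+    # WORLD_SIZE, RANK and LOCAL_RANK are expected to be exported by the job
+    # launcher (e.g. torchrun or an mpirun/Slurm wrapper script), e.g.:
+    #   WORLD_SIZE=2 RANK=0 LOCAL_RANK=0 python train_model.py --device cuda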
+    args.world_size = int(os.environ["WORLD_SIZE"])
+    args.world_rank = int(os.environ["RANK"])
+    args.local_rank = int(os.environ["LOCAL_RANK"])
+    args.global_batch_size = args.batch_size * args.world_size
+    args.verbosity = 0 if args.world_rank != 0 else args.verbosity  # only be verbose on the master process
 
     # specific to cifar 10 dataset
     args.num_classes = 10
 
-    if args.world_rank == 0:
-        print("Settings:")
-        settings_map = vars(args)
-        for name in sorted(settings_map.keys()):
-            print("--" + str(name) + ": " + str(settings_map[name]))
-        print("")
-        sys.stdout.flush()
+    # while debugging, print the settings from every rank; restore the
+    # `if args.world_rank == 0:` guard to print only on the master process
+    print("Settings:")
+    settings_map = vars(args)
+    for name in sorted(settings_map.keys()):
+        print("--" + str(name) + ": " + str(settings_map[name]))
+    print("")
+    sys.stdout.flush()
 
     return args
 
@@ -69,27 +74,28 @@ def load_dataset(args):
     x_test -= x_train_mean
 
     # dimensions
-    if args.world_rank == 0:
-        print(f"original train_shape: {x_train.shape}")
-        print(f"original test_shape: {x_test.shape}")
+    # printed from every rank while debugging (guard with world_rank == 0 otherwise)
+    print(f"original train_shape: {x_train.shape}")
+    print(f"original test_shape: {x_test.shape}")
     n_train, n_test = x_train.shape[0], x_test.shape[0]
     resize_size = 224 # use bigger images with ResNet
 
     # Generating input pipelines
     ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-                .shuffle(n_train).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+                .shuffle(n_train)
+                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
     )
     ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-               .shuffle(n_test).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+               .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
     )
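+    # note: with MultiWorkerMirroredStrategy the dataset is batched with the
+    # global batch size; tf.distribute then splits each batch across the
+    # replicas and auto-shards the data between workers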
 
     # get updated shapes
     train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
-    if args.world_rank == 0:
-        print(f"final train_shape: {train_shape}")
-        print(f"final test_shape: {test_shape}")
+    # printed from every rank while debugging
+    print(f"final train_shape: {train_shape}")
+    print(f"final test_shape: {test_shape}")
 
     return ds_train, ds_test, train_shape
 
@@ -101,15 +107,15 @@ def setup(args):
 
     l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
 
-    if args.world_rank == 0:
-        print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
-        print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
+    # printed from every rank while debugging
+    print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
+    print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
 
-        print("List of GPU devices found:")
-        for dev in l_gpu_devices:
-            print(str(dev.device_type) + ": " + dev.name)
-        print("")
-        sys.stdout.flush()
+    print("List of GPU devices found:")
+    for dev in l_gpu_devices:
+        print(str(dev.device_type) + ": " + dev.name)
+    print("")
+    sys.stdout.flush()
 
     tf.config.set_visible_devices(l_gpu_devices[args.local_rank], "GPU")
     tf.keras.backend.clear_session()
@@ -125,6 +131,11 @@ def setup(args):
     return strategy
 
 def main():
+    # always use the same seed
+    random.seed(42)
+    tf.random.set_seed(42)
+    np.random.seed(42)
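+    # note: seeding python, numpy and tf makes runs repeatable but does not
+    # force deterministic GPU kernels; recent TF versions (>= 2.8) additionally
+    # offer tf.config.experimental.enable_op_determinism() for that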
+
     # parse command line arguments
     args = parse_command_line()
 
@@ -134,13 +145,21 @@ def main():
     # loading desired dataset
     ds_train, ds_test, train_shape = load_dataset(args)
 
+    # optionally force element-wise (DATA) auto-sharding instead of the
+    # default AUTO policy:
+    # options = tf.data.Options()
+    # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
+    # ds_train = ds_train.with_options(options)
+
     # callbacks to register
     callbacks = []
 
     with strategy.scope():
-        model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
+        # ds_train = strategy.experimental_distribute_dataset(ds_train)
+
+        # swap in the no-op TrainLoggerModel to debug the input pipeline;
+        # restore ResNet50 for real training:
+        # model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
+        model = TrainLoggerModel()
+
         # model.summary() # display the model architecture
-        cur_optimizer = Adam(0.001)
+        # scale the learning rate with the number of workers (linear scaling rule)
+        cur_optimizer = Adam(learning_rate=0.001 * args.world_size)
         model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
 
     # callbacks to register
@@ -152,14 +171,26 @@ def main():
         )
         callbacks.append(tensorboard_callback)
 
+    class PrintLabelsCallback(tf.keras.callbacks.Callback):
+        def on_train_begin(self, logs=None):
+            # keep one iterator alive for the whole run; calling
+            # next(iter(ds_train)) per batch would re-create the iterator
+            # and yield the first batch every time
+            self._data_iter = iter(ds_train)
+
+        def on_train_batch_begin(self, batch, logs=None):
+            def print_labels(features, labels):
+                # print the labels processed by this worker
+                tf.print(f"Worker labels for batch {batch}:", labels, summarize=-1)
+
+            # run on every replica so each worker prints its own share
+            strategy.run(lambda x: print_labels(*x), args=(next(self._data_iter),))
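+    # caveat: the callback draws from its own pass over ds_train, not from
+    # the iterator model.fit() consumes, so the printed labels need not line
+    # up exactly with the batches being trained on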
+
+    callbacks.append(PrintLabelsCallback())
+
     # train the model
     model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
 
     # evaluate model
-    scores = model.evaluate(ds_test, verbose=args.verbosity)
-    if args.world_rank == 0:
-        print(f"Test Evaluation: Accuracy: {scores[1]}")
-        sys.stdout.flush()
+    # evaluation is skipped while the no-op TrainLoggerModel is in use;
+    # re-enable together with ResNet50:
+    # scores = model.evaluate(ds_test, verbose=args.verbosity)
+    # if args.world_rank == 0:
+    #     print(f"Test Evaluation: Accuracy: {scores[1]}")
+    #     sys.stdout.flush()
 
 if __name__ == "__main__":
     main()
diff --git a/tensorflow/cifar10_distributed/train_model_horovod.py b/tensorflow/cifar10_distributed/train_model_horovod.py
index c354c86cb24c4b56561d6ca67c5a04a18b949238..344e40f61969eac0f1160f2fbc8a5adca28e6de5 100644
--- a/tensorflow/cifar10_distributed/train_model_horovod.py
+++ b/tensorflow/cifar10_distributed/train_model_horovod.py
@@ -1,6 +1,7 @@
 from __future__ import print_function
 import numpy as np
 import os, sys
+import random
 import argparse
 import datetime
 import tensorflow as tf
@@ -15,7 +16,6 @@ def parse_command_line():
     parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
     parser.add_argument("--num_epochs", required=False, type=int, default=5)
     parser.add_argument("--batch_size", required=False, type=int, default=128)
-    parser.add_argument("--distributed", required=False, action="store_true", default=False)
     parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
     parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
     parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
@@ -24,14 +24,11 @@ def parse_command_line():
     args = parser.parse_args()
 
     # default args for distributed
-    args.global_batches = args.batch_size
-
-    if args.distributed:
-        args.global_batches = args.batch_size * hvd.size()
-
-    # only use verbose for master process
-    if hvd.rank() != 0:
-        args.verbosity = 0
+    args.world_size = hvd.size()
+    args.world_rank = hvd.rank()
+    args.local_rank = hvd.local_rank()
+    args.global_batch_size = args.batch_size * hvd.size()
+    args.verbosity = 0 if hvd.rank() != 0 else args.verbosity  # only be verbose on the master process
 
     # specific to cifar 10 dataset
     args.num_classes = 10
@@ -73,11 +70,13 @@ def load_dataset(args):
     # Generating input pipelines
     ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-                .shuffle(n_train).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+                .shard(num_shards=hvd.size(), index=hvd.rank()) # Horovod: shard manually, before shuffling, so each worker gets a disjoint subset
+                .shuffle(n_train // hvd.size())
+                .cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
     )
+    # Horovod: don't shard the test set here; otherwise the per-worker results would have to be reduced afterwards
     ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-               .shuffle(n_test).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+               .shuffle(n_test).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
     )
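+    # each worker batches its shard with the local batch size, so the
+    # effective global batch size is batch_size * hvd.size() (args.global_batch_size)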
 
     # get updated shapes
@@ -111,6 +110,11 @@ def setup(args):
     tf.config.optimizer.set_jit(True)
 
 def main():
+    # always use the same seed
+    random.seed(42)
+    tf.random.set_seed(42)
+    np.random.seed(42)
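+    # note: as in train_model.py, seeding makes runs repeatable; fully
+    # deterministic kernels would additionally need
+    # tf.config.experimental.enable_op_determinism()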
+
     # Horovod: initialize Horovod.
     hvd.init()
 
@@ -134,15 +138,15 @@ def main():
         hvd.callbacks.BroadcastGlobalVariablesCallback(0),
     ]
 
-    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
-    if hvd.rank() == 0:
-        callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))
+    # If desired: save checkpoints only on worker 0 to prevent other workers from corrupting them.
+    # if hvd.rank() == 0:
+    #     callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))
 
     model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
     # model.summary() # display the model architecture
     
-    # Horovod: add Horovod Distributed Optimizer.
-    cur_optimizer = Adam(0.001)
+    # Horovod: wrap the optimizer in DistributedOptimizer and scale the learning rate with the number of workers
+    cur_optimizer = Adam(learning_rate=0.001 * hvd.size())
     opt = hvd.DistributedOptimizer(cur_optimizer, compression=compression)
 
     model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])