From c89c2d2464bd4e87c6c281e6c0173ca49a03ea94 Mon Sep 17 00:00:00 2001
From: Jannis Klinkenberg <j.klinkenberg@itc.rwth-aachen.de>
Date: Thu, 14 Nov 2024 09:04:42 +0100
Subject: [PATCH] added initial horovod version for TensorFlow

---
 tensorflow/cifar10_distributed/README.md      |  10 +
 .../submit_job_container_horovod.sh           |  48 +++++
 .../train_model_horovod.py                    | 173 ++++++++++++++++++
 3 files changed, 231 insertions(+)
 create mode 100644 tensorflow/cifar10_distributed/README.md
 create mode 100644 tensorflow/cifar10_distributed/submit_job_container_horovod.sh
 create mode 100644 tensorflow/cifar10_distributed/train_model_horovod.py

diff --git a/tensorflow/cifar10_distributed/README.md b/tensorflow/cifar10_distributed/README.md
new file mode 100644
index 0000000..296a18d
--- /dev/null
+++ b/tensorflow/cifar10_distributed/README.md
@@ -0,0 +1,10 @@
+# TensorFlow - Distributed Training
+
+This folder contains the following 2 example versions for distributed training:
+- **Version 1:** A TensorFlow native version, that requires a bit more preparation
+- **Version 2:** A version that is using Horovod ontop of TensorFlow
+
+More information and examples concerning Horovod can be found under:
+- https://horovod.readthedocs.io/en/stable/tensorflow.html
+- https://horovod.readthedocs.io/en/stable/keras.html
+- https://github.com/horovod/horovod/tree/master/examples/
diff --git a/tensorflow/cifar10_distributed/submit_job_container_horovod.sh b/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
new file mode 100644
index 0000000..b400560
--- /dev/null
+++ b/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/zsh
+############################################################
+### Slurm flags
+############################################################
+
+#SBATCH --time=00:15:00
+#SBATCH --partition=c23g
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --cpus-per-task=24
+#SBATCH --gres=gpu:2
+
+############################################################
+### Load modules or software
+############################################################
+
+# load module for TensorFlow container
+module load TensorFlow/nvcr-24.01-tf2-py3
+module list
+
+############################################################
+### Parameters and Settings
+############################################################
+
+# print some information about current system
+echo "Job nodes: ${SLURM_JOB_NODELIST}"
+echo "Current machine: $(hostname)"
+nvidia-smi
+
+export NCCL_DEBUG=INFO
+export TF_CPP_MIN_LOG_LEVEL=1 # disable info messages
+export TF_GPU_THREAD_MODE='gpu_private'
+export NCCL_SOCKET_NTHREADS=8 # multi-threading for NCCL communication
+
+############################################################
+### Execution (Model Training)
+############################################################
+
+# TensorFlow in container often needs a tmp directory
+NEWTMP=$(pwd)/tmp
+mkdir -p ${NEWTMP}
+
+# each process sets required environment variables and
+# runs the python script inside the container
+srun zsh -c '\
+    source set_vars.sh && \
+    apptainer exec -e --nv -B ${NEWTMP}:/tmp ${TENSORFLOW_IMAGE} \
+        bash -c "python -W ignore train_model_horovod.py --distributed"'
diff --git a/tensorflow/cifar10_distributed/train_model_horovod.py b/tensorflow/cifar10_distributed/train_model_horovod.py
new file mode 100644
index 0000000..c354c86
--- /dev/null
+++ b/tensorflow/cifar10_distributed/train_model_horovod.py
@@ -0,0 +1,173 @@
+from __future__ import print_function
+import numpy as np
+import os, sys
+import argparse
+import datetime
+import tensorflow as tf
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras import backend as K
+from tensorflow.keras.datasets import cifar10
+import tensorflow.keras.applications as applications
+import horovod.keras as hvd
+
+def parse_command_line():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
+    parser.add_argument("--num_epochs", required=False, type=int, default=5)
+    parser.add_argument("--batch_size", required=False, type=int, default=128)
+    parser.add_argument("--distributed", required=False, action="store_true", default=False)
+    parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
+    parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
+    parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
+    parser.add_argument("--tensorboard", required=False, help="Whether to use tensorboard callback", action="store_true", default=False)
+    parser.add_argument("--profile_batches", required=False, help='Batches to profile with for tensorboard. Format "batch_start,batch_end"', type=str, default="2,5")
+    args = parser.parse_args()
+
+    # default args for distributed
+    args.global_batches = args.batch_size
+
+    if args.distributed:
+        args.global_batches = args.batch_size * hvd.size()
+
+    # only use verbose for master process
+    if hvd.rank() != 0:
+        args.verbosity = 0
+
+    # specific to cifar 10 dataset
+    args.num_classes = 10
+
+    if hvd.rank() == 0:
+        print("Settings:")
+        settings_map = vars(args)
+        for name in sorted(settings_map.keys()):
+            print("--" + str(name) + ": " + str(settings_map[name]))
+        print("")
+        sys.stdout.flush()
+
+    return args
+
+def load_dataset(args):
+    K.set_image_data_format("channels_last")
+
+    # load the cifar10 data
+    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+    # convert class vectors to binary class matrices.
+    y_train = tf.keras.utils.to_categorical(y_train, args.num_classes)
+    y_test = tf.keras.utils.to_categorical(y_test, args.num_classes)
+
+    # normalize base data
+    x_train = x_train.astype("float32") / 255
+    x_test = x_test.astype("float32") / 255
+    x_train_mean = np.mean(x_train, axis=0)
+    x_train -= x_train_mean
+    x_test -= x_train_mean
+
+    # dimensions
+    if hvd.rank() == 0:
+        print(f"original train_shape: {x_train.shape}")
+        print(f"original test_shape: {x_test.shape}")
+    n_train, n_test = x_train.shape[0], x_test.shape[0]
+    resize_size = 224 # use bigger images with ResNet
+
+    # Generating input pipelines
+    ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
+                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+                .shuffle(n_train).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    )
+    ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
+               .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+               .shuffle(n_test).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    )
+
+    # get updated shapes
+    train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
+    if hvd.rank() == 0:
+        print(f"final train_shape: {train_shape}")
+        print(f"final test_shape: {test_shape}")
+
+    return ds_train, ds_test, train_shape
+
+def setup(args):
+    if args.num_intraop_threads:
+        tf.config.threading.set_intra_op_parallelism_threads(args.num_intraop_threads)
+    if args.num_interop_threads:
+        tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
+
+    l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
+
+    if hvd.rank() == 0:
+        print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
+        print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
+
+        print("List of GPU devices found:")
+        for dev in l_gpu_devices:
+            print(str(dev.device_type) + ": " + dev.name)
+        print("")
+        sys.stdout.flush()
+
+    tf.config.set_visible_devices(l_gpu_devices[hvd.local_rank()], "GPU")
+    tf.keras.backend.clear_session()
+    tf.config.optimizer.set_jit(True)
+
+def main():
+    # Horovod: initialize Horovod.
+    hvd.init()
+
+    # parse command line arguments
+    args = parse_command_line()
+
+    # run setup (e.g., create distributed environment if desired)
+    setup(args)
+
+    # loading desired dataset
+    ds_train, ds_test, train_shape = load_dataset(args)
+
+    # Horovod: (optional) compression algorithm.
+    compression = hvd.Compression.none # or hvd.Compression.fp16
+
+    # callbacks to register
+    callbacks = [
+        # Horovod: broadcast initial variable states from rank 0 to all other processes.
+        # This is necessary to ensure consistent initialization of all workers when
+        # training is started with random weights or restored from a checkpoint.
+        hvd.callbacks.BroadcastGlobalVariablesCallback(0),
+    ]
+
+    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
+    if hvd.rank() == 0:
+        callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))
+
+    model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
+    # model.summary() # display the model architecture
+    
+    # Horovod: add Horovod Distributed Optimizer.
+    cur_optimizer = Adam(0.001)
+    opt = hvd.DistributedOptimizer(cur_optimizer, compression=compression)
+
+    model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
+
+    # callbacks to register
+    if args.tensorboard:
+        tensorboard_callback = tf.keras.callbacks.TensorBoard(
+            log_dir=os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
+            histogram_freq=1,
+            profile_batch=args.profile_batches,
+        )
+        callbacks.append(tensorboard_callback)
+
+    # train the model
+    model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
+
+    # evaluate model
+    scores = model.evaluate(ds_test, verbose=args.verbosity)
+    if hvd.rank() == 0:
+        print(f"Test Evaluation: Accuracy: {scores[1]}")
+        sys.stdout.flush()
+    
+    # Horovod: synchronize at the end (replacement for barrier)
+    cur_rank = hvd.rank()
+    avg_rank = hvd.allreduce(cur_rank, average=True)
+
+if __name__ == "__main__":
+    main()
-- 
GitLab