From c89c2d2464bd4e87c6c281e6c0173ca49a03ea94 Mon Sep 17 00:00:00 2001
From: Jannis Klinkenberg <j.klinkenberg@itc.rwth-aachen.de>
Date: Thu, 14 Nov 2024 09:04:42 +0100
Subject: [PATCH] added initial horovod version for TensorFlow

---
 tensorflow/cifar10_distributed/README.md           |  10 +
 .../submit_job_container_horovod.sh                |  48 +++++
 .../train_model_horovod.py                         | 173 ++++++++++++++++++
 3 files changed, 231 insertions(+)
 create mode 100644 tensorflow/cifar10_distributed/README.md
 create mode 100644 tensorflow/cifar10_distributed/submit_job_container_horovod.sh
 create mode 100644 tensorflow/cifar10_distributed/train_model_horovod.py

diff --git a/tensorflow/cifar10_distributed/README.md b/tensorflow/cifar10_distributed/README.md
new file mode 100644
index 0000000..296a18d
--- /dev/null
+++ b/tensorflow/cifar10_distributed/README.md
@@ -0,0 +1,10 @@
+# TensorFlow - Distributed Training
+
+This folder contains the following two example versions for distributed training:
+- **Version 1:** a TensorFlow-native version that requires a bit more preparation
+- **Version 2:** a version that uses Horovod on top of TensorFlow
+
+More information and examples concerning Horovod can be found at:
+- https://horovod.readthedocs.io/en/stable/tensorflow.html
+- https://horovod.readthedocs.io/en/stable/keras.html
+- https://github.com/horovod/horovod/tree/master/examples/
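+
+For orientation, the core Horovod pattern used by `train_model_horovod.py` boils
+down to the following minimal, self-contained sketch (an illustrative outline
+only, not one of the scripts in this folder):
+
+```python
+import tensorflow as tf
+import horovod.keras as hvd
+
+hvd.init()  # start Horovod; one process per GPU (e.g., launched via srun)
+
+# pin each process to a single local GPU
+gpus = tf.config.list_physical_devices("GPU")
+if gpus:
+    tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
+
+# toy data and model, just to keep the sketch runnable
+(x, y), _ = tf.keras.datasets.cifar10.load_data()
+ds = tf.data.Dataset.from_tensor_slices((x / 255.0, y)).batch(128)
+model = tf.keras.Sequential([
+    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
+    tf.keras.layers.Dense(10, activation="softmax"),
+])
+
+# wrap the optimizer so gradients are averaged across all workers
+opt = hvd.DistributedOptimizer(tf.keras.optimizers.Adam(0.001))
+model.compile(loss="sparse_categorical_crossentropy", optimizer=opt)
+
+# broadcast rank 0's initial weights so all workers start identically
+model.fit(ds, epochs=1, verbose=1 if hvd.rank() == 0 else 0,
+          callbacks=[hvd.callbacks.BroadcastGlobalVariablesCallback(0)])
+```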
diff --git a/tensorflow/cifar10_distributed/submit_job_container_horovod.sh b/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
new file mode 100644
index 0000000..b400560
--- /dev/null
+++ b/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
@@ -0,0 +1,48 @@
+#!/usr/bin/zsh
+############################################################
+### Slurm flags
+############################################################
+
+#SBATCH --time=00:15:00
+#SBATCH --partition=c23g
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --cpus-per-task=24
+#SBATCH --gres=gpu:2
+
+############################################################
+### Load modules or software
+############################################################
+
+# load the module for the TensorFlow container
+module load TensorFlow/nvcr-24.01-tf2-py3
+module list
+
+############################################################
+### Parameters and Settings
+############################################################
+
+# print some information about the current system
+echo "Job nodes: ${SLURM_JOB_NODELIST}"
+echo "Current machine: $(hostname)"
+nvidia-smi
+
+export NCCL_DEBUG=INFO
+export TF_CPP_MIN_LOG_LEVEL=1           # suppress TensorFlow info messages
+export TF_GPU_THREAD_MODE='gpu_private'
+export NCCL_SOCKET_NTHREADS=8           # multi-threading for NCCL communication
+
+############################################################
+### Execution (Model Training)
+############################################################
+
+# TensorFlow inside a container often needs a writable tmp directory;
+# export the variable so that it is visible inside the srun subshell
+export NEWTMP=$(pwd)/tmp
+mkdir -p ${NEWTMP}
+
+# each process sets the required environment variables and
+# runs the Python script inside the container
+srun zsh -c '\
+    source set_vars.sh && \
+    apptainer exec -e --nv -B ${NEWTMP}:/tmp ${TENSORFLOW_IMAGE} \
+    bash -c "python -W ignore train_model_horovod.py --distributed"'
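+
+# Note (an alternative launcher, not needed under Slurm): the same training
+# could also be started with Horovod's own launcher, e.g.:
+#   horovodrun -np 2 python -W ignore train_model_horovod.py --distributed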
diff --git a/tensorflow/cifar10_distributed/train_model_horovod.py b/tensorflow/cifar10_distributed/train_model_horovod.py
new file mode 100644
index 0000000..c354c86
--- /dev/null
+++ b/tensorflow/cifar10_distributed/train_model_horovod.py
@@ -0,0 +1,173 @@
+from __future__ import print_function
+import numpy as np
+import os, sys
+import argparse
+import datetime
+import tensorflow as tf
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras import backend as K
+from tensorflow.keras.datasets import cifar10
+import tensorflow.keras.applications as applications
+import horovod.keras as hvd
+
+def parse_command_line():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
+    parser.add_argument("--num_epochs", required=False, type=int, default=5)
+    parser.add_argument("--batch_size", required=False, type=int, default=128)
+    parser.add_argument("--distributed", required=False, action="store_true", default=False)
+    parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
+    parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
+    parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
+    parser.add_argument("--tensorboard", required=False, help="Whether to use the TensorBoard callback", action="store_true", default=False)
+    parser.add_argument("--profile_batches", required=False, help='Batches to profile with TensorBoard. Format "batch_start,batch_end"', type=str, default="2,5")
+    args = parser.parse_args()
+
+    # default: without distribution, the global batch size equals the local one
+    args.global_batches = args.batch_size
+
+    if args.distributed:
+        # the effective global batch size grows with the number of Horovod workers
+        args.global_batches = args.batch_size * hvd.size()
+
+    # only be verbose on the master process (rank 0)
+    if hvd.rank() != 0:
+        args.verbosity = 0
+
+    # specific to the CIFAR-10 dataset
+    args.num_classes = 10
+
+    if hvd.rank() == 0:
+        print("Settings:")
+        settings_map = vars(args)
+        for name in sorted(settings_map.keys()):
+            print("--" + str(name) + ": " + str(settings_map[name]))
+        print("")
+        sys.stdout.flush()
+
+    return args
+
+def load_dataset(args):
+    K.set_image_data_format("channels_last")
+
+    # load the CIFAR-10 data
+    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+    # convert class vectors to binary class matrices (one-hot encoding)
+    y_train = tf.keras.utils.to_categorical(y_train, args.num_classes)
+    y_test = tf.keras.utils.to_categorical(y_test, args.num_classes)
+
+    # normalize base data
+    x_train = x_train.astype("float32") / 255
+    x_test = x_test.astype("float32") / 255
+    x_train_mean = np.mean(x_train, axis=0)
+    x_train -= x_train_mean
+    x_test -= x_train_mean
+
+    # dimensions
+    if hvd.rank() == 0:
+        print(f"original train_shape: {x_train.shape}")
+        print(f"original test_shape: {x_test.shape}")
+    n_train, n_test = x_train.shape[0], x_test.shape[0]
+    resize_size = 224  # use bigger images with ResNet
+
+    # generate input pipelines; cache before shuffle so that
+    # every epoch sees a freshly shuffled order
+    ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
+        .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+        .cache().shuffle(n_train).batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    )
+    ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
+        .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+        .cache().shuffle(n_test).batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    )
+
+    # get the updated shapes
+    train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
+    if hvd.rank() == 0:
+        print(f"final train_shape: {train_shape}")
+        print(f"final test_shape: {test_shape}")
+
+    return ds_train, ds_test, train_shape
+
+def setup(args):
+    if args.num_intraop_threads:
+        tf.config.threading.set_intra_op_parallelism_threads(args.num_intraop_threads)
+    if args.num_interop_threads:
+        tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
+
+    l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
+
+    if hvd.rank() == 0:
+        print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
+        print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
+
+        print("List of GPU devices found:")
+        for dev in l_gpu_devices:
+            print(str(dev.device_type) + ": " + dev.name)
+        print("")
+        sys.stdout.flush()
+
+    # Horovod: pin each process to exactly one local GPU
+    if l_gpu_devices:
+        tf.config.set_visible_devices(l_gpu_devices[hvd.local_rank()], "GPU")
+    else:
+        tf.config.set_visible_devices([], "GPU")  # CPU-only mode
+    tf.keras.backend.clear_session()
+    tf.config.optimizer.set_jit(True)
+
+def main():
+    # Horovod: initialize Horovod
+    hvd.init()
+
+    # parse command line arguments
+    args = parse_command_line()
+
+    # run setup (threading, GPU pinning)
+    setup(args)
+
+    # load the desired dataset
+    ds_train, ds_test, train_shape = load_dataset(args)
+
+    # Horovod: (optional) gradient compression algorithm
+    compression = hvd.Compression.none  # or hvd.Compression.fp16
+
+    # callbacks to register
+    callbacks = [
+        # Horovod: broadcast initial variable states from rank 0 to all other processes.
+        # This is necessary to ensure consistent initialization of all workers when
+        # training is started with random weights or restored from a checkpoint.
+        hvd.callbacks.BroadcastGlobalVariablesCallback(0),
+    ]
+
+    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them
+    if hvd.rank() == 0:
+        callbacks.append(tf.keras.callbacks.ModelCheckpoint("./checkpoint-{epoch}.h5"))
+
+    model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
+    # model.summary()  # display the model architecture
+
+    # Horovod: wrap the regular optimizer in the Horovod DistributedOptimizer
+    cur_optimizer = Adam(0.001)
+    opt = hvd.DistributedOptimizer(cur_optimizer, compression=compression)
+
+    model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
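+
+    # Note (a common Horovod convention, not applied in this example): since the
+    # effective global batch size grows with hvd.size(), the Horovod docs suggest
+    # scaling the learning rate by the number of workers, e.g.
+    #   opt = hvd.DistributedOptimizer(Adam(0.001 * hvd.size()), compression=compression)
+    # optionally combined with hvd.callbacks.LearningRateWarmupCallback for the
+    # first few epochs to stabilize training.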
+
+    # optionally register the TensorBoard callback
+    if args.tensorboard:
+        tensorboard_callback = tf.keras.callbacks.TensorBoard(
+            log_dir=os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
+            histogram_freq=1,
+            profile_batch=args.profile_batches,
+        )
+        callbacks.append(tensorboard_callback)
+
+    # train the model
+    model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
+
+    # evaluate the model
+    scores = model.evaluate(ds_test, verbose=args.verbosity)
+    if hvd.rank() == 0:
+        print(f"Test Evaluation: Accuracy: {scores[1]}")
+        sys.stdout.flush()
+
+    # Horovod: synchronize all workers at the end (a simple replacement for a barrier)
+    cur_rank = hvd.rank()
+    avg_rank = hvd.allreduce(cur_rank, average=True)
+
+if __name__ == "__main__":
+    main()
-- 
GitLab