Skip to content
Snippets Groups Projects
Verified Commit c89c2d24 authored by Jannis Klinkenberg's avatar Jannis Klinkenberg
Browse files

added initial horovod version for TensorFlow

parent 34146740
No related branches found
No related tags found
No related merge requests found
# TensorFlow - Distributed Training
This folder contains the following 2 example versions for distributed training:
- **Version 1:** A TensorFlow native version, that requires a bit more preparation
- **Version 2:** A version that is using Horovod ontop of TensorFlow
More information and examples concerning Horovod can be found under:
- https://horovod.readthedocs.io/en/stable/tensorflow.html
- https://horovod.readthedocs.io/en/stable/keras.html
- https://github.com/horovod/horovod/tree/master/examples/
#!/usr/bin/zsh
############################################################
### Slurm flags
############################################################
#SBATCH --time=00:15:00
#SBATCH --partition=c23g
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=24
#SBATCH --gres=gpu:2
############################################################
### Load modules or software
############################################################
# load module for TensorFlow container
module load TensorFlow/nvcr-24.01-tf2-py3
module list
############################################################
### Parameters and Settings
############################################################
# print some information about current system
echo "Job nodes: ${SLURM_JOB_NODELIST}"
echo "Current machine: $(hostname)"
nvidia-smi
export NCCL_DEBUG=INFO
export TF_CPP_MIN_LOG_LEVEL=1 # disable info messages
export TF_GPU_THREAD_MODE='gpu_private'
export NCCL_SOCKET_NTHREADS=8 # multi-threading for NCCL communication
############################################################
### Execution (Model Training)
############################################################
# TensorFlow in container often needs a tmp directory
NEWTMP=$(pwd)/tmp
mkdir -p ${NEWTMP}
# each process sets required environment variables and
# runs the python script inside the container
srun zsh -c '\
source set_vars.sh && \
apptainer exec -e --nv -B ${NEWTMP}:/tmp ${TENSORFLOW_IMAGE} \
bash -c "python -W ignore train_model_horovod.py --distributed"'
from __future__ import print_function
import numpy as np
import os, sys
import argparse
import datetime
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import cifar10
import tensorflow.keras.applications as applications
import horovod.keras as hvd
def parse_command_line():
parser = argparse.ArgumentParser()
parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
parser.add_argument("--num_epochs", required=False, type=int, default=5)
parser.add_argument("--batch_size", required=False, type=int, default=128)
parser.add_argument("--distributed", required=False, action="store_true", default=False)
parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
parser.add_argument("--tensorboard", required=False, help="Whether to use tensorboard callback", action="store_true", default=False)
parser.add_argument("--profile_batches", required=False, help='Batches to profile with for tensorboard. Format "batch_start,batch_end"', type=str, default="2,5")
args = parser.parse_args()
# default args for distributed
args.global_batches = args.batch_size
if args.distributed:
args.global_batches = args.batch_size * hvd.size()
# only use verbose for master process
if hvd.rank() != 0:
args.verbosity = 0
# specific to cifar 10 dataset
args.num_classes = 10
if hvd.rank() == 0:
print("Settings:")
settings_map = vars(args)
for name in sorted(settings_map.keys()):
print("--" + str(name) + ": " + str(settings_map[name]))
print("")
sys.stdout.flush()
return args
def load_dataset(args):
K.set_image_data_format("channels_last")
# load the cifar10 data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, args.num_classes)
y_test = tf.keras.utils.to_categorical(y_test, args.num_classes)
# normalize base data
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
x_train_mean = np.mean(x_train, axis=0)
x_train -= x_train_mean
x_test -= x_train_mean
# dimensions
if hvd.rank() == 0:
print(f"original train_shape: {x_train.shape}")
print(f"original test_shape: {x_test.shape}")
n_train, n_test = x_train.shape[0], x_test.shape[0]
resize_size = 224 # use bigger images with ResNet
# Generating input pipelines
ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
.map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
.shuffle(n_train).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
)
ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
.map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
.shuffle(n_test).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
)
# get updated shapes
train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
if hvd.rank() == 0:
print(f"final train_shape: {train_shape}")
print(f"final test_shape: {test_shape}")
return ds_train, ds_test, train_shape
def setup(args):
if args.num_intraop_threads:
tf.config.threading.set_intra_op_parallelism_threads(args.num_intraop_threads)
if args.num_interop_threads:
tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
if hvd.rank() == 0:
print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
print("List of GPU devices found:")
for dev in l_gpu_devices:
print(str(dev.device_type) + ": " + dev.name)
print("")
sys.stdout.flush()
tf.config.set_visible_devices(l_gpu_devices[hvd.local_rank()], "GPU")
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True)
def main():
# Horovod: initialize Horovod.
hvd.init()
# parse command line arguments
args = parse_command_line()
# run setup (e.g., create distributed environment if desired)
setup(args)
# loading desired dataset
ds_train, ds_test, train_shape = load_dataset(args)
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.none # or hvd.Compression.fp16
# callbacks to register
callbacks = [
# Horovod: broadcast initial variable states from rank 0 to all other processes.
# This is necessary to ensure consistent initialization of all workers when
# training is started with random weights or restored from a checkpoint.
hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]
# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0:
callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))
model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
# model.summary() # display the model architecture
# Horovod: add Horovod Distributed Optimizer.
cur_optimizer = Adam(0.001)
opt = hvd.DistributedOptimizer(cur_optimizer, compression=compression)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
# callbacks to register
if args.tensorboard:
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
histogram_freq=1,
profile_batch=args.profile_batches,
)
callbacks.append(tensorboard_callback)
# train the model
model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
# evaluate model
scores = model.evaluate(ds_test, verbose=args.verbosity)
if hvd.rank() == 0:
print(f"Test Evaluation: Accuracy: {scores[1]}")
sys.stdout.flush()
# Horovod: synchronize at the end (replacement for barrier)
cur_rank = hvd.rank()
avg_rank = hvd.allreduce(cur_rank, average=True)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment