From fabbd911d297cad57f18494aed81cdc1fe1efda9 Mon Sep 17 00:00:00 2001
From: Jannis Klinkenberg <j.klinkenberg@itc.rwth-aachen.de>
Date: Fri, 15 Nov 2024 23:43:12 +0100
Subject: [PATCH] updated scripts

---
 .../cifar10_distributed/execution_wrapper.sh  |   7 +-
 tensorflow/cifar10_distributed/set_vars.sh    |   3 +
 .../submit_job_container_horovod.sh           |   2 +-
 .../cifar10_distributed/submit_job_venv.sh    |   4 +-
 .../submit_job_venv_horovod.sh                |  41 +++++++
 tensorflow/cifar10_distributed/train_model.py | 106 +++++++-----------
 .../train_model_horovod.py                    |  16 ++-
 7 files changed, 101 insertions(+), 78 deletions(-)
 create mode 100644 tensorflow/cifar10_distributed/submit_job_venv_horovod.sh

diff --git a/tensorflow/cifar10_distributed/execution_wrapper.sh b/tensorflow/cifar10_distributed/execution_wrapper.sh
index 997737c..c7899b3 100644
--- a/tensorflow/cifar10_distributed/execution_wrapper.sh
+++ b/tensorflow/cifar10_distributed/execution_wrapper.sh
@@ -4,12 +4,15 @@
 ############################################################
 export TF_CONFIG=$(python -W ignore create_tfconfig.py)
 
+# limit visible devices to ensure correct number of replicas in TensorFlow MultiWorkerMirroredStrategy
+export CUDA_VISIBLE_DEVICES=${SLURM_LOCALID}
+
 ############################################################
 ### Execution
 ############################################################
 
 # start model training
-python -W ignore train_model.py --distributed
+python -W ignore train_model.py
 
 # execute with XLA JIT
-# TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python -W ignore train_model.py --distributed
\ No newline at end of file
+# TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python -W ignore train_model.py
\ No newline at end of file
diff --git a/tensorflow/cifar10_distributed/set_vars.sh b/tensorflow/cifar10_distributed/set_vars.sh
index 6333adb..74ffc02 100644
--- a/tensorflow/cifar10_distributed/set_vars.sh
+++ b/tensorflow/cifar10_distributed/set_vars.sh
@@ -16,6 +16,9 @@ export APPTAINERENV_NCCL_SOCKET_NTHREADS=${NCCL_SOCKET_NTHREADS}
 export APPTAINERENV_NCCL_DEBUG=${NCCL_DEBUG}
 
 # make additional SLURM variables available inside container
+export APPTAINERENV_SLURM_PROCID=${SLURM_PROCID}
+export APPTAINERENV_SLURM_LOCALID=${SLURM_LOCALID}
+export APPTAINERENV_SLURM_NTASKS=${SLURM_NTASKS}
 export APPTAINERENV_SLURM_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK}
 export APPTAINERENV_SLURM_NTASKS_PER_NODE=${SLURM_NTASKS_PER_NODE}
 export APPTAINERENV_SLURM_NNODES=${SLURM_NNODES}
diff --git a/tensorflow/cifar10_distributed/submit_job_container_horovod.sh b/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
index b400560..0b62a97 100644
--- a/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
+++ b/tensorflow/cifar10_distributed/submit_job_container_horovod.sh
@@ -45,4 +45,4 @@ mkdir -p ${NEWTMP}
 srun zsh -c '\
     source set_vars.sh && \
     apptainer exec -e --nv -B ${NEWTMP}:/tmp ${TENSORFLOW_IMAGE} \
-    bash -c "python -W ignore train_model_horovod.py --distributed"'
+    bash -c "python -W ignore train_model_horovod.py"'
diff --git a/tensorflow/cifar10_distributed/submit_job_venv.sh b/tensorflow/cifar10_distributed/submit_job_venv.sh
index 14e2209..c42ac3f 100644
--- a/tensorflow/cifar10_distributed/submit_job_venv.sh
+++ b/tensorflow/cifar10_distributed/submit_job_venv.sh
@@ -36,6 +36,6 @@ export NCCL_SOCKET_NTHREADS=8 # multi-threading for NCCL communication
 
 # each process sets required environment variables and
 # runs the python script
-srun zsh -c '\
+srun zsh -c "\
     source set_vars.sh && \
-    zsh ./execution_wrapper.sh'
\ No newline at end of file
+    zsh ./execution_wrapper.sh"
\ No newline at end of file
diff --git a/tensorflow/cifar10_distributed/submit_job_venv_horovod.sh b/tensorflow/cifar10_distributed/submit_job_venv_horovod.sh
new file mode 100644
index 0000000..90b41fc
--- /dev/null
+++ b/tensorflow/cifar10_distributed/submit_job_venv_horovod.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/zsh
+############################################################
+### Slurm flags
+############################################################
+
+#SBATCH --time=00:15:00
+#SBATCH --partition=c23g
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --cpus-per-task=24
+#SBATCH --gres=gpu:2
+
+############################################################
+### Load modules or software
+############################################################
+
+# TODO: load/activate your desired modules and virtual environment
+
+############################################################
+### Parameters and Settings
+############################################################
+
+# print some information about current system
+echo "Job nodes: ${SLURM_JOB_NODELIST}"
+echo "Current machine: $(hostname)"
+nvidia-smi
+
+export NCCL_DEBUG=INFO
+export TF_CPP_MIN_LOG_LEVEL=1 # disable info messages
+export TF_GPU_THREAD_MODE='gpu_private'
+export NCCL_SOCKET_NTHREADS=8 # multi-threading for NCCL communication
+
+############################################################
+### Execution (Model Training)
+############################################################
+
+# each process sets required environment variables and
+# runs the python script
+srun zsh -c "\
+    source set_vars.sh && \
+    python -W ignore train_model_horovod.py"
diff --git a/tensorflow/cifar10_distributed/train_model.py b/tensorflow/cifar10_distributed/train_model.py
index 1169159..293921b 100644
--- a/tensorflow/cifar10_distributed/train_model.py
+++ b/tensorflow/cifar10_distributed/train_model.py
@@ -10,24 +10,10 @@ from tensorflow.keras import backend as K
 from tensorflow.keras.datasets import cifar10
 import tensorflow.keras.applications as applications
 
-class TrainLoggerModel(tf.keras.Model):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def train_step(self, data):
-        # # if hvd.rank() == 0:
-        #     x, y = data
-        #     tf.print('new batch:')
-        #     #tf.print(x,summarize=-1)
-        #     tf.print(y,summarize=-1)
-
-        # Return a dict mapping metric names to current value
-        return {m.name: m.result() for m in self.metrics}
-
 def parse_command_line():
     parser = argparse.ArgumentParser()
     parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
-    parser.add_argument("--num_epochs", required=False, type=int, default=3)
+    parser.add_argument("--num_epochs", required=False, type=int, default=5)
     parser.add_argument("--batch_size", required=False, type=int, default=128)
     parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
     parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
@@ -46,13 +32,13 @@ def parse_command_line():
     # specific to cifar 10 dataset
     args.num_classes = 10
 
-    # if args.world_rank == 0:
-    print("Settings:")
-    settings_map = vars(args)
-    for name in sorted(settings_map.keys()):
-        print("--" + str(name) + ": " + str(settings_map[name]))
-    print("")
-    sys.stdout.flush()
+    if args.world_rank == 0:
+        print("Settings:")
+        settings_map = vars(args)
+        for name in sorted(settings_map.keys()):
+            print("--" + str(name) + ": " + str(settings_map[name]))
+        print("")
+        sys.stdout.flush()
 
     return args
 
@@ -74,28 +60,32 @@ def load_dataset(args):
     x_test -= x_train_mean
 
     # dimensions
-    # if args.world_rank == 0:
-    print(f"original train_shape: {x_train.shape}")
-    print(f"original test_shape: {x_test.shape}")
+    if args.world_rank == 0:
+        print(f"original train_shape: {x_train.shape}")
+        print(f"original test_shape: {x_test.shape}")
     n_train, n_test = x_train.shape[0], x_test.shape[0]
     resize_size = 224 # use bigger images with ResNet
 
+    # disable any automatic data sharding in TensorFlow as we handle that manually here
+    # options = tf.data.Options()
+    # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
+
     # Generating input pipelines
     ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-                .shuffle(n_train)
-                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
+                .shuffle(n_train) # .shard(num_shards=args.world_size, index=args.world_rank)
+                .cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
                 )
     ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-               .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
+               .shuffle(n_test).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
                )
 
     # get updated shapes
     train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
-    # if args.world_rank == 0:
-    print(f"final train_shape: {train_shape}")
-    print(f"final test_shape: {test_shape}")
+    if args.world_rank == 0:
+        print(f"final train_shape: {train_shape}")
+        print(f"final test_shape: {test_shape}")
 
     return ds_train, ds_test, train_shape
 
@@ -105,19 +95,19 @@ def setup(args):
     if args.num_interop_threads:
         tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
 
-    l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
+    gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
 
-    # if args.world_rank == 0:
-    print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
-    print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
+    if args.world_rank == 0:
+        print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
+        print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
 
-    print("List of GPU devices found:")
-    for dev in l_gpu_devices:
-        print(str(dev.device_type) + ": " + dev.name)
-    print("")
-    sys.stdout.flush()
+        print("List of GPU devices found:")
+        for dev in gpu_devices:
+            print(str(dev.device_type) + ": " + dev.name)
+        print("")
+        sys.stdout.flush()
 
-    tf.config.set_visible_devices(l_gpu_devices[args.local_rank], "GPU")
+    tf.config.set_visible_devices(gpu_devices[args.local_rank], "GPU")
 
     tf.keras.backend.clear_session()
     tf.config.optimizer.set_jit(True)
@@ -145,20 +135,14 @@ def main():
     # loading desired dataset
     ds_train, ds_test, train_shape = load_dataset(args)
 
-    # options = tf.data.Options()
-    # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
-    # ds_train = ds_train.with_options(options)
-
     # callbacks to register
     callbacks = []
 
     with strategy.scope():
-        # ds_train = strategy.experimental_distribute_dataset(ds_train)
-
-        # model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
-        model = TrainLoggerModel()
-
+        model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
         # model.summary() # display the model architecture
+
+        # create optimizer and scale learning rate with number of workers
         cur_optimizer = Adam(learning_rate=0.001 * args.world_size)
         model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
 
@@ -171,26 +155,14 @@ def main():
         )
         callbacks.append(tensorboard_callback)
 
-        class PrintLabelsCallback(tf.keras.callbacks.Callback):
-            def on_train_batch_begin(self, batch, logs=None):
-                # Use strategy.run to access labels data on each worker
-                def print_labels(features, labels):
-                    # Print the actual labels processed by each worker
-                    tf.print(f"Worker labels for batch {batch}:", labels, summarize=-1)
-
-                # Iterate through dataset and extract labels only
-                strategy.run(lambda x: print_labels(*x), args=(next(iter(ds_train)),))
-
         # train the model
-        model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=[PrintLabelsCallback()])
+        model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
 
         # evaluate model
-        # scores = model.evaluate(ds_test, verbose=args.verbosity)
-        # if args.world_rank == 0:
-        #     print(f"Test Evaluation: Accuracy: {scores[1]}")
-        #     sys.stdout.flush()
-
-
+        scores = model.evaluate(ds_test, verbose=args.verbosity)
+        if args.world_rank == 0:
+            print(f"Test Evaluation: Accuracy: {scores[1]}")
+            sys.stdout.flush()
 
 if __name__ == "__main__":
     main()
diff --git a/tensorflow/cifar10_distributed/train_model_horovod.py b/tensorflow/cifar10_distributed/train_model_horovod.py
index 344e40f..ebb9abf 100644
--- a/tensorflow/cifar10_distributed/train_model_horovod.py
+++ b/tensorflow/cifar10_distributed/train_model_horovod.py
@@ -67,16 +67,20 @@ def load_dataset(args):
     n_train, n_test = x_train.shape[0], x_test.shape[0]
     resize_size = 224 # use bigger images with ResNet
 
+    # disable any automatic data sharding in TensorFlow as we handle that manually here
+    # options = tf.data.Options()
+    # options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
+
     # Generating input pipelines
     ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
                 .shuffle(n_train).shard(num_shards=hvd.size(), index=hvd.rank()) # Horovod: need to manually shard dataset
-                .cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
+                .cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
                 )
     # Horovod: dont use sharding for test here. Otherwise reduction of results is necessary
     ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
               .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
-              .shuffle(n_test).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
+              .shuffle(n_test).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE) #.with_options(options)
               )
 
     # get updated shapes
@@ -93,19 +97,19 @@ def setup(args):
     if args.num_interop_threads:
         tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
 
-    l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
+    gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
 
     if hvd.rank() == 0:
         print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
         print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
 
         print("List of GPU devices found:")
-        for dev in l_gpu_devices:
+        for dev in gpu_devices:
             print(str(dev.device_type) + ": " + dev.name)
         print("")
         sys.stdout.flush()
 
-    tf.config.set_visible_devices(l_gpu_devices[hvd.local_rank()], "GPU")
+    tf.config.set_visible_devices(gpu_devices[hvd.local_rank()], "GPU")
 
     tf.keras.backend.clear_session()
     tf.config.optimizer.set_jit(True)
@@ -145,7 +149,7 @@ def main():
     model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
     # model.summary() # display the model architecture
 
-    # Horovod: add Horovod Distributed Optimizer and scale learning rate with number of workers
+    # Horovod: create Horovod DistributedOptimizer and scale learning rate with number of workers
     cur_optimizer = Adam(learning_rate=0.001 * hvd.size())
     opt = hvd.DistributedOptimizer(cur_optimizer, compression=compression)
 
-- 
GitLab