diff --git a/tensorflow/.gitkeep b/tensorflow/.gitkeep
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/tensorflow/cifar10_distributed/create_tfconfig.py b/tensorflow/cifar10_distributed/create_tfconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..f423956d8e481961e45aa81b0effdf518efa0c3d
--- /dev/null
+++ b/tensorflow/cifar10_distributed/create_tfconfig.py
@@ -0,0 +1,50 @@
+import ast
+import json
+import os
+import subprocess
+
+def get_job_node_list_slurm():
+    # expand Slurm's compact nodelist notation (e.g. "node[01-03]") into
+    # individual hostnames via scontrol
+    host_list_str = os.environ["SLURM_JOB_NODELIST"]
+    output = subprocess.check_output(["scontrol", "show", "hostnames", host_list_str], text=True)
+    return output.split()
+
+
+def get_job_node_list_slurm_rwth():
+    # R_WLM_ABAQUSHOSTLIST holds a Python-style list of (host, ...) entries;
+    # parse it with literal_eval instead of eval to avoid executing arbitrary code
+    host_list_val = ast.literal_eval(os.environ["R_WLM_ABAQUSHOSTLIST"])
+    host_list = [x[0] for x in host_list_val]
+    return list(set(host_list))
+
+
+def build_tf_config():
+    # general settings
+    port_range_start = 23456
+    tasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
+
+    # create worker list
+    list_hosts = sorted(get_job_node_list_slurm_rwth())
+    list_workers = []
+    for host in list_hosts:
+        for i in range(tasks_per_node):
+            list_workers.append(f"{host}:{port_range_start+i}")
+
+    # create the config; the execution wrapper captures the printed JSON
+    # in the TF_CONFIG environment variable
+    tf_config = {
+        "cluster": {"worker": list_workers},
+        "task": {"type": "worker", "index": int(os.environ["RANK"])},
+    }
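+    # every worker must see the same ordered "worker" list; "index" marks
+    # this process's own entry (RANK is exported per task by set_vars.sh)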
+
+    str_dump = json.dumps(tf_config)
+    print(str_dump)
+
+
+if __name__ == "__main__":
+    # build the config and print it to stdout so the caller can run
+    # export TF_CONFIG=$(python create_tfconfig.py)
+    build_tf_config()
diff --git a/tensorflow/cifar10_distributed/execution_wrapper.sh b/tensorflow/cifar10_distributed/execution_wrapper.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ae1df9b286ff00dc41705b3a12f4e3c960581d54
--- /dev/null
+++ b/tensorflow/cifar10_distributed/execution_wrapper.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/zsh
+############################################################
+### Parameters & Directories
+############################################################
+
+export TF_CPP_MIN_LOG_LEVEL=1 # filter out TF INFO messages
+export TF_GPU_THREAD_MODE='gpu_private' # dedicated CPU threads for each GPU
+export NCCL_SOCKET_NTHREADS=8 # multi-threading for NCCL communication
+
+############################################################
+### Set TF_CONFIG
+############################################################
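+# each rank runs create_tfconfig.py itself, so the task "index" inside
+# TF_CONFIG identifies this specific worker process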
+export TF_CONFIG=$(python -W ignore create_tfconfig.py)
+
+############################################################
+### Execution
+############################################################
+
+# start model training
+python -W ignore train_model.py --distributed
+
+# execute with XLA JIT
+# TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python -W ignore train_model.py --distributed
\ No newline at end of file
diff --git a/tensorflow/cifar10_distributed/set_vars.sh b/tensorflow/cifar10_distributed/set_vars.sh
new file mode 100644
index 0000000000000000000000000000000000000000..19405a4e7631dae37f0ad1603bd933949126701c
--- /dev/null
+++ b/tensorflow/cifar10_distributed/set_vars.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/zsh
+
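+# map Slurm's per-process variables to the generic rank variables
+# expected by create_tfconfig.py and train_model.py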
+export RANK=${SLURM_PROCID}
+export LOCAL_RANK=${SLURM_LOCALID}
+export WORLD_SIZE=${SLURM_NTASKS}
+
+# make variables also available inside singularity container
+export APPTAINERENV_RANK=${RANK}
+export APPTAINERENV_LOCAL_RANK=${LOCAL_RANK}
+export APPTAINERENV_WORLD_SIZE=${WORLD_SIZE}
+export APPTAINERENV_TMP="/tmp"
+
+# make additional SLURM variables available to container
+export APPTAINERENV_SLURM_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK}
+export APPTAINERENV_SLURM_NTASKS_PER_NODE=${SLURM_NTASKS_PER_NODE}
+export APPTAINERENV_SLURM_NNODES=${SLURM_NNODES}
+export APPTAINERENV_SLURM_JOB_NODELIST=${SLURM_JOB_NODELIST}
+export APPTAINERENV_R_WLM_ABAQUSHOSTLIST="${R_WLM_ABAQUSHOSTLIST}"
diff --git a/tensorflow/cifar10_distributed/submit_job_container.sh b/tensorflow/cifar10_distributed/submit_job_container.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3ffe0d1046cc79881682d7926365e19a68880747
--- /dev/null
+++ b/tensorflow/cifar10_distributed/submit_job_container.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/zsh
+############################################################
+### Slurm flags
+############################################################
+
+#SBATCH --time=00:15:00
+#SBATCH --partition=c23g
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --cpus-per-task=24
+#SBATCH --gres=gpu:2
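+# note: ntasks-per-node matches the requested GPUs per node, i.e. one worker process per GPU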
+
+############################################################
+### Load modules or software
+############################################################
+
+# load module for TensorFlow container
+module load TensorFlow/nvcr-24.01-tf2-py3
+module list
+
+############################################################
+### Parameters and Settings
+############################################################
+
+# print some information about the current system
+echo "Job nodes: ${SLURM_JOB_NODELIST}"
+echo "Current machine: $(hostname)"
+nvidia-smi
+
+export NCCL_DEBUG=INFO # log NCCL communication setup for debugging
+
+############################################################
+### Execution (Model Training)
+############################################################
+
+# TensorFlow inside the container often needs a writable tmp directory;
+# export NEWTMP so the srun-launched shell below can expand it
+export NEWTMP=$(pwd)/tmp
+mkdir -p ${NEWTMP}
+
+# each process sets the required environment variables and runs the
+# python script inside the container; the single quotes defer variable
+# expansion to the srun-launched shell on each node
+srun zsh -c '\
+    source set_vars.sh && \
+    apptainer exec -e --nv -B ${NEWTMP}:/tmp ${TENSORFLOW_IMAGE} \
+        bash ./execution_wrapper.sh'
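+
+# submit this script with: sbatch submit_job_container.sh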
diff --git a/tensorflow/cifar10_distributed/train_model.py b/tensorflow/cifar10_distributed/train_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..41c29e80220f93eae4f5d613ad6b3fe3da12899b
--- /dev/null
+++ b/tensorflow/cifar10_distributed/train_model.py
@@ -0,0 +1,162 @@
+import numpy as np
+import os, sys
+import argparse
+import datetime
+import tensorflow as tf
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras import backend as K
+from tensorflow.keras.datasets import cifar10
+import tensorflow.keras.applications as applications
+
+def parse_command_line():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
+    parser.add_argument("--num_epochs", required=False, type=int, default=5)
+    parser.add_argument("--batch_size", required=False, type=int, default=128)
+    parser.add_argument("--num_workers", required=False, type=int, default=1)
+    parser.add_argument("--distributed", required=False, action="store_true", default=False)
+    parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
+    parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
+    parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
+    parser.add_argument("--tensorboard", required=False, help="Whether to use tensorboard callback", action="store_true", default=False)
+    parser.add_argument("--profile_batches", required=False, help='Batches to profile with for tensorboard. Format "batch_start,batch_end"', type=str, default="2,5")
+    args = parser.parse_args()
+
+    # defaults for non-distributed runs
+    args.world_size = 1
+    args.world_rank = 0
+    args.local_rank = 0
+    args.global_batch_size = args.batch_size
+
+    if args.distributed:
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.world_rank = int(os.environ["RANK"])
+        args.local_rank = int(os.environ["LOCAL_RANK"])
+        args.global_batch_size = args.batch_size * args.world_size
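+        # note: the strategy later splits each global batch across all
+        # replicas, so each replica still processes batch_size samples per step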
+
+    # print verbose output on the master process only
+    if args.world_rank != 0:
+        args.verbosity = 0
+
+    # specific to the CIFAR-10 dataset
+    args.num_classes = 10
+
+    if args.world_rank == 0:
+        print("Settings:")
+        settings_map = vars(args)
+        for name in sorted(settings_map.keys()):
+            print("--" + str(name) + ": " + str(settings_map[name]))
+        print("")
+        sys.stdout.flush()
+
+    return args
+
+def load_dataset(args):
+    K.set_image_data_format("channels_last")
+
+    # load the cifar10 data
+    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+    # convert class vectors to binary class matrices.
+    y_train = tf.keras.utils.to_categorical(y_train, args.num_classes)
+    y_test = tf.keras.utils.to_categorical(y_test, args.num_classes)
+
+    # scale pixel values to [0, 1] and subtract the training-set mean
+    x_train = x_train.astype("float32") / 255
+    x_test = x_test.astype("float32") / 255
+    x_train_mean = np.mean(x_train, axis=0)
+    x_train -= x_train_mean
+    x_test -= x_train_mean
+
+    if args.world_rank == 0:
+        print("x_train shape:", x_train.shape)
+        print("y_train shape:", y_train.shape)
+        print(x_train.shape[0], "train samples")
+        print(x_test.shape[0], "test samples")
+        sys.stdout.flush()
+
+    return (x_train, y_train), (x_test, y_test)
+
+def setup(args) -> None:
+    if args.num_intraop_threads:
+        tf.config.threading.set_intra_op_parallelism_threads(args.num_intraop_threads)
+    if args.num_interop_threads:
+        tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
+
+    if args.world_rank == 0:
+        print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
+        print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
+        sys.stdout.flush()
+
+    l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
+    if args.world_rank == 0:
+        print("List of GPU devices found:")
+        for dev in l_gpu_devices:
+            print(str(dev.device_type) + ": " + dev.name)
+        print("")
+        sys.stdout.flush()
+
+    if l_gpu_devices:
+        # pin this process to the single GPU matching its local rank
+        tf.config.set_visible_devices(l_gpu_devices[args.local_rank], "GPU")
+    tf.keras.backend.clear_session()
+    tf.config.optimizer.set_jit(True)  # enable XLA JIT compilation
+
+def main():
+    # parse command line arguments
+    args = parse_command_line()
+
+    # run setup (thread settings and GPU visibility)
+    setup(args)
+
+    # define the data-parallel strategy for distributed training
+    strategy = tf.distribute.MultiWorkerMirroredStrategy(
+        communication_options=tf.distribute.experimental.CommunicationOptions(
+            implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
+        )
+    )
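+    # the strategy reads the TF_CONFIG environment variable (set by
+    # execution_wrapper.sh) to discover the other participating workers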
+
+    # data set loading
+    (x_train, y_train), (x_test, y_test) = load_dataset(args)
+    n_train = x_train.shape[0]
+    input_shape = x_train.shape[1:]
+
+    # generate input pipelines; cache before shuffle so each epoch sees a
+    # fresh shuffle order, and batch with the global batch size, which the
+    # strategy splits across the workers
+    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).cache().shuffle(n_train).batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
+    ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).cache().batch(args.global_batch_size).prefetch(tf.data.AUTOTUNE)
+
+    # callbacks to register
+    callbacks = []
+
+    with strategy.scope():
+        model = applications.ResNet50(weights=None, input_shape=input_shape, classes=args.num_classes)
+        # model.summary() # display the model architecture
+        cur_optimizer = Adam(0.001)
+        model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
+
+        # register the TensorBoard callback if requested
+        if args.tensorboard:
+            tensorboard_callback = tf.keras.callbacks.TensorBoard(
+                log_dir=os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
+                histogram_freq=1,
+                profile_batch=args.profile_batches,
+            )
+            callbacks.append(tensorboard_callback)
+
+    # train the model
+    model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
+
+    # evaluate the model; print scores only on the master process
+    scores = model.evaluate(ds_test, verbose=args.verbosity)
+    if args.world_rank == 0:
+        print(f"Test Evaluation: Accuracy: {scores[1]}")
+        sys.stdout.flush()
+
+if __name__ == "__main__":
+    main()