train_model.py
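
Distributed data-parallel training of ResNet-50 on CIFAR-10 with tf.keras and
tf.distribute.MultiWorkerMirroredStrategy. Rank information is read from the
environment, so the script is intended to be started by a multi-process launcher.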
    import numpy as np
    import os
    import sys
    import random
    import argparse
    import datetime
    import tensorflow as tf
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras import backend as K
    from tensorflow.keras.datasets import cifar10
    import tensorflow.keras.applications as applications
    
    def parse_command_line():
        parser = argparse.ArgumentParser()
        parser.add_argument("--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
        parser.add_argument("--num_epochs", required=False, type=int, default=5)
        parser.add_argument("--batch_size", required=False, type=int, default=128)
        parser.add_argument("--verbosity", required=False, help="Keras verbosity level for training/evaluation", type=int, default=2)
        parser.add_argument("--num_intraop_threads", required=False, help="Number of intra-op threads", type=int, default=None)
        parser.add_argument("--num_interop_threads", required=False, help="Number of inter-op threads", type=int, default=None)
        parser.add_argument("--tensorboard", required=False, help="Whether to use tensorboard callback", action="store_true", default=False)
        parser.add_argument("--profile_batches", required=False, help='Batches to profile with for tensorboard. Format "batch_start,batch_end"', type=str, default="2,5")
        args = parser.parse_args()
    
        # defaults for distributed training; the launcher is expected to export
        # these variables (fall back to a single-process run when they are unset)
        args.world_size = int(os.environ.get("WORLD_SIZE", 1))
        args.world_rank = int(os.environ.get("RANK", 0))
        args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        args.global_batch_size = args.batch_size * args.world_size
        args.verbosity = 0 if args.world_rank != 0 else args.verbosity  # verbose output only on the master process
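        # NOTE: the input pipeline batches with this global batch size; the
        # distribution strategy then splits each global batch across the replicas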
    
        # specific to the CIFAR-10 dataset
        args.num_classes = 10
    
        if args.world_rank == 0:
            print("Settings:")
            settings_map = vars(args)
            for name in sorted(settings_map.keys()):
                print("--" + str(name) + ": " + str(settings_map[name]))
            print("")
            sys.stdout.flush()
    
        return args
    
    def load_dataset(args):
        K.set_image_data_format("channels_last")
    
        # load the cifar10 data
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    
        # convert class vectors to binary class matrices.
        y_train = tf.keras.utils.to_categorical(y_train, args.num_classes)
        y_test = tf.keras.utils.to_categorical(y_test, args.num_classes)
    
        # normalize base data
        x_train = x_train.astype("float32") / 255
        x_test = x_test.astype("float32") / 255
        x_train_mean = np.mean(x_train, axis=0)
        x_train -= x_train_mean
        x_test -= x_train_mean
    
        # dimensions
        if args.world_rank == 0:
            print(f"original train_shape: {x_train.shape}")
            print(f"original test_shape: {x_test.shape}")
        n_train, n_test = x_train.shape[0], x_test.shape[0]
        resize_size = 224 # use bigger images with ResNet
    
        # build the input pipelines: cache the small 32x32 images, reshuffle each
        # epoch, then resize on the fly (caching *after* the resize would pin the
        # shuffle order and hold the full 224x224 float tensors in memory)
        ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                    .cache()
                    .shuffle(n_train)
                    .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label),
                         num_parallel_calls=tf.data.AUTOTUNE)
                    .batch(args.global_batch_size)
                    .prefetch(tf.data.AUTOTUNE)
        )
        ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                   .cache()  # no shuffling needed for evaluation
                   .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label),
                        num_parallel_calls=tf.data.AUTOTUNE)
                   .batch(args.global_batch_size)
                   .prefetch(tf.data.AUTOTUNE)
        )
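        # NOTE: with an in-memory (non-file) source like this, Keras auto-sharding
        # falls back to the DATA policy: every worker reads the full dataset and
        # processes only its own shard of each epoch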
    
        # get updated shapes
        train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
        if args.world_rank == 0:
            print(f"final train_shape: {train_shape}")
            print(f"final test_shape: {test_shape}")
    
        return ds_train, ds_test, train_shape
    
    def setup(args):
        if args.num_intraop_threads:
            tf.config.threading.set_intra_op_parallelism_threads(args.num_intraop_threads)
        if args.num_interop_threads:
            tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
    
        gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
    
        if args.world_rank == 0:
            print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
            print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
    
            print("List of GPU devices found:")
            for dev in gpu_devices:
                print(str(dev.device_type) + ": " + dev.name)
            print("")
            sys.stdout.flush()
    
        # pin this process to its own GPU (hide all GPUs when running on CPU)
        visible = gpu_devices[args.local_rank % len(gpu_devices)] if gpu_devices else []
        tf.config.set_visible_devices(visible, "GPU")
        tf.keras.backend.clear_session()
        tf.config.optimizer.set_jit(True)  # enable XLA JIT compilation
    
        # define the data-parallel strategy for distributed training
        strategy = tf.distribute.MultiWorkerMirroredStrategy(
            communication_options=tf.distribute.experimental.CommunicationOptions(
                implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
            )
        )
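        # NOTE: MultiWorkerMirroredStrategy resolves the cluster from the TF_CONFIG
        # environment variable (TFConfigClusterResolver by default); WORLD_SIZE/RANK
        # above only drive batch-size/learning-rate scaling and logging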
    
        print("MultiWorkerMirroredStrategy.num_replicas_in_sync", strategy.num_replicas_in_sync)
        print("MultiWorkerMirroredStrategy.worker_index", strategy.cluster_resolver.task_id)
    
        return strategy
    
    def main():
        # always use the same seed
        random.seed(42)
        tf.random.set_seed(42)
        np.random.seed(42)
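        # note: seeding alone does not make GPU kernels deterministic; TF >= 2.8
        # additionally offers tf.config.experimental.enable_op_determinism()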
    
        # parse command line arguments
        args = parse_command_line()
    
        # run setup (e.g., create distributed environment if desired)
        strategy = setup(args)
    
        # load the dataset
        ds_train, ds_test, train_shape = load_dataset(args)
    
        # callbacks to register
        callbacks = []
    
        with strategy.scope():
            model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
            # model.summary() # display the model architecture
    
            # create optimizer and scale learning rate with number of workers
            cur_optimizer = Adam(learning_rate=0.001 * args.world_size)
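            # (linear scaling rule: the effective batch is batch_size * world_size,
            # so the base rate of 1e-3 is scaled accordingly; a warmup phase is
            # often combined with this rule at large worker counts)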
            model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
    
        # optionally register the TensorBoard callback
        if args.tensorboard:
            tensorboard_callback = tf.keras.callbacks.TensorBoard(
                log_dir=os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
                histogram_freq=1,
                profile_batch=args.profile_batches,
            )
            callbacks.append(tensorboard_callback)
    
        # train the model
        model.fit(ds_train, epochs=args.num_epochs, verbose=args.verbosity, callbacks=callbacks)
    
        # evaluate model
        scores = model.evaluate(ds_test, verbose=args.verbosity)
        if args.world_rank == 0:
            print(f"Test Evaluation: Accuracy: {scores[1]}")
            sys.stdout.flush()
    
    if __name__ == "__main__":
        main()
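
    # Example launch for a smoke test (a sketch: assumes one process and one
    # visible GPU, since the strategy is configured for NCCL; the WORLD_SIZE/
    # RANK/LOCAL_RANK variables are normally exported by the launcher, e.g.
    # torchrun or a SLURM wrapper):
    #
    #   WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 python train_model.py --num_epochs 1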