From b2a1a55c151b687ac303d3379153ff6394266d10 Mon Sep 17 00:00:00 2001
From: Jannis Klinkenberg <j.klinkenberg@itc.rwth-aachen.de>
Date: Sat, 9 Nov 2024 18:42:43 +0100
Subject: [PATCH] implement image resizing for ResNet input
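
CIFAR-10 images (32x32) are upscaled to 224x224, the input resolution
ResNet-50 was originally designed for. The resize now happens inside the
tf.data input pipeline, so pipeline construction moves from main() into
load_dataset(), which returns the ready-to-use datasets together with the
resulting element shape.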

---
 tensorflow/cifar10/train_model.py             | 39 +++++++++++-------
 tensorflow/cifar10_distributed/train_model.py | 40 ++++++++++++-------
 2 files changed, 49 insertions(+), 30 deletions(-)
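
Note: the sketch below illustrates the tf.data pattern both scripts now
share (resize inside the pipeline, then read the post-resize shape from
element_spec). It is illustrative only; make_pipeline, its arguments, and
the toy data are placeholder names, not part of the scripts.

    import tensorflow as tf

    def make_pipeline(x, y, batch_size, resize_size=224):
        """Resize images inside tf.data; return a batched, prefetched dataset."""
        n = x.shape[0]
        ds = (tf.data.Dataset.from_tensor_slices((x, y))
              # resize every image to the target ResNet input resolution
              .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
              # cache the resized images, then reshuffle on every epoch
              .cache().shuffle(n)
              .batch(batch_size).prefetch(tf.data.AUTOTUNE))
        return ds

    # toy data standing in for CIFAR-10 (shapes only; values are random)
    x = tf.random.uniform([8, 32, 32, 3])
    y = tf.one_hot(tf.random.uniform([8], maxval=10, dtype=tf.int32), 10)
    ds = make_pipeline(x, y, batch_size=4)

    # element_spec carries the post-resize shape; slicing off the batch
    # dimension yields the input_shape for ResNet50: (224, 224, 3)
    print(ds.element_spec[0].shape)  # (None, 224, 224, 3)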

diff --git a/tensorflow/cifar10/train_model.py b/tensorflow/cifar10/train_model.py
index 1251e61..edf6ba4 100644
--- a/tensorflow/cifar10/train_model.py
+++ b/tensorflow/cifar10/train_model.py
@@ -50,13 +50,28 @@ def load_dataset(args):
     x_train -= x_train_mean
     x_test -= x_train_mean
 
-    print("x_train shape:", x_train.shape)
-    print("y_train shape:", y_train.shape)
-    print(x_train.shape[0], "train samples")
-    print(x_test.shape[0], "test samples")
-    sys.stdout.flush()
+    # report original dataset dimensions
+    print(f"original train_shape: {x_train.shape}")
+    print(f"original test_shape: {x_test.shape}")
+    n_train, n_test = x_train.shape[0], x_test.shape[0]
+    resize_size = 224  # upscale CIFAR-10 32x32 images to the 224x224 input size ResNet was designed for
 
-    return (x_train, y_train), (x_test, y_test)
+    # Generating input pipelines: resize inside tf.data, cache before shuffle so each epoch is reshuffled
+    ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
+                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+                .cache().shuffle(n_train).batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
+    )
+    ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
+               .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+               .cache().shuffle(n_test).batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
+    )
+
+    # get updated (post-resize) shapes
+    train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
+    print(f"final train_shape: {train_shape}")
+    print(f"final test_shape: {test_shape}")
+
+    return ds_train, ds_test, train_shape
 
 def setup(args):
     if args.num_intraop_threads:
@@ -84,19 +99,13 @@ def main():
     # run setup (e.g., create distributed environment if desired)
     setup(args)
 
-    # data set loading
-    (x_train, y_train), (x_test, y_test) = load_dataset(args)
-    n_train, n_test = x_train.shape[0], x_test.shape[0]
-    input_shape = x_train.shape[1:]
-
-    # Generating input pipelines
-    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(n_train).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
-    ds_test = ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(n_test).cache().batch(args.batch_size).prefetch(tf.data.AUTOTUNE)
+    # load dataset and build the input pipelines
+    ds_train, ds_test, train_shape = load_dataset(args)
 
     # callbacks to register
     callbacks = []
 
-    model = applications.ResNet50(weights=None, input_shape=input_shape, classes=args.num_classes)
+    model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
     # model.summary() # display the model architecture
     cur_optimizer = Adam(0.001)
     model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
diff --git a/tensorflow/cifar10_distributed/train_model.py b/tensorflow/cifar10_distributed/train_model.py
index ab0c3c0..42932f2 100644
--- a/tensorflow/cifar10_distributed/train_model.py
+++ b/tensorflow/cifar10_distributed/train_model.py
@@ -68,14 +68,30 @@ def load_dataset(args):
     x_train -= x_train_mean
     x_test -= x_train_mean
 
+    # report original dataset dimensions
     if args.world_rank == 0:
-        print("x_train shape:", x_train.shape)
-        print("y_train shape:", y_train.shape)
-        print(x_train.shape[0], "train samples")
-        print(x_test.shape[0], "test samples")
-        sys.stdout.flush()
+        print(f"original train_shape: {x_train.shape}")
+        print(f"original test_shape: {x_test.shape}")
+    n_train, n_test = x_train.shape[0], x_test.shape[0]
+    resize_size = 224  # upscale CIFAR-10 32x32 images to the 224x224 input size ResNet was designed for
 
-    return (x_train, y_train), (x_test, y_test)
+    # Generating input pipelines: resize inside tf.data, cache before shuffle so each epoch is reshuffled
+    ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
+                .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+                .cache().shuffle(n_train).batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    )
+    ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
+               .map(lambda image, label: (tf.image.resize(image, [resize_size, resize_size]), label))
+               .cache().shuffle(n_test).batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    )
+
+    # get updated (post-resize) shapes
+    train_shape, test_shape = ds_train.element_spec[0].shape, ds_test.element_spec[0].shape
+    if args.world_rank == 0:
+        print(f"final train_shape: {train_shape}")
+        print(f"final test_shape: {test_shape}")
+
+    return ds_train, ds_test, train_shape
 
 def setup(args):
     if args.num_intraop_threads:
@@ -115,20 +131,14 @@ def main():
     # run setup (e.g., create distributed environment if desired)
     strategy = setup(args)
 
-    # data set loading
-    (x_train, y_train), (x_test, y_test) = load_dataset(args)
-    n_train, n_test = x_train.shape[0], x_test.shape[0]
-    input_shape = x_train.shape[1:]
-
-    # Generating input pipelines
-    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(n_train).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
-    ds_test = ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(n_test).cache().batch(args.global_batches).prefetch(tf.data.AUTOTUNE)
+    # load dataset and build the input pipelines
+    ds_train, ds_test, train_shape = load_dataset(args)
 
     # callbacks to register
     callbacks = []
 
     with strategy.scope():
-        model = applications.ResNet50(weights=None, input_shape=input_shape, classes=args.num_classes)
+        model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
         # model.summary() # display the model architecture
         cur_optimizer = Adam(0.001)
         model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
-- 
GitLab