Skip to content
Snippets Groups Projects
Commit 7be20fb3 authored by Jannis Klinkenberg's avatar Jannis Klinkenberg
Browse files

unified scripts for single GPU usage

parent d046cc6e
Branches
No related tags found
No related merge requests found
from __future__ import print_function
import numpy as np
import os, sys
import random
import argparse
import datetime
import tensorflow as tf
......@@ -79,12 +80,14 @@ def setup(args):
if args.num_interop_threads:
tf.config.threading.set_inter_op_parallelism_threads(args.num_interop_threads)
gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
print(f"Tensorflow get_intra_op_parallelism_threads: {tf.config.threading.get_intra_op_parallelism_threads()}")
print(f"Tensorflow get_inter_op_parallelism_threads: {tf.config.threading.get_inter_op_parallelism_threads()}")
l_gpu_devices = [] if args.device == "cpu" else tf.config.list_physical_devices("GPU")
print("List of GPU devices found:")
for dev in l_gpu_devices:
for dev in gpu_devices:
print(str(dev.device_type) + ": " + dev.name)
print("")
sys.stdout.flush()
......@@ -93,6 +96,11 @@ def setup(args):
tf.config.optimizer.set_jit(True)
def main():
# always use the same seed
random.seed(42)
tf.random.set_seed(42)
np.random.seed(42)
# parse command line arguments
args = parse_command_line()
......@@ -107,6 +115,7 @@ def main():
model = applications.ResNet50(weights=None, input_shape=train_shape[1:], classes=args.num_classes)
# model.summary() # display the model architecture
cur_optimizer = Adam(0.001)
model.compile(loss="categorical_crossentropy", optimizer=cur_optimizer, metrics=["accuracy"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment