Skip to content
Snippets Groups Projects
Commit aba59654 authored by Leon Michel Gorißen
Browse files

Refactor benchmark script to modularize training workflow for multiple models

- Added functionality to prepare data and execute training loops for various models (LLT, ITA).
- Implemented data analysis and interpolation for trajectory data.
- Integrated W&B sweeps to dynamically adjust training based on validation loss thresholds.
- Enabled saving models upon reaching the loss threshold.
- Improved logging and modular handling of multiple model configurations and UUIDs.
- Removed old hardcoded training loops and replaced with flexible, scalable partial functions for different models.
parent d35fbc75
Branches
Tags
No related merge requests found
# NOTE(review): THRESHOLD is assigned again (to 160) after the imports, so this
# value is overridden before anything reads it.
THRESHOLD = 50
from dynamics_learning.data_retrieval import download_resource_content_into_uuid_folders
from dynamics_learning.preprocessing.dataset_analysis import analyze
from dynamics_learning.preprocessing.trajectory_interpolation import interpolate
from dynamics_learning.environment import (
WANDB_API_TOKEN,
SWEEP_ID,
WANDB_PROJECT,
WANDB_ENTITY,
)
from dynamics_learning.sweep.setup import setup_sweep
from dynamics_learning.training import train
from dynamics_learning.model_io import (
save_model_to_binary_file,
load_model_from_binary_file,
)
from pathlib import Path
from pritty_logger import RichLogger
import numpy as np
import tensorflow as tf
import wandb
from functools import partial
import os
import signal
from keras.models import Sequential
# Logger shared by all functions and the main block of this script.
logger = RichLogger("dynamics_learning-benchmark_number_of_runs")
# Validation-loss threshold: training_loop stops the sweep once the final
# val_loss of a run is below this value.
THRESHOLD = 160
# NOTE(review): not referenced anywhere in the visible part of this script.
NUMBER_OF_TRAJECTORIES = 100
# Robot UUIDs; the main block uses them to build the training-data directory,
# training_loop uses them to build the model output directory.
LLT_ROBOT_UUID = "f2e72889-c140-4397-809f-fba1b892f17a"
ITA_ROBOT_UUID = "c9ff52e1-1733-4829-a209-ebd1586a8697"
WZL_ROBOT_UUID = "2e60a671-dcc3-4a36-9734-a239c899b57d"
# Module-level state: number of finished runs and the most recent validation
# loss. training_loop updates both via `global`; the main block resets them
# before each model is trained.
runs = 0
val_loss = 1000
# When True, the main block downloads the training data before doing anything else.
download_file = False
# Switches selecting which model sections of the main block are executed.
model1 = False
model2 = True
model3 = False
model4 = False
def prepare_data(directory: Path) -> tuple:
    """Analyze and interpolate the data in ``directory`` and build the training tensors.

    Args:
        directory: Folder passed to both ``analyze`` and ``interpolate``.

    Returns:
        A 5-tuple ``(attained_data, command_data, interpolated_command_data,
        q_qd_qdd_interpolated_command_input, tau_attained_input)``. The first
        three entries are the raw results of ``analyze`` / ``interpolate``; the
        last two are the tensors built from them.
    """
    attained_data, command_data = analyze(directory)
    interpolated_command_data = interpolate(directory)

    # Model input: the interpolated commanded q, qd and qdd stacked column-wise.
    input_keys = (
        "q_interpolated_command",
        "qd_interpolated_command",
        "qdd_interpolated_command",
    )
    stacked_commands = np.column_stack(
        tuple(interpolated_command_data[key] for key in input_keys)
    )
    q_qd_qdd_interpolated_command_input = tf.convert_to_tensor(stacked_commands)

    # Training target: the attained torques.
    tau_attained_input = tf.convert_to_tensor(attained_data["tau_attained"])

    return (
        attained_data,
        command_data,
        interpolated_command_data,
        q_qd_qdd_interpolated_command_input,
        tau_attained_input,
    )
def training_loop(
    sweep_id: str,
    robot_uuid: str,
    q_qd_qdd_interpolated_command_input: tf.Tensor,
    tau_attained_input: tf.Tensor,
    model: Sequential = None,
):
    """Execute one training run and stop the sweep once the loss threshold is undercut.

    Updates the module-level ``runs`` and ``val_loss``. When the final
    validation loss of the run is below ``THRESHOLD``, the W&B run is finished,
    the model is saved below
    ``/app/dynamics_learning/models/<robot_uuid>/<sweep_id>/`` and SIGINT is
    sent to the current process.

    Args:
        sweep_id: Sweep identifier; used as part of the model output path.
        robot_uuid: Robot identifier; used as part of the model output path.
        q_qd_qdd_interpolated_command_input: Input tensor handed to ``train``.
        tau_attained_input: Target tensor handed to ``train``.
        model: Optional model forwarded to ``train``.

    Returns:
        Tuple ``(runs, val_loss)`` holding the updated module-level values.
    """
    global val_loss, runs

    model, history, run, _config = train(
        q_qd_qdd_interpolated_command_input, tau_attained_input, model=model
    )

    # The last recorded validation loss is the value the threshold is checked against.
    final_val_loss = history.history["val_loss"][-1]
    val_loss = float(final_val_loss)
    runs = runs + 1

    # These values are logged once a run is finished.
    logger.info(f"runs: {runs}")
    logger.info(f"val_loss: {val_loss}")

    # Threshold not undercut yet: just report the current state.
    if not (val_loss < THRESHOLD):
        return runs, val_loss

    logger.info(
        f"Stopping training as val_loss has subceeded the threshold: {THRESHOLD}. The current run will be finished."
    )
    wandb.finish()  # Finish the W&B run cleanly

    # Stop the sweep entirely
    logger.info("Stopping the entire sweep.")
    Path(f"/app/dynamics_learning/models/{robot_uuid}/{sweep_id}").mkdir(
        parents=True, exist_ok=True
    )
    file_path = f"/app/dynamics_learning/models/{robot_uuid}/{sweep_id}/{run.id}.h5"
    save_model_to_binary_file(model, file_path)

    # Send SIGINT to this process; intended to ensure no further runs are started.
    own_pid = os.getpid()
    os.kill(own_pid, signal.SIGINT)
    return runs, val_loss
def train_until_threshold_val_loss(
    sweep_id: str,
    robot_uuid: str,
    q_qd_qdd_interpolated_command_input: tf.Tensor,
    tau_attained_input: tf.Tensor,
    model: Sequential = None,
):
    """Run a W&B agent on ``sweep_id`` with ``training_loop`` as its run function.

    Every parameter of ``training_loop`` is bound up front, so the callable
    handed to the agent needs no further arguments. After the agent returns,
    the module-level ``runs`` and ``val_loss`` are logged.

    Args:
        sweep_id: Sweep the agent is attached to; also bound into ``training_loop``.
        robot_uuid: Robot identifier bound into ``training_loop``.
        q_qd_qdd_interpolated_command_input: Input tensor bound into ``training_loop``.
        tau_attained_input: Target tensor bound into ``training_loop``.
        model: Optional model bound into ``training_loop``.

    Returns:
        None.
    """
    bound_arguments = {
        "sweep_id": sweep_id,
        "robot_uuid": robot_uuid,
        "q_qd_qdd_interpolated_command_input": q_qd_qdd_interpolated_command_input,
        "tau_attained_input": tau_attained_input,
        "model": model,
    }
    run_function = partial(training_loop, **bound_arguments)

    wandb.agent(
        sweep_id,
        run_function,
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
    )

    # Report the module-level state left behind by the last executed run.
    summary = f"""Training concluded.
Number of runs: {runs}
Validation Loss: {val_loss}
"""
    logger.info(summary)
if __name__ == "__main__":
    # Script entry point: optionally download the training data, log in to
    # W&B, then run the benchmark for every model section whose flag
    # (model1 .. model4) is enabled.
    #
    # Remember to set the sweep id as needed:
    # g5qxvipa: LLT instance from scratch
    # 5vlv6m3t: ITA instance from scratch (used to train LLT instance, not trained using this script)
    # 7x5hkf35: Foundation model (used to train LLT instance)
    # 42d8t40t: LLT instance based on ITA instance
    # fe3gjovo: LLT instance based on ITA instance with known hyper parameters
    # 7tglijx8: LLT instance based on foundation model

    # Download Training Data from the server
    if download_file:
        download_resource_content_into_uuid_folders()

    wandb.login(key=WANDB_API_TOKEN, relogin=True)

    ###############################################
    ####################Model 1####################
    ###############################################
    if model1:
        # LLT instance model trained from scratch
        robot_uuid = LLT_ROBOT_UUID
        directory = Path(f"/app/dynamics_learning/Trajectory Data/train/{robot_uuid}")

        # Interpolate Training Data in UUID folders
        (
            attained_data,
            command_data,
            interpolated_command_data,
            q_qd_qdd_interpolated_command_input,
            tau_attained_input,
        ) = prepare_data(directory)

        # ensure that the sweep id is set correctly: g5qxvipa
        # assert SWEEP_ID == "g5qxvipa", "Sweep ID is not set correctly. Ensure that the sweep id is set to g5qxvipa"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT_UUID"

        sweep_id, sweep_config = setup_sweep(create_sweep=True)

        # reset runs counter
        runs = 0
        val_loss = 1000

        # Train the model until the threshold validation loss is reached
        train_until_threshold_val_loss(
            sweep_id=sweep_id,
            robot_uuid=robot_uuid,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
            model=None,
        )

        # Keep the result of this model; runs / val_loss are reset for the next one.
        runs_model1 = runs
        val_loss_model1 = val_loss

    ###############################################
    ####################Model 2####################
    ###############################################
    if model2:
        # LLT model based on ITA model without known hyperparameters
        robot_uuid = LLT_ROBOT_UUID
        directory = Path(f"/app/dynamics_learning/Trajectory Data/train/{robot_uuid}")

        # Interpolate Training Data in UUID folders
        (
            attained_data,
            command_data,
            interpolated_command_data,
            q_qd_qdd_interpolated_command_input,
            tau_attained_input,
        ) = prepare_data(directory)

        # assert SWEEP_ID == "42d8t40t", "Sweep ID is not set correctly. Ensure that the sweep id is set to 42d8t40t"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT_UUID"

        sweep_id, sweep_config = setup_sweep(create_sweep=True)

        # reset runs counter
        runs = 0
        val_loss = 1000

        # TODO load ITA model instead of dummy model
        model = load_model_from_binary_file(
            "/app/dynamics_learning/models/99.99706268310547.h5"
        )

        # Train the model until the threshold validation loss is reached
        train_until_threshold_val_loss(
            sweep_id=sweep_id,
            robot_uuid=robot_uuid,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
            model=model,
        )

        # Keep the result of this model; runs / val_loss are reset for the next one.
        runs_model2 = runs
        val_loss_model2 = val_loss

    ###############################################
    ####################Model 3####################
    ###############################################
    if model3:
        # LLT model based on ITA model with known hyperparameters
        # Assign the UUID here so this section does not depend on an earlier
        # section having run (otherwise the assert below raises NameError).
        robot_uuid = LLT_ROBOT_UUID
        assert (
            SWEEP_ID == "fe3gjovo"
        ), "Sweep ID is not set correctly. Ensure that the sweep id is set to fe3gjovo"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT"
        # TODO load LLT Data
        # TODO load ITA model
        # TODO set hyperparameters for LLT model based on ITA model with known hyperparameters

    ###############################################
    ####################Model 4####################
    ###############################################
    if model4:
        # LLT model based on foundation model
        # Assign the UUID here so this section does not depend on an earlier
        # section having run (otherwise the assert below raises NameError).
        robot_uuid = LLT_ROBOT_UUID
        assert (
            SWEEP_ID == "7tglijx8"
        ), "Sweep ID is not set correctly. Ensure that the sweep id is set to 7tglijx8"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT"
        # TODO Download Test Data from the server
        # TODO load LLT Data
        # TODO load foundation model
        # TODO set hyperparameters for LLT model based on foundation model
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment