Skip to content
Snippets Groups Projects
Commit aba59654 authored by Leon Michel Gorißen
Browse files

Refactor benchmark script to modularize training workflow for multiple models

- Added functionality to prepare data and execute training loops for various models (LLT, ITA).
- Implemented data analysis and interpolation for trajectory data.
- Integrated W&B sweeps to dynamically adjust training based on validation loss thresholds.
- Enabled saving models upon reaching the loss threshold.
- Improved logging and modular handling of multiple model configurations and UUIDs.
- Removed old hardcoded training loops and replaced with flexible, scalable partial functions for different models.
parent d35fbc75
No related branches found
No related tags found
No related merge requests found
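# NOTE(review): THRESHOLD is reassigned to 160 below the imports; this earlier
# value looks like residue from the diff view -- confirm and remove.
THRESHOLD = 50
from dynamics_learning.data_retrieval import download_resource_content_into_uuid_folders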
from dynamics_learning.preprocessing.dataset_analysis import analyze
from dynamics_learning.preprocessing.trajectory_interpolation import interpolate
from dynamics_learning.environment import (
WANDB_API_TOKEN,
SWEEP_ID,
WANDB_PROJECT,
WANDB_ENTITY,
)
from dynamics_learning.sweep.setup import setup_sweep
from dynamics_learning.training import train
from dynamics_learning.model_io import (
save_model_to_binary_file,
load_model_from_binary_file,
)
from pathlib import Path
from pritty_logger import RichLogger
import numpy as np
import tensorflow as tf
import wandb
from functools import partial
import os
import signal
from keras.models import Sequential
# Logger used by every function in this benchmark script.
logger = RichLogger("dynamics_learning-benchmark_number_of_runs")

# training_loop() stops the sweep once a run's final validation loss is below
# this value.
THRESHOLD = 160
NUMBER_OF_TRAJECTORIES = 100

# Robot UUIDs; they select the per-robot training-data and model directories.
LLT_ROBOT_UUID = "f2e72889-c140-4397-809f-fba1b892f17a"
ITA_ROBOT_UUID = "c9ff52e1-1733-4829-a209-ebd1586a8697"
WZL_ROBOT_UUID = "2e60a671-dcc3-4a36-9734-a239c899b57d"

# Module-level progress state: training_loop() updates both via `global`, and
# the __main__ section resets them before each benchmark.
runs = 0
val_loss = 1000

# Switches read by the __main__ section to decide which steps are executed.
download_file = False
model1 = False
model2 = True
model3 = False
model4 = False
def prepare_data(directory: Path) -> tuple:
    """Analyze and interpolate the trajectory data stored in ``directory``.

    Args:
        directory: Folder handed unchanged to ``analyze`` and ``interpolate``.

    Returns:
        A 5-tuple ``(attained_data, command_data, interpolated_command_data,
        q_qd_qdd_interpolated_command_input, tau_attained_input)``. The first
        two entries come from ``analyze``, the third from ``interpolate``, and
        the last two are tensors built from them.
    """
    attained_data, command_data = analyze(directory)
    interpolated_command_data = interpolate(directory)

    # Stack the interpolated commanded q, qd and qdd column-wise into a single
    # array before turning it into the input tensor.
    command_columns = tuple(
        interpolated_command_data[key]
        for key in (
            "q_interpolated_command",
            "qd_interpolated_command",
            "qdd_interpolated_command",
        )
    )
    q_qd_qdd_interpolated_command_input = tf.convert_to_tensor(
        np.column_stack(command_columns)
    )

    # Second tensor: the attained ``tau_attained`` values.
    tau_attained_input = tf.convert_to_tensor(attained_data["tau_attained"])

    return (
        attained_data,
        command_data,
        interpolated_command_data,
        q_qd_qdd_interpolated_command_input,
        tau_attained_input,
    )
def training_loop(
    sweep_id: str,
    robot_uuid: str,
    q_qd_qdd_interpolated_command_input: tf.Tensor,
    tau_attained_input: tf.Tensor,
    model: Sequential = None,
):
    """Run one training run and stop the sweep once the loss threshold is met.

    Calls ``train`` once, stores the last recorded ``val_loss`` and the
    incremented run counter in the module-level ``val_loss`` / ``runs``
    variables and logs both. If the validation loss is below ``THRESHOLD``
    the W&B run is finished, the trained model is saved under
    ``/app/dynamics_learning/models/<robot_uuid>/<sweep_id>/<run id>.h5`` and
    SIGINT is sent to the current process.

    Args:
        sweep_id: Sweep identifier; used as a folder name for the saved model.
        robot_uuid: Robot identifier; used as a folder name for the saved model.
        q_qd_qdd_interpolated_command_input: First positional argument of ``train``.
        tau_attained_input: Second positional argument of ``train``.
        model: Forwarded to ``train`` as ``model``; ``None`` by default.

    Returns:
        The tuple ``(runs, val_loss)`` after this run.
    """
    global val_loss, runs

    trained_model, history, run, _config = train(
        q_qd_qdd_interpolated_command_input, tau_attained_input, model=model
    )
    val_loss = float(history.history["val_loss"][-1])
    runs += 1

    # These values are logged once a run is finished.
    logger.info(f"runs: {runs}")
    logger.info(f"val_loss: {val_loss}")

    # Threshold not undercut yet: nothing else to do for this run.
    if not val_loss < THRESHOLD:
        return runs, val_loss

    logger.info(
        f"Stopping training as val_loss has subceeded the threshold: {THRESHOLD}. The current run will be finished."
    )
    wandb.finish()  # finish the W&B run cleanly

    # Stop the sweep entirely; the trained model is saved first.
    logger.info("Stopping the entire sweep.")
    model_dir = f"/app/dynamics_learning/models/{robot_uuid}/{sweep_id}"
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    save_model_to_binary_file(trained_model, f"{model_dir}/{run.id}.h5")

    # Send SIGINT to this very process to ensure no further runs are started.
    os.kill(os.getpid(), signal.SIGINT)
    return runs, val_loss
def train_until_threshold_val_loss(
    sweep_id: str,
    robot_uuid: str,
    q_qd_qdd_interpolated_command_input: tf.Tensor,
    tau_attained_input: tf.Tensor,
    model: Sequential = None,
):
    """Run a W&B agent on ``sweep_id`` with ``training_loop`` as its function.

    All arguments are bound to ``training_loop`` up front, so the agent can
    invoke it without parameters. After ``wandb.agent`` returns, the
    module-level ``runs`` and ``val_loss`` values are logged.

    Args:
        sweep_id: Sweep handed to ``wandb.agent`` and to ``training_loop``.
        robot_uuid: Forwarded to ``training_loop``.
        q_qd_qdd_interpolated_command_input: Forwarded to ``training_loop``.
        tau_attained_input: Forwarded to ``training_loop``.
        model: Forwarded to ``training_loop``; ``None`` by default.

    Returns:
        None.
    """
    bound_arguments = dict(
        sweep_id=sweep_id,
        robot_uuid=robot_uuid,
        q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
        tau_attained_input=tau_attained_input,
        model=model,
    )
    run_once = partial(training_loop, **bound_arguments)

    wandb.agent(sweep_id, run_once, project=WANDB_PROJECT, entity=WANDB_ENTITY)

    # `runs` and `val_loss` are the module-level values maintained by
    # training_loop().
    logger.info(f"""Training concluded.
Number of runs: {runs}
Validation Loss: {val_loss}
""")
    return None
if __name__ == "__main__":
    # Benchmark entry point: for every enabled model flag, prepare the data,
    # set up a sweep and train until the validation loss threshold is reached.
    #
    # Remember to set the sweep id as needed:
    # g5qxvipa: LLT instance from scratch
    # 5vlv6m3t: ITA instance from scratch (used to train LLT instance, not trained using this script)
    # 7x5hkf35: Foundation model (used to train LLT instance)
    # 42d8t40t: LLT instance based on ITA instance
    # fe3gjovo: LLT instance based on ITA instance with known hyper parameters
    # 7tglijx8: LLT instance based on foundation model

    # Download Training Data from the server
    if download_file:
        download_resource_content_into_uuid_folders()

    wandb.login(key=WANDB_API_TOKEN, relogin=True)

    ###############################################
    ####################Model 1####################
    ###############################################
    if model1:
        # LLT instance model trained from scratch
        robot_uuid = LLT_ROBOT_UUID
        directory = Path(f"/app/dynamics_learning/Trajectory Data/train/{robot_uuid}")

        # Interpolate Training Data in UUID folders
        (
            attained_data,
            command_data,
            interpolated_command_data,
            q_qd_qdd_interpolated_command_input,
            tau_attained_input,
        ) = prepare_data(directory)

        # ensure that the sweep id is set correctly: g5qxvipa
        # assert SWEEP_ID == "g5qxvipa", "Sweep ID is not set correctly. Ensure that the sweep id is set to g5qxvipa"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT_UUID"

        sweep_id, sweep_config = setup_sweep(create_sweep=True)

        # Reset the module-level counters that training_loop() updates.
        runs = 0
        val_loss = 1000

        # Train the model until the threshold validation loss is reached
        train_until_threshold_val_loss(
            sweep_id=sweep_id,
            robot_uuid=robot_uuid,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
            model=None,
        )

        # Keep this benchmark's result before the counters are reset again.
        runs_model1 = runs
        val_loss_model1 = val_loss

    ###############################################
    ####################Model 2####################
    ###############################################
    if model2:
        # LLT model based on ITA model without known hyperparameters
        robot_uuid = LLT_ROBOT_UUID
        directory = Path(f"/app/dynamics_learning/Trajectory Data/train/{robot_uuid}")

        # Interpolate Training Data in UUID folders
        (
            attained_data,
            command_data,
            interpolated_command_data,
            q_qd_qdd_interpolated_command_input,
            tau_attained_input,
        ) = prepare_data(directory)

        # assert SWEEP_ID == "42d8t40t", "Sweep ID is not set correctly. Ensure that the sweep id is set to 42d8t40t"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT_UUID"

        sweep_id, sweep_config = setup_sweep(create_sweep=True)

        # Reset the module-level counters that training_loop() updates.
        runs = 0
        val_loss = 1000

        # TODO load ITA model instead of dummy model
        model = load_model_from_binary_file(
            "/app/dynamics_learning/models/99.99706268310547.h5"
        )

        # Train the model until the threshold validation loss is reached
        train_until_threshold_val_loss(
            sweep_id=sweep_id,
            robot_uuid=robot_uuid,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
            model=model,
        )

        # Keep this benchmark's result before the counters are reset again.
        runs_model2 = runs
        val_loss_model2 = val_loss

    ###############################################
    ####################Model 3####################
    ###############################################
    if model3:
        # LLT model based on ITA model with known hyperparameters
        # Assign the UUID here as the other branches do; otherwise the assert
        # below raises NameError when model1 and model2 are both disabled.
        robot_uuid = LLT_ROBOT_UUID
        assert (
            SWEEP_ID == "fe3gjovo"
        ), "Sweep ID is not set correctly. Ensure that the sweep id is set to fe3gjovo"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT"
        # TODO load LLT Data
        # TODO load ITA model
        # TODO set hyperparameters for LLT model based on ITA model with known hyperparameters

    ###############################################
    ####################Model 4####################
    ###############################################
    if model4:
        # LLT model based on foundation model
        # Assign the UUID here as the other branches do; otherwise the assert
        # below raises NameError when model1 and model2 are both disabled.
        robot_uuid = LLT_ROBOT_UUID
        assert (
            SWEEP_ID == "7tglijx8"
        ), "Sweep ID is not set correctly. Ensure that the sweep id is set to 7tglijx8"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT"
        # TODO load LLT Data
        # TODO load foundation model
        # TODO set hyperparameters for LLT model based on foundation model

    # NOTE: the former hard-coded `while val_loss > 50` placeholder loops are
    # not kept here: they never updated val_loss and therefore could not
    # terminate. Training is driven by train_until_threshold_val_loss() instead.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment