"""Benchmark how many sweep runs are needed to reach a validation-loss threshold.

Source file: dynamics_learning/benchmark_number_of_runs.py

A W&B sweep agent repeatedly calls :func:`training_loop`. Each call performs one
training run, updates the module-level ``runs`` / ``val_loss`` counters and, once
the validation loss drops below ``THRESHOLD``, saves the model and interrupts the
process so that the agent starts no further runs.
"""

import os
import signal
from functools import partial
from pathlib import Path
from typing import Optional

import numpy as np
import tensorflow as tf
import wandb
from keras.models import Sequential
from pritty_logger import RichLogger

from dynamics_learning.data_retrieval import download_resource_content_into_uuid_folders
from dynamics_learning.environment import (
    WANDB_API_TOKEN,
    SWEEP_ID,
    WANDB_PROJECT,
    WANDB_ENTITY,
)
from dynamics_learning.model_io import (
    save_model_to_binary_file,
    load_model_from_binary_file,
)
from dynamics_learning.preprocessing.dataset_analysis import analyze
from dynamics_learning.preprocessing.trajectory_interpolation import interpolate
from dynamics_learning.sweep.setup import setup_sweep
from dynamics_learning.training import train

logger = RichLogger("dynamics_learning-benchmark_number_of_runs")

# Validation loss below which a sweep is considered finished.
THRESHOLD = 160
NUMBER_OF_TRAJECTORIES = 100
LLT_ROBOT_UUID = "f2e72889-c140-4397-809f-fba1b892f17a"
ITA_ROBOT_UUID = "c9ff52e1-1733-4829-a209-ebd1586a8697"
WZL_ROBOT_UUID = "2e60a671-dcc3-4a36-9734-a239c899b57d"

# Module-level state shared between the sweep callback and the script section.
runs = 0
val_loss = 1000

# Switches selecting which parts of the script section are executed.
download_file = False
model1 = False
model2 = True
model3 = False
model4 = False


def prepare_data(directory: Path) -> tuple:
    """Load the trajectory data in ``directory`` and build the training tensors.

    Args:
        directory: Folder handed to ``analyze`` and ``interpolate``.

    Returns:
        A 5-tuple ``(attained_data, command_data, interpolated_command_data,
        q_qd_qdd_interpolated_command_input, tau_attained_input)``. The first
        three entries are the raw results of ``analyze`` / ``interpolate``; the
        last two are the tensors built from them.
    """
    attained_data, command_data = analyze(directory)
    interpolated_command_data = interpolate(directory)

    # build input and cross validation tensors
    interpolated_command_input = np.column_stack(
        (
            interpolated_command_data["q_interpolated_command"],
            interpolated_command_data["qd_interpolated_command"],
            interpolated_command_data["qdd_interpolated_command"],
        )
    )
    q_qd_qdd_interpolated_command_input = tf.convert_to_tensor(
        interpolated_command_input
    )
    tau_attained_input = tf.convert_to_tensor(attained_data["tau_attained"])
    return (
        attained_data,
        command_data,
        interpolated_command_data,
        q_qd_qdd_interpolated_command_input,
        tau_attained_input,
    )


def training_loop(
    sweep_id: str,
    robot_uuid: str,
    q_qd_qdd_interpolated_command_input: tf.Tensor,
    tau_attained_input: tf.Tensor,
    model: Optional[Sequential] = None,
):
    """Execute one training run and stop the sweep once the threshold is met.

    Updates the module-level ``runs`` and ``val_loss``. When the final
    validation loss of the run is below ``THRESHOLD``, the W&B run is finished,
    the model is written to
    ``/app/dynamics_learning/models/<robot_uuid>/<sweep_id>/<run id>.h5`` and
    SIGINT is sent to the current process.

    Args:
        sweep_id: Sweep identifier, used for the output folder name.
        robot_uuid: Robot identifier, used for the output folder name.
        q_qd_qdd_interpolated_command_input: Input tensor passed to ``train``.
        tau_attained_input: Target tensor passed to ``train``.
        model: Optional model forwarded to ``train``; ``None`` by default.

    Returns:
        The tuple ``(runs, val_loss)`` after this run.
    """
    global val_loss, runs
    model, history, run, _config = train(
        q_qd_qdd_interpolated_command_input, tau_attained_input, model=model
    )
    val_loss = float(history.history["val_loss"][-1])
    runs += 1
    # these values are logged once a run is finished
    logger.info(f"runs: {runs}")
    logger.info(f"val_loss: {val_loss}")

    # Check if the threshold has been subceeded
    # NOTE(review): the nesting of the statements below was reconstructed from a
    # whitespace-collapsed diff; they are assumed to belong to this branch.
    if val_loss < THRESHOLD:
        logger.info(
            f"Stopping training as val_loss has subceeded the threshold: {THRESHOLD}. The current run will be finished."
        )
        wandb.finish()  # Finish the W&B run cleanly

        # Stop the sweep entirely
        logger.info("Stopping the entire sweep.")
        model_directory = Path(f"/app/dynamics_learning/models/{robot_uuid}/{sweep_id}")
        model_directory.mkdir(parents=True, exist_ok=True)

        file_path = str(model_directory / f"{run.id}.h5")
        save_model_to_binary_file(model, file_path)
        # send SIGINT to this process to ensure no further runs are started
        # (presumably the sweep agent stops on the resulting interrupt -- TODO confirm)
        os.kill(os.getpid(), signal.SIGINT)
    return runs, val_loss


def train_until_threshold_val_loss(
    sweep_id: str,
    robot_uuid: str,
    q_qd_qdd_interpolated_command_input: tf.Tensor,
    tau_attained_input: tf.Tensor,
    model: Optional[Sequential] = None,
) -> None:
    """Run a W&B agent on ``sweep_id`` with :func:`training_loop` as its job.

    All arguments are bound to ``training_loop`` via ``functools.partial`` so
    the agent can call it without arguments. After the agent returns, the
    module-level ``runs`` and ``val_loss`` are logged.

    Args:
        sweep_id: Sweep the agent attaches to.
        robot_uuid: Forwarded to ``training_loop``.
        q_qd_qdd_interpolated_command_input: Forwarded to ``training_loop``.
        tau_attained_input: Forwarded to ``training_loop``.
        model: Forwarded to ``training_loop``; ``None`` by default.
    """
    training = partial(
        training_loop,
        sweep_id=sweep_id,
        robot_uuid=robot_uuid,
        q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
        tau_attained_input=tau_attained_input,
        model=model,
    )
    wandb.agent(
        sweep_id,
        training,
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
    )
    logger.info(f"""Training concluded.
Number of runs: {runs}
Validation Loss: {val_loss}
""")
    return None


if __name__ == "__main__":
    # Remember to set the sweep id as needed:
    # g5qxvipa: LLT instance from scratch
    # 5vlv6m3t: ITA instance from scratch (used to train LLT instance, not trained using this script)
    # 7x5hkf35: Foundation model (used to train LLT instance)
    # 42d8t40t: LLT instance based on ITA instance
    # fe3gjovo: LLT instance based on ITA instance with known hyper parameters
    # 7tglijx8: LLT instance based on foundation model

    # Download Training Data from the server
    if download_file:
        download_resource_content_into_uuid_folders()

    wandb.login(key=WANDB_API_TOKEN, relogin=True)

    ###############################################
    ####################Model 1####################
    ###############################################
    if model1:
        # LLT instance model trained from scratch
        robot_uuid = LLT_ROBOT_UUID
        directory = Path(f"/app/dynamics_learning/Trajectory Data/train/{robot_uuid}")
        # Interpolate Training Data in UUID folders
        (
            attained_data,
            command_data,
            interpolated_command_data,
            q_qd_qdd_interpolated_command_input,
            tau_attained_input,
        ) = prepare_data(directory)

        # ensure that the sweep id is set correctly: g5qxvipa
        # assert SWEEP_ID == "g5qxvipa", "Sweep ID is not set correctly. Ensure that the sweep id is set to g5qxvipa"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT_UUID"
        sweep_id, _sweep_config = setup_sweep(create_sweep=True)

        # reset runs counter
        runs = 0
        val_loss = 1000
        # Train the model until the threshold validation loss is reached
        train_until_threshold_val_loss(
            sweep_id=sweep_id,
            robot_uuid=robot_uuid,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
            model=None,
        )

        runs_model1 = runs
        val_loss_model1 = val_loss

    ###############################################
    ####################Model 2####################
    ###############################################
    if model2:
        # LLT model based on ITA model without known hyperparameters
        robot_uuid = LLT_ROBOT_UUID
        directory = Path(f"/app/dynamics_learning/Trajectory Data/train/{robot_uuid}")
        # Interpolate Training Data in UUID folders
        (
            attained_data,
            command_data,
            interpolated_command_data,
            q_qd_qdd_interpolated_command_input,
            tau_attained_input,
        ) = prepare_data(directory)
        # assert SWEEP_ID == "42d8t40t", "Sweep ID is not set correctly. Ensure that the sweep id is set to 42d8t40t"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT_UUID"

        sweep_id, _sweep_config = setup_sweep(create_sweep=True)

        # reset runs counter
        runs = 0
        val_loss = 1000

        # TODO load ITA model instead of dummy model
        model = load_model_from_binary_file(
            "/app/dynamics_learning/models/99.99706268310547.h5"
        )

        # Train the model until the threshold validation loss is reached
        train_until_threshold_val_loss(
            sweep_id=sweep_id,
            robot_uuid=robot_uuid,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
            model=model,
        )

        runs_model2 = runs
        val_loss_model2 = val_loss

    ###############################################
    ####################Model 3####################
    ###############################################
    if model3:
        # LLT model based on ITA model with known hyperparameters
        # Bind robot_uuid here: it is otherwise only defined when the Model 1 or
        # Model 2 section ran, and the assert below would raise NameError.
        robot_uuid = LLT_ROBOT_UUID
        assert (
            SWEEP_ID == "fe3gjovo"
        ), "Sweep ID is not set correctly. Ensure that the sweep id is set to fe3gjovo"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT"

        # TODO load LLT Data
        # TODO load ITA model
        # TODO set hyperparameters for LLT model based on ITA model with known hyperparameters

    ###############################################
    ####################Model 4####################
    ###############################################
    if model4:
        # LLT model based on foundation model
        # Bind robot_uuid here: it is otherwise only defined when the Model 1 or
        # Model 2 section ran, and the assert below would raise NameError.
        robot_uuid = LLT_ROBOT_UUID
        assert (
            SWEEP_ID == "7tglijx8"
        ), "Sweep ID is not set correctly. Ensure that the sweep id is set to 7tglijx8"
        assert (
            robot_uuid == LLT_ROBOT_UUID
        ), "Robot UUID is not set correctly. Ensure that the robot uuid is set to LLT_ROBOT"

        # TODO load LLT Data
        # TODO load foundation model
        # TODO set hyperparameters for LLT model based on foundation model