#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# foundation_model.py
# Author: Leon Michel Gorißen
# Copyright Leon Gorissen
# Released under MIT License
import warnings
from functools import partial
from dynamics_learning.environment import WANDB_API_TOKEN, WANDB_ENTITY, WANDB_PROJECT
import numpy as np
import tensorflow as tf
from pritty_logger import RichLogger
# import env variables and set tensorflow variables
import wandb
from dynamics_learning.data_retrieval import download_resource_content
from dynamics_learning.preprocessing.dataset_analysis import analyze
from dynamics_learning.preprocessing.trajectory_interpolation import interpolate
from dynamics_learning.sweep.setup import setup_sweep
# from dynamics_learning.data_retrieval import download_resource_content
from dynamics_learning.testing import Dataset, model_analysis
from dynamics_learning.training import train, upload
# Ignore every FutureWarning for the rest of this process. This was added to
# silence a pandas deprecation warning, but the filter is not limited to
# pandas: any FutureWarning from any module is suppressed.
warnings.filterwarnings("ignore", category=FutureWarning)

# Logger used for progress messages in this script.
logger = RichLogger("dynamics_learning-foundation_model")
if __name__ == "__main__":
    # Download any resource content that is not present locally yet.
    local_resource_path = download_resource_content()

    # Preprocess the training split.
    attained_data, _command_data = analyze(local_resource_path / "train")
    interpolated_command_data = interpolate(local_resource_path / "train")

    # Build the input tensor by stacking the interpolated commanded q, qd and
    # qdd columns side by side; tau_attained is passed to train() alongside it.
    interpolated_command_input = np.column_stack(
        (
            interpolated_command_data["q_interpolated_command"],
            interpolated_command_data["qd_interpolated_command"],
            interpolated_command_data["qdd_interpolated_command"],
        )
    )
    q_qd_qdd_interpolated_command_input = tf.convert_to_tensor(
        interpolated_command_input
    )
    tau_attained_input = tf.convert_to_tensor(attained_data["tau_attained"])

    wandb.login(key=WANDB_API_TOKEN, relogin=True)
    # Reuse the configured sweep_id if set; otherwise setup_sweep() creates a
    # new sweep.
    sweep_id, _sweep_config = setup_sweep()

    def train_save_upload(
        q_qd_qdd_interpolated_command_input: tf.Tensor, tau_attained_input: tf.Tensor
    ) -> None:
        """Train one sweep model, upload it, and evaluate it if it is the best.

        Invoked by ``wandb.agent`` for each sweep run, with both tensors bound
        in advance through ``functools.partial``.

        Args:
            q_qd_qdd_interpolated_command_input: Column-stacked interpolated
                commanded q, qd and qdd values used as model input.
            tau_attained_input: Attained ``tau`` values passed to ``train``
                together with the input.
        """
        model, history, _run, config = train(
            q_qd_qdd_interpolated_command_input, tau_attained_input
        )
        logger.info(
            "\n=====================================\nModel trained\n=====================================\n"
        )
        model = upload(model, history, config)
        logger.info(
            "\n=====================================\nModel uploaded\n=====================================\n"
        )
        # If upload() returned a model, it is the best-performing one, so it
        # is the only one evaluated on the test data.
        if model:
            logger.info("Evaluating current model with test data.")
            test_dataset = Dataset(type="test")
            test_dataset.download()  # FIXME
            # NOTE(review): hard-coded container path; presumably this should
            # point to wherever test_dataset.download() stores the data
            # (compare local_resource_path / "train" above) — confirm. The
            # return value is discarded, so this call is presumably made for
            # its side effects only — verify.
            interpolate("/app/dynamics_learning/Trajectory Data/test")
            model_analysis(test_dataset, model)  # FIXME
            # Previously a bare string expression that was never logged.
            logger.info(
                "\n=====================================\nModel evaluated\n=====================================\n"
            )
        else:
            logger.info("Model is not being evaluated.")

    # wandb.agent calls its function without arguments, so bind the tensors.
    train_save_upload_with_args = partial(
        train_save_upload,
        q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
        tau_attained_input=tau_attained_input,
    )
    wandb.agent(
        sweep_id,
        train_save_upload_with_args,
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        # count=NUM_MODELS,  # optionally cap the number of runs this agent executes
    )