Skip to content
Snippets Groups Projects
Select Git revision
  • 6ef1ca816beab069502e841bb224cc52b21ee8ec
  • main default
  • with_hexapod_in_scene
  • ITA
  • moritzlennartz-main-patch-ec82
  • wzl
  • arxiv2412.12231v1
  • no_movement
8 results

foundation_model.py

  • foundation_model.py 3.50 KiB
    #!/usr/bin/env python3
    # -*- coding:utf-8 -*-
    # Copyright Leon Gorissen
    # Released under MIT License
    import warnings
    from functools import partial
    
    from dynamics_learning.environment import WANDB_API_TOKEN, WANDB_ENTITY, WANDB_PROJECT
    
    import numpy as np
    import tensorflow as tf
    from pritty_logger import RichLogger
    
    # import env variables and set tensorflow variables
    import wandb
    
    from dynamics_learning.data_retrieval import download_resource_content
    from dynamics_learning.preprocessing.dataset_analysis import analyze
    from dynamics_learning.preprocessing.trajectory_interpolation import interpolate
    from dynamics_learning.sweep.setup import setup_sweep
    
    # from dynamics_learning.data_retrieval import download_resource_content
    from dynamics_learning.testing import Dataset, model_analysis
    from dynamics_learning.training import train, upload
    
    # Intended to hide a pandas deprecation FutureWarning. Note that this filter
    # matches *every* FutureWarning, whatever module raises it, and applies to
    # the whole process because the warnings filter list is global.
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Logger used by this script's training / sweep entry point.
    logger = RichLogger("dynamics_learning-foundation_model")
    
    
    if __name__ == "__main__":
        # Download any trajectory data that is not yet available locally.
        local_resource_path = download_resource_content()
        # local_resource_path = Path("/app/dynamics_learning/Trajectory Data")

        # Preprocess the training split. Only the attained data returned by
        # `analyze` is used below; its command data is deliberately ignored.
        attained_data, _command_data = analyze(local_resource_path / "train")
        interpolated_command_data = interpolate(local_resource_path / "train")

        # Model input: the interpolated q / qd / qdd command arrays stacked
        # column-wise (presumably joint position / velocity / acceleration --
        # confirm against `interpolate`). Target: the attained tau values.
        interpolated_command_input = np.column_stack(
            (
                interpolated_command_data["q_interpolated_command"],
                interpolated_command_data["qd_interpolated_command"],
                interpolated_command_data["qdd_interpolated_command"],
            )
        )
        q_qd_qdd_interpolated_command_input = tf.convert_to_tensor(
            interpolated_command_input
        )
        tau_attained_input = tf.convert_to_tensor(attained_data["tau_attained"])

        wandb.login(key=WANDB_API_TOKEN, relogin=True)
        # Reuse the configured sweep if a sweep_id is set, otherwise create a new one.
        sweep_id, _sweep_config = setup_sweep()

        def train_save_upload(
            q_qd_qdd_interpolated_command_input: tf.Tensor, tau_attained_input: tf.Tensor
        ) -> None:
            """Train, upload and (if selected) evaluate the model of one sweep run.

            Invoked once per sweep run by ``wandb.agent`` with both tensors
            pre-bound via ``functools.partial``.

            Args:
                q_qd_qdd_interpolated_command_input: Column-stacked interpolated
                    q / qd / qdd command features used as model input.
                tau_attained_input: Attained tau values used as training target.

            Returns:
                None.
            """
            model, history, _run, config = train(
                q_qd_qdd_interpolated_command_input, tau_attained_input
            )
            logger.info(
                "\n=====================================\nModel trained\n=====================================\n"
            )
            model = upload(model, history, config)
            logger.info(
                "\n=====================================\nModel uploaded\n=====================================\n"
            )
            # If `upload` returned a model, it is the best-performing model;
            # only that model is evaluated on the test data.
            if model:
                logger.info("Evaluating current model with test data.")
                test_dataset = Dataset(type="test")
                test_dataset.download()  # FIXME
                # NOTE(review): unlike the training split, this path is hard-coded
                # rather than derived from `local_resource_path`, and the return
                # value is discarded -- presumably the call matters only for its
                # side effects. Confirm both.
                interpolate("/app/dynamics_learning/Trajectory Data/test")
                model_analysis(test_dataset, model)  # FIXME
                logger.info(
                    "\n=====================================\nModel evaluated\n=====================================\n"
                )
            else:
                logger.info("Model is not being evaluated.")
            return None

        # Bind the training tensors so the sweep agent can invoke the
        # function without passing any arguments.
        train_save_upload_with_args = partial(
            train_save_upload,
            q_qd_qdd_interpolated_command_input=q_qd_qdd_interpolated_command_input,
            tau_attained_input=tau_attained_input,
        )

        # No `count` is passed, so the number of runs is left to wandb / the sweep.
        wandb.agent(
            sweep_id,
            train_save_upload_with_args,
            project=WANDB_PROJECT,
            entity=WANDB_ENTITY,
            # count=NUM_MODELS,
        )