Skip to content
Snippets Groups Projects
Select Git revision
  • 0727d8fa4c3dfb05aafc7119e60a1b8d8a816754
  • main default protected
2 results

simulation_reading.py

Blame
  • draft.py 1.57 KiB
    import numpy as np
    import gymnasium as gym
    
    import environment 
    
    from stable_baselines3 import PPO, DQN
    import os
    import time
    
    def main():
        """Train a DQN agent on ``environment.BusinessProcessEnv`` indefinitely.

        Builds the ``[process, case]`` space description passed to the
        environment and creates one timestamped model directory and one
        timestamped log directory (sharing the same run id). It then trains
        in chunks of ``TIMESTEPS`` steps, saving a checkpoint named after the
        cumulative step count after every chunk.

        The training loop never returns on its own; stop it externally
        (e.g. Ctrl+C). Checkpoints already written are kept.
        """
        # Per-stage counts. Each entry of ``process`` is the count + 1.
        # NOTE(review): the "+ 1" presumably leaves room for a zero value in a
        # discrete space — confirm against environment.BusinessProcessEnv.
        stage_counts = [
            1,  # num_s
            5,  # num_ot
            3,  # num_sh_a
            3,  # num_sh_b
            3,  # num_sh_c
            3,  # num_m_a
            2,  # num_m_b
            4,  # num_p_a
            5,  # num_p_b
            4,  # num_p_c
            7,  # num_ds_a
            7,  # num_ds_b
            7,  # num_ds_c
        ]
        process = [count + 1 for count in stage_counts]
        case = [2] * 15

        space = [process, case]
        activities = 16

        env = environment.BusinessProcessEnv(space, activities)
        env.reset()

        # Take the timestamp once so the model and log directories always share
        # the same run id (two separate time.time() calls can straddle a second).
        run_id = int(time.time())
        models_dir = f"models/{run_id}/"
        logdir = f"logs/{run_id}/"

        # exist_ok=True replaces the racy exists()-then-makedirs() pattern.
        os.makedirs(models_dir, exist_ok=True)
        os.makedirs(logdir, exist_ok=True)

        # To train with PPO instead, swap in:
        #   model = PPO('MultiInputPolicy', env, verbose=1, tensorboard_log=logdir)
        model = DQN('MultiInputPolicy', env, verbose=1, tensorboard_log=logdir)

        TIMESTEPS = 500000
        iters = 0
        while True:
            iters += 1
            # reset_num_timesteps=False keeps the model's timestep counter running
            # across chunks instead of restarting it on every learn() call.
            model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="DQN")
            # models_dir already ends with "/", so append the name directly
            # (the checkpoint name is the cumulative number of steps trained).
            model.save(f"{models_dir}{TIMESTEPS * iters}")
    
        
    
    
    # Script entry point: start training only when this file is executed
    # directly, not when it is imported as a module.
    if __name__ == "__main__":
        main()