Skip to content
Snippets Groups Projects
Select Git revision
  • 683b7a21551afb3ca2fefeb6c9a25df8969c7a55
  • main default protected
  • nour2
  • aleks4
  • geno2
  • petri-net-output
  • nour
  • deep-rl-1
  • geno3
  • paula
  • aleks2
  • aleks3
  • Nour
  • geno
  • aleks
15 results

draft.py

Blame
  • draft.py 1.57 KiB
    import numpy as np
    import gymnasium as gym
    
    import environment 
    
    from stable_baselines3 import PPO, DQN
    import os
    import time
    
    def _build_space():
        """Return the ``[process, case]`` space description for the environment.

        ``process`` holds one entry per resource type, equal to that type's
        count plus one. ``case`` holds 15 entries, each equal to 2.
        """
        # (label, count) pairs, in the order the environment receives them.
        # Labels mirror the original local names (num_s, num_ot, ...); their
        # domain meaning is not visible in this file.
        resource_counts = (
            ("s", 1),
            ("ot", 5),
            ("sh_a", 3),
            ("sh_b", 3),
            ("sh_c", 3),
            ("m_a", 3),
            ("m_b", 2),
            ("p_a", 4),
            ("p_b", 5),
            ("p_c", 4),
            ("ds_a", 7),
            ("ds_b", 7),
            ("ds_c", 7),
        )
        # NOTE(review): the +1 presumably gives the number of distinct values
        # 0..count for each slot -- confirm against BusinessProcessEnv.
        process = [count + 1 for _, count in resource_counts]
        case = [2] * 15
        return [process, case]

    def main(max_iters=None):
        """Train a DQN agent on ``environment.BusinessProcessEnv``.

        TensorBoard logs go under ``logs/<timestamp>``. A checkpoint is saved
        under ``models/<timestamp>`` after every ``TIMESTEPS`` training steps,
        named after the cumulative number of requested steps.

        Args:
            max_iters: Number of ``TIMESTEPS``-sized training rounds to run.
                ``None`` (the default) trains until the process is interrupted,
                as the original script did.
        """
        space = _build_space()
        activities = 16

        env = environment.BusinessProcessEnv(space, activities)
        env.reset()

        # Take the timestamp once so the model and log directories of a single
        # run always share the same name; two separate time.time() calls can
        # straddle a second boundary.
        run_id = str(int(time.time()))
        models_dir = os.path.join("models", run_id)
        logdir = os.path.join("logs", run_id)

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(models_dir, exist_ok=True)
        os.makedirs(logdir, exist_ok=True)

        # Alternative agent: PPO('MultiInputPolicy', env, verbose=1, tensorboard_log=logdir)
        model = DQN('MultiInputPolicy', env, verbose=1, tensorboard_log=logdir)

        TIMESTEPS = 500000
        iters = 0
        while max_iters is None or iters < max_iters:
            iters += 1
            # reset_num_timesteps=False: successive learn() calls continue one
            # run (and one TensorBoard curve) instead of restarting the count.
            model.learn(
                total_timesteps=TIMESTEPS,
                reset_num_timesteps=False,
                tb_log_name="DQN",
            )
            model.save(os.path.join(models_dir, str(TIMESTEPS * iters)))
    
        
    
    
    # Start training only when this file is executed directly, not on import.
    if __name__ == "__main__":
        main()