# Select Git revision
# draft.py 1.57 KiB
import numpy as np
import gymnasium as gym
import environment
from stable_baselines3 import PPO, DQN
import os
import time
def main():
    """Train a DQN agent on ``environment.BusinessProcessEnv`` until stopped.

    Builds the two-part ``space`` specification, constructs the environment,
    creates ``models/<run_id>/`` and ``logs/<run_id>/`` directories that share
    one timestamp, then trains in rounds of ``TIMESTEPS`` steps and saves a
    checkpoint after every round.

    The training loop has no exit condition: the script runs until it is
    interrupted (e.g. Ctrl+C), leaving the checkpoints written so far on disk.
    """
    # Named counts; each one contributes (count + 1) to ``process``, in this
    # exact order.
    # NOTE(review): what each key stands for, and why the environment wants
    # count + 1, is defined by environment.BusinessProcessEnv -- confirm there.
    counts = {
        "s": 1,
        "ot": 5,
        "sh_a": 3,
        "sh_b": 3,
        "sh_c": 3,
        "m_a": 3,
        "m_b": 2,
        "p_a": 4,
        "p_b": 5,
        "p_c": 4,
        "ds_a": 7,
        "ds_b": 7,
        "ds_c": 7,
    }
    process = [count + 1 for count in counts.values()]

    # NOTE(review): ``case`` has 15 entries while ``process`` has 13 and
    # ``activities`` is 16 -- verify these sizes against the environment.
    case = [2] * 15
    space = [process, case]
    activities = 16

    env = environment.BusinessProcessEnv(space, activities)
    env.reset()

    # Sample the clock once so the model and log directories of a single run
    # always carry the same id, even if a second boundary passes in between.
    run_id = int(time.time())
    models_dir = f"models/{run_id}/"
    logdir = f"logs/{run_id}/"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(logdir, exist_ok=True)

    # To try PPO instead, construct PPO(...) here with the same arguments.
    model = DQN('MultiInputPolicy', env, verbose=1, tensorboard_log=logdir)

    TIMESTEPS = 500000
    iters = 0
    while True:
        iters += 1
        # reset_num_timesteps=False keeps the step counter running across
        # rounds instead of restarting it on every learn() call.
        model.learn(
            total_timesteps=TIMESTEPS,
            reset_num_timesteps=False,
            tb_log_name="DQN",
        )
        # Checkpoint is named after the cumulative number of trained steps;
        # os.path.join avoids the doubled "/" of plain concatenation.
        model.save(os.path.join(models_dir, str(TIMESTEPS * iters)))
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()