Skip to content
Snippets Groups Projects
Commit fcc8943e authored by Nour's avatar Nour
Browse files

train the agent with a new reward

parent 528e7d85
No related branches found
No related tags found
No related merge requests found
Showing
with 11890 additions and 13 deletions
......@@ -13,6 +13,7 @@ The following are the import for the backend. Just write the name of the script
Then you can immediately use the function from the script directly in the app.
"""
sys.path.append(os.path.join(os.path.dirname(sys.path[0]),'backend'))
import input
#define the app
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
File added
No preview for this file type
No preview for this file type
......@@ -23,7 +23,7 @@ def train(space, activities):
# model = PPO('MultiInputPolicy', env, verbose=1, tensorboard_log=logdir)
model = DQN('MultiInputPolicy', env, verbose=1, exploration_fraction = 0.33, learning_starts = 10000, tensorboard_log=logdir)
TIMESTEPS = 2000000
TIMESTEPS = 1000000
iters = 0
while True:
iters += 1
......@@ -32,7 +32,7 @@ def train(space, activities):
def deploy(state, model_path=None):
    """Predict the next action for ``state`` with a trained DQN model.

    Args:
        state: Observation passed unchanged to ``model.predict``.
        model_path: Optional path of the saved DQN checkpoint to load.
            Defaults to the checkpoint this module was last trained with.

    Returns:
        The action returned by ``model.predict`` in deterministic mode.
    """
    if model_path is None:
        # Only the most recent checkpoint is loaded; the previously loaded
        # older checkpoint was immediately overwritten and never used.
        model_path = r"C:\Users\nourm\OneDrive\Desktop\Nour\optis_app\models\1687769377\5000000"
    model = DQN.load(model_path)
    action, _ = model.predict(state, deterministic=True)
    return action
......
......@@ -40,6 +40,8 @@ class BusinessProcessEnv(gym.Env):
self.reward = 0
def get_current_state(self, caseid):
process, case, event = simmodel.get_current_state(self.process, caseid)
state = OrderedDict()
......@@ -48,6 +50,18 @@ class BusinessProcessEnv(gym.Env):
state['process'] = np.asarray(process)
return state
def get_ressources(self, caseid, action):
    """Return the ``'process'`` state entry that corresponds to ``action``.

    Actions 1 and 15 map to slot 0, actions 2 and 3 map to slot 1, and
    every other action maps to slot ``action - 2``.

    NOTE(review): presumably the entry is the number of free resources for
    the activity (the caller penalises a value of 0) -- confirm.
    """
    process_state = self.get_current_state(caseid)['process']
    if action in (1, 15):
        slot = 0
    elif action in (2, 3):
        slot = 1
    else:
        slot = action - 2
    return process_state[slot]
def step(self, action):
self.process.next = action
......@@ -94,13 +108,31 @@ class BusinessProcessEnv(gym.Env):
if self.process.is_valid(self.current_state['event'], action, case_obj):
print(action)
print(self.current_state['process'])
print(self.get_ressources(case_obj, action))
while(self.process.flag):
self.model_env.step()
if self.get_ressources(case_obj, action ) == 0 :
reward = -10
self.reward += reward
next_state = self.current_state
done = False
truncated = False
info = {}
return next_state, reward, done, truncated, info
stop = self.process.env.now
# case_obj = self.process.case_objects[self.process.case_id]
# print(f"Agent did case {self.process.case_id} activity {action}.")
print(f"Agent did case {self.process.case_id} activity {action}.")
next_state = self.get_current_state(case_obj)
......@@ -110,7 +142,7 @@ class BusinessProcessEnv(gym.Env):
if time == 0:
reward = 0
else:
reward = - math.log(time, 10)
reward = 4.5 - math.log(time, 10)
self.reward += reward
done = True if (len(self.process.done_cases) == 10 or len(self.process.active_cases) == 0) else False
truncated = False
......@@ -118,7 +150,7 @@ class BusinessProcessEnv(gym.Env):
return next_state, reward, done, truncated, info # either self.reward or just reward ???
else:
reward = - 6
reward = -20
self.reward += reward
# next_state = self.flatten_observation_to_int(self.current_state)
next_state = self.current_state
......
......@@ -35,7 +35,7 @@ def export_to_xes(process, file_path):
pass
def get_active_cases(event_log_path=None):
    """Return the IDs of all cases that are not completed yet.

    A case counts as active when none of its events in the event log has
    the activity ``'order completed'``.

    Args:
        event_log_path: Optional path of the event-log CSV; it must contain
            the columns ``CaseID`` and ``Activity``. Defaults to the event
            log exported by this project.

    Returns:
        A list of the unique ``CaseID`` values of the active cases, in
        order of first appearance in the log.
    """
    if event_log_path is None:
        # The log is read once from the current location; the earlier read
        # from the superseded path was discarded without being used.
        event_log_path = r'C:\Users\nourm\OneDrive\Desktop\Nour\optis_app\backend\eventlog_test\optis.csv'
    event_log = pd.read_csv(event_log_path)
    open_events = event_log.groupby('CaseID').filter(
        lambda x: 'order completed' not in x['Activity'].values
    )
    return open_events['CaseID'].unique().tolist()
......@@ -90,7 +90,7 @@ def get_state(case_id,):
'order completed': 15,
}
event_log = pd.read_csv(r'D:\test\optis.csv')
event_log = pd.read_csv(r'C:\Users\nourm\OneDrive\Desktop\Nour\optis_app\backend\eventlog_test\optis.csv')
# Sort the event log by case ID and start timestamp
event_log.sort_values(by=['CaseID', 'StartTimestamp'], inplace=True)
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
......@@ -54,14 +54,14 @@ def main():
# dqn.train(space, activities)
"""
''''''
# generate event log
env = simpy.Environment()
business_process = model.BusinessProcess(env, ressources)
business_process.event_log_flag = True
env.process(model.run_process(env, business_process))
env.run(until = 30000)
log.export_to_csv(business_process, r'D:\test\optis.csv')
log.export_to_csv(business_process, r'C:\Users\nourm\OneDrive\Desktop\Nour\optis_app\backend\eventlog_test\optis.csv')
# extract active cases from event log
active_cases = log.get_active_cases()
......@@ -74,7 +74,7 @@ def main():
state = log.get_state(caseid)
print(dqn.deploy(state))
"""
test.test_agent()
......
......@@ -3,9 +3,9 @@ import numpy as np
import dqn
def test_agent():
case = [1,0,1,0,0,0,1,0,0,0,0,0,0,0,0]
case = [1,1,0,1,0,0,0,0,0,0,0,0,0,0,0]
event = 4
process = [1,4,0,2,0,1,0,2,0,3,30,45,45]
process = [0,2,1,2,0,0,3,2,3,3,7,6,0]
state = OrderedDict()
state['case'] = np.asarray(case)
......
File added
File added
File added
File added
File added
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment