Commit a14538ba authored by Alexander Nasuta

Merge remote-tracking branch 'github/master'

# Conflicts:
#	requirements_dev.txt
parents 903648f5 4ae3adad
@@ -10,11 +10,10 @@ from gymcts.logger import log
 import gymnasium as gym
 import numpy as np
 class GraphMatrixJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
     def __init__(self, env: DisjunctiveGraphJspEnv):
-        gym.Wrapper.__init__(self, env)
+        super().__init__(env)
     def load_state(self, state: np.ndarray) -> None:
         self.env.unwrapped.load_state(state)
@@ -23,13 +22,13 @@ class GraphMatrixJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
         return self.env.unwrapped.is_terminal_state()
     def get_valid_actions(self) -> list[int]:
-        return self.env.unwrapped.valid_action_mask()
+        return self.env.unwrapped.valid_action_list()
     def rollout(self) -> float:
-        return self.env.unwrapped.random_rollout()
+        return self.env.unwrapped.greedy_rollout()
     def get_state(self) -> np.ndarray:
-        return env.unwrapped.get_state
+        return self.env.unwrapped.get_state()
 if __name__ == '__main__':
@@ -37,31 +36,21 @@ if __name__ == '__main__':
     env = DisjunctiveGraphJspEnv(
         jsp_instance=ft06,
-        reward_function="makespan",
+        c_lb=ft06_makespan,
+        reward_function="mcts", # this reward is in range [-inf, 1]
     )
-    # map reward to [1, -inf]
-    # ideally you want the reward to be in the range of [-1, 1] for the UBC score
-    env = TransformReward(env, lambda r: r / ft06_makespan + 2 if r != 0 else 0.0)
-    env.reset()
-    def mask_fn(env: gym.Env) -> np.ndarray:
-        # Do whatever you'd like in this function to return the action mask
-        # for the current env. In this example, we assume the env has a
-        # helpful method we can rely on.
-        return env.unwrapped.valid_action_mask()
-    env = DeepCopyMCTSGymEnvWrapper(
-        env,
-        action_mask_fn=mask_fn
+    env.reset()
+    env = GraphMatrixJspGYMCTSWrapper(
+        env
     )
     agent = GymctsAgent(
         env=env,
         render_tree_after_step=True,
         exclude_unvisited_nodes_from_render=True,
-        number_of_simulations_per_step=1000,
+        number_of_simulations_per_step=25,
     )
     root = agent.search_root_node.get_root()
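Note: the removed block above scaled the raw makespan reward by hand, before the example switched to the environment's built-in "mcts" reward function (which, per the inline comment, already lies in [-inf, 1]). A minimal sketch of that manual scaling, for context only; the constant and variable names are illustrative assumptions, not part of the commit, and it assumes the terminal reward is the negative makespan:

    from gymnasium.wrappers import TransformReward

    # illustrative lower bound: ft06's optimal makespan is 55
    ft06_lb = 55

    # assuming the terminal reward is the negative makespan, r / ft06_lb lies in (-inf, -1],
    # so shifting by 2 maps the best achievable return to roughly 1, close to the
    # [-1, 1] range the UCB score prefers; intermediate zero rewards stay at 0
    env = TransformReward(env, lambda r: r / ft06_lb + 2 if r != 0 else 0.0)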
...
@@ -15,7 +15,7 @@ import numpy as np
 class GraphMatrixJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
     def __init__(self, env: DisjunctiveGraphJspEnv):
-        gym.Wrapper.__init__(self, env)
+        super().__init__(env)
     def load_state(self, state: np.ndarray) -> None:
         self.env.unwrapped.load_state(state)
@@ -24,13 +24,13 @@ class GraphMatrixJspGYMCTSWrapper(GymctsABC, gym.Wrapper):
         return self.env.unwrapped.is_terminal_state()
     def get_valid_actions(self) -> list[int]:
-        return self.env.unwrapped.valid_action_mask()
+        return self.env.unwrapped.valid_action_list()
     def rollout(self) -> float:
-        return self.env.unwrapped.random_rollout()
+        return self.env.unwrapped.greedy_rollout()
     def get_state(self) -> np.ndarray:
-        return env.unwrapped.get_state
+        return self.env.unwrapped.get_state()
 if __name__ == '__main__':
@@ -38,31 +38,22 @@ if __name__ == '__main__':
     env = DisjunctiveGraphJspEnv(
         jsp_instance=ft06,
-        reward_function="makespan",
+        c_lb=ft06_makespan,
+        reward_function="mcts", # this reward is in range [-inf, 1]
     )
-    # map reward to [1, -inf]
-    # ideally you want the reward to be in the range of [-1, 1] for the UBC score
-    env = TransformReward(env, lambda r: r / ft06_makespan + 2 if r != 0 else 0.0)
-    env.reset()
-    def mask_fn(env: gym.Env) -> np.ndarray:
-        # Do whatever you'd like in this function to return the action mask
-        # for the current env. In this example, we assume the env has a
-        # helpful method we can rely on.
-        return env.unwrapped.valid_action_mask()
-    env = DeepCopyMCTSGymEnvWrapper(
-        env,
-        action_mask_fn=mask_fn
+    env.reset()
+    env = GraphMatrixJspGYMCTSWrapper(
+        env
     )
     agent = DistributedGymctsAgent(
         env=env,
         render_tree_after_step=True,
+        clear_mcts_tree_after_step=False,
         exclude_unvisited_nodes_from_render=True,
-        number_of_simulations_per_step=125,
+        number_of_simulations_per_step=2,
         num_parallel=4,
     )
@@ -75,6 +66,10 @@ if __name__ == '__main__':
     for a in actions:
         obs, rew, term, trun, info = env.step(a)
+    agent.show_mcts_tree_from_root(
+        tree_max_depth=None
+    )
     env.render()
     makespan = env.unwrapped.get_makespan()
     print(f"makespan: {makespan}")
@@ -32,7 +32,7 @@ examples = [
 ]
 dev = [
     "jsp-instance-utils",
-    "graph-matrix-jsp-env",
+    "graph-matrix-jsp-env>=0.3.0",
     "graph-jsp-env",
     "JSSEnv",
...
 #
-# This file is autogenerated by pip-compile with Python 3.10
+# This file is autogenerated by pip-compile with Python 3.11
 # by the following command:
 #
 #    pip-compile pyproject.toml
@@ -45,6 +45,4 @@ rich==13.9.4
 six==1.17.0
     # via python-dateutil
 typing-extensions==4.12.2
-    # via
-    #   gymnasium
-    #   rich
+    # via gymnasium
 import copy
+import random
 import gymnasium as gym
 from typing import TypeVar, Any, SupportsFloat, Callable
@@ -63,7 +64,10 @@ class GymctsAgent:
         # NAVIGATION STRATEGY
         # select child with highest UCB score
         while not temp_node.is_leaf():
-            temp_node = max(temp_node.children.values(), key=lambda child: child.ucb_score())
+            children = list(temp_node.children.values())
+            max_ucb_score = max(child.ucb_score() for child in children)
+            best_children = [child for child in children if child.ucb_score() == max_ucb_score]
+            temp_node = random.choice(best_children)
         log.debug(f"Selected leaf node: {temp_node}")
         return temp_node
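Note: this hunk changes leaf selection to break UCB ties at random instead of always taking the first maximum found by max(). GymctsNode.ucb_score() itself is not part of this diff; a minimal UCB1-style sketch (the exploration constant and attribute access are assumptions) illustrates why ties are common, e.g. all unvisited siblings share the same +inf score:

    import math

    # hypothetical UCB1-style score; the real GymctsNode.ucb_score() is not shown in this commit
    def ucb_score(node, exploration_constant: float = math.sqrt(2)) -> float:
        if node.visit_count == 0:
            return float("inf")  # every unvisited sibling ties at +inf
        exploitation = node.mean_value
        exploration = exploration_constant * math.sqrt(
            math.log(node.parent.visit_count) / node.visit_count
        )
        return exploitation + exploration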
...
@@ -118,6 +118,7 @@ class DistributedGymctsAgent:
             render_tree_after_step: bool = False,
             render_tree_max_depth: int = 2,
             num_parallel: int = 4,
+            clear_mcts_tree_after_step: bool = False,
             number_of_simulations_per_step: int = 25,
             exclude_unvisited_nodes_from_render: bool = False
     ):
@@ -134,6 +135,7 @@ class DistributedGymctsAgent:
         self.number_of_simulations_per_step = number_of_simulations_per_step
         self.env = env
+        self.clear_mcts_tree_after_step = clear_mcts_tree_after_step
         self.search_root_node = GymctsNode(
             action=None,
@@ -206,6 +208,8 @@ class DistributedGymctsAgent:
             ready_node = ray.get(ready_node_ref)
             # merge the tree
+            if not self.clear_mcts_tree_after_step:
+                self.backpropagation(search_start_node, ready_node.mean_value, ready_node.visit_count)
             search_start_node = merge_nodes(search_start_node, ready_node)
         action = search_start_node.get_best_action()
@@ -217,22 +221,34 @@ class DistributedGymctsAgent:
                 tree_max_depth=self.render_tree_max_depth
             )
+        if self.clear_mcts_tree_after_step:
             # to clear memory we need to remove all nodes except the current node
             # this is done by setting the root node to the current node
             # and setting the parent of the current node to None
             # we also need to reset the children of the current node
             # this is done by calling the reset method
+            #
+            # in a distributed setting we delete all previous nodes,
+            # because merging trees during backpropagation is already computationally expensive
+            # and backpropagating the whole tree would be even more expensive
             next_node.reset()
         self.search_root_node = next_node
         return action, next_node
+    def backpropagation(self, node: GymctsNode, average_episode_return: float, num_episodes: int) -> None:
+        log.debug(f"performing backpropagation from leaf node: {node}")
+        while not node.is_root():
+            node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+                    node.visit_count + num_episodes)
+            node.visit_count += num_episodes
+            node.max_value = max(node.max_value, average_episode_return)
+            node.min_value = min(node.min_value, average_episode_return)
+            node = node.parent
+        # also update root node
+        node.mean_value = (node.mean_value * node.visit_count + average_episode_return * num_episodes) / (
+                node.visit_count + num_episodes)
+        node.visit_count += num_episodes
+        node.max_value = max(node.max_value, average_episode_return)
+        node.min_value = min(node.min_value, average_episode_return)
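Note: the new backpropagation() folds the statistics of an entire simulated subtree into the retained tree with one weighted-mean step per ancestor. A quick standalone check of that arithmetic (the numbers are illustrative only, not taken from the commit):

    # illustrative values only
    visit_count, mean_value = 100, 0.60   # existing node statistics
    num_episodes, avg_return = 25, 0.80   # statistics of the merged subtree

    # weighted running mean, as in backpropagation() above
    new_mean = (mean_value * visit_count + avg_return * num_episodes) / (visit_count + num_episodes)
    new_visit_count = visit_count + num_episodes

    print(new_mean, new_visit_count)  # 0.64 125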
     def show_mcts_tree(self, start_node: GymctsNode = None, tree_max_depth: int = None) -> None:
         if start_node is None:
@@ -268,7 +284,7 @@ if __name__ == '__main__':
     agent1 = DistributedGymctsAgent(
         env=env,
         render_tree_after_step=True,
-        number_of_simulations_per_step=1000,
+        number_of_simulations_per_step=10,
         exclude_unvisited_nodes_from_render=True,
         num_parallel=1,
     )
@@ -278,4 +294,6 @@ if __name__ == '__main__':
     actions = agent1.solve()
     end_time = time.perf_counter()
+    agent1.show_mcts_tree_from_root()
     print(f"solution time per action: {end_time - start_time}/{len(actions)}")