/**
 * Runs one search iteration with AMAF (all-moves-as-first) bookkeeping.
 *
 * <p>Phases, in order:
 * <ol>
 *   <li>Descend from the root while the current node is a visited, non-terminal
 *       inner node and {@code mctsConfig.maxTreeDepth} has not been reached.
 *       Chance nodes are advanced with {@code gameTick}; every other node
 *       delegates to {@code RAVE.selectChild}.</li>
 *   <li>If the descent stopped on a visited, non-terminal leaf, expand it via
 *       {@code tree.expand} and continue from the expanded child.</li>
 *   <li>Play {@code mctsConfig.nRolloutsPerIteration} random playouts from the
 *       reached state and accumulate the per-player result in {@code wins}.</li>
 *   <li>Update every distinct node on the visited path, then call
 *       {@code updateAMAF} on each collected sibling whose source action equals
 *       an action the decision player (the player to move at the root) took
 *       during descent, expansion or any playout.</li>
 * </ol>
 */
private void runRAVE() {
    GameFactory factory = gameFactory.copy();
    ArrayList<Node> visited = new ArrayList<>();
    ArrayList<AMAFNode> siblingsOfVisited = new ArrayList<>();
    ArrayList<int[]> sourceActionsTakenByDecisionPlayer = new ArrayList<>();
    int decisionPlayer = tree.getRoot().getPlayer();

    // --- Selection ---------------------------------------------------------
    Node currentNode = tree.getRoot();
    visited.add(currentNode);
    int depth = 1;
    while (!currentNode.isLeaf()
            && !currentNode.isTerminal()
            && currentNode.visited()
            && depth < mctsConfig.maxTreeDepth) {
        Node parentNode = currentNode;
        if (currentNode.isChanceNode()) {
            Game game = factory.getGame(currentNode.getState());
            game.gameTick(true);
            currentNode = new Node(game.getState(), game.isTerminal(), game.getCurrentPlayer());
            // NOTE(review): the node produced here is not added to `visited`, and the
            // result of putNodeIfAbsent is discarded, so when the tree already holds an
            // equal node the descent continues on this fresh instance — confirm intended.
            tree.putNodeIfAbsent(currentNode);
        } else {
            currentNode = RAVE.selectChild(currentNode, tree, mctsConfig.C, mctsConfig.minVisits, mctsConfig.k);
            visited.add(currentNode);
            recordDecisionPlayerMove(parentNode, currentNode, decisionPlayer,
                    sourceActionsTakenByDecisionPlayer, siblingsOfVisited);
        }
        depth++;
    }

    // --- Expansion ---------------------------------------------------------
    // When the condition is false the node has not been visited often enough (or
    // is terminal / the depth limit was hit); the playouts then start from it directly.
    if (currentNode.isLeaf() && currentNode.visited() && !currentNode.isTerminal()) {
        Node randomExpandedNode = tree.expand(currentNode, factory, mctsConfig.selectionPolicy);
        // Re-fetch the expanded node from the tree to read its children.
        Node parentNode = tree.getTreeNode(new NodeState(currentNode.getState(), currentNode.getSourceAction()));
        recordDecisionPlayerMove(parentNode, randomExpandedNode, decisionPlayer,
                sourceActionsTakenByDecisionPlayer, siblingsOfVisited);
        visited.add(randomExpandedNode);
        currentNode = randomExpandedNode;
    }

    // --- Simulation --------------------------------------------------------
    Game game = factory.getGame(currentNode.getState());
    double[] wins = new double[GameFactory.nMaxPlayers()];
    final int nRollouts = mctsConfig.nRolloutsPerIteration;
    for (int i = 0; i < nRollouts; i++) {
        Game gameClone = game.copy();
        while (!gameClone.isTerminal()) {
            int[] action = gameClone.getRandomLegalAction();
            if (gameClone.getCurrentPlayer() == decisionPlayer) {
                sourceActionsTakenByDecisionPlayer.add(action);
            }
            gameClone.performAction(action, true);
        }
        int winner = gameClone.getWinner();
        if (winner != -1) {
            wins[winner] += 1.0;
        } else {
            // Draw: credit half a win to every player. This must accumulate; a plain
            // assignment would discard the results of earlier rollouts in this iteration.
            for (int pl = 0; pl < wins.length; pl++) {
                wins[pl] += 0.5;
            }
        }
    }

    // --- Backpropagation ---------------------------------------------------
    // De-duplicate so that a node appearing more than once on the path is updated once.
    HashSet<Node> v = new HashSet<>(visited.size());
    v.addAll(visited);
    v.forEach(n -> n.update(wins, nRollouts));
    v.forEach(n -> n.updateVisitedStatus(mctsConfig.minVisits));

    // --- AMAF update -------------------------------------------------------
    // TicTacToe3D analysis:
    // if an action was good for decision player 1 and 'contributed' to a win, wins is [0, 1, 0].
    // When choosing between an AMAF node and other nodes during selection, we want to see how
    // many wins its action brought the player making that choice:
    //   - for the decision player that is wins[decisionPlayer] = wins[root.getCurrentPlayer()]
    //   - for an opponent's node that is wins[node.getCurrentPlayer()]
    for (int[] actionEncounteredByDecisionPlayer : sourceActionsTakenByDecisionPlayer) {
        for (AMAFNode sibling : siblingsOfVisited) {
            if (Arrays.equals(actionEncounteredByDecisionPlayer, sibling.getSourceAction())) {
                sibling.updateAMAF(wins, nRollouts);
            }
        }
    }
}

/**
 * Records the move from {@code parent} to {@code chosen} for the later AMAF update.
 *
 * <p>Does nothing unless {@code parent.getPlayer()} equals {@code decisionPlayer}.
 * Otherwise the source action of {@code chosen} is appended to {@code actions} and
 * every child of {@code parent} whose state differs from the state of {@code chosen}
 * is appended to {@code siblings}.
 *
 * @param parent         node whose children are inspected
 * @param chosen         child of {@code parent} that the search moved to
 * @param decisionPlayer player to move at the root
 * @param actions        output list of actions taken by the decision player
 * @param siblings       output list of the alternatives that were not taken
 */
private void recordDecisionPlayerMove(Node parent, Node chosen, int decisionPlayer,
        ArrayList<int[]> actions, ArrayList<AMAFNode> siblings) {
    ArrayList<Node> children = parent.getChildren(tree);
    if (parent.getPlayer() != decisionPlayer) {
        return;
    }
    actions.add(chosen.getSourceAction());
    for (Node child : children) {
        AMAFNode amafNode = (AMAFNode) child;
        if (!Arrays.equals(chosen.getState(), amafNode.getState())) {
            siblings.add(amafNode);
        }
    }
}
Add a code snippet to your website: www.paste.org