From 57cc6c9456e773d33cc62b14cdb5d841e7b4996b Mon Sep 17 00:00:00 2001 From: iamakash Date: Tue, 23 Apr 2024 18:42:17 +0200 Subject: [PATCH 01/14] feat: Add commented observation class. The observation class defines the basic structure to implement the observation tree. It's still a WIP. Therefore, the entire class is commented out. --- .../railsim/observation/TreeObservation.java | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java new file mode 100644 index 00000000000..f755dce676a --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java @@ -0,0 +1,102 @@ +//package ch.sbb.matsim.contrib.railsim.qsimengine.observation; +//import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; +//import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; +//import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; +//import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResource; +//import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; +//import org.matsim.api.core.v01.network.Node; +//// TODO: What is the time resolution of the environment? +//// TODO: Is there any EVENT already created when the RL agent should be called? If not then we need to crate this event. +//// TODO: Need help with creating these additional APIs +//import java.util.List; +//import java.util.ArrayList; +//class OtherAgent{ +// // Distance of agent to other agent's tail position if the train is moving in the same direction +// // Otherwise, distance to BufferTip position if moving in opposite direction +// double distance; +// List bufferTipPosition; +// List tailPosition; +// double speed; +// +// public OtherAgent(double distance, List bufferTipPosition, List tailPosition, double speed, boolean sameDirection) { +// this.distance = distance; +// this.bufferTipPosition = bufferTipPosition; +// this.tailPosition = tailPosition; +// this.speed = speed; +// this.sameDirection = sameDirection; +// } +// +// boolean sameDirection; +// +//} +//class ObservationNode{ +// Node node; +// double distNodeAgent; +// double distNodeStop; +// boolean isSwitchable; +// OtherAgent sameDirAgent; +//// OtherAgent oppDirAgent; +//} +// +//public class TreeObservation { +// private TrainPosition position; +// private RailResourceManager resources; +// private List observation; +// public TreeObservation(TrainPosition position, RailResourceManager resources){ +// this.resources = resources; +// this.position = position; +// createTreeObs(); +// } +// +// public RailLink getBufferTip(){ +// double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); +// RailLink currentLink = resources.getLink(position.getHeadLink()); +// List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); +// // TODO Verify if the links are added in the list in the sequence of occurence. +// RailLink bufferTip = reservedSegment.get(reservedSegment.size() -1); +// +// return bufferTip; +// } +// +// private void createTreeObs(){ +// int depth =3; +// createTreeObs(depth); +// } +// +//// TODO: Handle cases when the agent should also see the next 2 halts. - But this may not be needed as the agent ca be penalised +//// if the agent is unable to reach the next station. +// private void createTreeObs(int depth){ +// // Get the link of the tip of the buffer +// RailLink bufferTipLink = getBufferTip(); +// +// // Get the toNode of the bufferTipLink +// Node toNode = getToNode(bufferTipLink); +// List exploreQueue = new ArrayList (); +// +// exploreQueue.addLast(toNode); +// for (int i=0; i< depth; i++){ +// // Depth traversal: Making an observation Tree of fixed depth +// int lenExploreQueue = exploreQueue.size(); +// while (lenExploreQueue > 0){ +// // Level traversal +// Node curNode = exploreQueue.getFirst(); +// exploreQueue.remove(0); +// +// // Look for switches/intersections/stops on the branches stemming out of the current switch +// List nextNodes = getNextNodes(curNode); +// +// for (Node nextNode: nextNodes){ +// exploreQueue.addLast(nextNode); +// TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); +// TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); +// observation.add(createObservatioNode(nextNode, trainF, trainR)); +// } +// lenExploreQueue -= 1; +// } +// } +// } +// +// public List getObservation(){ +// return observation; +// } +//} From c4261a8df3160e23a9738bce79fc7b664a4e08a4 Mon Sep 17 00:00:00 2001 From: iamakash Date: Fri, 26 Apr 2024 15:37:32 +0200 Subject: [PATCH 02/14] Added Doubts as TODOs --- .../contrib/railsim/observation/TreeObservation.java | 4 +--- .../contrib/railsim/qsimengine/RailsimEngine.java | 3 ++- .../qsimengine/disposition/SimpleDisposition.java | 10 ++++++++++ .../microTrackOppositeTrafficMany/trainNetwork.xml | 1 + 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java index f755dce676a..7aaeb572fc8 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java @@ -5,9 +5,7 @@ //import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResource; //import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; //import org.matsim.api.core.v01.network.Node; -//// TODO: What is the time resolution of the environment? -//// TODO: Is there any EVENT already created when the RL agent should be called? If not then we need to crate this event. -//// TODO: Need help with creating these additional APIs +//// TODO: What is the time resolution of the environment? 1 iteration is equal to home many seconds in rlsim? //import java.util.List; //import java.util.ArrayList; //class OtherAgent{ diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index 3b20be8ca7c..cbd40d54e6f 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -253,6 +253,7 @@ private void checkTrackReservation(double time, UpdateEvent event) { } } +// TODO: More clarity needed on how stopTime is calculated private void updateDeparture(double time, UpdateEvent event) { TrainState state = event.state; @@ -602,7 +603,7 @@ private void decideNextUpdate(UpdateEvent event) { assert FuzzyUtils.greaterEqualThan(headDist, 0) : "Head distance must be positive"; // Find the earliest required update - + //TODO: The following if else statements are not clear double dist; if (FuzzyUtils.lessEqualThan(tailDist, decelDist) && FuzzyUtils.lessEqualThan(tailDist, reserveDist) && diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java index 3e1094fc776..277fedab883 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java @@ -58,6 +58,14 @@ public void onDeparture(double time, MobsimDriverAgent driver, List ro // Nothing to do. } + /** + * This method tries to first calculate the links needed by the train for moving the safety_distance. + * Then for each link of the segment, it check if it can blocked completely. + * Only when all the links of the segment (list of links) can be blocked, + * a Response with approved distance = length of the links is returned. + */ + //TODO: In the PPT it's mentioned that the requestNewSegment loops over all the links in the route, However, it just tries to block the segment (set of links) + // which are needed for safety distance. @Override public DispositionResponse requestNextSegment(double time, TrainPosition position, double dist) { @@ -114,6 +122,8 @@ public DispositionResponse requestNextSegment(double time, TrainPosition positio return new DispositionResponse(reserveDist, stop ? 0 : Double.POSITIVE_INFINITY, null); } + // TODO: I don't understand why detour is happening if the train is on the entrylink + // or if the train blocked links, one of which is an entry link ? private Detour checkDetour(double time, List segment, TrainPosition position) { if (position.getPt() != null && considerReRouting(segment, resources.getLink(position.getHeadLink()))) { diff --git a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml index 46f01336767..1f3501ba109 100644 --- a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml +++ b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml @@ -30,6 +30,7 @@ + From 03f9bc073ba4046ea0c0fb1188dcda70a29edb4b Mon Sep 17 00:00:00 2001 From: iamakash Date: Mon, 13 May 2024 12:38:21 +0200 Subject: [PATCH 03/14] feat: add psuedocode for RL inference in Railsim --- .../railsim/observation/TreeObservation.java | 274 +++++++++++------- .../railsim/qsimengine/RailsimCalc.java | 50 ++++ .../railsim/qsimengine/RailsimEngine.java | 28 +- 3 files changed, 247 insertions(+), 105 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java index 7aaeb572fc8..1e5fd22879a 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java @@ -1,100 +1,174 @@ -//package ch.sbb.matsim.contrib.railsim.qsimengine.observation; -//import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; -//import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; -//import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; -//import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResource; -//import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; -//import org.matsim.api.core.v01.network.Node; -//// TODO: What is the time resolution of the environment? 1 iteration is equal to home many seconds in rlsim? -//import java.util.List; -//import java.util.ArrayList; -//class OtherAgent{ -// // Distance of agent to other agent's tail position if the train is moving in the same direction -// // Otherwise, distance to BufferTip position if moving in opposite direction -// double distance; -// List bufferTipPosition; -// List tailPosition; -// double speed; -// -// public OtherAgent(double distance, List bufferTipPosition, List tailPosition, double speed, boolean sameDirection) { -// this.distance = distance; -// this.bufferTipPosition = bufferTipPosition; -// this.tailPosition = tailPosition; -// this.speed = speed; -// this.sameDirection = sameDirection; -// } -// -// boolean sameDirection; -// -//} -//class ObservationNode{ -// Node node; -// double distNodeAgent; -// double distNodeStop; -// boolean isSwitchable; -// OtherAgent sameDirAgent; -//// OtherAgent oppDirAgent; -//} -// -//public class TreeObservation { -// private TrainPosition position; -// private RailResourceManager resources; -// private List observation; -// public TreeObservation(TrainPosition position, RailResourceManager resources){ -// this.resources = resources; -// this.position = position; -// createTreeObs(); -// } -// -// public RailLink getBufferTip(){ -// double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); -// RailLink currentLink = resources.getLink(position.getHeadLink()); -// List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); -// // TODO Verify if the links are added in the list in the sequence of occurence. -// RailLink bufferTip = reservedSegment.get(reservedSegment.size() -1); -// -// return bufferTip; -// } -// -// private void createTreeObs(){ -// int depth =3; -// createTreeObs(depth); -// } -// -//// TODO: Handle cases when the agent should also see the next 2 halts. - But this may not be needed as the agent ca be penalised -//// if the agent is unable to reach the next station. -// private void createTreeObs(int depth){ -// // Get the link of the tip of the buffer -// RailLink bufferTipLink = getBufferTip(); -// -// // Get the toNode of the bufferTipLink -// Node toNode = getToNode(bufferTipLink); -// List exploreQueue = new ArrayList (); -// -// exploreQueue.addLast(toNode); -// for (int i=0; i< depth; i++){ -// // Depth traversal: Making an observation Tree of fixed depth -// int lenExploreQueue = exploreQueue.size(); -// while (lenExploreQueue > 0){ -// // Level traversal -// Node curNode = exploreQueue.getFirst(); -// exploreQueue.remove(0); -// -// // Look for switches/intersections/stops on the branches stemming out of the current switch -// List nextNodes = getNextNodes(curNode); -// -// for (Node nextNode: nextNodes){ -// exploreQueue.addLast(nextNode); -// TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); -// TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); -// observation.add(createObservatioNode(nextNode, trainF, trainR)); -// } -// lenExploreQueue -= 1; -// } -// } -// } -// -// public List getObservation(){ -// return observation; -// } -//} +package ch.sbb.matsim.contrib.railsim.observation; + +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResource; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.ResourceState; +import org.matsim.api.core.v01.network.Link; +import org.matsim.api.core.v01.network.Network; +import org.matsim.api.core.v01.network.Node; +// TODO: What is the time resolution of the environment? 1 iteration is equal to home many seconds in rlsim? +// The time resolution of matsim is seconds, same for railsim. But there are seconds where nothing happens, no events in update queue. +import java.util.List; +import java.util.ArrayList; + +class OtherAgent { + // Distance of agent to other agent's tail position if the train is moving in the same direction + // Otherwise, distance to BufferTip position if moving in opposite direction + double distance; + List bufferTipPosition; + List tailPosition; + double speed; + + public OtherAgent(double distance, List bufferTipPosition, List tailPosition, double speed, boolean sameDirection) { + this.distance = distance; + this.bufferTipPosition = bufferTipPosition; + this.tailPosition = tailPosition; + this.speed = speed; + this.sameDirection = sameDirection; + } + + boolean sameDirection; + +} + +class ObservationNode { + Node node; + double distNodeAgent; + double distNodeStop; + boolean isSwitchable; + OtherAgent sameDirAgent; + OtherAgent oppDirAgent; + int numParallelTracks; +} + +public class TreeObservation { + private final TrainPosition position; + private final RailResourceManager resources; + private final Network network; + private List observation; + + public TreeObservation(TrainPosition position, RailResourceManager resources, Network network) { + this.resources = resources; + this.position = position; + this.network = network; + createTreeObs(); + } + + public RailLink getBufferTip() { + double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); + RailLink currentLink = resources.getLink(position.getHeadLink()); + List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); + // TODO Verify if the links are added in the list in the sequence of occurence. + RailLink bufferTip = reservedSegment.get(reservedSegment.size() - 1); + + return bufferTip; + } + + private void createTreeObs() { + int depth = 3; + createTreeObs(depth); + } + + // TODO: Handle cases when the agent should also see the next 2 halts. - But this may not be needed as the agent can be penalised + // if the agent is unable to reach the next station. + private void createTreeObs(int depth) { +// TODO: This has to be fixed such that the next observation node is calculated from the current position + + // Get the link of the tip of the buffer + RailLink bufferTipLink = getBufferTip(); + + // Get the toNode of the bufferTipLink + Node toNode = getToNode(bufferTipLink); + List exploreQueue = new ArrayList(); + + List obsNode = getNextNodes(toNode); + exploreQueue.addLast(obsNode); + + for (int i = 0; i < depth; i++) { + // Depth traversal: Making an observation Tree of fixed depth + int lenExploreQueue = exploreQueue.size(); + while (lenExploreQueue > 0) { + // Level traversal + Node curNode = exploreQueue.getFirst(); + exploreQueue.remove(0); + + // Look for switches/intersections/stops on the branches stemming out of the current switch + List nextNodes = getNextNodes(curNode); + + for (Node nextNode : nextNodes) { + exploreQueue.addLast(nextNode); + TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); + TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); + observation.add(createObservatioNode(nextNode, trainF, trainR)); + } + lenExploreQueue -= 1; + } + } + } + + private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) { + // complete route to current position of the train from schedule, if needed? + List previousRoute = position.getRoute(0, position.getRouteIndex()); + + // TODO: If all opposite trains are needed, follow the inLinks of curNode until the nextNode is reached. Store the of the path links. + List path = null; + + // check for each link if there is capacity + for (Link link : path) { + RailLink railLink = resources.getLink(link.getId()); + RailResource resource = railLink.getResource(); + ResourceState state = resource.getState(railLink); + // TODO: Ask Christian how we get the train position of the nearest train. + } + + return null; + } + + private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) { + // complete route from current position of the train from schedule, if needed? + List upcomingRoute = position.getRoute(position.getRouteIndex(), position.getRouteSize()); + + // TODO: If all opposite trains are needed, follow the outLinks of curNode until the nextNode is reached. Store the of the path links. + List path = null; + + // same procedure as above... + + return null; + } + +// TODO: Add the logic to also consider halts/stops + private List getNextNodes(Node curNode) { + // next switches + List switchNodes = new ArrayList<>(); + + // check all outgoing links from the current node + for (Link outLink : curNode.getOutLinks().values()) { + Node nextNode = outLink.getToNode(); + + // follow nodes with only one outgoing link until a switch is reached + while (nextNode.getOutLinks().size() == 1) { + // get the single outgoing link and follow it + // TODO: Handle end of network, will throw NoSuchElementException at the moment + nextNode = nextNode.getOutLinks().values().iterator().next().getToNode(); + } + + // if the next node has more than one outgoing link, it's a switch + if (nextNode.getOutLinks().size() > 1) { + switchNodes.add(nextNode); + } + } + + return switchNodes; + } + + private Node getToNode(RailLink bufferTipLink) { + return network.getLinks().get(bufferTipLink.getLinkId()).getToNode(); + } + + public List getObservation() { + return observation; + } +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java index 117947728d3..a4401286cda 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java @@ -263,6 +263,56 @@ public static List calcLinksToBlock(TrainPosition position, RailLink c return result; } + +// public static List calcLinksToBlockRL(TrainPosition position, RailLink currentLink, double reserveDist) { +// +// // Assumption: The route for each train from start position to next switch/goal is calculated during train +// // initialisation. +// +// List result = new ArrayList<>(); +// +// // Assume current distance left on link is already reserved (only for fixed block) +// double dist = currentLink.length - position.getHeadPosition(); +// +// int idx = position.getRouteIndex(); +// +// Railink targetLink = getTargetLink(thisTrainId); +// // This function always needs to provide more reserve distance than requested (except when it will stop) +// while (FuzzyUtils.lessEqualThan(dist, reserveDist) && targetLink != currentLink) { +// RailLink nextLink = position.getRoute(idx++); +// +// // nextLink will be NULL only when the to-node of the link is a decision node +// if (nextLink == NULL){ +// Map obs= getObservationsOfAllTrains(); +// Map actions = rlModel(obs); +// +// // 3 possible actions: 0: do nothing; 1: switch; 2: STOP +// actionForThisTrain = actions[thisTrainId]; +// if (actionForThisTrain == STOP){ +// // update the event for this train to reduce velocity to zero +// // Assumption 1: The train is stopped write before entering the link (assuming fixed block). +// // Assumption 2: The Railsim engine automatically polls to reserve distance for this train after some +// // time period. +// } +// else{ +// // This method calculates the route of the given train +// // according to the chosen direction +// updateRouteForCurrentTrain(thisTrainId, actionForThisTrain, obs[thisTrainId]); +// } +// continue; +// } +// dist += nextLink.length; +// +// result.add(nextLink); +// +// // Don't block beyond stop +// if (position.isStop(nextLink.getLinkId())) +// break; +// } +// +// return result; +// } + /** * Calculate distance to the next stop. */ diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index cbd40d54e6f..9c7ce2e695a 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -19,11 +19,7 @@ package ch.sbb.matsim.contrib.railsim.qsimengine; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.PriorityQueue; -import java.util.Queue; +import java.util.*; import java.util.stream.Collectors; import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.DispositionResponse; @@ -96,6 +92,26 @@ public void doSimStep(double time) { } } + public void doSimStepUsingRLInference(double time) { + + Map> obs = getObservationForAliveTrains(); + Map actions = rlModel(obsDict); + Set agent_ids = actions.keySet(); + for (String aid:agent_ids){ + int action = actions.get(aid); + if (action == 2){ + // Create an event for stop + } + else { + updateRouteForAgent(aid, action); + } + } + + doSimStep(time); + } + + + /** * Update the current state of all trains, even if no update would be needed. */ @@ -141,6 +157,8 @@ private void updateAllPositions(double time) { for (TrainState train : activeTrains) { if (train.timestamp < time) updateState(time, new UpdateEvent(train, UpdateEvent.Type.POSITION)); +// updateState(time, new UpdateEvent(train, UpdateEvent.Type.POSITION)); +// updateState(time, new UpdateEvent(train, UpdateEvent.Type.POSITION)); } } From 2a17a8eb2798575682e154fd04f27fa11589df37 Mon Sep 17 00:00:00 2001 From: iamakash Date: Mon, 27 May 2024 18:20:59 +0200 Subject: [PATCH 04/14] feat: create pipeline to integrate rl with railsim --- contribs/railsim/pom.xml | 54 +++++ .../railsim/EnvironmentFactoryServer.java | 120 +++++++++ .../matsim/contrib/railsim/RailsimEnv.java | 59 +++++ .../contrib/railsim/RunRailsimExample.java | 49 +++- .../railsim/observation/TreeObservation.java | 174 -------------- .../railsim/qsimengine/RailsimCalc.java | 2 +- .../railsim/qsimengine/RailsimEngine.java | 90 +++++-- .../railsim/qsimengine/RailsimQSimEngine.java | 17 +- .../railsim/qsimengine/TrainState.java | 4 +- .../matsim/contrib/railsim/rl/RLClient.java | 94 ++++++++ .../railsim/rl/observation/Observation.java | 76 ++++++ .../rl/observation/ObservationNode.java | 30 +++ .../railsim/rl/observation/StepOutput.java | 75 ++++++ .../rl/observation/TreeObservation.java | 227 ++++++++++++++++++ .../contrib/railsim/rl/utils/RLUtils.java | 150 ++++++++++++ contribs/railsim/src/main/proto/railsim.proto | 60 +++++ 16 files changed, 1083 insertions(+), 198 deletions(-) create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java delete mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/ObservationNode.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java create mode 100644 contribs/railsim/src/main/proto/railsim.proto diff --git a/contribs/railsim/pom.xml b/contribs/railsim/pom.xml index 6ff552a868c..f870ab1fcb6 100644 --- a/contribs/railsim/pom.xml +++ b/contribs/railsim/pom.xml @@ -29,5 +29,59 @@ test + + io.grpc + grpc-netty-shaded + 1.63.0 + runtime + + + io.grpc + grpc-protobuf + 1.63.0 + + + io.grpc + grpc-stub + 1.63.0 + + + org.apache.tomcat + annotations-api + 6.0.53 + provided + + + + + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:3.25.1:exe:${os.detected.classifier} + grpc-java + io.grpc:protoc-gen-grpc-java:1.63.0:exe:${os.detected.classifier} + + + + + compile + compile-custom + + + + + + diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java new file mode 100644 index 00000000000..52cd5574ff7 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java @@ -0,0 +1,120 @@ +package ch.sbb.matsim.contrib.railsim; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; +import io.grpc.Server; +import io.grpc.stub.StreamObserver; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoConfirmationResponse; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoGrpcPort; +import ch.sbb.matsim.contrib.railsim.grpc.RailsimFactoryGrpc; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoAgentIDs; + + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; +public class EnvironmentFactoryServer { + + private static final Logger logger = Logger.getLogger(EnvironmentFactoryServer.class.getName()); + + private Server server; + + private void start() throws IOException { + /* The port on which the server should run */ + int factoryServerPort = 50051; + server = Grpc.newServerBuilderForPort(factoryServerPort, InsecureServerCredentials.create()) + .addService(new RailsimFactory()) + .build() + .start(); + logger.info("Server started, listening on " + factoryServerPort); + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + // Use stderr here since the logger may have been reset by its JVM shutdown hook. + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + EnvironmentFactoryServer.this.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** server shut down"); + } + }); + } + + private void stop() throws InterruptedException { + if (server != null) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + /** + * Await termination on the main thread since the grpc library uses daemon threads. + */ + private void blockUntilShutdown() throws InterruptedException { + if (server != null) { + server.awaitTermination(); + } + } + + /** + * Main launches the server from the command line. + */ + public static void main(String[] args) throws IOException, InterruptedException { + final EnvironmentFactoryServer server = new EnvironmentFactoryServer(); + server.start(); + server.blockUntilShutdown(); + } + + + // Implementation of the gRPC service on the server-side. + private class RailsimFactory extends RailsimFactoryGrpc.RailsimFactoryImplBase { + + Map envMap = new HashMap<>(); + + @Override + public void getEnvironment(ProtoGrpcPort grpcPort, StreamObserver responseObserver) { + // Create an instance of Railsim environment and store it in a map + System.out.println("getEnvironment() -> Create env with id: "+grpcPort); + + RLClient rlClient = new RLClient(grpcPort.getGrpcPort()); + RailsimEnv env = new RailsimEnv(rlClient); + + // Store the environment created with it's key being the port + this.envMap.put(grpcPort.getGrpcPort(), env); + + // Send the Ack message back to the client. + ProtoConfirmationResponse response = ProtoConfirmationResponse.newBuilder() + .setAck("OK") + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + + public void resetEnv(ProtoGrpcPort grpcPort, StreamObserver responseObserver) { + System.out.println("Reset env id: "+grpcPort); + + // fetch the object from map and reset it + RailsimEnv env = this.envMap.get(grpcPort.getGrpcPort()); + List agentIds = env.reset(); + + //Create response using agentIds + ProtoAgentIDs.Builder agentIDsBuilder = ProtoAgentIDs.newBuilder(); + agentIDsBuilder.addAllAgentId(agentIds); + ProtoAgentIDs response = agentIDsBuilder.build(); + + // Send the reply back to the client. + responseObserver.onNext(response); + // Indicate that no further messages will be sent to the client. + responseObserver.onCompleted(); + + // Start the simulation + env.startSimulation(); + + } + } + +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java new file mode 100644 index 00000000000..570927b5863 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -0,0 +1,59 @@ +package ch.sbb.matsim.contrib.railsim; + +//import ch.sbb.matsim.contrib.railsim.rl.RLClient; + +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import org.matsim.api.core.v01.Scenario; +import org.matsim.core.config.Config; +import org.matsim.core.config.ConfigUtils; +import org.matsim.core.controler.Controler; +import org.matsim.core.controler.OutputDirectoryHierarchy; +import org.matsim.core.scenario.ScenarioUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class RailsimEnv { + RLClient rlClient; // RLClient would be needed by RailsimEngine. + Controler controler; + public RailsimEnv(RLClient rlClient){ + +// TODO: How to pass this rlClient to the RailsimEngine from this class? + this.rlClient = rlClient; + + } + + public List reset(){ + + // start the simulation + // pass the observation to the RLClient + + String configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/config.xml"; + + Config config = ConfigUtils.loadConfig(configFilename); + config.controller().setOverwriteFileSetting(OutputDirectoryHierarchy.OverwriteFileSetting.deleteDirectoryIfExists); + + Scenario scenario = ScenarioUtils.loadScenario(config); + controler = new Controler(scenario); + + controler.addOverridingModule(new RailsimModule()); + + // if you have other extensions that provide QSim components, call their configure-method here + controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); + + // get all train Ids in this scenario. + //TODO: Fix Me: implement the method getAllTrainIds() + List trainIds = new ArrayList<>(); //getAllTrainIds(); + trainIds.add("train0"); + + return trainIds; + } + + void startSimulation(){ + controler.run(); + } + + +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java index 5b5343ed8f6..11973b734cf 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java @@ -19,7 +19,10 @@ package ch.sbb.matsim.contrib.railsim; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainState; +import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.Scenario; +import org.matsim.api.core.v01.network.Network; import org.matsim.core.config.Config; import org.matsim.core.config.ConfigUtils; import org.matsim.core.controler.Controler; @@ -27,6 +30,12 @@ import org.matsim.core.scenario.ScenarioUtils; import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; +import org.matsim.pt.transitSchedule.api.TransitLine; +import org.matsim.pt.transitSchedule.api.TransitRoute; +import org.matsim.pt.transitSchedule.api.TransitStopFacility; + +import java.util.List; +import java.util.Map; /** * Example script that shows how to use railsim included in this contrib. @@ -42,7 +51,7 @@ public static void main(String[] args) { if (args.length != 0) { configFilename = args[0]; } else { - configFilename = "test/input/ch/sbb/matsim/contrib/railsim/integration/microOlten/config.xml"; + configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/config.xml"; } Config config = ConfigUtils.loadConfig(configFilename); @@ -57,6 +66,44 @@ public static void main(String[] args) { controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); controler.run(); + + // Required: List of all agents + // Schedule of arrival and departure for all the halts for all trains + +// // Transit stops +// Map, TransitStopFacility> transitStops = scenario.getTransitSchedule().getFacilities(); +// +// // Transit lines +// Map, TransitLine> transitLineMap = scenario.getTransitSchedule().getTransitLines(); + + /* + Transit line contains: + - transitRoute + - route profile : sequence of stops with their arrival and departure times + - route: sequence of links + - departures: the train ids and their corresponding departure times. + + Each stop is essentially a link + + There can be multiple transit lines + */ + + /* + Output data structure + + */ +// List transitLines = scenario.getTransitSchedule().getTransitLines().values().stream().toList(); +// +// for (TransitLine tl: transitLines){ +// List transitRoutes = tl.getRoutes().values().stream().toList(); +// for (TransitRoute tr: transitRoutes){ +// tr.getDepartures(); +// tr.getStops().get(0).getStopFacility(); +// +// } +// } + + } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java deleted file mode 100644 index 1e5fd22879a..00000000000 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/observation/TreeObservation.java +++ /dev/null @@ -1,174 +0,0 @@ -package ch.sbb.matsim.contrib.railsim.observation; - -import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; -import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; -import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; -import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResource; -import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; -import ch.sbb.matsim.contrib.railsim.qsimengine.resources.ResourceState; -import org.matsim.api.core.v01.network.Link; -import org.matsim.api.core.v01.network.Network; -import org.matsim.api.core.v01.network.Node; -// TODO: What is the time resolution of the environment? 1 iteration is equal to home many seconds in rlsim? -// The time resolution of matsim is seconds, same for railsim. But there are seconds where nothing happens, no events in update queue. -import java.util.List; -import java.util.ArrayList; - -class OtherAgent { - // Distance of agent to other agent's tail position if the train is moving in the same direction - // Otherwise, distance to BufferTip position if moving in opposite direction - double distance; - List bufferTipPosition; - List tailPosition; - double speed; - - public OtherAgent(double distance, List bufferTipPosition, List tailPosition, double speed, boolean sameDirection) { - this.distance = distance; - this.bufferTipPosition = bufferTipPosition; - this.tailPosition = tailPosition; - this.speed = speed; - this.sameDirection = sameDirection; - } - - boolean sameDirection; - -} - -class ObservationNode { - Node node; - double distNodeAgent; - double distNodeStop; - boolean isSwitchable; - OtherAgent sameDirAgent; - OtherAgent oppDirAgent; - int numParallelTracks; -} - -public class TreeObservation { - private final TrainPosition position; - private final RailResourceManager resources; - private final Network network; - private List observation; - - public TreeObservation(TrainPosition position, RailResourceManager resources, Network network) { - this.resources = resources; - this.position = position; - this.network = network; - createTreeObs(); - } - - public RailLink getBufferTip() { - double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); - RailLink currentLink = resources.getLink(position.getHeadLink()); - List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); - // TODO Verify if the links are added in the list in the sequence of occurence. - RailLink bufferTip = reservedSegment.get(reservedSegment.size() - 1); - - return bufferTip; - } - - private void createTreeObs() { - int depth = 3; - createTreeObs(depth); - } - - // TODO: Handle cases when the agent should also see the next 2 halts. - But this may not be needed as the agent can be penalised - // if the agent is unable to reach the next station. - private void createTreeObs(int depth) { -// TODO: This has to be fixed such that the next observation node is calculated from the current position - - // Get the link of the tip of the buffer - RailLink bufferTipLink = getBufferTip(); - - // Get the toNode of the bufferTipLink - Node toNode = getToNode(bufferTipLink); - List exploreQueue = new ArrayList(); - - List obsNode = getNextNodes(toNode); - exploreQueue.addLast(obsNode); - - for (int i = 0; i < depth; i++) { - // Depth traversal: Making an observation Tree of fixed depth - int lenExploreQueue = exploreQueue.size(); - while (lenExploreQueue > 0) { - // Level traversal - Node curNode = exploreQueue.getFirst(); - exploreQueue.remove(0); - - // Look for switches/intersections/stops on the branches stemming out of the current switch - List nextNodes = getNextNodes(curNode); - - for (Node nextNode : nextNodes) { - exploreQueue.addLast(nextNode); - TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); - TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); - observation.add(createObservatioNode(nextNode, trainF, trainR)); - } - lenExploreQueue -= 1; - } - } - } - - private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) { - // complete route to current position of the train from schedule, if needed? - List previousRoute = position.getRoute(0, position.getRouteIndex()); - - // TODO: If all opposite trains are needed, follow the inLinks of curNode until the nextNode is reached. Store the of the path links. - List path = null; - - // check for each link if there is capacity - for (Link link : path) { - RailLink railLink = resources.getLink(link.getId()); - RailResource resource = railLink.getResource(); - ResourceState state = resource.getState(railLink); - // TODO: Ask Christian how we get the train position of the nearest train. - } - - return null; - } - - private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) { - // complete route from current position of the train from schedule, if needed? - List upcomingRoute = position.getRoute(position.getRouteIndex(), position.getRouteSize()); - - // TODO: If all opposite trains are needed, follow the outLinks of curNode until the nextNode is reached. Store the of the path links. - List path = null; - - // same procedure as above... - - return null; - } - -// TODO: Add the logic to also consider halts/stops - private List getNextNodes(Node curNode) { - // next switches - List switchNodes = new ArrayList<>(); - - // check all outgoing links from the current node - for (Link outLink : curNode.getOutLinks().values()) { - Node nextNode = outLink.getToNode(); - - // follow nodes with only one outgoing link until a switch is reached - while (nextNode.getOutLinks().size() == 1) { - // get the single outgoing link and follow it - // TODO: Handle end of network, will throw NoSuchElementException at the moment - nextNode = nextNode.getOutLinks().values().iterator().next().getToNode(); - } - - // if the next node has more than one outgoing link, it's a switch - if (nextNode.getOutLinks().size() > 1) { - switchNodes.add(nextNode); - } - } - - return switchNodes; - } - - private Node getToNode(RailLink bufferTipLink) { - return network.getLinks().get(bufferTipLink.getLinkId()).getToNode(); - } - - public List getObservation() { - return observation; - } -} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java index a4401286cda..84e7b9add38 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimCalc.java @@ -173,7 +173,7 @@ static double calcTargetSpeedForStop(double dist, double acceleration, double de /** * Calculate the minimum distance that needs to be reserved for the train, such that it can stop safely. */ - static double calcReservationDistance(TrainState state, RailLink currentLink) { + public static double calcReservationDistance(TrainState state, RailLink currentLink) { double assumedSpeed = calcPossibleMaxSpeed(state); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index 9c7ce2e695a..5b95889bb5e 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -25,11 +25,18 @@ import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.DispositionResponse; import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import ch.sbb.matsim.contrib.railsim.rl.observation.Observation; +import ch.sbb.matsim.contrib.railsim.rl.observation.ObservationNode; +import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; +import ch.sbb.matsim.contrib.railsim.rl.observation.TreeObservation; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.events.*; import org.matsim.api.core.v01.network.Link; +import org.matsim.api.core.v01.network.Network; +import org.matsim.api.core.v01.network.Node; import org.matsim.core.api.experimental.events.EventsManager; import org.matsim.core.mobsim.framework.MobsimDriverAgent; import org.matsim.core.mobsim.framework.Steppable; @@ -45,7 +52,7 @@ /** * Engine to simulate train movement. */ -final class RailsimEngine implements Steppable { +public class RailsimEngine implements Steppable { /** * Additional safety distance in meter that is added to the reservation distance. @@ -55,12 +62,24 @@ final class RailsimEngine implements Steppable { private static final Logger log = LogManager.getLogger(RailsimEngine.class); private final EventsManager eventsManager; private final RailsimConfigGroup config; - private final List activeTrains = new ArrayList<>(); + protected final List activeTrains = new ArrayList<>(); private final Queue updateQueue = new PriorityQueue<>(); private final RailResourceManager resources; private final TrainDisposition disposition; + private Network network; + private RLClient rlClient; - RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition) { + // Overloaded constructor to be used when using RL based inference + public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition, Network network, RLClient rlClient) { + this.eventsManager = eventsManager; + this.config = config; + this.resources = resources; + this.disposition = disposition; + this.network = network; + this.rlClient=rlClient; + } + + public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition) { this.eventsManager = eventsManager; this.config = config; this.resources = resources; @@ -92,24 +111,59 @@ public void doSimStep(double time) { } } - public void doSimStepUsingRLInference(double time) { - - Map> obs = getObservationForAliveTrains(); - Map actions = rlModel(obsDict); - Set agent_ids = actions.keySet(); - for (String aid:agent_ids){ - int action = actions.get(aid); - if (action == 2){ - // Create an event for stop - } - else { - updateRouteForAgent(aid, action); - } +// TODO: Implement getStepOutput() function - calculate reward, truncated and terminated flags + private Map getStepOutput(){ + + // set done before the train reaches the final goal - > when the blocked segment contains the goal link + for (TrainState train : activeTrains){ + // get observation for each train + TreeObservation treeObs = new TreeObservation(train, this.resources, this.network); + List treeObsFlattened = treeObs.getFlattenedObservationTree(); + List listObsNodes = treeObs.getObservationTree(); + // coordinates of train head + List headPosition; + // speed of train } + + return null; + } + + // TODO: Complete the following function + public void doSimStepRL(double time){ + + if (activeTrains.size() >0){ + // check if there is at least 1 active train + // send back Map + Map stepOutputMap = getStepOutput(); + rlClient.sendObservation(stepOutputMap); + + // get action corresponding to the initial observation sent + Map actionMap = rlClient.getAction(); + + // TODO: upadte route based on actionMap + } doSimStep(time); + } +// public void doSimStepUsingRLInference(double time) { +// +// Map> obs = getObservationForAliveTrains(); +// Map actions = rlModel(obsDict); +// Set agent_ids = actions.keySet(); +// for (String aid:agent_ids){ +// int action = actions.get(aid); +// if (action == 2){ +// // Create an event for stop +// } +// else { +// updateRouteForTrain(aid, action); +// } +// } +// +// doSimStep(time); +// } /** @@ -157,8 +211,6 @@ private void updateAllPositions(double time) { for (TrainState train : activeTrains) { if (train.timestamp < time) updateState(time, new UpdateEvent(train, UpdateEvent.Type.POSITION)); -// updateState(time, new UpdateEvent(train, UpdateEvent.Type.POSITION)); -// updateState(time, new UpdateEvent(train, UpdateEvent.Type.POSITION)); } } @@ -168,6 +220,7 @@ private void createEvent(Event event) { this.eventsManager.processEvent(event); } +// TODO: Where is UNBLOCK_LINK event being created? private void updateState(double time, UpdateEvent event) { // Do different updates depending on the type @@ -564,7 +617,6 @@ private double handleTransitStop(double time, TrainState state) { // Time needs to be rounded to current sim step double stopTime = state.pt.handleTransitStop(state.nextStop, Math.ceil(time)); state.nextStop = state.pt.getNextTransitStop(); - return stopTime; } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java index 2e65080b4a1..51e5f040d02 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java @@ -25,6 +25,7 @@ import com.google.inject.Inject; import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.network.Link; +import org.matsim.api.core.v01.network.Network; import org.matsim.api.core.v01.population.Leg; import org.matsim.api.core.v01.population.PlanElement; import org.matsim.api.core.v01.population.Route; @@ -58,6 +59,8 @@ public class RailsimQSimEngine implements DepartureHandler, MobsimEngine { private RailsimEngine engine; + private Network network; + @Inject public RailsimQSimEngine(QSim qsim, RailResourceManager res, TrainDisposition disposition, TransitStopAgentTracker agentTracker) { this.qsim = qsim; @@ -68,6 +71,17 @@ public RailsimQSimEngine(QSim qsim, RailResourceManager res, TrainDisposition di this.agentTracker = agentTracker; } + @Inject + public RailsimQSimEngine(QSim qsim, RailResourceManager res, TrainDisposition disposition, TransitStopAgentTracker agentTracker, Network network) { + this.qsim = qsim; + this.config = ConfigUtils.addOrGetModule(qsim.getScenario().getConfig(), RailsimConfigGroup.class); + this.res = res; + this.disposition = disposition; + this.modes = config.getNetworkModes(); + this.agentTracker = agentTracker; + this.network = network; + } + @Override public void setInternalInterface(InternalInterface internalInterface) { this.internalInterface = internalInterface; @@ -75,7 +89,8 @@ public void setInternalInterface(InternalInterface internalInterface) { @Override public void onPrepareSim() { - engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition); +// engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition); + engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition, this.network); } @Override diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/TrainState.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/TrainState.java index acd544c74e4..2b1ef43e969 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/TrainState.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/TrainState.java @@ -32,7 +32,7 @@ /** * Stores the mutable current state of a train. */ -final class TrainState implements TrainPosition { +public final class TrainState implements TrainPosition { /** * Driver of the train. @@ -59,7 +59,7 @@ final class TrainState implements TrainPosition { /** * Route of this train. */ - final List route; + public final List route; /** * Current index in the list of route links. diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java new file mode 100644 index 00000000000..441fd5c64c5 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java @@ -0,0 +1,94 @@ +package ch.sbb.matsim.contrib.railsim.rl; + +import ch.sbb.matsim.contrib.railsim.grpc.RailsimConnecterGrpc; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoObservationMap; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoActionMap; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoStepOutputMap; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoObservation; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoStepOutput; +//import ch.sbb.matsim.contrib.railsim.grpc.*; + +import ch.sbb.matsim.contrib.railsim.rl.observation.Observation; +import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; +import io.grpc.*; + +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class RLClient { + private final RailsimConnecterGrpc.RailsimConnecterBlockingStub blockingStub; + private static final Logger logger = Logger.getLogger(RLClient.class.getName()); + public RLClient(int port){ + String target = "localhost:"+port; + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .build(); + blockingStub = RailsimConnecterGrpc.newBlockingStub(channel); + } + + public Map getAction(){ + + ProtoActionMap actionMap; + + try { + // Call the original method on the server. + actionMap = blockingStub.getAction(null); + } catch (StatusRuntimeException e) { + // Log a warning if the RPC fails. + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return null; + } + + return actionMap.getDictActionMap(); + } + + public String sendObservation(Map stepOutputMap){ + + ch.sbb.matsim.contrib.railsim.grpc.ProtoConfirmationResponse msg; + + ProtoStepOutputMap.Builder protoStepOutputMapBuilder = ProtoStepOutputMap.newBuilder(); + for (Map.Entry entry: stepOutputMap.entrySet()){ + + StepOutput stepOutput = entry.getValue(); + + //Build the Observation object - this will be used inside the StepOutput class + ProtoObservation protoObservation = ProtoObservation.newBuilder() + .addAllObsTree(stepOutput.getObservation().getObsTree()) + .addAllPositionNextNode(stepOutput.getObservation().getPositionNextNode()) + .addAllTrainState(stepOutput.getObservation().getTrainState()) + .build(); + + // build StepOutput + ProtoStepOutput protoStepOutput = ProtoStepOutput.newBuilder() + .setObservation(protoObservation) + .setReward(stepOutput.getReward()) + .setTerminated(stepOutput.isTerminated()) + .setTruncated(stepOutput.isTruncated()) + .build(); + + // add stepOutput in the map + protoStepOutputMapBuilder.putDictStepOutput(entry.getKey(), protoStepOutput); + } + + ProtoStepOutputMap protoStepOutputMap = protoStepOutputMapBuilder.build(); + try { + // Call the original method on the server. + msg = blockingStub.updateState(protoStepOutputMap); + } catch (StatusRuntimeException e) { + // Log a warning if the RPC fails. + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return null; + } + + return msg.getAck(); + } + + public static void main(String args[]) throws InterruptedException { + // Access a service running on the local machine on port 50051 + RLClient client = new RLClient(50051); + Observation ob = new Observation(2, true); + Map actionMap = client.getAction(); + System.out.println(actionMap); + } +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java new file mode 100644 index 00000000000..46368d1e514 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java @@ -0,0 +1,76 @@ +package ch.sbb.matsim.contrib.railsim.rl.observation; + +import java.util.*; + +public class Observation{ + //TODO: Add timestamp as well in the observation + List obsTree = new ArrayList<>(); + // + List trainState = new ArrayList(4); + // x and y coordinate of the node + List positionNextNode = new ArrayList(2); + + public int getTimestamp() { + return timestamp; + } + + public void setTimestamp(int timestamp) { + this.timestamp = timestamp; + } + + int timestamp; + public List getObsTree() { + return obsTree; + } + + public void setObsTree(List obsTree) { + this.obsTree = obsTree; + } + + public List getTrainState() { + return trainState; + } + + public void setTrainState(List trainState) { + this.trainState = trainState; + } + + public List getPositionNextNode() { + return positionNextNode; + } + + public void setPositionNextNode(List positionNextNode) { + this.positionNextNode = positionNextNode; + } + + + @Override + public String toString() { + return "Observation{" + + "obsTree=" + obsTree + + ", trainState=" + trainState + + ", positionNextNode=" + positionNextNode + + '}'; + } + + public void generateRandomObservation(double depthObservationTree){ + for (int i= 0; i<4; i++){ + this.trainState.add(i, Math.random()); + } + for (int i= 0; i<2; i++){ + this.positionNextNode.add(i, Math.random()); + } + int lenObsTree = (int)(Math.pow(2.0, depthObservationTree+1)-1)*17; + for (int i=0; i nodeId; + List position; + double distNodeAgent; + double distNodeStop; + + public ObservationNode(List position, double distNodeAgent, double distNodeStop, int isSwitchable, OtherAgent sameDirAgent, OtherAgent oppDirAgent, int numParallelTracks, Id nodeId) { + this.position = position; + this.distNodeAgent = distNodeAgent; + this.distNodeStop = distNodeStop; + this.isSwitchable = isSwitchable; + this.sameDirAgent = sameDirAgent; + this.oppDirAgent = oppDirAgent; + this.numParallelTracks = numParallelTracks; + this.nodeId = nodeId; + } + + int isSwitchable; + OtherAgent sameDirAgent; + OtherAgent oppDirAgent; + int numParallelTracks; +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java new file mode 100644 index 00000000000..f31e7cc9a71 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java @@ -0,0 +1,75 @@ +package ch.sbb.matsim.contrib.railsim.rl.observation; + +import java.util.Collections; +import java.util.Map; + +public class StepOutput{ + Observation observation; + double reward; + boolean terminated; + + public Observation getObservation() { + return observation; + } + + public void setObservation(Observation observation) { + this.observation = observation; + } + + public double getReward() { + return reward; + } + + public void setReward(double reward) { + this.reward = reward; + } + + public boolean isTerminated() { + return terminated; + } + + public void setTerminated(boolean terminated) { + this.terminated = terminated; + } + + public boolean isTruncated() { + return truncated; + } + + public void setTruncated(boolean truncated) { + this.truncated = truncated; + } + + public Map getInfo() { + return info; + } + + public void setInfo(Map info) { + this.info = info; + } + + boolean truncated; + Map info; + + @Override + public String toString() { + return "StepOutput{" + + "observation=" + observation + + ", reward=" + reward + + ", terminated=" + terminated + + ", truncated=" + truncated + + ", info=" + info + + '}'; + } + + public StepOutput(int depthObservationTree, boolean random){ + if (random){ + this.observation = new Observation(depthObservationTree, random); + this.reward = 0; + this.terminated = false; + this.truncated = false; + this.info = Collections.emptyMap(); + } + } + +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java new file mode 100644 index 00000000000..1f4d04b4faf --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java @@ -0,0 +1,227 @@ +package ch.sbb.matsim.contrib.railsim.rl.observation; + +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainState; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; +import org.apache.commons.lang3.NotImplementedException; +import org.matsim.api.core.v01.network.Link; +import org.matsim.api.core.v01.network.Network; +import org.matsim.api.core.v01.network.Node; + +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; + +import ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils; + +import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.isSwitchable; + +class OtherAgent { + // Distance of agent to other agent's tail position if the train is moving in the same direction + // Otherwise, distance to BufferTip position if moving in opposite direction + double distance; + List bufferTipPosition; + List tailPosition; + double speed; + + public OtherAgent(double distance, List bufferTipPosition, List tailPosition, double speed, boolean sameDirection) { + this.distance = distance; + this.bufferTipPosition = bufferTipPosition; + this.tailPosition = tailPosition; + this.speed = speed; + this.sameDirection = sameDirection; + } + + boolean sameDirection; + +} + +public class TreeObservation { + private final TrainState position; + private final RailResourceManager resources; + private final Network network; + private List observationList; + private List flattenedObservation; + public TreeObservation(TrainState position, RailResourceManager resources, Network network) { + this.resources = resources; + this.position = position; + this.network = network; + this.observationList = new ArrayList<>(); + this.flattenedObservation= new ArrayList<>(); + createTreeObs(); + } + + private RailLink getBufferTip() { + double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); + RailLink currentLink = resources.getLink(position.getHeadLink()); + List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); + // TODO Verify if the links are added in the list in the sequence of occurence. + RailLink bufferTip = reservedSegment.get(reservedSegment.size() - 1); + + return bufferTip; + } + + private void createTreeObs() { + int depth = 3; + createTreeObs(depth); + } + + + private boolean isObsTreeNode(Node toNode){ + + // Check if the node is halt position of train + boolean isStop = false; + //TODO: This would fail if more than one halts lie in the observation tree. The code would recognise just one. + Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); + if (nextHaltToNode.equals(toNode)) + isStop = true; + + // Check if the node is a switch node. Each node will have minimum two outgoing nodes. + boolean switch_node = toNode.getOutLinks().size() > 2; + if (isStop || switch_node){ + return true; + } + return false; + } + + private ObservationNode createObservatioNode(TrainPosition train, RailLink curLink, Node node, OtherAgent sameDirAgent, OtherAgent oppDirAgent){ + + // Get coordinates of the nextNode + List nodePosition = new ArrayList(Arrays.asList(node.getCoord().getX(), node.getCoord().getY())); + + // Get coordinates of the train head + Node toNodeCurLink = getToNode(curLink); + List toNodeCurLinkPosition = new ArrayList(Arrays.asList(toNodeCurLink.getCoord().getX(), toNodeCurLink.getCoord().getY())); + + //calculate distance of train to nextNode + double distNodeAgent = train.getHeadPosition()+RLUtils.calculateEuclideanDistance(nodePosition, toNodeCurLinkPosition); + + Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); + List nextHaltToNodePosition = new ArrayList(Arrays.asList(nextHaltToNode.getCoord().getX(), nextHaltToNode.getCoord().getY())); + + double distNextHalt = RLUtils.calculateEuclideanDistance(nodePosition, nextHaltToNodePosition); + + int isSwitchable = isSwitchable(node, curLink, network) ? 1 : 0; + + int numParallelIncomingTracks = node.getInLinks().size(); + + return new ObservationNode(nodePosition,distNodeAgent, distNextHalt, isSwitchable, sameDirAgent, oppDirAgent, numParallelIncomingTracks, node.getId()); + + } + + private List flattenObservationNode(ObservationNode obsNode) { + + List flattenedNode = new ArrayList<>(); + + flattenedNode.addAll(obsNode.position); + flattenedNode.add((double) obsNode.numParallelTracks); + flattenedNode.add(obsNode.distNodeStop); + flattenedNode.add(obsNode.distNodeAgent); + flattenedNode.add((double)obsNode.isSwitchable); + + return flattenedNode; + } + + private void createTreeObs(int depth){ + // Get the link of the tip of the buffer + RailLink bufferTipLink = getBufferTip(); + + // Get the toNode of the bufferTipLink + Node toNode = getToNode(bufferTipLink); + List exploreQueue = new ArrayList(); + + + while(!isObsTreeNode(toNode)){ + toNode = toNode.getOutLinks().values().iterator().next().getToNode(); + } + exploreQueue.add(toNode); + + for (int i = 0; i < depth; i++) { + // Level Traversal algorithm + int lenExploreQueue = exploreQueue.size(); + while (lenExploreQueue > 0) { + // Level traversal + Node curNode = exploreQueue.get(0); + exploreQueue.remove(0); + + // Look for switches/intersections/stops on the branches stemming out of the current switch + List nextNodes = getNextNodes(curNode); + + for (Node nextNode : nextNodes) { + exploreQueue.add(nextNode); +// TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); +// TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); + ObservationNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), nextNode, null, null); + this.observationList.add(obsNode); + this.flattenedObservation.addAll(flattenObservationNode(obsNode)); + } + lenExploreQueue -= 1; + } + } + } + + private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) { +// // complete route to current position of the train from schedule, if needed? +// List previousRoute = position.getRoute(0, position.getRouteIndex()); + throw new NotImplementedException(); +// // TODO: If all opposite trains are needed, follow the inLinks of curNode until the nextNode is reached. Store the of the path links. +// List path = null; +// +// // check for each link if there is capacity +// for (Link link : path) { +// RailLink railLink = resources.getLink(link.getId()); +// RailResource resource = railLink.getResource(); +// ResourceState state = resource.getState(railLink); +// // TODO: Ask Christian how we get the train position of the nearest train. +// } +// +// return null; + } + + private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) { + throw new NotImplementedException(); +// // complete route from current position of the train from schedule, if needed? +// List upcomingRoute = position.getRoute(position.getRouteIndex(), position.getRouteSize()); +// +// // TODO: If all opposite trains are needed, follow the outLinks of curNode until the nextNode is reached. Store the of the path links. +// List path = null; +// +// // same procedure as above... +// +// return null; + } + + private List getNextNodes(Node curNode) { + // next switches + List obsTreeNodes = new ArrayList<>(); + + // check all outgoing links from the current node + for (Link outLink : curNode.getOutLinks().values()) { + Node nextNode = outLink.getToNode(); + + // follow nodes with only one outgoing link until a switch or a halt is reached + while (!isObsTreeNode(nextNode)) { + // get the single outgoing link and follow it + // TODO: Handle end of network, will throw NoSuchElementException at the moment + nextNode = nextNode.getOutLinks().values().iterator().next().getToNode(); + } + obsTreeNodes.add(nextNode); + } + return obsTreeNodes; + } + + private Node getToNode(RailLink link) { + return network.getLinks().get(link.getLinkId()).getToNode(); + } + + public List getObservationTree() { + return observationList; + } + + public List getFlattenedObservationTree() { + return this.flattenedObservation; + } +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java new file mode 100644 index 00000000000..1f2241c123c --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java @@ -0,0 +1,150 @@ +package ch.sbb.matsim.contrib.railsim.rl.utils; + +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainState; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; +import org.matsim.api.core.v01.events.Event; +import org.matsim.api.core.v01.network.Link; +import org.matsim.api.core.v01.network.Network; +import org.matsim.api.core.v01.network.Node; +import org.matsim.core.events.algorithms.EventWriterXML; +import org.matsim.core.events.handler.BasicEventHandler; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +public class RLUtils { + + public static Node getToNode(Network network, RailLink link) { + return network.getLinks().get(link.getLinkId()).getToNode(); + } + +// TODO: implement the functions: getSwitchNodeOnTrack and updateRoute + private boolean getSwitchNodeOnTrack(Network network, Node start, Node target, List path){ + + /* + get links from a node to the switch node on the same track + l0 St l1 l2 l3 l4 Sw + path = l1, l2, l3, l4 + */ + + while (!start.equals(target) || (start.getOutLinks().values().size() > 2)){ + + // Get outgoing node from the start node + List outLinks = start.getOutLinks().values().stream().collect(Collectors.toList()); + + // + for (Link link: outLinks){ + if new RailLink(link) in path + } + path.add(new RailLink(temp)); + start = temp.getToNode(); + if (start.equals(target)) + return true; + } + + } + public static void updateRoute(Network network, TrainState train, Node nextObsNode){ + + + // Calculate path until nextObsNode + + // Get the last link in the route + RailLink lastLinkInRoute = train.route.get(train.route.size() -1); + + // get the toNode of the lastLinkInRoute + Node nextNode = getToNode(network, lastLinkInRoute); + + // To store the path from lastLinkInRoute to node connecting the nextObsNode + List path = new ArrayList<>(); + + + if (isSwitchable(nextNode, lastLinkInRoute, network)){ + // the nextNode is not a switch Node + + } + else{ + // iterate through all possible directions searching for the nextObsNode + + } + + + // update route of the current train + train.route.addAll(path); + } + + public static double calculateEuclideanDistance(List point1, List point2) { + if (point1.size() != point2.size()) { + throw new IllegalArgumentException("Points must have the same number of dimensions"); + } + + double sumOfSquares = 0.0; + for (int i = 0; i < point1.size(); i++) { + double diff = point1.get(i) - point2.get(i); + sumOfSquares += diff * diff; + } + + return Math.sqrt(sumOfSquares); + } + + public static double calculateAngle(Node point1, Node point2, Node point3) { + double x1 = point1.getCoord().getX(); + double y1 = point1.getCoord().getY(); + double x2 = point2.getCoord().getX(); + double y2 = point2.getCoord().getY(); + double x3 = point3.getCoord().getX(); + double y3 = point3.getCoord().getY(); + + // Calculate the vectors between the points + double vector1x = x2 - x1; + double vector1y = y2 - y1; + double vector2x = x3 - x2; + double vector2y = y3 - y2; + + // Calculate the dot product + double dotProduct = (vector1x * vector2x) + (vector1y * vector2y); + + // Calculate the magnitudes of the vectors + double magnitude1 = Math.sqrt(vector1x * vector1x + vector1y * vector1y); + double magnitude2 = Math.sqrt(vector2x * vector2x + vector2y * vector2y); + + // Calculate the angle in radians + double angleInRadians = Math.acos(dotProduct / (magnitude1 * magnitude2)); + + // Convert radians to degrees + double angleInDegrees = Math.toDegrees(angleInRadians); + + return angleInDegrees; + } + + + + public static Boolean isSwitchable(Node switchNode, RailLink curLink, Network network){ + + int numOutGoingLinks = switchNode.getOutLinks().size(); + if (numOutGoingLinks <= 2){ + return false; + } + else{ + List outLinks = switchNode.getOutLinks().values().stream().toList(); + List toNodeOfOutLinkList = new ArrayList<>(); + for (Link link: outLinks){ + toNodeOfOutLinkList.add(link.getToNode()); + } + + Node fromNodeCurLink = network.getLinks().get(curLink.getLinkId()).getFromNode(); + + int possibleDirections = 0; + for (Node toNodeOfOutLink: toNodeOfOutLinkList){ + double angle = calculateAngle(fromNodeCurLink, switchNode, toNodeOfOutLink); + if (angle < 90.0 && angle > -90.0) + possibleDirections ++; + } + if (possibleDirections > 1) + return true; + else + return false; + } + } + +} diff --git a/contribs/railsim/src/main/proto/railsim.proto b/contribs/railsim/src/main/proto/railsim.proto new file mode 100644 index 00000000000..55055dcc8f2 --- /dev/null +++ b/contribs/railsim/src/main/proto/railsim.proto @@ -0,0 +1,60 @@ +syntax = "proto3"; +option java_multiple_files = true; +import "google/protobuf/empty.proto"; + +option java_package = "ch.sbb.matsim.contrib.railsim.grpc"; + + +message ProtoObservation { + repeated double obsTree = 1; + repeated double trainState = 2; + repeated double positionNextNode = 3; + int32 timestamp=4; +} + +message ProtoStepOutput { + ProtoObservation observation = 1; + double reward = 2; + bool terminated = 3; + bool truncated = 4; + map info = 5; +} + +message ProtoStepOutputMap { + map dictStepOutput= 1; +} + +message ProtoAgentIDs { + repeated string agentId = 1; +} + +message ProtoActionMap{ + map dictAction = 1; +} + +message ProtoObservationMap { + map dictObservation= 1; +} + +message ProtoConfirmationResponse{ + string ack = 1; +} + +service RailsimConnecter { + + rpc getAction (google.protobuf.Empty) returns (ProtoActionMap) {}; + + rpc updateState (ProtoStepOutputMap) returns (ProtoConfirmationResponse) {}; + +} + +message ProtoGrpcPort{ + int32 grpcPort = 1; +} + +service RailsimFactory{ + rpc getEnvironment (ProtoGrpcPort) returns (ProtoConfirmationResponse) {}; + + rpc resetEnv (ProtoGrpcPort) returns (ProtoAgentIDs) {}; +} + From 4cef09f786e53d186ca9d2331b9cdafa50909a80 Mon Sep 17 00:00:00 2001 From: iamakash Date: Tue, 28 May 2024 11:55:18 +0200 Subject: [PATCH 05/14] feat: create function to add route to train route --- .../contrib/railsim/rl/utils/RLUtils.java | 74 ++++++++++++------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java index 1f2241c123c..ecfc0688516 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java @@ -20,56 +20,78 @@ public static Node getToNode(Network network, RailLink link) { } // TODO: implement the functions: getSwitchNodeOnTrack and updateRoute - private boolean getSwitchNodeOnTrack(Network network, Node start, Node target, List path){ + private static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path){ /* - get links from a node to the switch node on the same track + get links from a node to the next switch node on the same track. l0 St l1 l2 l3 l4 Sw path = l1, l2, l3, l4 - */ - while (!start.equals(target) || (start.getOutLinks().values().size() > 2)){ + If target node found in the path, return true else false + */ + Node start = curLink.getToNode(); + while ( start.getOutLinks().values().size() <= 2){ + Link nextLink = null; - // Get outgoing node from the start node + // get outLinks from the start node List outLinks = start.getOutLinks().values().stream().collect(Collectors.toList()); - - // - for (Link link: outLinks){ - if new RailLink(link) in path + assert (outLinks.size() == 2); + + // get the fromNode of the curLink + Node fromNodeCurLink = curLink.getFromNode(); + + for (Link outLink : outLinks){ + if (outLink.getToNode().equals(fromNodeCurLink)){ + // Ignore the link where outLink(start) == curLink + continue; + } + else { + nextLink = outLink; + break; + } } - path.add(new RailLink(temp)); - start = temp.getToNode(); - if (start.equals(target)) + path.add(new RailLink(nextLink)); + + // update start node and curLink + start = nextLink.getToNode(); + curLink = nextLink; + if (start.equals(target)) { + // target found in the path return true; + } } - + // no path found to the target + return false; } public static void updateRoute(Network network, TrainState train, Node nextObsNode){ - - // Calculate path until nextObsNode - // Get the last link in the route RailLink lastLinkInRoute = train.route.get(train.route.size() -1); // get the toNode of the lastLinkInRoute - Node nextNode = getToNode(network, lastLinkInRoute); + Node toNodeLastLinkInRoute = getToNode(network, lastLinkInRoute); - // To store the path from lastLinkInRoute to node connecting the nextObsNode - List path = new ArrayList<>(); + // get the fromNode of the lastLinkInRoute + Node fromNodeLastLinkInRoute = network.getLinks().get(lastLinkInRoute.getLinkId()).getFromNode(); + // get the outLinks from the toNode of the lastLinkInRoute + List nextLinks = toNodeLastLinkInRoute.getOutLinks().values().stream().collect(Collectors.toList()); - if (isSwitchable(nextNode, lastLinkInRoute, network)){ - // the nextNode is not a switch Node + //path to the nextObsNode + List path = null; + for (Link nextLink : nextLinks){ - } - else{ - // iterate through all possible directions searching for the nextObsNode + // skip the link that takes back on the same track + if (nextLink.getToNode().equals(fromNodeLastLinkInRoute)) + continue; + // To store the path from lastLinkInRoute to node connecting the nextObsNode + path = new ArrayList<>(); + if (getPathToSwitchNodeOnTrack(network.getLinks().get(lastLinkInRoute.getLinkId()), nextObsNode, path)) + break; } - - // update route of the current train + assert path != null; train.route.addAll(path); } From 16a74ae8340e9a291c684cb0cd861db0fb154dc3 Mon Sep 17 00:00:00 2001 From: iamakash Date: Wed, 29 May 2024 21:51:44 +0200 Subject: [PATCH 06/14] feat: implement RLDisposition --- .../railsim/EnvironmentFactoryServer.java | 6 +- .../matsim/contrib/railsim/RailsimEnv.java | 38 ++- .../railsim/qsimengine/RailsimEngine.java | 81 +----- .../railsim/qsimengine/RailsimQSimEngine.java | 2 +- .../disposition/RLTrainDisposition.java | 237 ++++++++++++++++++ .../disposition/SimpleDisposition.java | 5 + .../disposition/TrainDisposition.java | 5 + .../matsim/contrib/railsim/rl/RLClient.java | 4 +- .../railsim/rl/observation/Observation.java | 38 ++- .../rl/observation/ObservationNode.java | 30 --- .../rl/observation/ObservationTreeNode.java | 93 +++++++ .../railsim/rl/observation/StepOutput.java | 5 + .../rl/observation/TreeObservation.java | 30 +-- .../contrib/railsim/rl/utils/RLUtils.java | 4 +- 14 files changed, 436 insertions(+), 142 deletions(-) create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java delete mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/ObservationNode.java create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/ObservationTreeNode.java diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java index 52cd5574ff7..f3b6508e3c4 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java @@ -102,9 +102,9 @@ public void resetEnv(ProtoGrpcPort grpcPort, StreamObserver respo List agentIds = env.reset(); //Create response using agentIds - ProtoAgentIDs.Builder agentIDsBuilder = ProtoAgentIDs.newBuilder(); - agentIDsBuilder.addAllAgentId(agentIds); - ProtoAgentIDs response = agentIDsBuilder.build(); + ProtoAgentIDs response = ProtoAgentIDs.newBuilder() + .addAllAgentId(agentIds) + .build(); // Send the reply back to the client. responseObserver.onNext(response); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index 570927b5863..d0b5dfd2824 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -4,27 +4,51 @@ import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.Scenario; +import org.matsim.api.core.v01.population.Route; import org.matsim.core.config.Config; import org.matsim.core.config.ConfigUtils; import org.matsim.core.controler.Controler; import org.matsim.core.controler.OutputDirectoryHierarchy; import org.matsim.core.scenario.ScenarioUtils; - +import org.matsim.pt.transitSchedule.api.Departure; +import org.matsim.pt.transitSchedule.api.TransitRoute; +import org.matsim.vehicles.Vehicle; +import org.matsim.visum.VisumNetwork; +import org.matsim.pt.transitSchedule.api.TransitLine; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; public class RailsimEnv { RLClient rlClient; // RLClient would be needed by RailsimEngine. Controler controler; public RailsimEnv(RLClient rlClient){ - -// TODO: How to pass this rlClient to the RailsimEngine from this class? +// TODO: Pass this to RLDisposition this.rlClient = rlClient; } + private List getAllTrainIds(Scenario scenario){ + + List trainIds = new ArrayList<>(); + + List transitLines = scenario.getTransitSchedule().getTransitLines().values().stream().collect(Collectors.toList()); + for (TransitLine trainLine : transitLines){ + List transitRoutes = trainLine.getRoutes().values().stream().collect(Collectors.toList()); + for (TransitRoute transitRoute: transitRoutes){ + List departures= transitRoute.getDepartures().values().stream().collect(Collectors.toList()); + for(Departure departure: departures){ + Id ID = departure.getVehicleId(); + trainIds.add(ID.toString()); + } + } + } + return trainIds; + } + public List reset(){ // start the simulation @@ -42,13 +66,9 @@ public List reset(){ // if you have other extensions that provide QSim components, call their configure-method here controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); - - // get all train Ids in this scenario. //TODO: Fix Me: implement the method getAllTrainIds() - List trainIds = new ArrayList<>(); //getAllTrainIds(); - trainIds.add("train0"); - - return trainIds; + // get all train Ids in this scenario. + return getAllTrainIds(scenario); } void startSimulation(){ diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index 5b95889bb5e..83a85969e6e 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -26,8 +26,7 @@ import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; import ch.sbb.matsim.contrib.railsim.rl.RLClient; -import ch.sbb.matsim.contrib.railsim.rl.observation.Observation; -import ch.sbb.matsim.contrib.railsim.rl.observation.ObservationNode; +import ch.sbb.matsim.contrib.railsim.rl.observation.ObservationTreeNode; import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; import ch.sbb.matsim.contrib.railsim.rl.observation.TreeObservation; import org.apache.logging.log4j.LogManager; @@ -36,7 +35,6 @@ import org.matsim.api.core.v01.events.*; import org.matsim.api.core.v01.network.Link; import org.matsim.api.core.v01.network.Network; -import org.matsim.api.core.v01.network.Node; import org.matsim.core.api.experimental.events.EventsManager; import org.matsim.core.mobsim.framework.MobsimDriverAgent; import org.matsim.core.mobsim.framework.Steppable; @@ -66,18 +64,16 @@ public class RailsimEngine implements Steppable { private final Queue updateQueue = new PriorityQueue<>(); private final RailResourceManager resources; private final TrainDisposition disposition; - private Network network; - private RLClient rlClient; // Overloaded constructor to be used when using RL based inference - public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition, Network network, RLClient rlClient) { - this.eventsManager = eventsManager; - this.config = config; - this.resources = resources; - this.disposition = disposition; - this.network = network; - this.rlClient=rlClient; - } +// public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition, Network network, RLClient rlClient) { +// this.eventsManager = eventsManager; +// this.config = config; +// this.resources = resources; +// this.disposition = disposition; +// this.network = network; +// this.rlClient=rlClient; +// } public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition) { this.eventsManager = eventsManager; @@ -111,61 +107,6 @@ public void doSimStep(double time) { } } -// TODO: Implement getStepOutput() function - calculate reward, truncated and terminated flags - private Map getStepOutput(){ - - // set done before the train reaches the final goal - > when the blocked segment contains the goal link - for (TrainState train : activeTrains){ - // get observation for each train - TreeObservation treeObs = new TreeObservation(train, this.resources, this.network); - List treeObsFlattened = treeObs.getFlattenedObservationTree(); - List listObsNodes = treeObs.getObservationTree(); - // coordinates of train head - List headPosition; - // speed of train - } - - - return null; - } - - // TODO: Complete the following function - public void doSimStepRL(double time){ - - if (activeTrains.size() >0){ - // check if there is at least 1 active train - // send back Map - Map stepOutputMap = getStepOutput(); - rlClient.sendObservation(stepOutputMap); - - // get action corresponding to the initial observation sent - Map actionMap = rlClient.getAction(); - - // TODO: upadte route based on actionMap - } - doSimStep(time); - - } - -// public void doSimStepUsingRLInference(double time) { -// -// Map> obs = getObservationForAliveTrains(); -// Map actions = rlModel(obsDict); -// Set agent_ids = actions.keySet(); -// for (String aid:agent_ids){ -// int action = actions.get(aid); -// if (action == 2){ -// // Create an event for stop -// } -// else { -// updateRouteForTrain(aid, action); -// } -// } -// -// doSimStep(time); -// } - - /** * Update the current state of all trains, even if no update would be needed. */ @@ -442,6 +383,9 @@ private void enterLink(double time, UpdateEvent event) { // Arrival at destination if (!event.waitingForLink && state.isRouteAtEnd()) { + //call disposition + disposition.onArrival(time, event.state); + assert FuzzyUtils.equals(state.speed, 0) : "Speed must be 0 at end, but was " + state.speed; // Free all reservations @@ -687,6 +631,7 @@ private void decideNextUpdate(UpdateEvent event) { event.type = UpdateEvent.Type.LEAVE_LINK; } else if (reserveDist <= accelDist && reserveDist <= decelDist && reserveDist <= tailDist && reserveDist <= headDist) { +// TODO: I understand why reserveDist is compared with headDist but the other comparisions are not clear to me. dist = reserveDist; event.type = UpdateEvent.Type.BLOCK_TRACK; } else if (accelDist <= decelDist && accelDist <= reserveDist && accelDist <= tailDist && accelDist <= headDist) { diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java index 51e5f040d02..5bb0bf876d1 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java @@ -90,7 +90,7 @@ public void setInternalInterface(InternalInterface internalInterface) { @Override public void onPrepareSim() { // engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition); - engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition, this.network); + engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition); } @Override diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java new file mode 100644 index 00000000000..e5eea38514a --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -0,0 +1,237 @@ +package ch.sbb.matsim.contrib.railsim.qsimengine.disposition; + +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainState; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResource; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; +import ch.sbb.matsim.contrib.railsim.qsimengine.router.TrainRouter; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import ch.sbb.matsim.contrib.railsim.rl.observation.Observation; +import ch.sbb.matsim.contrib.railsim.rl.observation.ObservationTreeNode; +import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; +import ch.sbb.matsim.contrib.railsim.rl.observation.TreeObservation; +import org.apache.commons.jxpath.ri.compiler.Step; +import org.matsim.api.core.v01.Coord; +import org.matsim.api.core.v01.network.Link; +import org.matsim.api.core.v01.network.Network; +import org.matsim.api.core.v01.network.Node; +import org.matsim.core.mobsim.framework.MobsimDriverAgent; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.getPathToSwitchNodeOnTrack; +import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.updateRoute; + +public class RLTrainDisposition implements TrainDisposition { + RailResourceManager resources; + TrainRouter router; + Network network; + RLClient rlClient; + + Map bufferStepOutputMap; + public RLTrainDisposition(RailResourceManager resources, TrainRouter router, Network network, RLClient rlClient) { + this.resources = resources; + this.router = router; + this.network = network; + this.rlClient = rlClient; + this.bufferStepOutputMap = new HashMap<>(); + } + + private Double getReward(TrainState train){ + /** + * 1. check if the train is departing from one of the halts + * 2. Get the actual departure time + * 3. Get the scheduled departure time + * 4. calculate reward + * + * TODO: What happens when the train arrives late at a halt? Does it wait for fixed amount of time as scheduled or it leaves as per the scheduled departure time if possible? + * + * TODO: What happens if the train arrives at a halt later than it's departure time? Does the train stop at all? + */ + + + return -1.0; + } + + @Override + public void onDeparture(double time, MobsimDriverAgent driver, List route) { + // Update route for the train until the next switch node. + Link curLink = network.getLinks().get(driver.getCurrentLinkId()); + getPathToSwitchNodeOnTrack(curLink, null, route); + } + + @Override + public DispositionResponse requestNextSegment(double time, TrainPosition position, double dist) { + // calculate and send StepOutput to rl + Map stepOutputMap = getStepOutput(position, false, time); + + if (bufferStepOutputMap.size()==0){ + rlClient.sendObservation(stepOutputMap); + } + else{ + bufferStepOutputMap.putAll(stepOutputMap); + bufferStepOutputMap.clear(); + } + + // get action from rl + Map actionMap = rlClient.getAction(); + + //update route based on the action from rl + int action = actionMap.get(position.getTrain().id().toString()); + + StepOutput out = stepOutputMap.get(position.getTrain().id().toString()); + switch (action){ + case 0:{ + // update route in the query direction + Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); + updateRoute(network, (TrainState) position, nextSwitchNodePos); + break; + } + case 1:{ + // update route in the other direction + Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); + updateRoute(network, (TrainState) position, nextSwitchNodePos); + break; + } + case 2:{ + // stop the train + return new DispositionResponse(0, 0, null); + } + + } + + RailLink currentLink = resources.getLink(position.getHeadLink()); + List segment = RailsimCalc.calcLinksToBlock(position, currentLink, dist); + + // NOTE: Check for rerouting is omitted in this implementation + + double reserveDist = resources.tryBlockLink(time, currentLink, RailResourceManager.ANY_TRACK_NON_BLOCKING, position); + + if (reserveDist == RailResource.NO_RESERVATION) + return new DispositionResponse(0, 0, null); + + // current link only partial reserved + if (reserveDist < currentLink.length) { + return new DispositionResponse(reserveDist - position.getHeadPosition(), 0, null); + } + + // remove already used distance + reserveDist -= position.getHeadPosition(); + + boolean stop = false; + // Iterate all links that need to be blocked + for (RailLink link : segment) { + + // first link does not need to be blocked again + if (link == currentLink) + continue; + + dist = resources.tryBlockLink(time, link, RailResourceManager.ANY_TRACK_NON_BLOCKING, position); + + if (dist == RailResource.NO_RESERVATION) { + stop = true; + break; + } + + // partial reservation + reserveDist += dist; + + // If the link is not fully reserved then stop + // there might be a better advised speed (speed of train in-front) + if (dist < link.getLength()) { + stop = true; + break; + } + } + return new DispositionResponse(reserveDist, stop ? 0 : Double.POSITIVE_INFINITY, null); + + } + + @Override + public void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link) { + // put resource handling into release track + resources.releaseLink(time, link, driver); + + } + + @Override + public void onArrival(double time, TrainPosition position) { + // Store the StepOutput in bufferStepOutput. + // bufferStepOutput is not sent to RL until there is an observation for a train whose done=false + bufferStepOutputMap.putAll(getStepOutput(position, true, time)); + + } + + + Observation getObservation(double time, TrainPosition train){ + + Observation ob = new Observation(); + + // get observation for each train + TreeObservation treeObs = new TreeObservation((TrainState) train, this.resources, this.network, 2); + List treeObsFlattened = treeObs.getFlattenedObservationTree(); + List listObsNodes = treeObs.getObservationTree(); + + // set ObsTree field of observation + ob.setObsTree(listObsNodes); + ob.setFlattenedObsTree(treeObsFlattened); + + // choose the left child of the root node as the nextNode + Node nextNode = network.getNodes().get(listObsNodes.get(1).getNodeId()); + List positionNextNode = new ArrayList<>(); + positionNextNode.add(nextNode.getCoord().getX()); + positionNextNode.add(nextNode.getCoord().getY()); + + // set the PositionNextNode of the observation + ob.setPositionNextNode(positionNextNode); + + // set the state of the train - headlink fromNode Coords, headPostion, speed + List extractedTrainState = new ArrayList<>(); + + // Add headlink fromNode Coords + List headLinkFromNodePosition = new ArrayList<>(); + Coord headLinkFromNodeCoord = network.getLinks().get(train.getHeadLink()).getFromNode().getCoord(); + headLinkFromNodePosition.add(headLinkFromNodeCoord.getX()); + headLinkFromNodePosition.add(headLinkFromNodeCoord.getY()); + + extractedTrainState.addAll(headLinkFromNodePosition); + + // TODO: Add train speed + extractedTrainState.add(train.getTrain().maxVelocity()); + + // Add headPosition + extractedTrainState.add(train.getHeadPosition()); + + ob.setTrainState(extractedTrainState); + + // Add railsim timestep + ob.setRailsim_timestamp(time); + + return ob; + } + + private Map getStepOutput(TrainPosition train, Boolean done, double time){ + + StepOutput stepOutput = new StepOutput(); + + stepOutput.setInfo(null); + stepOutput.setReward(getReward((TrainState) train)); + stepOutput.setTerminated(done); + stepOutput.setTruncated(done); + stepOutput.setObservation(getObservation(time, train)); + + Map stepOutputMap= new HashMap<>(); + stepOutputMap.put(train.getTrain().id().toString(), stepOutput); + + return stepOutputMap; + } +} + + + + diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java index 277fedab883..8369c7c0a0d 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java @@ -196,4 +196,9 @@ public void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link // put resource handling into release track resources.releaseLink(time, link, driver); } + + @Override + public void onArrival(double time, TrainPosition position) { + + } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java index 3cd2cdeb11a..cbd6ba2cff5 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java @@ -49,4 +49,9 @@ public interface TrainDisposition { */ void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link); + /** + * Method invoked when a train is arriving at rout end. + */ + void onArrival(double time, TrainPosition position); + } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java index 441fd5c64c5..e936a0fe900 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java @@ -1,7 +1,6 @@ package ch.sbb.matsim.contrib.railsim.rl; import ch.sbb.matsim.contrib.railsim.grpc.RailsimConnecterGrpc; -import ch.sbb.matsim.contrib.railsim.grpc.ProtoObservationMap; import ch.sbb.matsim.contrib.railsim.grpc.ProtoActionMap; import ch.sbb.matsim.contrib.railsim.grpc.ProtoStepOutputMap; import ch.sbb.matsim.contrib.railsim.grpc.ProtoObservation; @@ -13,7 +12,6 @@ import io.grpc.*; import java.util.Map; -import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; @@ -54,7 +52,7 @@ public String sendObservation(Map stepOutputMap){ //Build the Observation object - this will be used inside the StepOutput class ProtoObservation protoObservation = ProtoObservation.newBuilder() - .addAllObsTree(stepOutput.getObservation().getObsTree()) + .addAllObsTree(stepOutput.getObservation().getFlattenedObsTree()) .addAllPositionNextNode(stepOutput.getObservation().getPositionNextNode()) .addAllTrainState(stepOutput.getObservation().getTrainState()) .build(); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java index 46368d1e514..b26fe4a4bc3 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java @@ -4,27 +4,43 @@ public class Observation{ //TODO: Add timestamp as well in the observation - List obsTree = new ArrayList<>(); + double railsim_timestamp; + + List flattenedObsTree = new ArrayList<>(); + + public Observation() { + + } + + public List getObsTree() { + return obsTree; + } + + public void setObsTree(List obsTree) { + this.obsTree = obsTree; + } + + List obsTree= new ArrayList(); // List trainState = new ArrayList(4); // x and y coordinate of the node List positionNextNode = new ArrayList(2); - public int getTimestamp() { - return timestamp; + public double getRailsim_timestamp() { + return railsim_timestamp; } - public void setTimestamp(int timestamp) { - this.timestamp = timestamp; + public void setRailsim_timestamp(double railsim_timestamp) { + this.railsim_timestamp = railsim_timestamp; } int timestamp; - public List getObsTree() { - return obsTree; + public List getFlattenedObsTree() { + return flattenedObsTree; } - public void setObsTree(List obsTree) { - this.obsTree = obsTree; + public void setFlattenedObsTree(List flattenedObsTree) { + this.flattenedObsTree = flattenedObsTree; } public List getTrainState() { @@ -47,7 +63,7 @@ public void setPositionNextNode(List positionNextNode) { @Override public String toString() { return "Observation{" + - "obsTree=" + obsTree + + "obsTree=" + flattenedObsTree + ", trainState=" + trainState + ", positionNextNode=" + positionNextNode + '}'; @@ -62,7 +78,7 @@ public void generateRandomObservation(double depthObservationTree){ } int lenObsTree = (int)(Math.pow(2.0, depthObservationTree+1)-1)*17; for (int i=0; i nodeId; - List position; - double distNodeAgent; - double distNodeStop; - - public ObservationNode(List position, double distNodeAgent, double distNodeStop, int isSwitchable, OtherAgent sameDirAgent, OtherAgent oppDirAgent, int numParallelTracks, Id nodeId) { - this.position = position; - this.distNodeAgent = distNodeAgent; - this.distNodeStop = distNodeStop; - this.isSwitchable = isSwitchable; - this.sameDirAgent = sameDirAgent; - this.oppDirAgent = oppDirAgent; - this.numParallelTracks = numParallelTracks; - this.nodeId = nodeId; - } - - int isSwitchable; - OtherAgent sameDirAgent; - OtherAgent oppDirAgent; - int numParallelTracks; -} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/ObservationTreeNode.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/ObservationTreeNode.java new file mode 100644 index 00000000000..27b695c8be1 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/ObservationTreeNode.java @@ -0,0 +1,93 @@ +package ch.sbb.matsim.contrib.railsim.rl.observation; + +import org.matsim.api.core.v01.Id; +import org.matsim.api.core.v01.network.Node; + +import java.util.List; + +public class ObservationTreeNode { + + Id nodeId; + List position; + double distNodeAgent; + double distNodeStop; + int isSwitchable; + OtherAgent sameDirAgent; + OtherAgent oppDirAgent; + int numParallelTracks; + + public Id getNodeId() { + return nodeId; + } + + public void setNodeId(Id nodeId) { + this.nodeId = nodeId; + } + + public List getPosition() { + return position; + } + + public void setPosition(List position) { + this.position = position; + } + + public double getDistNodeAgent() { + return distNodeAgent; + } + + public void setDistNodeAgent(double distNodeAgent) { + this.distNodeAgent = distNodeAgent; + } + + public double getDistNodeStop() { + return distNodeStop; + } + + public void setDistNodeStop(double distNodeStop) { + this.distNodeStop = distNodeStop; + } + + public int getIsSwitchable() { + return isSwitchable; + } + + public void setIsSwitchable(int isSwitchable) { + this.isSwitchable = isSwitchable; + } + + public OtherAgent getSameDirAgent() { + return sameDirAgent; + } + + public void setSameDirAgent(OtherAgent sameDirAgent) { + this.sameDirAgent = sameDirAgent; + } + + public OtherAgent getOppDirAgent() { + return oppDirAgent; + } + + public void setOppDirAgent(OtherAgent oppDirAgent) { + this.oppDirAgent = oppDirAgent; + } + + public int getNumParallelTracks() { + return numParallelTracks; + } + + public void setNumParallelTracks(int numParallelTracks) { + this.numParallelTracks = numParallelTracks; + } + + public ObservationTreeNode(List position, double distNodeAgent, double distNodeStop, int isSwitchable, OtherAgent sameDirAgent, OtherAgent oppDirAgent, int numParallelTracks, Id nodeId) { + this.position = position; + this.distNodeAgent = distNodeAgent; + this.distNodeStop = distNodeStop; + this.isSwitchable = isSwitchable; + this.sameDirAgent = sameDirAgent; + this.oppDirAgent = oppDirAgent; + this.numParallelTracks = numParallelTracks; + this.nodeId = nodeId; + } +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java index f31e7cc9a71..afdd1790505 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/StepOutput.java @@ -72,4 +72,9 @@ public StepOutput(int depthObservationTree, boolean random){ } } + public StepOutput(){ + + } + + } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java index 1f4d04b4faf..bf82e115e7c 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java @@ -10,7 +10,6 @@ import org.matsim.api.core.v01.network.Network; import org.matsim.api.core.v01.network.Node; -import java.lang.reflect.Field; import java.util.Arrays; import java.util.List; import java.util.ArrayList; @@ -43,15 +42,18 @@ public class TreeObservation { private final TrainState position; private final RailResourceManager resources; private final Network network; - private List observationList; + private List observationList; private List flattenedObservation; - public TreeObservation(TrainState position, RailResourceManager resources, Network network) { + + int depth; + public TreeObservation(TrainState position, RailResourceManager resources, Network network, int depth) { this.resources = resources; this.position = position; this.network = network; this.observationList = new ArrayList<>(); this.flattenedObservation= new ArrayList<>(); - createTreeObs(); + this.depth = depth; + createTreeObs(this.depth); } private RailLink getBufferTip() { @@ -64,11 +66,6 @@ private RailLink getBufferTip() { return bufferTip; } - private void createTreeObs() { - int depth = 3; - createTreeObs(depth); - } - private boolean isObsTreeNode(Node toNode){ @@ -87,7 +84,7 @@ private boolean isObsTreeNode(Node toNode){ return false; } - private ObservationNode createObservatioNode(TrainPosition train, RailLink curLink, Node node, OtherAgent sameDirAgent, OtherAgent oppDirAgent){ + private ObservationTreeNode createObservatioNode(TrainPosition train, RailLink curLink, Node node, OtherAgent sameDirAgent, OtherAgent oppDirAgent){ // Get coordinates of the nextNode List nodePosition = new ArrayList(Arrays.asList(node.getCoord().getX(), node.getCoord().getY())); @@ -108,11 +105,11 @@ private ObservationNode createObservatioNode(TrainPosition train, RailLink curLi int numParallelIncomingTracks = node.getInLinks().size(); - return new ObservationNode(nodePosition,distNodeAgent, distNextHalt, isSwitchable, sameDirAgent, oppDirAgent, numParallelIncomingTracks, node.getId()); + return new ObservationTreeNode(nodePosition,distNodeAgent, distNextHalt, isSwitchable, sameDirAgent, oppDirAgent, numParallelIncomingTracks, node.getId()); } - private List flattenObservationNode(ObservationNode obsNode) { + private List flattenObservationNode(ObservationTreeNode obsNode) { List flattenedNode = new ArrayList<>(); @@ -154,7 +151,7 @@ private void createTreeObs(int depth){ exploreQueue.add(nextNode); // TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); // TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); - ObservationNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), nextNode, null, null); + ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), nextNode, null, null); this.observationList.add(obsNode); this.flattenedObservation.addAll(flattenObservationNode(obsNode)); } @@ -177,7 +174,6 @@ private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) { // ResourceState state = resource.getState(railLink); // // TODO: Ask Christian how we get the train position of the nearest train. // } -// // return null; } @@ -194,6 +190,8 @@ private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) { // return null; } + + private List getNextNodes(Node curNode) { // next switches List obsTreeNodes = new ArrayList<>(); @@ -206,6 +204,8 @@ private List getNextNodes(Node curNode) { while (!isObsTreeNode(nextNode)) { // get the single outgoing link and follow it // TODO: Handle end of network, will throw NoSuchElementException at the moment + + //TODO: Fix me. This is wrong as they nextNode in the same track can have 2 outLinks. nextNode = nextNode.getOutLinks().values().iterator().next().getToNode(); } obsTreeNodes.add(nextNode); @@ -217,7 +217,7 @@ private Node getToNode(RailLink link) { return network.getLinks().get(link.getLinkId()).getToNode(); } - public List getObservationTree() { + public List getObservationTree() { return observationList; } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java index ecfc0688516..1869ac78ee1 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java @@ -20,7 +20,7 @@ public static Node getToNode(Network network, RailLink link) { } // TODO: implement the functions: getSwitchNodeOnTrack and updateRoute - private static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path){ + public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path){ /* get links from a node to the next switch node on the same track. @@ -55,7 +55,7 @@ private static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, Li // update start node and curLink start = nextLink.getToNode(); curLink = nextLink; - if (start.equals(target)) { + if (target != null && start.equals(target)) { // target found in the path return true; } From 44a372de3e4584a10688ff82ae57c3c0da9759f3 Mon Sep 17 00:00:00 2001 From: iamakash Date: Thu, 30 May 2024 20:32:31 +0200 Subject: [PATCH 07/14] fix: fix bug in observation tree fixed bug in observation tree in getNextNodes method. Added the code to integrate RLDisposition with Railsim --- .../railsim/EnvironmentFactoryServer.java | 2 + .../matsim/contrib/railsim/RailsimEnv.java | 3 +- .../qsimengine/RailsimRLQSimModule.java | 43 ++++++++ .../disposition/RLTrainDisposition.java | 3 + .../matsim/contrib/railsim/rl/RLClient.java | 1 + .../railsim/rl/observation/Observation.java | 1 - .../rl/observation/TreeObservation.java | 102 ++++++++++-------- contribs/railsim/src/main/proto/railsim.proto | 2 +- 8 files changed, 110 insertions(+), 47 deletions(-) create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java index f3b6508e3c4..716f6d4e709 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java @@ -108,10 +108,12 @@ public void resetEnv(ProtoGrpcPort grpcPort, StreamObserver respo // Send the reply back to the client. responseObserver.onNext(response); + // Indicate that no further messages will be sent to the client. responseObserver.onCompleted(); // Start the simulation + //TODO: Should this be started on a different thread so that the endpoint is not blocked or is it automatically taken care of? env.startSimulation(); } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index d0b5dfd2824..d1e902becd8 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -3,6 +3,7 @@ //import ch.sbb.matsim.contrib.railsim.rl.RLClient; import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimRLQSimModule; import ch.sbb.matsim.contrib.railsim.rl.RLClient; import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.Scenario; @@ -65,7 +66,7 @@ public List reset(){ controler.addOverridingModule(new RailsimModule()); // if you have other extensions that provide QSim components, call their configure-method here - controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); + controler.configureQSimComponents(components -> new RailsimRLQSimModule().configure(components)); //TODO: Fix Me: implement the method getAllTrainIds() // get all train Ids in this scenario. return getAllTrainIds(scenario); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java new file mode 100644 index 00000000000..0e6758055c8 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java @@ -0,0 +1,43 @@ +package ch.sbb.matsim.contrib.railsim.qsimengine; + +import ch.sbb.matsim.contrib.railsim.qsimengine.deadlocks.DeadlockAvoidance; +import ch.sbb.matsim.contrib.railsim.qsimengine.deadlocks.SimpleDeadlockAvoidance; +import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.RLTrainDisposition; +import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.SimpleDisposition; +import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.TrainDisposition; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; +import ch.sbb.matsim.contrib.railsim.qsimengine.router.TrainRouter; +import com.google.inject.multibindings.OptionalBinder; +import org.matsim.core.mobsim.qsim.AbstractQSimModule; +import org.matsim.core.mobsim.qsim.components.QSimComponentsConfig; +import org.matsim.core.mobsim.qsim.components.QSimComponentsConfigurator; +import org.matsim.core.mobsim.qsim.pt.TransitDriverAgentFactory; + +public class RailsimRLQSimModule extends AbstractQSimModule implements QSimComponentsConfigurator{ + + public static final String COMPONENT_NAME = "Railsim"; + + @Override + public void configure(QSimComponentsConfig components) { + components.addNamedComponent(COMPONENT_NAME); + } + + @Override + protected void configureQSim() { + bind(RailsimQSimEngine.class).asEagerSingleton(); + + bind(TrainRouter.class).asEagerSingleton(); + bind(RailResourceManager.class).asEagerSingleton(); + + // These interfaces might be replaced with other implementations + bind(TrainDisposition.class).to(RLTrainDisposition.class).asEagerSingleton(); + bind(DeadlockAvoidance.class).to(SimpleDeadlockAvoidance.class).asEagerSingleton(); + + addQSimComponentBinding(COMPONENT_NAME).to(RailsimQSimEngine.class); + + OptionalBinder.newOptionalBinder(binder(), TransitDriverAgentFactory.class) + .setBinding().to(RailsimDriverAgentFactory.class); + } + + +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java index e5eea38514a..64c0a9907eb 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -12,6 +12,7 @@ import ch.sbb.matsim.contrib.railsim.rl.observation.ObservationTreeNode; import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; import ch.sbb.matsim.contrib.railsim.rl.observation.TreeObservation; +import jakarta.inject.Inject; import org.apache.commons.jxpath.ri.compiler.Step; import org.matsim.api.core.v01.Coord; import org.matsim.api.core.v01.network.Link; @@ -34,6 +35,8 @@ public class RLTrainDisposition implements TrainDisposition { RLClient rlClient; Map bufferStepOutputMap; + + @Inject public RLTrainDisposition(RailResourceManager resources, TrainRouter router, Network network, RLClient rlClient) { this.resources = resources; this.router = router; diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java index e936a0fe900..136e226d83a 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java @@ -55,6 +55,7 @@ public String sendObservation(Map stepOutputMap){ .addAllObsTree(stepOutput.getObservation().getFlattenedObsTree()) .addAllPositionNextNode(stepOutput.getObservation().getPositionNextNode()) .addAllTrainState(stepOutput.getObservation().getTrainState()) + .setTimestamp(stepOutput.getObservation().getRailsim_timestamp()) .build(); // build StepOutput diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java index b26fe4a4bc3..c52111727de 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/Observation.java @@ -3,7 +3,6 @@ import java.util.*; public class Observation{ - //TODO: Add timestamp as well in the observation double railsim_timestamp; List flattenedObsTree = new ArrayList<>(); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java index bf82e115e7c..b4da093bc11 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.List; import java.util.ArrayList; +import java.util.stream.Collectors; import ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils; @@ -127,14 +128,16 @@ private void createTreeObs(int depth){ RailLink bufferTipLink = getBufferTip(); // Get the toNode of the bufferTipLink - Node toNode = getToNode(bufferTipLink); - List exploreQueue = new ArrayList(); + Node toNodeBufferTipLink = getToNode(bufferTipLink); + List exploreQueue = new ArrayList<>(); + // store a list of visitedNodes to avoid infinite loop in case of cycles in the network. + List visitedNodes = new ArrayList<>(); - while(!isObsTreeNode(toNode)){ - toNode = toNode.getOutLinks().values().iterator().next().getToNode(); + while(!isObsTreeNode(toNodeBufferTipLink)){ + toNodeBufferTipLink = toNodeBufferTipLink.getOutLinks().values().iterator().next().getToNode(); } - exploreQueue.add(toNode); + exploreQueue.add(toNodeBufferTipLink); for (int i = 0; i < depth; i++) { // Level Traversal algorithm @@ -142,71 +145,82 @@ private void createTreeObs(int depth){ while (lenExploreQueue > 0) { // Level traversal Node curNode = exploreQueue.get(0); + + // Create observationTreeNode from the curNode +// TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); +// TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); + ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), curNode, null, null); + this.observationList.add(obsNode); + this.flattenedObservation.addAll(flattenObservationNode(obsNode)); + exploreQueue.remove(0); + visitedNodes.add(curNode); // Look for switches/intersections/stops on the branches stemming out of the current switch - List nextNodes = getNextNodes(curNode); + List nextNodes = getNextNodes(toNodeBufferTipLink, curNode); + // Add nextNode only if it's not already visited for (Node nextNode : nextNodes) { - exploreQueue.add(nextNode); -// TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); -// TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); - ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), nextNode, null, null); - this.observationList.add(obsNode); - this.flattenedObservation.addAll(flattenObservationNode(obsNode)); + if (!visitedNodes.contains(nextNode)){ + exploreQueue.add(nextNode); + } } lenExploreQueue -= 1; } } } - private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) { -// // complete route to current position of the train from schedule, if needed? -// List previousRoute = position.getRoute(0, position.getRouteIndex()); + private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) throws NotImplementedException{ throw new NotImplementedException(); -// // TODO: If all opposite trains are needed, follow the inLinks of curNode until the nextNode is reached. Store the of the path links. -// List path = null; -// -// // check for each link if there is capacity -// for (Link link : path) { -// RailLink railLink = resources.getLink(link.getId()); -// RailResource resource = railLink.getResource(); -// ResourceState state = resource.getState(railLink); -// // TODO: Ask Christian how we get the train position of the nearest train. -// } -// return null; } - - private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) { + private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) throws NotImplementedException { throw new NotImplementedException(); -// // complete route from current position of the train from schedule, if needed? -// List upcomingRoute = position.getRoute(position.getRouteIndex(), position.getRouteSize()); -// -// // TODO: If all opposite trains are needed, follow the outLinks of curNode until the nextNode is reached. Store the of the path links. -// List path = null; -// -// // same procedure as above... -// -// return null; } - - - private List getNextNodes(Node curNode) { + private List getNextNodes(Node toNodeBufferTipLink, Node obsNode) { // next switches List obsTreeNodes = new ArrayList<>(); - // check all outgoing links from the current node - for (Link outLink : curNode.getOutLinks().values()) { + // check all outgoing links from the current obsNode + for (Link outLink : obsNode.getOutLinks().values()) { + + boolean reverseDirection = false; Node nextNode = outLink.getToNode(); + Node prevNode = obsNode; // follow nodes with only one outgoing link until a switch or a halt is reached while (!isObsTreeNode(nextNode)) { // get the single outgoing link and follow it // TODO: Handle end of network, will throw NoSuchElementException at the moment - //TODO: Fix me. This is wrong as they nextNode in the same track can have 2 outLinks. - nextNode = nextNode.getOutLinks().values().iterator().next().getToNode(); + List outLinks = nextNode.getOutLinks().values().stream().collect(Collectors.toList()); + Link nextLink = null; + for(Link link : outLinks){ + // skip the link that leads to prevNode to avoid an infinite loop + if (link.getToNode().equals(prevNode)){ + continue; + } + else{ + nextLink = link; + break; + } + } + + // update prevNode + prevNode = nextNode; + + // update nextNode + nextNode = nextLink.getToNode(); + + if (nextNode.equals(toNodeBufferTipLink)){ + reverseDirection =true; + break; + } + } + + if (reverseDirection){ + // skip the current outLink as this link from the switchNode leads to the observing train + continue; } obsTreeNodes.add(nextNode); } diff --git a/contribs/railsim/src/main/proto/railsim.proto b/contribs/railsim/src/main/proto/railsim.proto index 55055dcc8f2..b7d680c506a 100644 --- a/contribs/railsim/src/main/proto/railsim.proto +++ b/contribs/railsim/src/main/proto/railsim.proto @@ -9,7 +9,7 @@ message ProtoObservation { repeated double obsTree = 1; repeated double trainState = 2; repeated double positionNextNode = 3; - int32 timestamp=4; + double timestamp=4; } message ProtoStepOutput { From 5ca5b4be9b80b89e7b38893463943c442bf7b6a6 Mon Sep 17 00:00:00 2001 From: rakow Date: Tue, 4 Jun 2024 09:38:10 +0200 Subject: [PATCH 08/14] separate rl disposition into own module --- .../matsim/contrib/railsim/RailsimEnv.java | 12 +++--- .../RailsimRLDispositionModule.java | 31 +++++++++++++ .../qsimengine/RailsimRLQSimModule.java | 43 ------------------- 3 files changed, 37 insertions(+), 49 deletions(-) create mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLDispositionModule.java delete mode 100644 contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index d1e902becd8..b7b19097a7b 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -3,23 +3,21 @@ //import ch.sbb.matsim.contrib.railsim.rl.RLClient; import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; -import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimRLQSimModule; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimRLDispositionModule; import ch.sbb.matsim.contrib.railsim.rl.RLClient; import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.Scenario; -import org.matsim.api.core.v01.population.Route; import org.matsim.core.config.Config; import org.matsim.core.config.ConfigUtils; import org.matsim.core.controler.Controler; import org.matsim.core.controler.OutputDirectoryHierarchy; import org.matsim.core.scenario.ScenarioUtils; import org.matsim.pt.transitSchedule.api.Departure; +import org.matsim.pt.transitSchedule.api.TransitLine; import org.matsim.pt.transitSchedule.api.TransitRoute; import org.matsim.vehicles.Vehicle; -import org.matsim.visum.VisumNetwork; -import org.matsim.pt.transitSchedule.api.TransitLine; + import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -64,9 +62,11 @@ public List reset(){ controler = new Controler(scenario); controler.addOverridingModule(new RailsimModule()); + controler.addOverridingQSimModule(new RailsimRLDispositionModule(rlClient)); // if you have other extensions that provide QSim components, call their configure-method here - controler.configureQSimComponents(components -> new RailsimRLQSimModule().configure(components)); + controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); + //TODO: Fix Me: implement the method getAllTrainIds() // get all train Ids in this scenario. return getAllTrainIds(scenario); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLDispositionModule.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLDispositionModule.java new file mode 100644 index 00000000000..4c655035988 --- /dev/null +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLDispositionModule.java @@ -0,0 +1,31 @@ +package ch.sbb.matsim.contrib.railsim.qsimengine; + +import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.RLTrainDisposition; +import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.TrainDisposition; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import org.matsim.core.mobsim.qsim.AbstractQSimModule; +import org.matsim.core.mobsim.qsim.components.QSimComponentsConfig; +import org.matsim.core.mobsim.qsim.components.QSimComponentsConfigurator; + +public class RailsimRLDispositionModule extends AbstractQSimModule implements QSimComponentsConfigurator { + + private final RLClient rlClient; + + public RailsimRLDispositionModule(RLClient rlClient) { + + this.rlClient = rlClient; + } + + @Override + public void configure(QSimComponentsConfig components) { + } + + @Override + protected void configureQSim() { + + bind(RLClient.class).toInstance(rlClient); + bind(TrainDisposition.class).to(RLTrainDisposition.class).asEagerSingleton(); + + + } +} diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java deleted file mode 100644 index 0e6758055c8..00000000000 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java +++ /dev/null @@ -1,43 +0,0 @@ -package ch.sbb.matsim.contrib.railsim.qsimengine; - -import ch.sbb.matsim.contrib.railsim.qsimengine.deadlocks.DeadlockAvoidance; -import ch.sbb.matsim.contrib.railsim.qsimengine.deadlocks.SimpleDeadlockAvoidance; -import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.RLTrainDisposition; -import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.SimpleDisposition; -import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.TrainDisposition; -import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; -import ch.sbb.matsim.contrib.railsim.qsimengine.router.TrainRouter; -import com.google.inject.multibindings.OptionalBinder; -import org.matsim.core.mobsim.qsim.AbstractQSimModule; -import org.matsim.core.mobsim.qsim.components.QSimComponentsConfig; -import org.matsim.core.mobsim.qsim.components.QSimComponentsConfigurator; -import org.matsim.core.mobsim.qsim.pt.TransitDriverAgentFactory; - -public class RailsimRLQSimModule extends AbstractQSimModule implements QSimComponentsConfigurator{ - - public static final String COMPONENT_NAME = "Railsim"; - - @Override - public void configure(QSimComponentsConfig components) { - components.addNamedComponent(COMPONENT_NAME); - } - - @Override - protected void configureQSim() { - bind(RailsimQSimEngine.class).asEagerSingleton(); - - bind(TrainRouter.class).asEagerSingleton(); - bind(RailResourceManager.class).asEagerSingleton(); - - // These interfaces might be replaced with other implementations - bind(TrainDisposition.class).to(RLTrainDisposition.class).asEagerSingleton(); - bind(DeadlockAvoidance.class).to(SimpleDeadlockAvoidance.class).asEagerSingleton(); - - addQSimComponentBinding(COMPONENT_NAME).to(RailsimQSimEngine.class); - - OptionalBinder.newOptionalBinder(binder(), TransitDriverAgentFactory.class) - .setBinding().to(RailsimDriverAgentFactory.class); - } - - -} From 38468318b67c78dfdf9c86fa80dc57470dae1690 Mon Sep 17 00:00:00 2001 From: iamakash Date: Tue, 4 Jun 2024 09:50:24 +0200 Subject: [PATCH 09/14] feat: implement reward function --- .../railsim/EnvironmentFactoryServer.java | 2 + .../matsim/contrib/railsim/RailsimEnv.java | 6 +- .../matsim/contrib/railsim/RailsimModule.java | 3 +- .../contrib/railsim/RunRailsimExample.java | 46 +---- .../railsim/qsimengine/RailsimEngine.java | 15 +- .../railsim/qsimengine/RailsimQSimEngine.java | 11 -- .../qsimengine/RailsimRLQSimModule.java | 15 +- .../disposition/RLTrainDisposition.java | 166 ++++++++++++++---- .../disposition/SimpleDisposition.java | 16 +- .../disposition/TrainDisposition.java | 16 +- .../contrib/railsim/rl/utils/RLUtils.java | 2 +- 11 files changed, 199 insertions(+), 99 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java index 716f6d4e709..c5552bb6c7e 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java @@ -75,6 +75,8 @@ private class RailsimFactory extends RailsimFactoryGrpc.RailsimFactoryImplBase { Map envMap = new HashMap<>(); + + @Override public void getEnvironment(ProtoGrpcPort grpcPort, StreamObserver responseObserver) { // Create an instance of Railsim environment and store it in a map diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index d1e902becd8..12f42f70dbe 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -18,6 +18,8 @@ import org.matsim.vehicles.Vehicle; import org.matsim.visum.VisumNetwork; import org.matsim.pt.transitSchedule.api.TransitLine; + +import javax.inject.Inject; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -26,6 +28,8 @@ public class RailsimEnv { RLClient rlClient; // RLClient would be needed by RailsimEngine. Controler controler; + + @Inject public RailsimEnv(RLClient rlClient){ // TODO: Pass this to RLDisposition this.rlClient = rlClient; @@ -66,7 +70,7 @@ public List reset(){ controler.addOverridingModule(new RailsimModule()); // if you have other extensions that provide QSim components, call their configure-method here - controler.configureQSimComponents(components -> new RailsimRLQSimModule().configure(components)); + controler.configureQSimComponents(components -> new RailsimRLQSimModule(rlClient).configure(components)); //TODO: Fix Me: implement the method getAllTrainIds() // get all train Ids in this scenario. return getAllTrainIds(scenario); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java index 9d25bc4a9a9..96dc6839a3f 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java @@ -23,6 +23,7 @@ import ch.sbb.matsim.contrib.railsim.analysis.trainstates.RailsimTrainStateControlerListener; import ch.sbb.matsim.contrib.railsim.config.RailsimConfigGroup; import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimRLQSimModule; import com.google.inject.Singleton; import org.matsim.core.config.ConfigUtils; import org.matsim.core.controler.AbstractModule; @@ -34,7 +35,7 @@ public class RailsimModule extends AbstractModule { @Override public void install() { - installQSimModule(new RailsimQSimModule()); +// installQSimModule(new RailsimRLQSimModule(null)); ConfigUtils.addOrGetModule(getConfig(), RailsimConfigGroup.class); bind(RailsimLinkStateControlerListener.class).in(Singleton.class); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java index 11973b734cf..9793fb6fd4a 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java @@ -19,7 +19,9 @@ package ch.sbb.matsim.contrib.railsim; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimRLQSimModule; import ch.sbb.matsim.contrib.railsim.qsimengine.TrainState; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.Scenario; import org.matsim.api.core.v01.network.Network; @@ -56,6 +58,9 @@ public static void main(String[] args) { Config config = ConfigUtils.loadConfig(configFilename); config.controller().setOverwriteFileSetting(OutputDirectoryHierarchy.OverwriteFileSetting.deleteDirectoryIfExists); + config.controller().setLastIteration(0); + config.controller().setDumpDataAtEnd(true); + config.controller().setCreateGraphs(false); Scenario scenario = ScenarioUtils.loadScenario(config); Controler controler = new Controler(scenario); @@ -63,47 +68,10 @@ public static void main(String[] args) { controler.addOverridingModule(new RailsimModule()); // if you have other extensions that provide QSim components, call their configure-method here - controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); +// controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); + controler.configureQSimComponents(components -> new RailsimRLQSimModule(new RLClient(55422)).configure(components)); controler.run(); - - // Required: List of all agents - // Schedule of arrival and departure for all the halts for all trains - -// // Transit stops -// Map, TransitStopFacility> transitStops = scenario.getTransitSchedule().getFacilities(); -// -// // Transit lines -// Map, TransitLine> transitLineMap = scenario.getTransitSchedule().getTransitLines(); - - /* - Transit line contains: - - transitRoute - - route profile : sequence of stops with their arrival and departure times - - route: sequence of links - - departures: the train ids and their corresponding departure times. - - Each stop is essentially a link - - There can be multiple transit lines - */ - - /* - Output data structure - - */ -// List transitLines = scenario.getTransitSchedule().getTransitLines().values().stream().toList(); -// -// for (TransitLine tl: transitLines){ -// List transitRoutes = tl.getRoutes().values().stream().toList(); -// for (TransitRoute tr: transitRoutes){ -// tr.getDepartures(); -// tr.getStops().get(0).getStopFacility(); -// -// } -// } - - } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index 83a85969e6e..ae10a5f7d9c 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -141,7 +141,7 @@ public boolean handleDeparture(double now, MobsimDriverAgent agent, Id lin activeTrains.add(state); - disposition.onDeparture(now, state.driver, state.route); + disposition.onDeparture(now, state, state.route); updateQueue.add(new UpdateEvent(state, UpdateEvent.Type.DEPARTURE)); @@ -377,6 +377,9 @@ private void enterLink(double time, UpdateEvent event) { // Same event is re-scheduled after stopping, event.plannedTime = time + stopTime; + // call disposition +// disposition.onHaltDeparture(event.plannedTime, event.state); + return; } @@ -384,7 +387,7 @@ private void enterLink(double time, UpdateEvent event) { if (!event.waitingForLink && state.isRouteAtEnd()) { //call disposition - disposition.onArrival(time, event.state); + disposition.onArrival(time, event.state, true); assert FuzzyUtils.equals(state.speed, 0) : "Speed must be 0 at end, but was " + state.speed; @@ -470,6 +473,12 @@ private void leaveLink(double time, UpdateEvent event) { unblockTrack(time, state, tailLink); else updateQueue.add(new UpdateEvent(state, tailLink, time)); + + // Call disposition to calculate delay between scheduled departure and actual departure for the current stop. + if (state.isStop(state.getTailLink())){ + disposition.onStopDeparture(time,state); + } + } /** @@ -805,5 +814,7 @@ void clearTrains(double now) { eventsManager.processEvent(new VehicleAbortsEvent(now, train.driver.getVehicle().getId(), train.headLink)); eventsManager.processEvent(new PersonStuckEvent(now, train.driver.getId(), train.headLink, train.driver.getMode())); } + + disposition.onSimulationEnd(now); } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java index 5bb0bf876d1..151dcb2f2f3 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimQSimEngine.java @@ -71,16 +71,6 @@ public RailsimQSimEngine(QSim qsim, RailResourceManager res, TrainDisposition di this.agentTracker = agentTracker; } - @Inject - public RailsimQSimEngine(QSim qsim, RailResourceManager res, TrainDisposition disposition, TransitStopAgentTracker agentTracker, Network network) { - this.qsim = qsim; - this.config = ConfigUtils.addOrGetModule(qsim.getScenario().getConfig(), RailsimConfigGroup.class); - this.res = res; - this.disposition = disposition; - this.modes = config.getNetworkModes(); - this.agentTracker = agentTracker; - this.network = network; - } @Override public void setInternalInterface(InternalInterface internalInterface) { @@ -89,7 +79,6 @@ public void setInternalInterface(InternalInterface internalInterface) { @Override public void onPrepareSim() { -// engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition); engine = new RailsimEngine(qsim.getEventsManager(), config, res, disposition); } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java index 0e6758055c8..5e2ac9f14b3 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimRLQSimModule.java @@ -7,6 +7,8 @@ import ch.sbb.matsim.contrib.railsim.qsimengine.disposition.TrainDisposition; import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; import ch.sbb.matsim.contrib.railsim.qsimengine.router.TrainRouter; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; +import com.google.inject.Provides; import com.google.inject.multibindings.OptionalBinder; import org.matsim.core.mobsim.qsim.AbstractQSimModule; import org.matsim.core.mobsim.qsim.components.QSimComponentsConfig; @@ -17,6 +19,10 @@ public class RailsimRLQSimModule extends AbstractQSimModule implements QSimCompo public static final String COMPONENT_NAME = "Railsim"; + RLClient rlClient; + public RailsimRLQSimModule(RLClient rlClient){ + this.rlClient = rlClient; + } @Override public void configure(QSimComponentsConfig components) { components.addNamedComponent(COMPONENT_NAME); @@ -29,8 +35,11 @@ protected void configureQSim() { bind(TrainRouter.class).asEagerSingleton(); bind(RailResourceManager.class).asEagerSingleton(); + bind(RLClient.class).toInstance(this.rlClient); + // These interfaces might be replaced with other implementations bind(TrainDisposition.class).to(RLTrainDisposition.class).asEagerSingleton(); +// bind(TrainDisposition.class).to(SimpleDisposition.class).asEagerSingleton(); bind(DeadlockAvoidance.class).to(SimpleDeadlockAvoidance.class).asEagerSingleton(); addQSimComponentBinding(COMPONENT_NAME).to(RailsimQSimEngine.class); @@ -39,5 +48,9 @@ protected void configureQSim() { .setBinding().to(RailsimDriverAgentFactory.class); } - +// @Provides +// RLClient provideRLClient() { +// // Create and return the instance of RLClient +// return this.rlClient; +// } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java index 64c0a9907eb..efd722866ed 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -13,17 +13,19 @@ import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; import ch.sbb.matsim.contrib.railsim.rl.observation.TreeObservation; import jakarta.inject.Inject; -import org.apache.commons.jxpath.ri.compiler.Step; import org.matsim.api.core.v01.Coord; +import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.network.Link; import org.matsim.api.core.v01.network.Network; import org.matsim.api.core.v01.network.Node; import org.matsim.core.mobsim.framework.MobsimDriverAgent; +import org.matsim.pt.transitSchedule.api.Departure; +import org.matsim.pt.transitSchedule.api.TransitRoute; +import org.matsim.pt.transitSchedule.api.TransitRouteStop; +import org.matsim.vehicles.Vehicle; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; +import java.util.stream.Collectors; import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.getPathToSwitchNodeOnTrack; import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.updateRoute; @@ -36,6 +38,12 @@ public class RLTrainDisposition implements TrainDisposition { Map bufferStepOutputMap; + Map, Double> delays; + + Map, Map, List>> departureSchedule; + + Map, TrainPosition> activeTrains; + @Inject public RLTrainDisposition(RailResourceManager resources, TrainRouter router, Network network, RLClient rlClient) { this.resources = resources; @@ -43,43 +51,80 @@ public RLTrainDisposition(RailResourceManager resources, TrainRouter router, Net this.network = network; this.rlClient = rlClient; this.bufferStepOutputMap = new HashMap<>(); + this.departureSchedule = new HashMap<>(); + this.activeTrains = new HashMap<>(); + System.out.println("RLTrainDisposition created"); } - private Double getReward(TrainState train){ + private void calculateScheduleOnDeparture(TrainState train){ + assert train.getPt() != null; + + // get vehicle Id + Id trainId = train.getPt().getPlannedVehicleId(); + + // use trainSit route to get departures for the train + TransitRoute transitRoute = train.getPt().getTransitRoute(); + List departures = transitRoute.getDepartures().values().stream().collect(Collectors.toList()); + + for (Departure departure: departures){ + double routeDepartureTime = departure.getDepartureTime(); + for (TransitRouteStop stop : transitRoute.getStops()){ + Id stopLinkId = stop.getStopFacility().getLinkId(); + double offset = stop.getDepartureOffset().seconds(); + double scheduledDepartureTime = routeDepartureTime + offset; + + // store time corresponding to the train and halt + if (!this.departureSchedule.containsKey(trainId)){ + Map, List> mapLinkDepartureTime = new HashMap<>(); + departureSchedule.put(trainId, mapLinkDepartureTime); + } + if (!departureSchedule.get(trainId).containsKey(stopLinkId)){ + List listDepartureTimes = new ArrayList<>(); + departureSchedule.get(trainId).put(stopLinkId, listDepartureTimes); + } + departureSchedule.get(trainId).get(stopLinkId).add(scheduledDepartureTime); + } + } + } + + private Double getReward(){ /** - * 1. check if the train is departing from one of the halts - * 2. Get the actual departure time - * 3. Get the scheduled departure time - * 4. calculate reward - * - * TODO: What happens when the train arrives late at a halt? Does it wait for fixed amount of time as scheduled or it leaves as per the scheduled departure time if possible? - * - * TODO: What happens if the train arrives at a halt later than it's departure time? Does the train stop at all? - */ - - - return -1.0; + * 1. Calculate the sum of delays incurred + * 2. Clear the delays + * */ + + double reward = 0.0; + List departureDelays= delays.values().stream().collect(Collectors.toList()); + for (double t : departureDelays){ + reward -= t; + } + + this.delays.clear(); + return reward; } @Override - public void onDeparture(double time, MobsimDriverAgent driver, List route) { + public void onDeparture(double time, TrainPosition train, List route) { + // Update route for the train until the next switch node. - Link curLink = network.getLinks().get(driver.getCurrentLinkId()); + Link curLink = network.getLinks().get(train.getHeadLink()); getPathToSwitchNodeOnTrack(curLink, null, route); + + // calculate the scheduled departure times of the train + calculateScheduleOnDeparture((TrainState) train); + + // update the active departure list + activeTrains.put(train.getPt().getPlannedVehicleId(), train); } @Override public DispositionResponse requestNextSegment(double time, TrainPosition position, double dist) { // calculate and send StepOutput to rl - Map stepOutputMap = getStepOutput(position, false, time); + Map stepOutputMap = getStepOutput(position, getObservation(time, position), getReward(), false); + bufferStepOutputMap.putAll(stepOutputMap); + rlClient.sendObservation(bufferStepOutputMap); + bufferStepOutputMap.clear(); - if (bufferStepOutputMap.size()==0){ - rlClient.sendObservation(stepOutputMap); - } - else{ - bufferStepOutputMap.putAll(stepOutputMap); - bufferStepOutputMap.clear(); - } // get action from rl Map actionMap = rlClient.getAction(); @@ -105,6 +150,9 @@ public DispositionResponse requestNextSegment(double time, TrainPosition positio // stop the train return new DispositionResponse(0, 0, null); } + default:{ + System.out.println("Illegal action"); + } } @@ -162,20 +210,63 @@ public void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link } - @Override - public void onArrival(double time, TrainPosition position) { + public void onTermination(double time, TrainPosition position){ // Store the StepOutput in bufferStepOutput. // bufferStepOutput is not sent to RL until there is an observation for a train whose done=false - bufferStepOutputMap.putAll(getStepOutput(position, true, time)); + bufferStepOutputMap.putAll(getStepOutput(position, getObservation(time, position), getReward(), true)); + } + @Override + public void onArrival(double time, TrainPosition position, Boolean terminated) { + if (Boolean.TRUE.equals(terminated)){ + onTermination(time, position); + // remove the train from active trains + activeTrains.remove(position.getPt().getPlannedVehicleId()); + } } + @Override + public void onStopDeparture(double time, TrainPosition position){ + // Calculate delays incurred by the system + // this function should be called in leaveLink() method when the train departs from a halt + Id trainId = position.getPt().getPlannedVehicleId(); + assert position.isStop(position.getHeadLink()); + List scheduledDepartureTimes = departureSchedule.get(trainId).get(position.getHeadLink()); + + Double delay = 0.0; + for (Double depTime : scheduledDepartureTimes){ + if (time - depTime > 0) { + delay = Math.min(delay, time - depTime); + } + } + + if (delay<0){ + delay = 0.0; + } + + delays.put(trainId, delay); + } + + @Override + public void onSimulationEnd(double now){ + + // get observation for all active trains and + // set all the trains having heavy negative reward + + for(TrainPosition train: activeTrains.values()){ + Map stepOutputMap = getStepOutput(train, getObservation(now, train), -100.0, true); + bufferStepOutputMap.putAll(stepOutputMap); + } + rlClient.sendObservation(bufferStepOutputMap); + bufferStepOutputMap.clear(); + + } Observation getObservation(double time, TrainPosition train){ Observation ob = new Observation(); - // get observation for each train + // get observation for the train TreeObservation treeObs = new TreeObservation((TrainState) train, this.resources, this.network, 2); List treeObsFlattened = treeObs.getFlattenedObservationTree(); List listObsNodes = treeObs.getObservationTree(); @@ -184,7 +275,7 @@ Observation getObservation(double time, TrainPosition train){ ob.setObsTree(listObsNodes); ob.setFlattenedObsTree(treeObsFlattened); - // choose the left child of the root node as the nextNode + // choose the 1st child of the root node as the nextNode for query direction Node nextNode = network.getNodes().get(listObsNodes.get(1).getNodeId()); List positionNextNode = new ArrayList<>(); positionNextNode.add(nextNode.getCoord().getX()); @@ -193,7 +284,7 @@ Observation getObservation(double time, TrainPosition train){ // set the PositionNextNode of the observation ob.setPositionNextNode(positionNextNode); - // set the state of the train - headlink fromNode Coords, headPostion, speed + // set the state of the train - headlink fromNode Coords, headPosition, speed List extractedTrainState = new ArrayList<>(); // Add headlink fromNode Coords @@ -218,15 +309,16 @@ Observation getObservation(double time, TrainPosition train){ return ob; } - private Map getStepOutput(TrainPosition train, Boolean done, double time){ + private Map getStepOutput + (TrainPosition train, Observation observation, double reward, Boolean done){ StepOutput stepOutput = new StepOutput(); stepOutput.setInfo(null); - stepOutput.setReward(getReward((TrainState) train)); + stepOutput.setReward(reward); stepOutput.setTerminated(done); stepOutput.setTruncated(done); - stepOutput.setObservation(getObservation(time, train)); + stepOutput.setObservation(observation); Map stepOutputMap= new HashMap<>(); stepOutputMap.put(train.getTrain().id().toString(), stepOutput); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java index 8369c7c0a0d..ce4d0c43730 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/SimpleDisposition.java @@ -54,13 +54,23 @@ public SimpleDisposition(RailResourceManager resources, TrainRouter router) { } @Override - public void onDeparture(double time, MobsimDriverAgent driver, List route) { + public void onDeparture(double time, TrainPosition driver, List route) { // Nothing to do. } + @Override + public void onStopDeparture(double time, TrainPosition position){ + //Nothing to do + } + + @Override + public void onSimulationEnd(double now){ + // Nothing to do + } + /** * This method tries to first calculate the links needed by the train for moving the safety_distance. - * Then for each link of the segment, it check if it can blocked completely. + * Then for each link of the segment, it checks if it can blocked completely. * Only when all the links of the segment (list of links) can be blocked, * a Response with approved distance = length of the links is returned. */ @@ -198,7 +208,7 @@ public void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link } @Override - public void onArrival(double time, TrainPosition position) { + public void onArrival(double time, TrainPosition position, Boolean terminated) { } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java index cbd6ba2cff5..0c3d3bdc321 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/TrainDisposition.java @@ -33,7 +33,7 @@ public interface TrainDisposition { /** * Method invoked when a train is departing. */ - void onDeparture(double time, MobsimDriverAgent driver, List route); + void onDeparture(double time, TrainPosition position, List route); /** * Request the next segment to be reserved. @@ -50,8 +50,18 @@ public interface TrainDisposition { void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link); /** - * Method invoked when a train is arriving at rout end. + * Method invoked when a train is arriving at a stop. + * @param terminated a flag to indicate if the stop is the route end */ - void onArrival(double time, TrainPosition position); + void onArrival(double time, TrainPosition position, Boolean terminated); + + /** + * Method invoked when the train departs from stops + * @param time current time + * @param position position information + */ + public void onStopDeparture(double time, TrainPosition position); + + public void onSimulationEnd(double now); } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java index 1869ac78ee1..a3806a34f22 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java @@ -148,7 +148,7 @@ public static Boolean isSwitchable(Node switchNode, RailLink curLink, Network ne return false; } else{ - List outLinks = switchNode.getOutLinks().values().stream().toList(); + List outLinks = switchNode.getOutLinks().values().stream().collect(Collectors.toList()); List toNodeOfOutLinkList = new ArrayList<>(); for (Link link: outLinks){ toNodeOfOutLinkList.add(link.getToNode()); From c47ac932d52dfbb40e939f1ac70e914f4a0cd736 Mon Sep 17 00:00:00 2001 From: rakow Date: Tue, 4 Jun 2024 15:10:07 +0200 Subject: [PATCH 10/14] re enable railsim qsim module --- .../main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java | 3 ++- .../java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java index af16fbb9e8d..9d25bc4a9a9 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimModule.java @@ -22,6 +22,7 @@ import ch.sbb.matsim.contrib.railsim.analysis.linkstates.RailsimLinkStateControlerListener; import ch.sbb.matsim.contrib.railsim.analysis.trainstates.RailsimTrainStateControlerListener; import ch.sbb.matsim.contrib.railsim.config.RailsimConfigGroup; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimQSimModule; import com.google.inject.Singleton; import org.matsim.core.config.ConfigUtils; import org.matsim.core.controler.AbstractModule; @@ -33,7 +34,7 @@ public class RailsimModule extends AbstractModule { @Override public void install() { -// installQSimModule(new RailsimRLQSimModule(null)); + installQSimModule(new RailsimQSimModule()); ConfigUtils.addOrGetModule(getConfig(), RailsimConfigGroup.class); bind(RailsimLinkStateControlerListener.class).in(Singleton.class); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java index 506ee1e8632..d3ef247fbb2 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java @@ -19,6 +19,8 @@ package ch.sbb.matsim.contrib.railsim; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimRLDispositionModule; +import ch.sbb.matsim.contrib.railsim.rl.RLClient; import org.matsim.api.core.v01.Scenario; import org.matsim.core.config.Config; import org.matsim.core.config.ConfigUtils; @@ -55,6 +57,7 @@ public static void main(String[] args) { Controler controler = new Controler(scenario); controler.addOverridingModule(new RailsimModule()); + controler.addOverridingQSimModule(new RailsimRLDispositionModule(new RLClient(9000))); // if you have other extensions that provide QSim components, call their configure-method here controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); From e6fe992e81d87e9b825a8b8592259754d72a2ecc Mon Sep 17 00:00:00 2001 From: iamakash Date: Tue, 4 Jun 2024 15:17:17 +0200 Subject: [PATCH 11/14] refactor: code cleanup --- .../matsim/contrib/railsim/RunRailsimExample.java | 1 + .../contrib/railsim/qsimengine/RailsimEngine.java | 4 ---- .../disposition/RLTrainDisposition.java | 15 ++++++--------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java index d3ef247fbb2..dc7db10c0bc 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RunRailsimExample.java @@ -59,6 +59,7 @@ public static void main(String[] args) { controler.addOverridingModule(new RailsimModule()); controler.addOverridingQSimModule(new RailsimRLDispositionModule(new RLClient(9000))); + // if you have other extensions that provide QSim components, call their configure-method here controler.configureQSimComponents(components -> new RailsimQSimModule().configure(components)); controler.run(); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index ae10a5f7d9c..1e38900d546 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -376,10 +376,6 @@ private void enterLink(double time, UpdateEvent event) { // Same event is re-scheduled after stopping, event.plannedTime = time + stopTime; - - // call disposition -// disposition.onHaltDeparture(event.plannedTime, event.state); - return; } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java index efd722866ed..f369d4ef8e3 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -62,7 +62,7 @@ private void calculateScheduleOnDeparture(TrainState train){ // get vehicle Id Id trainId = train.getPt().getPlannedVehicleId(); - // use trainSit route to get departures for the train + // use trainsit route to get departures for the train TransitRoute transitRoute = train.getPt().getTransitRoute(); List departures = transitRoute.getDepartures().values().stream().collect(Collectors.toList()); @@ -123,6 +123,8 @@ public DispositionResponse requestNextSegment(double time, TrainPosition positio Map stepOutputMap = getStepOutput(position, getObservation(time, position), getReward(), false); bufferStepOutputMap.putAll(stepOutputMap); rlClient.sendObservation(bufferStepOutputMap); + + // clear the buffer after sending the stepOutput bufferStepOutputMap.clear(); @@ -214,14 +216,14 @@ public void onTermination(double time, TrainPosition position){ // Store the StepOutput in bufferStepOutput. // bufferStepOutput is not sent to RL until there is an observation for a train whose done=false bufferStepOutputMap.putAll(getStepOutput(position, getObservation(time, position), getReward(), true)); + + // remove the train from active trains + activeTrains.remove(position.getPt().getPlannedVehicleId()); } @Override public void onArrival(double time, TrainPosition position, Boolean terminated) { if (Boolean.TRUE.equals(terminated)){ onTermination(time, position); - - // remove the train from active trains - activeTrains.remove(position.getPt().getPlannedVehicleId()); } } @@ -239,11 +241,6 @@ public void onStopDeparture(double time, TrainPosition position){ delay = Math.min(delay, time - depTime); } } - - if (delay<0){ - delay = 0.0; - } - delays.put(trainId, delay); } From 7576be90d3ba201a3357788859a18f567465b155 Mon Sep 17 00:00:00 2001 From: iamakash Date: Wed, 5 Jun 2024 18:52:29 +0200 Subject: [PATCH 12/14] fix: bug fixes after integrating railsim with rl --- .../matsim/contrib/railsim/RailsimEnv.java | 3 + .../railsim/qsimengine/RailsimEngine.java | 13 +- .../disposition/RLTrainDisposition.java | 144 ++++++++------ .../matsim/contrib/railsim/rl/RLClient.java | 8 +- .../rl/observation/TreeObservation.java | 180 ++++++++++++------ .../contrib/railsim/rl/utils/RLUtils.java | 54 ++++-- .../trainNetwork.xml | 1 - .../transitSchedule.xml | 18 +- 8 files changed, 270 insertions(+), 151 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index 317d13e6e79..b352d99ef4e 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -58,6 +58,9 @@ public List reset(){ Config config = ConfigUtils.loadConfig(configFilename); config.controller().setOverwriteFileSetting(OutputDirectoryHierarchy.OverwriteFileSetting.deleteDirectoryIfExists); + config.controller().setLastIteration(0); + config.controller().setDumpDataAtEnd(true); + config.controller().setCreateGraphs(false); Scenario scenario = ScenarioUtils.loadScenario(config); controler = new Controler(scenario); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index 1e38900d546..803a1a69438 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -65,16 +65,6 @@ public class RailsimEngine implements Steppable { private final RailResourceManager resources; private final TrainDisposition disposition; - // Overloaded constructor to be used when using RL based inference -// public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition, Network network, RLClient rlClient) { -// this.eventsManager = eventsManager; -// this.config = config; -// this.resources = resources; -// this.disposition = disposition; -// this.network = network; -// this.rlClient=rlClient; -// } - public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, RailResourceManager resources, TrainDisposition disposition) { this.eventsManager = eventsManager; this.config = config; @@ -161,7 +151,6 @@ private void createEvent(Event event) { this.eventsManager.processEvent(event); } -// TODO: Where is UNBLOCK_LINK event being created? private void updateState(double time, UpdateEvent event) { // Do different updates depending on the type @@ -265,7 +254,6 @@ private void checkTrackReservation(double time, UpdateEvent event) { } } -// TODO: More clarity needed on how stopTime is calculated private void updateDeparture(double time, UpdateEvent event) { TrainState state = event.state; @@ -811,6 +799,7 @@ void clearTrains(double now) { eventsManager.processEvent(new PersonStuckEvent(now, train.driver.getId(), train.headLink, train.driver.getMode())); } + System.out.println("end simulation called"); disposition.onSimulationEnd(now); } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java index f369d4ef8e3..08eb4b0567b 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -27,8 +27,7 @@ import java.util.*; import java.util.stream.Collectors; -import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.getPathToSwitchNodeOnTrack; -import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.updateRoute; +import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.*; public class RLTrainDisposition implements TrainDisposition { RailResourceManager resources; @@ -53,6 +52,7 @@ public RLTrainDisposition(RailResourceManager resources, TrainRouter router, Net this.bufferStepOutputMap = new HashMap<>(); this.departureSchedule = new HashMap<>(); this.activeTrains = new HashMap<>(); + this.delays = new HashMap<>(); System.out.println("RLTrainDisposition created"); } @@ -67,22 +67,38 @@ private void calculateScheduleOnDeparture(TrainState train){ List departures = transitRoute.getDepartures().values().stream().collect(Collectors.toList()); for (Departure departure: departures){ - double routeDepartureTime = departure.getDepartureTime(); - for (TransitRouteStop stop : transitRoute.getStops()){ - Id stopLinkId = stop.getStopFacility().getLinkId(); - double offset = stop.getDepartureOffset().seconds(); - double scheduledDepartureTime = routeDepartureTime + offset; - - // store time corresponding to the train and halt - if (!this.departureSchedule.containsKey(trainId)){ - Map, List> mapLinkDepartureTime = new HashMap<>(); - departureSchedule.put(trainId, mapLinkDepartureTime); + if (departure.getVehicleId().equals(trainId)){ + double routeDepartureTime = departure.getDepartureTime(); + + for (int i=0; i stopLinkId = stop.getStopFacility().getLinkId(); + + double offset = 0.0; + if (i==transitRoute.getStops().size()-1){ + // for the last stop, use arrival offset + offset = stop.getArrivalOffset().seconds(); + }else{ + // this would work fine all stops except the last stop + // as the last stop does not have any departure offset + offset = stop.getDepartureOffset().seconds(); + } + // calculate scheduled departure time + double scheduledDepartureTime = routeDepartureTime + offset; + + // store time scheduled departure time corresponding to the train and halt + if (!this.departureSchedule.containsKey(trainId)){ + Map, List> mapLinkDepartureTime = new HashMap<>(); + departureSchedule.put(trainId, mapLinkDepartureTime); + } + if (!departureSchedule.get(trainId).containsKey(stopLinkId)){ + List listDepartureTimes = new ArrayList<>(); + departureSchedule.get(trainId).put(stopLinkId, listDepartureTimes); + } + departureSchedule.get(trainId).get(stopLinkId).add(scheduledDepartureTime); } - if (!departureSchedule.get(trainId).containsKey(stopLinkId)){ - List listDepartureTimes = new ArrayList<>(); - departureSchedule.get(trainId).put(stopLinkId, listDepartureTimes); - } - departureSchedule.get(trainId).get(stopLinkId).add(scheduledDepartureTime); + } } } @@ -108,7 +124,8 @@ public void onDeparture(double time, TrainPosition train, List route) // Update route for the train until the next switch node. Link curLink = network.getLinks().get(train.getHeadLink()); - getPathToSwitchNodeOnTrack(curLink, null, route); + + getPathToSwitchNodeOnTrack(curLink, null, route, resources); // calculate the scheduled departure times of the train calculateScheduleOnDeparture((TrainState) train); @@ -119,43 +136,48 @@ public void onDeparture(double time, TrainPosition train, List route) @Override public DispositionResponse requestNextSegment(double time, TrainPosition position, double dist) { - // calculate and send StepOutput to rl - Map stepOutputMap = getStepOutput(position, getObservation(time, position), getReward(), false); - bufferStepOutputMap.putAll(stepOutputMap); - rlClient.sendObservation(bufferStepOutputMap); - - // clear the buffer after sending the stepOutput - bufferStepOutputMap.clear(); - - - // get action from rl - Map actionMap = rlClient.getAction(); - - //update route based on the action from rl - int action = actionMap.get(position.getTrain().id().toString()); +// // calculate and send StepOutput to rl + Observation ob = getObservation(time, position); + + RailLink bufferTipLink = getBufferTip(resources, (TrainState) position); + Node toNodeBufferTipLink = network.getLinks().get(bufferTipLink.getLinkId()).getToNode(); + if(toNodeBufferTipLink.getOutLinks().size()>=3){ + // RL should be called only when the buffer is about to reach the decision node + Map stepOutputMap = getStepOutput(position, ob, getReward(), false); + bufferStepOutputMap.putAll(stepOutputMap); + rlClient.sendObservation(bufferStepOutputMap); + // clear the buffer after sending the stepOutput + bufferStepOutputMap.clear(); + + // get action from rl + Map actionMap = rlClient.getAction(); + + //update route based on the action from rl + int action = actionMap.get(position.getPt().getPlannedVehicleId().toString()); + + StepOutput out = stepOutputMap.get(position.getPt().getPlannedVehicleId().toString()); + switch (action){ + case 0:{ + // update route in the query direction + Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); + updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); + break; + } + case 1:{ + // update route in the other direction + Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); + updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); + break; + } + case 2:{ + // stop the train + return new DispositionResponse(0, 0, null); + } + default:{ + System.out.println("Illegal action"); + } - StepOutput out = stepOutputMap.get(position.getTrain().id().toString()); - switch (action){ - case 0:{ - // update route in the query direction - Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); - updateRoute(network, (TrainState) position, nextSwitchNodePos); - break; - } - case 1:{ - // update route in the other direction - Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); - updateRoute(network, (TrainState) position, nextSwitchNodePos); - break; - } - case 2:{ - // stop the train - return new DispositionResponse(0, 0, null); } - default:{ - System.out.println("Illegal action"); - } - } RailLink currentLink = resources.getLink(position.getHeadLink()); @@ -246,7 +268,6 @@ public void onStopDeparture(double time, TrainPosition position){ @Override public void onSimulationEnd(double now){ - // get observation for all active trains and // set all the trains having heavy negative reward @@ -272,8 +293,17 @@ Observation getObservation(double time, TrainPosition train){ ob.setObsTree(listObsNodes); ob.setFlattenedObsTree(treeObsFlattened); - // choose the 1st child of the root node as the nextNode for query direction - Node nextNode = network.getNodes().get(listObsNodes.get(1).getNodeId()); + + Node nextNode = null; + if (listObsNodes.size()>1){ + // choose the 1st child of the root node as the nextNode for query direction + nextNode = network.getNodes().get(listObsNodes.get(1).getNodeId()); + }else{ + // if there are no children of the root node then make root as the nextNode for query + nextNode = network.getNodes().get(listObsNodes.get(0).getNodeId()); + } + + // get coords of the query node List positionNextNode = new ArrayList<>(); positionNextNode.add(nextNode.getCoord().getX()); positionNextNode.add(nextNode.getCoord().getY()); @@ -318,7 +348,7 @@ Observation getObservation(double time, TrainPosition train){ stepOutput.setObservation(observation); Map stepOutputMap= new HashMap<>(); - stepOutputMap.put(train.getTrain().id().toString(), stepOutput); + stepOutputMap.put(train.getPt().getPlannedVehicleId().toString(), stepOutput); return stepOutputMap; } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java index 136e226d83a..c7154733556 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java @@ -15,8 +15,11 @@ import java.util.logging.Level; import java.util.logging.Logger; +import com.google.protobuf.Empty; + + public class RLClient { - private final RailsimConnecterGrpc.RailsimConnecterBlockingStub blockingStub; + private RailsimConnecterGrpc.RailsimConnecterBlockingStub blockingStub=null; private static final Logger logger = Logger.getLogger(RLClient.class.getName()); public RLClient(int port){ String target = "localhost:"+port; @@ -31,7 +34,8 @@ public Map getAction(){ try { // Call the original method on the server. - actionMap = blockingStub.getAction(null); + Empty request = Empty.newBuilder().build(); + actionMap = blockingStub.getAction(request); } catch (StatusRuntimeException e) { // Log a warning if the RPC fails. logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java index b4da093bc11..aee72ac5b02 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java @@ -6,17 +6,22 @@ import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; import org.apache.commons.lang3.NotImplementedException; +import org.matsim.api.core.v01.Id; import org.matsim.api.core.v01.network.Link; import org.matsim.api.core.v01.network.Network; import org.matsim.api.core.v01.network.Node; -import java.util.Arrays; -import java.util.List; -import java.util.ArrayList; +import java.util.*; import java.util.stream.Collectors; import ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils; +import org.matsim.pt.transitSchedule.api.Departure; +import org.matsim.pt.transitSchedule.api.TransitLine; +import org.matsim.pt.transitSchedule.api.TransitRoute; +import org.matsim.pt.transitSchedule.api.TransitRouteStop; +import org.matsim.vehicles.Vehicle; +import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.getBufferTip; import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.isSwitchable; class OtherAgent { @@ -45,8 +50,10 @@ public class TreeObservation { private final Network network; private List observationList; private List flattenedObservation; - int depth; + Map, List>> transitRouteStops; + + public TreeObservation(TrainState position, RailResourceManager resources, Network network, int depth) { this.resources = resources; this.position = position; @@ -54,30 +61,61 @@ public TreeObservation(TrainState position, RailResourceManager resources, Netwo this.observationList = new ArrayList<>(); this.flattenedObservation= new ArrayList<>(); this.depth = depth; + this.transitRouteStops = new HashMap<>(); createTreeObs(this.depth); - } - - private RailLink getBufferTip() { - double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); - RailLink currentLink = resources.getLink(position.getHeadLink()); - List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); - // TODO Verify if the links are added in the list in the sequence of occurence. - RailLink bufferTip = reservedSegment.get(reservedSegment.size() - 1); - return bufferTip; } + private void getStopsOnTransitLine(TransitRoute transitRoute){ + // Calculate the stops on a particular transit line and store it in a map + List> stops = new ArrayList<>(); + // get any departure on the transitline + Departure departure = transitRoute.getDepartures().values().stream().collect(Collectors.toList()).get(0); + for (TransitRouteStop stop: transitRoute.getStops() ){ + stops.add(stop.getStopFacility().getLinkId()); + } + this.transitRouteStops.put(transitRoute.getId(), stops); + } private boolean isObsTreeNode(Node toNode){ - // Check if the node is halt position of train + // Boolean to check if the toNode is halt position of train boolean isStop = false; - //TODO: This would fail if more than one halts lie in the observation tree. The code would recognise just one. - Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); - if (nextHaltToNode.equals(toNode)) - isStop = true; - // Check if the node is a switch node. Each node will have minimum two outgoing nodes. + // get vehicle Id + Id trainId = position.getPt().getPlannedVehicleId(); + + // use transit route to get departures for the train + TransitRoute transitRoute = position.getPt().getTransitRoute(); + + // calculate stops in this route + if (!transitRouteStops.containsKey(transitRoute.getId())){ + getStopsOnTransitLine(transitRoute); + } + List> stopsList = transitRouteStops.get(transitRoute.getId()); + + // Iterate through all the stops of the transitLine + for (Id linkId : stopsList){ + // if toNode(stopLink) == toNode, mark the toNode to be a halt + if (network.getLinks().get(linkId).getToNode().equals(toNode)){ + isStop = true; + break; + } + } + +// if (position.getPt().getNextTransitStop() != null){ +// // ensure that there is a nextTransitStop. On the final destination +// // nextTransitStop is null +// Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); +// if (nextHaltToNode.equals(toNode)) +// isStop = true; +// } +// +// // toNode is the end of the track with just one outLink in opposite direction +// if (toNode.getOutLinks().size()==1) +// isStop = true; + + // Check if the node is a switch node. A switch node must have more than . boolean switch_node = toNode.getOutLinks().size() > 2; if (isStop || switch_node){ return true; @@ -95,12 +133,16 @@ private ObservationTreeNode createObservatioNode(TrainPosition train, RailLink c List toNodeCurLinkPosition = new ArrayList(Arrays.asList(toNodeCurLink.getCoord().getX(), toNodeCurLink.getCoord().getY())); //calculate distance of train to nextNode - double distNodeAgent = train.getHeadPosition()+RLUtils.calculateEuclideanDistance(nodePosition, toNodeCurLinkPosition); - - Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); - List nextHaltToNodePosition = new ArrayList(Arrays.asList(nextHaltToNode.getCoord().getX(), nextHaltToNode.getCoord().getY())); + double distNodeAgent = curLink.length - train.getHeadPosition()+RLUtils.calculateEuclideanDistance(nodePosition, toNodeCurLinkPosition); + + double distNextHalt = 0; + if (position.getPt().getNextTransitStop() != null){ + // next transit stop will be null at the end of the route + Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); + List nextHaltToNodePosition = new ArrayList(Arrays.asList(nextHaltToNode.getCoord().getX(), nextHaltToNode.getCoord().getY())); + distNextHalt = RLUtils.calculateEuclideanDistance(nodePosition, nextHaltToNodePosition); + } - double distNextHalt = RLUtils.calculateEuclideanDistance(nodePosition, nextHaltToNodePosition); int isSwitchable = isSwitchable(node, curLink, network) ? 1 : 0; @@ -125,39 +167,49 @@ private List flattenObservationNode(ObservationTreeNode obsNode) { private void createTreeObs(int depth){ // Get the link of the tip of the buffer - RailLink bufferTipLink = getBufferTip(); + RailLink bufferTipLink = getBufferTip(resources, position); + if (bufferTipLink == null){ + // if there is no buffer because there is no reserved distance then + // assume bufferTipLink to be the headLink + bufferTipLink = resources.getLink(position.getHeadLink()); + } - // Get the toNode of the bufferTipLink + // Get the toNode and fromNode of the bufferTipLink Node toNodeBufferTipLink = getToNode(bufferTipLink); + Node fromNodeBufferTipLink = network.getLinks().get(bufferTipLink.getLinkId()).getFromNode(); + List exploreQueue = new ArrayList<>(); // store a list of visitedNodes to avoid infinite loop in case of cycles in the network. List visitedNodes = new ArrayList<>(); - while(!isObsTreeNode(toNodeBufferTipLink)){ - toNodeBufferTipLink = toNodeBufferTipLink.getOutLinks().values().iterator().next().getToNode(); - } + // the added node may not necessarily be an observation tree node exploreQueue.add(toNodeBufferTipLink); + // If the "toNodeBufferTipLink" is not an observation tree node then increase the depth by 1 + if(!isObsTreeNode(toNodeBufferTipLink)) + depth +=1; + for (int i = 0; i < depth; i++) { // Level Traversal algorithm int lenExploreQueue = exploreQueue.size(); while (lenExploreQueue > 0) { // Level traversal Node curNode = exploreQueue.get(0); + exploreQueue.remove(0); // Create observationTreeNode from the curNode -// TrainPosition trainF = getClosestTrainOnPathF(curNode, nextNode); -// TrainPosition trainR = getClosestTrainOnPathR(curNode, nextNode); - ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), curNode, null, null); - this.observationList.add(obsNode); - this.flattenedObservation.addAll(flattenObservationNode(obsNode)); - - exploreQueue.remove(0); + if(isObsTreeNode(curNode)){ + // if the exploredNode fits the criteria of observation tree node, then + // add the node in the ObservationTree list + ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), curNode, null, null); + this.observationList.add(obsNode); + this.flattenedObservation.addAll(flattenObservationNode(obsNode)); + } visitedNodes.add(curNode); // Look for switches/intersections/stops on the branches stemming out of the current switch - List nextNodes = getNextNodes(toNodeBufferTipLink, curNode); + List nextNodes = getNextNodes(fromNodeBufferTipLink, curNode); // Add nextNode only if it's not already visited for (Node nextNode : nextNodes) { @@ -177,7 +229,7 @@ private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) throws throw new NotImplementedException(); } - private List getNextNodes(Node toNodeBufferTipLink, Node obsNode) { + private List getNextNodes(Node fromNodeBufferTipLink, Node obsNode) { // next switches List obsTreeNodes = new ArrayList<>(); @@ -188,37 +240,49 @@ private List getNextNodes(Node toNodeBufferTipLink, Node obsNode) { Node nextNode = outLink.getToNode(); Node prevNode = obsNode; - // follow nodes with only one outgoing link until a switch or a halt is reached + + // traverse the track until a halt/final destination/switch is reached while (!isObsTreeNode(nextNode)) { // get the single outgoing link and follow it // TODO: Handle end of network, will throw NoSuchElementException at the moment - - List outLinks = nextNode.getOutLinks().values().stream().collect(Collectors.toList()); - Link nextLink = null; - for(Link link : outLinks){ - // skip the link that leads to prevNode to avoid an infinite loop - if (link.getToNode().equals(prevNode)){ - continue; - } - else{ - nextLink = link; - break; - } + if (nextNode.equals(fromNodeBufferTipLink)){ + reverseDirection =true; + break; } + List outLinks = nextNode.getOutLinks().values().stream().collect(Collectors.toList()); - // update prevNode - prevNode = nextNode; + // At the final stop, there will be just 1 outlink + Link nextLink = outLinks.get(0); + + if (outLinks.size()>1){ + // update the nextLink to point the right direction if there are more than 1 outlink + for(Link link : outLinks){ + // skip the link that leads to prevNode to avoid an infinite loop + if (link.getToNode().equals(prevNode)){ + continue; + } + else{ + nextLink = link; + break; + } + } + // update prevNode + prevNode = nextNode; - // update nextNode - nextNode = nextLink.getToNode(); + // update nextNode + nextNode = nextLink.getToNode(); - if (nextNode.equals(toNodeBufferTipLink)){ - reverseDirection =true; + if (nextNode.equals(fromNodeBufferTipLink)){ + reverseDirection =true; + break; + } + }else{ + // nextNode is the final stop with just 1 outgoing link break; } } - if (reverseDirection){ + if (reverseDirection || nextNode.equals(fromNodeBufferTipLink)){ // skip the current outLink as this link from the switchNode leads to the observing train continue; } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java index a3806a34f22..f162fa2db27 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java @@ -1,7 +1,10 @@ package ch.sbb.matsim.contrib.railsim.rl.utils; +import ch.sbb.matsim.contrib.railsim.qsimengine.RailsimCalc; +import ch.sbb.matsim.contrib.railsim.qsimengine.TrainPosition; import ch.sbb.matsim.contrib.railsim.qsimengine.TrainState; import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailLink; +import ch.sbb.matsim.contrib.railsim.qsimengine.resources.RailResourceManager; import org.matsim.api.core.v01.events.Event; import org.matsim.api.core.v01.network.Link; import org.matsim.api.core.v01.network.Network; @@ -19,18 +22,29 @@ public static Node getToNode(Network network, RailLink link) { return network.getLinks().get(link.getLinkId()).getToNode(); } -// TODO: implement the functions: getSwitchNodeOnTrack and updateRoute - public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path){ - /* - get links from a node to the next switch node on the same track. - l0 St l1 l2 l3 l4 Sw - path = l1, l2, l3, l4 + public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path, RailResourceManager resources){ + + // In the beginning there will be just one link (corresponding to the start link) + // specified in the route of each train. However, the route of the train has a duplicated entry + // of the entry link. Therefore, min(path.size()) = 2 at the time of start of departure. + if (path.size()==2 && path.get(0).equals(path.get(1))){ + // remove the duplicate entry link for the train + path.remove(1); + } + +// get links from a node to the next switch node on the same track. +// l0 St l1 l2 l3 l4 Sw +// path = l1, l2, l3, l4 +// +// If target node found in the path, return true else false - If target node found in the path, return true else false - */ Node start = curLink.getToNode(); - while ( start.getOutLinks().values().size() <= 2){ +// Node start = network.getLinks().get(curLink.getLinkId()).getToNode(); + + // at all points in the track start node will have two outgoing links except at the final goal point + // where outgoing link will be 1 or at a switch node where the outgoing links will be more than 2. + while ( start.getOutLinks().values().size() == 2){ Link nextLink = null; // get outLinks from the start node @@ -50,7 +64,10 @@ public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, Lis break; } } - path.add(new RailLink(nextLink)); + + // convert Link to Raillink + RailLink convertedNextLink = resources.getLink(nextLink.getId()); + path.add(convertedNextLink); // update start node and curLink start = nextLink.getToNode(); @@ -63,7 +80,7 @@ public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, Lis // no path found to the target return false; } - public static void updateRoute(Network network, TrainState train, Node nextObsNode){ + public static void updateRoute(Network network, TrainState train, Node nextObsNode, RailResourceManager resources){ // Get the last link in the route RailLink lastLinkInRoute = train.route.get(train.route.size() -1); @@ -87,7 +104,7 @@ public static void updateRoute(Network network, TrainState train, Node nextObsNo // To store the path from lastLinkInRoute to node connecting the nextObsNode path = new ArrayList<>(); - if (getPathToSwitchNodeOnTrack(network.getLinks().get(lastLinkInRoute.getLinkId()), nextObsNode, path)) + if (getPathToSwitchNodeOnTrack(network.getLinks().get(lastLinkInRoute.getLinkId()), nextObsNode, path, resources)) break; } @@ -169,4 +186,17 @@ public static Boolean isSwitchable(Node switchNode, RailLink curLink, Network ne } } + public static RailLink getBufferTip(RailResourceManager resources, TrainState position ) { + double reserveDist = RailsimCalc.calcReservationDistance(position, resources.getLink(position.getHeadLink())); + RailLink currentLink = resources.getLink(position.getHeadLink()); + List reservedSegment = RailsimCalc.calcLinksToBlock(position, currentLink, reserveDist); + // if no track could be reserved return null + if (reservedSegment.size()==0) + return null; + else{ + RailLink bufferTip = reservedSegment.get(reservedSegment.size() - 1); + return bufferTip; + } + } + } diff --git a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml index 1f3501ba109..46f01336767 100644 --- a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml +++ b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/trainNetwork.xml @@ -30,7 +30,6 @@ - diff --git a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/transitSchedule.xml b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/transitSchedule.xml index 5d12af79465..cd323479a1e 100644 --- a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/transitSchedule.xml +++ b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/transitSchedule.xml @@ -27,10 +27,10 @@ - - - - + + + + @@ -49,10 +49,10 @@ - - - - + + + + @@ -61,4 +61,4 @@ - \ No newline at end of file + From 198e1e3b81d03cc0ae53ac65318f0a255f43eade Mon Sep 17 00:00:00 2001 From: iamakash Date: Sat, 8 Jun 2024 02:04:43 +0200 Subject: [PATCH 13/14] fix: observation tree bugs --- .../railsim/EnvironmentFactoryServer.java | 28 ++- .../matsim/contrib/railsim/RailsimEnv.java | 9 +- .../disposition/RLTrainDisposition.java | 132 +++++++--- .../matsim/contrib/railsim/rl/RLClient.java | 24 +- .../rl/observation/TreeObservation.java | 238 +++++++----------- .../contrib/railsim/rl/utils/RLUtils.java | 120 ++++----- contribs/railsim/src/main/proto/railsim.proto | 2 + .../microJunctionY/transitSchedule.xml | 20 +- 8 files changed, 302 insertions(+), 271 deletions(-) diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java index c5552bb6c7e..3707a5a3d04 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java @@ -8,6 +8,10 @@ import ch.sbb.matsim.contrib.railsim.grpc.ProtoGrpcPort; import ch.sbb.matsim.contrib.railsim.grpc.RailsimFactoryGrpc; import ch.sbb.matsim.contrib.railsim.grpc.ProtoAgentIDs; +import org.matsim.api.core.v01.Scenario; +import org.matsim.core.config.Config; +import org.matsim.core.config.ConfigUtils; +import org.matsim.core.scenario.ScenarioUtils; import java.io.IOException; @@ -99,9 +103,14 @@ public void getEnvironment(ProtoGrpcPort grpcPort, StreamObserver responseObserver) { System.out.println("Reset env id: "+grpcPort); + String configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/config.xml"; +// String configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microJunctionY/config.xml"; +// String configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microStationRerouting/config.xml"; + + // fetch the object from map and reset it RailsimEnv env = this.envMap.get(grpcPort.getGrpcPort()); - List agentIds = env.reset(); + List agentIds = env.reset(configFilename); //Create response using agentIds ProtoAgentIDs response = ProtoAgentIDs.newBuilder() @@ -119,6 +128,23 @@ public void resetEnv(ProtoGrpcPort grpcPort, StreamObserver respo env.startSimulation(); } + + public void getAgentIds(ProtoGrpcPort grpcPort, StreamObserver responseObserver){ + String configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/config.xml"; + Config config = ConfigUtils.loadConfig(configFilename); + Scenario scenario = ScenarioUtils.loadScenario(config); + RailsimEnv env = this.envMap.get(grpcPort.getGrpcPort()); + + List agentIds = env.getAllTrainIds(scenario); + //Create response using agentIds + ProtoAgentIDs response = ProtoAgentIDs.newBuilder() + .addAllAgentId(agentIds) + .build(); + + // Send the reply back to the client. + responseObserver.onNext(response); + responseObserver.onCompleted(); + } } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index b352d99ef4e..654c271b668 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -31,7 +31,7 @@ public RailsimEnv(RLClient rlClient){ this.rlClient = rlClient; } - private List getAllTrainIds(Scenario scenario){ + public List getAllTrainIds(Scenario scenario){ List trainIds = new ArrayList<>(); @@ -49,12 +49,7 @@ private List getAllTrainIds(Scenario scenario){ return trainIds; } - public List reset(){ - - // start the simulation - // pass the observation to the RLClient - - String configFilename = "/Users/akashsinha/Documents/SBB/matsim-libs/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microTrackOppositeTrafficMany/config.xml"; + public List reset(String configFilename){ Config config = ConfigUtils.loadConfig(configFilename); config.controller().setOverwriteFileSetting(OutputDirectoryHierarchy.OverwriteFileSetting.deleteDirectoryIfExists); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java index 08eb4b0567b..4699b19d5d1 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -25,9 +25,11 @@ import org.matsim.vehicles.Vehicle; import java.util.*; +import java.util.concurrent.*; import java.util.stream.Collectors; import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.*; +import static java.util.concurrent.Executors.*; public class RLTrainDisposition implements TrainDisposition { RailResourceManager resources; @@ -125,7 +127,7 @@ public void onDeparture(double time, TrainPosition train, List route) // Update route for the train until the next switch node. Link curLink = network.getLinks().get(train.getHeadLink()); - getPathToSwitchNodeOnTrack(curLink, null, route, resources); + getPathToSwitchNodeOnTrack(curLink, null, route, resources, network); // calculate the scheduled departure times of the train calculateScheduleOnDeparture((TrainState) train); @@ -136,47 +138,103 @@ public void onDeparture(double time, TrainPosition train, List route) @Override public DispositionResponse requestNextSegment(double time, TrainPosition position, double dist) { -// // calculate and send StepOutput to rl - Observation ob = getObservation(time, position); + Id trainId = position.getPt().getPlannedVehicleId(); RailLink bufferTipLink = getBufferTip(resources, (TrainState) position); - Node toNodeBufferTipLink = network.getLinks().get(bufferTipLink.getLinkId()).getToNode(); - if(toNodeBufferTipLink.getOutLinks().size()>=3){ - // RL should be called only when the buffer is about to reach the decision node - Map stepOutputMap = getStepOutput(position, ob, getReward(), false); - bufferStepOutputMap.putAll(stepOutputMap); - rlClient.sendObservation(bufferStepOutputMap); - // clear the buffer after sending the stepOutput - bufferStepOutputMap.clear(); - // get action from rl - Map actionMap = rlClient.getAction(); + // outTracks contain list of outLinks from bufferTipLink. The list does not contain the reverse link of bufferTipLink + List outTracks = getOutTracks(network.getLinks().get(bufferTipLink.getLinkId()), network); - //update route based on the action from rl - int action = actionMap.get(position.getPt().getPlannedVehicleId().toString()); - - StepOutput out = stepOutputMap.get(position.getPt().getPlannedVehicleId().toString()); - switch (action){ - case 0:{ - // update route in the query direction - Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); - updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); - break; - } - case 1:{ - // update route in the other direction - Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); - updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); - break; - } - case 2:{ - // stop the train - return new DispositionResponse(0, 0, null); - } - default:{ - System.out.println("Illegal action"); - } +// if(outTracks.size()>=2){ + if(true){ + // RL should be called only when the train's safety buffer is about to reach the decision node + // calculate and send StepOutput to rl + Observation ob = getObservation(time, position); + Map stepOutputMap = getStepOutput(position, ob, getReward(), false); + if (activeTrains.containsKey(trainId)){ + bufferStepOutputMap.putAll(stepOutputMap); + rlClient.sendObservation(bufferStepOutputMap); + + + // start a new thread and send observation through rlClient +// Thread rlSendObservationThread = new Thread(() -> rlClient.sendObservation(bufferStepOutputMap)); +// rlSendObservationThread.start(); +// try { +// // Wait for the rlSendObservationThread to finish +// rlSendObservationThread.join(); +// } catch (InterruptedException e) { +// e.printStackTrace(); +// } + // clear the buffer after sending the stepOutput + bufferStepOutputMap.clear(); + + // start a new thread ot get action from rl + Map actionMap = new HashMap<>(); + rlClient.getAction(actionMap); +// Thread rlgetActionThread = new Thread(() -> rlClient.getAction(actionMap)); +// rlgetActionThread.start(); +// try { +// // Wait for the rlgetActionThread to finish +// rlgetActionThread.join(); +// } catch (InterruptedException e) { +// e.printStackTrace(); +// } + +// rlClient.getAction(actionMap); + int action = actionMap.get(trainId); +// int action = actionMap.values().iterator().next(); + + +// // Create a single-threaded executor +// ExecutorService executorService = newSingleThreadExecutor(); +// +// // Create a Callable task +// Callable> task = () -> rlClient.getAction(); +// +// // Submit the task to the executor service and get a Future object +// Future> futureResult = executorService.submit(task); +// Map result = null; +// try { +// // Wait for the task to complete and get the result +// result = futureResult.get(); +// System.out.println("task returned: " + result); +// } catch (InterruptedException | ExecutionException e) { +// e.printStackTrace(); +//// executorService.shutdown(); +// } +// finally { +// // Shutdown the executor service +// executorService.shutdown(); +// } + + //update route based on the action from rl +// int action = result.get(trainId); + +// StepOutput out = stepOutputMap.get(position.getPt().getPlannedVehicleId().toString()); + StepOutput out = stepOutputMap.values().iterator().next(); +// switch (action){ +// case 0:{ +// // update route in the query direction +// Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); +// updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); +// break; +// } +// case 1:{ +// // update route in the other direction +// Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); +// updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); +// break; +// } +// case 2:{ +// // stop the train +// return new DispositionResponse(0, 0, null); +// } +// default:{ +// System.out.println("Illegal action"); +// } +// +// } } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java index c7154733556..e6a7116e64f 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java @@ -28,9 +28,9 @@ public RLClient(int port){ blockingStub = RailsimConnecterGrpc.newBlockingStub(channel); } - public Map getAction(){ + public void getAction(Map resultActionMap){ - ProtoActionMap actionMap; + ProtoActionMap actionMap=null; try { // Call the original method on the server. @@ -39,10 +39,10 @@ public Map getAction(){ } catch (StatusRuntimeException e) { // Log a warning if the RPC fails. logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); - return null; +// return null; } - - return actionMap.getDictActionMap(); + resultActionMap.putAll(actionMap.getDictActionMap()); +// return actionMap.getDictActionMap(); } public String sendObservation(Map stepOutputMap){ @@ -87,11 +87,11 @@ public String sendObservation(Map stepOutputMap){ return msg.getAck(); } - public static void main(String args[]) throws InterruptedException { - // Access a service running on the local machine on port 50051 - RLClient client = new RLClient(50051); - Observation ob = new Observation(2, true); - Map actionMap = client.getAction(); - System.out.println(actionMap); - } +// public static void main(String args[]) throws InterruptedException { +// // Access a service running on the local machine on port 50051 +// RLClient client = new RLClient(50051); +// Observation ob = new Observation(2, true); +// Map actionMap = client.getAction(); +// System.out.println(actionMap); +// } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java index aee72ac5b02..273b286042a 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/observation/TreeObservation.java @@ -21,8 +21,7 @@ import org.matsim.pt.transitSchedule.api.TransitRouteStop; import org.matsim.vehicles.Vehicle; -import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.getBufferTip; -import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.isSwitchable; +import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.*; class OtherAgent { // Distance of agent to other agent's tail position if the train is moving in the same direction @@ -77,51 +76,6 @@ private void getStopsOnTransitLine(TransitRoute transitRoute){ } this.transitRouteStops.put(transitRoute.getId(), stops); } - private boolean isObsTreeNode(Node toNode){ - - // Boolean to check if the toNode is halt position of train - boolean isStop = false; - - // get vehicle Id - Id trainId = position.getPt().getPlannedVehicleId(); - - // use transit route to get departures for the train - TransitRoute transitRoute = position.getPt().getTransitRoute(); - - // calculate stops in this route - if (!transitRouteStops.containsKey(transitRoute.getId())){ - getStopsOnTransitLine(transitRoute); - } - List> stopsList = transitRouteStops.get(transitRoute.getId()); - - // Iterate through all the stops of the transitLine - for (Id linkId : stopsList){ - // if toNode(stopLink) == toNode, mark the toNode to be a halt - if (network.getLinks().get(linkId).getToNode().equals(toNode)){ - isStop = true; - break; - } - } - -// if (position.getPt().getNextTransitStop() != null){ -// // ensure that there is a nextTransitStop. On the final destination -// // nextTransitStop is null -// Node nextHaltToNode = network.getLinks().get(position.getPt().getNextTransitStop().getLinkId()).getToNode(); -// if (nextHaltToNode.equals(toNode)) -// isStop = true; -// } -// -// // toNode is the end of the track with just one outLink in opposite direction -// if (toNode.getOutLinks().size()==1) -// isStop = true; - - // Check if the node is a switch node. A switch node must have more than . - boolean switch_node = toNode.getOutLinks().size() > 2; - if (isStop || switch_node){ - return true; - } - return false; - } private ObservationTreeNode createObservatioNode(TrainPosition train, RailLink curLink, Node node, OtherAgent sameDirAgent, OtherAgent oppDirAgent){ @@ -165,6 +119,80 @@ private List flattenObservationNode(ObservationTreeNode obsNode) { return flattenedNode; } + private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) throws NotImplementedException{ + throw new NotImplementedException(); + } + private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) throws NotImplementedException { + throw new NotImplementedException(); + } + + private Node getToNode(RailLink link) { + return network.getLinks().get(link.getLinkId()).getToNode(); + } + + public List getObservationTree() { + return observationList; + } + + public List getFlattenedObservationTree() { + return this.flattenedObservation; + } + + private boolean linkContainsObsTreeNode(Link curLink){ + + Node toNode = curLink.getToNode(); + + // Boolean to check if the toNode is halt position of train + boolean isStop = false; + + // get vehicle Id + Id trainId = position.getPt().getPlannedVehicleId(); + + // use transit route to get departures for the train + TransitRoute transitRoute = position.getPt().getTransitRoute(); + + // calculate stops in this route + if (!transitRouteStops.containsKey(transitRoute.getId())){ + getStopsOnTransitLine(transitRoute); + } + List> stopsList = transitRouteStops.get(transitRoute.getId()); + + // Iterate through all the stops of the transitLine + for (Id linkId : stopsList){ + // if toNode(stopLink) == toNode, mark the toNode to be a halt + if (network.getLinks().get(linkId).getToNode().equals(toNode)){ + isStop = true; + break; + } + } + + // Check if the node is a switch node. A switch node must have more than one outgoing track + boolean switch_node = getOutTracks(curLink, network).size() >= 2; + if (isStop || switch_node){ + return true; + } + return false; + } + + private List getNextLinks(Link curLink) { + // links contains ObsNode + List linksContainingObsNode = new ArrayList<>(); + + // check all outgoing links from the current obsNode + for (Link track : getOutTracks(curLink, network)) { + + // traverse the track until a halt/final destination/switch is reached + Link outLink = track; + while (!linkContainsObsTreeNode(outLink)) { + // get the single outgoing link and follow it + // TODO: Handle end of network, will throw NoSuchElementException at the moment + outLink = getOutTracks(track, network).get(0); + } + linksContainingObsNode.add(outLink); + } + return linksContainingObsNode; + } + private void createTreeObs(int depth){ // Get the link of the tip of the buffer RailLink bufferTipLink = getBufferTip(resources, position); @@ -174,20 +202,16 @@ private void createTreeObs(int depth){ bufferTipLink = resources.getLink(position.getHeadLink()); } - // Get the toNode and fromNode of the bufferTipLink - Node toNodeBufferTipLink = getToNode(bufferTipLink); - Node fromNodeBufferTipLink = network.getLinks().get(bufferTipLink.getLinkId()).getFromNode(); - - List exploreQueue = new ArrayList<>(); + List exploreQueue = new ArrayList<>(); // store a list of visitedNodes to avoid infinite loop in case of cycles in the network. - List visitedNodes = new ArrayList<>(); + List visitedLinks = new ArrayList<>(); - // the added node may not necessarily be an observation tree node - exploreQueue.add(toNodeBufferTipLink); + // the added bufferTipLink may not necessarily be the link having observation tree node + exploreQueue.add(network.getLinks().get(bufferTipLink.getLinkId())); // If the "toNodeBufferTipLink" is not an observation tree node then increase the depth by 1 - if(!isObsTreeNode(toNodeBufferTipLink)) + if(!linkContainsObsTreeNode(network.getLinks().get(bufferTipLink.getLinkId()))) depth +=1; for (int i = 0; i < depth; i++) { @@ -195,26 +219,26 @@ private void createTreeObs(int depth){ int lenExploreQueue = exploreQueue.size(); while (lenExploreQueue > 0) { // Level traversal - Node curNode = exploreQueue.get(0); + Link exploredLink = exploreQueue.get(0); exploreQueue.remove(0); // Create observationTreeNode from the curNode - if(isObsTreeNode(curNode)){ + if(linkContainsObsTreeNode(exploredLink)){ // if the exploredNode fits the criteria of observation tree node, then // add the node in the ObservationTree list - ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), curNode, null, null); + Node node = exploredLink.getToNode(); + ObservationTreeNode obsNode = createObservatioNode(position, resources.getLink(position.getHeadLink()), node, null, null); this.observationList.add(obsNode); this.flattenedObservation.addAll(flattenObservationNode(obsNode)); + visitedLinks.add(exploredLink); } - visitedNodes.add(curNode); - - // Look for switches/intersections/stops on the branches stemming out of the current switch - List nextNodes = getNextNodes(fromNodeBufferTipLink, curNode); + // Look for links having switches/intersections/stops on the branches stemming out of the current explored link + List nextLinks = getNextLinks(exploredLink); - // Add nextNode only if it's not already visited - for (Node nextNode : nextNodes) { - if (!visitedNodes.contains(nextNode)){ - exploreQueue.add(nextNode); + // Add next links to the exploreQueue only if it's not already visited + for (Link nextLink : nextLinks) { + if (!visitedLinks.contains(nextLink)){ + exploreQueue.add(nextLink); } } lenExploreQueue -= 1; @@ -222,84 +246,4 @@ private void createTreeObs(int depth){ } } - private TrainPosition getClosestTrainOnPathR(Node curNode, Node nextNode) throws NotImplementedException{ - throw new NotImplementedException(); - } - private TrainPosition getClosestTrainOnPathF(Node curNode, Node nextNode) throws NotImplementedException { - throw new NotImplementedException(); - } - - private List getNextNodes(Node fromNodeBufferTipLink, Node obsNode) { - // next switches - List obsTreeNodes = new ArrayList<>(); - - // check all outgoing links from the current obsNode - for (Link outLink : obsNode.getOutLinks().values()) { - - boolean reverseDirection = false; - Node nextNode = outLink.getToNode(); - - Node prevNode = obsNode; - - // traverse the track until a halt/final destination/switch is reached - while (!isObsTreeNode(nextNode)) { - // get the single outgoing link and follow it - // TODO: Handle end of network, will throw NoSuchElementException at the moment - if (nextNode.equals(fromNodeBufferTipLink)){ - reverseDirection =true; - break; - } - List outLinks = nextNode.getOutLinks().values().stream().collect(Collectors.toList()); - - // At the final stop, there will be just 1 outlink - Link nextLink = outLinks.get(0); - - if (outLinks.size()>1){ - // update the nextLink to point the right direction if there are more than 1 outlink - for(Link link : outLinks){ - // skip the link that leads to prevNode to avoid an infinite loop - if (link.getToNode().equals(prevNode)){ - continue; - } - else{ - nextLink = link; - break; - } - } - // update prevNode - prevNode = nextNode; - - // update nextNode - nextNode = nextLink.getToNode(); - - if (nextNode.equals(fromNodeBufferTipLink)){ - reverseDirection =true; - break; - } - }else{ - // nextNode is the final stop with just 1 outgoing link - break; - } - } - - if (reverseDirection || nextNode.equals(fromNodeBufferTipLink)){ - // skip the current outLink as this link from the switchNode leads to the observing train - continue; - } - obsTreeNodes.add(nextNode); - } - return obsTreeNodes; - } - - private Node getToNode(RailLink link) { - return network.getLinks().get(link.getLinkId()).getToNode(); - } - - public List getObservationTree() { - return observationList; - } - - public List getFlattenedObservationTree() { - return this.flattenedObservation; - } } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java index f162fa2db27..54d3ada283c 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/utils/RLUtils.java @@ -23,88 +23,59 @@ public static Node getToNode(Network network, RailLink link) { } - public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path, RailResourceManager resources){ + public static boolean getPathToSwitchNodeOnTrack(Link curLink, Node target, List path, RailResourceManager resources, Network network){ + /** + get links from a node to the next switch node on the same track. + l0 St l1 l2 l3 l4 Sw + path = l1, l2, l3, l4 + + If target node found in the path, return true else false + + **/ // In the beginning there will be just one link (corresponding to the start link) // specified in the route of each train. However, the route of the train has a duplicated entry // of the entry link. Therefore, min(path.size()) = 2 at the time of start of departure. + + if (curLink.getToNode().equals(target)) + return true; + if (path.size()==2 && path.get(0).equals(path.get(1))){ // remove the duplicate entry link for the train path.remove(1); } -// get links from a node to the next switch node on the same track. -// l0 St l1 l2 l3 l4 Sw -// path = l1, l2, l3, l4 -// -// If target node found in the path, return true else false - - Node start = curLink.getToNode(); -// Node start = network.getLinks().get(curLink.getLinkId()).getToNode(); + List outTracks = getOutTracks(curLink, network); + while (outTracks.size()==1){ + Link nextTrackLink = outTracks.get(0); - // at all points in the track start node will have two outgoing links except at the final goal point - // where outgoing link will be 1 or at a switch node where the outgoing links will be more than 2. - while ( start.getOutLinks().values().size() == 2){ - Link nextLink = null; - - // get outLinks from the start node - List outLinks = start.getOutLinks().values().stream().collect(Collectors.toList()); - assert (outLinks.size() == 2); - - // get the fromNode of the curLink - Node fromNodeCurLink = curLink.getFromNode(); - - for (Link outLink : outLinks){ - if (outLink.getToNode().equals(fromNodeCurLink)){ - // Ignore the link where outLink(start) == curLink - continue; - } - else { - nextLink = outLink; - break; - } - } + // convert Link to Raillink and append the link to path + RailLink convertedTrackLink = resources.getLink(nextTrackLink.getId()); + path.add(convertedTrackLink); - // convert Link to Raillink - RailLink convertedNextLink = resources.getLink(nextLink.getId()); - path.add(convertedNextLink); + outTracks = getOutTracks(nextTrackLink, network); - // update start node and curLink - start = nextLink.getToNode(); - curLink = nextLink; - if (target != null && start.equals(target)) { - // target found in the path + if (nextTrackLink.getToNode().equals(target)) return true; - } + } - // no path found to the target return false; } + public static void updateRoute(Network network, TrainState train, Node nextObsNode, RailResourceManager resources){ // Get the last link in the route RailLink lastLinkInRoute = train.route.get(train.route.size() -1); - // get the toNode of the lastLinkInRoute - Node toNodeLastLinkInRoute = getToNode(network, lastLinkInRoute); - - // get the fromNode of the lastLinkInRoute - Node fromNodeLastLinkInRoute = network.getLinks().get(lastLinkInRoute.getLinkId()).getFromNode(); - // get the outLinks from the toNode of the lastLinkInRoute - List nextLinks = toNodeLastLinkInRoute.getOutLinks().values().stream().collect(Collectors.toList()); + List nextTracks = getOutTracks(network.getLinks().get(lastLinkInRoute.getLinkId()), network); //path to the nextObsNode List path = null; - for (Link nextLink : nextLinks){ - - // skip the link that takes back on the same track - if (nextLink.getToNode().equals(fromNodeLastLinkInRoute)) - continue; - + for (Link nextLink : nextTracks){ // To store the path from lastLinkInRoute to node connecting the nextObsNode path = new ArrayList<>(); - if (getPathToSwitchNodeOnTrack(network.getLinks().get(lastLinkInRoute.getLinkId()), nextObsNode, path, resources)) + if (getPathToSwitchNodeOnTrack(network.getLinks().get(lastLinkInRoute.getLinkId()), nextObsNode, path, resources, network)) break; } @@ -160,8 +131,8 @@ public static double calculateAngle(Node point1, Node point2, Node point3) { public static Boolean isSwitchable(Node switchNode, RailLink curLink, Network network){ - int numOutGoingLinks = switchNode.getOutLinks().size(); - if (numOutGoingLinks <= 2){ + int numOutGoingTracks = getOutTracks(network.getLinks().get(curLink.getLinkId()), network).size(); + if (numOutGoingTracks <= 1){ return false; } else{ @@ -199,4 +170,39 @@ public static RailLink getBufferTip(RailResourceManager resources, TrainState po } } + public static List getOutTracks(Link link, Network network){ + + List outTracks = new ArrayList<>(); + + Node toNode = link.getToNode(); + if (toNode.getOutLinks() != null){ + for (Link l : toNode.getOutLinks().values()){ + if (l.getToNode().equals(link.getFromNode())){ + // skip the outlink that leads backward in a bi-directional network + continue; + }else{ + outTracks.add(l); + } + } + } + return outTracks; + } + + public static List getMergingTracks(Link link){ + List inTracks = new ArrayList<>(); + + Node toNode = link.getToNode(); + if (toNode.getInLinks() != null){ + for(Link l: toNode.getInLinks().values()){ + if (l.equals(link)){ + continue; + }else{ + inTracks.add(l); + } + } + } + + return inTracks; + } + } diff --git a/contribs/railsim/src/main/proto/railsim.proto b/contribs/railsim/src/main/proto/railsim.proto index b7d680c506a..1d0ba85e9aa 100644 --- a/contribs/railsim/src/main/proto/railsim.proto +++ b/contribs/railsim/src/main/proto/railsim.proto @@ -56,5 +56,7 @@ service RailsimFactory{ rpc getEnvironment (ProtoGrpcPort) returns (ProtoConfirmationResponse) {}; rpc resetEnv (ProtoGrpcPort) returns (ProtoAgentIDs) {}; + + rpc getAgentIds (ProtoGrpcPort) returns (ProtoAgentIDs) {}; } diff --git a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microJunctionY/transitSchedule.xml b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microJunctionY/transitSchedule.xml index 8642858618f..fce483e6722 100644 --- a/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microJunctionY/transitSchedule.xml +++ b/contribs/railsim/test/input/ch/sbb/matsim/contrib/railsim/integration/microJunctionY/transitSchedule.xml @@ -22,11 +22,11 @@ - - - - - + + + + + @@ -42,11 +42,11 @@ - - - - - + + + + + From b72ddcbf99b9ac22c302ba6e0b01814ef220a2eb Mon Sep 17 00:00:00 2001 From: iamakash Date: Wed, 12 Jun 2024 10:15:36 +0200 Subject: [PATCH 14/14] fix: fix bugs with the synchronisation issues --- contribs/railsim/pom.xml | 8 +- .../railsim/EnvironmentFactoryServer.java | 13 +- .../matsim/contrib/railsim/RailsimEnv.java | 4 + .../railsim/qsimengine/RailsimEngine.java | 2 +- .../disposition/RLTrainDisposition.java | 144 ++++++------------ .../matsim/contrib/railsim/rl/RLClient.java | 61 ++++++-- contribs/railsim/src/main/proto/railsim.proto | 7 +- 7 files changed, 120 insertions(+), 119 deletions(-) diff --git a/contribs/railsim/pom.xml b/contribs/railsim/pom.xml index f870ab1fcb6..1a0da60ef68 100644 --- a/contribs/railsim/pom.xml +++ b/contribs/railsim/pom.xml @@ -51,9 +51,15 @@ 6.0.53 provided + + io.grpc + grpc-netty + 1.63.0 + compile + - + diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java index 3707a5a3d04..b15f668c13c 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/EnvironmentFactoryServer.java @@ -78,9 +78,6 @@ public static void main(String[] args) throws IOException, InterruptedException private class RailsimFactory extends RailsimFactoryGrpc.RailsimFactoryImplBase { Map envMap = new HashMap<>(); - - - @Override public void getEnvironment(ProtoGrpcPort grpcPort, StreamObserver responseObserver) { // Create an instance of Railsim environment and store it in a map @@ -89,7 +86,7 @@ public void getEnvironment(ProtoGrpcPort grpcPort, StreamObserver responseObserver) { System.out.println("Reset env id: "+grpcPort); @@ -125,8 +123,11 @@ public void resetEnv(ProtoGrpcPort grpcPort, StreamObserver respo // Start the simulation //TODO: Should this be started on a different thread so that the endpoint is not blocked or is it automatically taken care of? - env.startSimulation(); - + try{ + env.startSimulation(); + }finally{ + env.shutdown(); + } } public void getAgentIds(ProtoGrpcPort grpcPort, StreamObserver responseObserver){ diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java index 654c271b668..fed5ec9ea47 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/RailsimEnv.java @@ -74,5 +74,9 @@ void startSimulation(){ controler.run(); } + void shutdown(){ + this.rlClient.shutdown(); + } + } diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java index 803a1a69438..468a294bf2a 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/RailsimEngine.java @@ -74,7 +74,7 @@ public RailsimEngine(EventsManager eventsManager, RailsimConfigGroup config, Rai @Override public void doSimStep(double time) { - + log.info("Thread name:: "+Thread.currentThread().getName() + " timestep: "+time); UpdateEvent update = updateQueue.peek(); // Update loop over all required state updates diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java index 4699b19d5d1..e9c8d01b9c4 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/qsimengine/disposition/RLTrainDisposition.java @@ -12,6 +12,7 @@ import ch.sbb.matsim.contrib.railsim.rl.observation.ObservationTreeNode; import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; import ch.sbb.matsim.contrib.railsim.rl.observation.TreeObservation; +import ch.sbb.matsim.routing.pt.raptor.SwissRailRaptor; import jakarta.inject.Inject; import org.matsim.api.core.v01.Coord; import org.matsim.api.core.v01.Id; @@ -25,13 +26,15 @@ import org.matsim.vehicles.Vehicle; import java.util.*; -import java.util.concurrent.*; import java.util.stream.Collectors; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import static ch.sbb.matsim.contrib.railsim.rl.utils.RLUtils.*; -import static java.util.concurrent.Executors.*; public class RLTrainDisposition implements TrainDisposition { + + private static final Logger log = LogManager.getLogger(RLTrainDisposition.class);; RailResourceManager resources; TrainRouter router; Network network; @@ -145,96 +148,49 @@ public DispositionResponse requestNextSegment(double time, TrainPosition positio // outTracks contain list of outLinks from bufferTipLink. The list does not contain the reverse link of bufferTipLink List outTracks = getOutTracks(network.getLinks().get(bufferTipLink.getLinkId()), network); -// if(outTracks.size()>=2){ - if(true){ - // RL should be called only when the train's safety buffer is about to reach the decision node - - // calculate and send StepOutput to rl + // RL should be called only when the train's safety buffer is about to reach the decision node + if(outTracks.size()>=2){ +// if(true){ + // calculate StepOutput Observation ob = getObservation(time, position); - Map stepOutputMap = getStepOutput(position, ob, getReward(), false); - if (activeTrains.containsKey(trainId)){ - bufferStepOutputMap.putAll(stepOutputMap); - rlClient.sendObservation(bufferStepOutputMap); - - - // start a new thread and send observation through rlClient -// Thread rlSendObservationThread = new Thread(() -> rlClient.sendObservation(bufferStepOutputMap)); -// rlSendObservationThread.start(); -// try { -// // Wait for the rlSendObservationThread to finish -// rlSendObservationThread.join(); -// } catch (InterruptedException e) { -// e.printStackTrace(); -// } - // clear the buffer after sending the stepOutput - bufferStepOutputMap.clear(); - - // start a new thread ot get action from rl - Map actionMap = new HashMap<>(); - rlClient.getAction(actionMap); -// Thread rlgetActionThread = new Thread(() -> rlClient.getAction(actionMap)); -// rlgetActionThread.start(); -// try { -// // Wait for the rlgetActionThread to finish -// rlgetActionThread.join(); -// } catch (InterruptedException e) { -// e.printStackTrace(); -// } - -// rlClient.getAction(actionMap); - int action = actionMap.get(trainId); -// int action = actionMap.values().iterator().next(); - - -// // Create a single-threaded executor -// ExecutorService executorService = newSingleThreadExecutor(); -// -// // Create a Callable task -// Callable> task = () -> rlClient.getAction(); -// -// // Submit the task to the executor service and get a Future object -// Future> futureResult = executorService.submit(task); -// Map result = null; -// try { -// // Wait for the task to complete and get the result -// result = futureResult.get(); -// System.out.println("task returned: " + result); -// } catch (InterruptedException | ExecutionException e) { -// e.printStackTrace(); -//// executorService.shutdown(); -// } -// finally { -// // Shutdown the executor service -// executorService.shutdown(); -// } + Map stepOutputMap = getStepOutput(position, ob, getReward(), false, false); + bufferStepOutputMap.putAll(stepOutputMap); + + log.info("Thread: "+ Thread.currentThread().getName() + " Send observation for train: "+trainId + " at timestep: " + (time)); + rlClient.sendObservation(bufferStepOutputMap); + + // clear the buffer after sending the stepOutput + bufferStepOutputMap.clear(); + // start a new thread to get action from rl + log.info("Thread: "+ Thread.currentThread().getName() + " Get action for train: "+trainId + " at timestep: " + (time)); + Map actionMap = new HashMap<>(); + rlClient.getAction(time, trainId.toString(), actionMap); + log.info("Thread: "+ Thread.currentThread().getName() + " Recd. action for train: "+trainId + " at timestep" + Math.ceil(time) + " action: "+actionMap); //update route based on the action from rl -// int action = result.get(trainId); - -// StepOutput out = stepOutputMap.get(position.getPt().getPlannedVehicleId().toString()); - StepOutput out = stepOutputMap.values().iterator().next(); -// switch (action){ -// case 0:{ -// // update route in the query direction -// Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); -// updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); -// break; -// } -// case 1:{ -// // update route in the other direction -// Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); -// updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); -// break; -// } -// case 2:{ -// // stop the train -// return new DispositionResponse(0, 0, null); -// } -// default:{ -// System.out.println("Illegal action"); -// } -// -// } + int action = actionMap.get(trainId.toString()); + StepOutput out = stepOutputMap.values().iterator().next(); + switch (action){ + case 0:{ + // update route in the query direction + Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(1).getNodeId()); + updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); + break; + } + case 1:{ + // update route in the other direction + Node nextSwitchNodePos = network.getNodes().get(out.getObservation().getObsTree().get(2).getNodeId()); + updateRoute(network, (TrainState) position, nextSwitchNodePos, resources); + break; + } + case 2:{ + // stop the train + return new DispositionResponse(0, 0, null); + } + default:{ + System.out.println("Illegal action"); + } + } } @@ -295,7 +251,7 @@ public void unblockRailLink(double time, MobsimDriverAgent driver, RailLink link public void onTermination(double time, TrainPosition position){ // Store the StepOutput in bufferStepOutput. // bufferStepOutput is not sent to RL until there is an observation for a train whose done=false - bufferStepOutputMap.putAll(getStepOutput(position, getObservation(time, position), getReward(), true)); + bufferStepOutputMap.putAll(getStepOutput(position, getObservation(time, position), getReward(), true, false)); // remove the train from active trains activeTrains.remove(position.getPt().getPlannedVehicleId()); @@ -330,7 +286,7 @@ public void onSimulationEnd(double now){ // set all the trains having heavy negative reward for(TrainPosition train: activeTrains.values()){ - Map stepOutputMap = getStepOutput(train, getObservation(now, train), -100.0, true); + Map stepOutputMap = getStepOutput(train, getObservation(now, train), -100.0, false, true); bufferStepOutputMap.putAll(stepOutputMap); } rlClient.sendObservation(bufferStepOutputMap); @@ -395,14 +351,14 @@ Observation getObservation(double time, TrainPosition train){ } private Map getStepOutput - (TrainPosition train, Observation observation, double reward, Boolean done){ + (TrainPosition train, Observation observation, double reward, Boolean terminated, Boolean truncated){ StepOutput stepOutput = new StepOutput(); stepOutput.setInfo(null); stepOutput.setReward(reward); - stepOutput.setTerminated(done); - stepOutput.setTruncated(done); + stepOutput.setTerminated(terminated); + stepOutput.setTruncated(truncated); stepOutput.setObservation(observation); Map stepOutputMap= new HashMap<>(); diff --git a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java index e6a7116e64f..eac58762134 100644 --- a/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java +++ b/contribs/railsim/src/main/java/ch/sbb/matsim/contrib/railsim/rl/RLClient.java @@ -5,44 +5,72 @@ import ch.sbb.matsim.contrib.railsim.grpc.ProtoStepOutputMap; import ch.sbb.matsim.contrib.railsim.grpc.ProtoObservation; import ch.sbb.matsim.contrib.railsim.grpc.ProtoStepOutput; +import ch.sbb.matsim.contrib.railsim.grpc.ProtoGetActionRequest; //import ch.sbb.matsim.contrib.railsim.grpc.*; import ch.sbb.matsim.contrib.railsim.rl.observation.Observation; import ch.sbb.matsim.contrib.railsim.rl.observation.StepOutput; + import io.grpc.*; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.logging.Level; import java.util.logging.Logger; import com.google.protobuf.Empty; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; public class RLClient { private RailsimConnecterGrpc.RailsimConnecterBlockingStub blockingStub=null; private static final Logger logger = Logger.getLogger(RLClient.class.getName()); - public RLClient(int port){ - String target = "localhost:"+port; - ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + + ManagedChannel channel; + ExecutorService executorService; + private final EventLoopGroup eventLoopGroup; + + + public RLClient(int port){ + // Create a single-threaded executor + executorService = Executors.newSingleThreadExecutor(); + + // Create a single-threaded executor + this.executorService = Executors.newSingleThreadExecutor(); + + // Create a single-threaded event loop group + this.eventLoopGroup = new NioEventLoopGroup(1); + + // Build the channel with the custom executor and event loop group + this.channel = NettyChannelBuilder.forAddress("localhost", port) + .usePlaintext() // Use plaintext for simplicity; switch to TLS in production + .executor(executorService) + .eventLoopGroup(eventLoopGroup) + .channelType(NioSocketChannel.class) .build(); - blockingStub = RailsimConnecterGrpc.newBlockingStub(channel); + + blockingStub = RailsimConnecterGrpc.newBlockingStub(channel); } - public void getAction(Map resultActionMap){ + public void getAction(double time, String trainId, Map resultActionMap){ ProtoActionMap actionMap=null; - + ProtoGetActionRequest request = ProtoGetActionRequest.newBuilder() + .setTimestamp((time)) + .setTrainId(trainId) + .build(); try { // Call the original method on the server. - Empty request = Empty.newBuilder().build(); actionMap = blockingStub.getAction(request); } catch (StatusRuntimeException e) { // Log a warning if the RPC fails. logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); -// return null; } resultActionMap.putAll(actionMap.getDictActionMap()); -// return actionMap.getDictActionMap(); } public String sendObservation(Map stepOutputMap){ @@ -87,11 +115,12 @@ public String sendObservation(Map stepOutputMap){ return msg.getAck(); } -// public static void main(String args[]) throws InterruptedException { -// // Access a service running on the local machine on port 50051 -// RLClient client = new RLClient(50051); -// Observation ob = new Observation(2, true); -// Map actionMap = client.getAction(); -// System.out.println(actionMap); -// } + public void shutdown() { + channel.shutdown(); + executorService.shutdown(); + eventLoopGroup.shutdownGracefully(); + } + + + } diff --git a/contribs/railsim/src/main/proto/railsim.proto b/contribs/railsim/src/main/proto/railsim.proto index 1d0ba85e9aa..73a64c84216 100644 --- a/contribs/railsim/src/main/proto/railsim.proto +++ b/contribs/railsim/src/main/proto/railsim.proto @@ -40,9 +40,14 @@ message ProtoConfirmationResponse{ string ack = 1; } +message ProtoGetActionRequest{ + double timestamp = 1; + string trainId = 2; +} + service RailsimConnecter { - rpc getAction (google.protobuf.Empty) returns (ProtoActionMap) {}; + rpc getAction (ProtoGetActionRequest) returns (ProtoActionMap) {}; rpc updateState (ProtoStepOutputMap) returns (ProtoConfirmationResponse) {};