diff --git a/examples/Application/Traffic-Light-Control/README.md b/examples/Application/Traffic-Light-Control/README.md
new file mode 100644
index 000000000..ad59667e7
--- /dev/null
+++ b/examples/Application/Traffic-Light-Control/README.md
@@ -0,0 +1,72 @@
+## Baseline Algorithms For Traffic Light Control
+Based on PARL, we use the deep-RL DDQN algorithm to reproduce several baselines for Traffic Light Control (TLC), matching the level of results reported by the original papers on common TLC benchmarks.
+
+### Traffic Light Control Simulator Introduction
+
+See [sumo](https://github.com/eclipse/sumo) or [cityflow](https://github.com/cityflow-project/CityFlow) to learn more about TLC simulators.
+We use the CityFlow simulator in our experiments; for installation instructions, please refer to the [CityFlow documentation](https://cityflow.readthedocs.io/en/latest/index.html).
+
+### Benchmark Result
+Following the papers, we set the yellow signal time to 5 seconds to clear the intersection and the action interval to 10 seconds; see `config.py` for details.
+You can download the data from [here](https://traffic-signal-control.github.io/) and from [MPLight data](https://github.com/Chacha-Chen/MPLight/tree/master/data).
+We use the average travel time of all vehicles to evaluate the performance of each signal control method.
+The performance of presslight and FRAP on the CityFlow environments after 300 training episodes is shown below.
+
+| average travel time | hz_1x1_tms-xy_18041608 | hz_1x1_bc-tyc_18041608 | syn_1x3_gaussian | syn_2x2_gaussian | anon_4_4_750_0.6 | anon_4_4_750_0.3 | anon_4_4_700_0.6 | anon_4_4_700_0.3 |
+| :----- | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
+| max_pressure | 284.02 | 445.62 | 240.08 | 316.67 | 589.03 | 536.89 | 545.29 | 483.08 |
+| presslight | 110.62 | 189.97 | 127.83 | 184.58 | 437.86 | 357.10 | 410.34 | 434.33 |
+| FRAP | 113.79 | 135.88 | 123.97 | 166.45 | 374.73 | 331.43 | 343.79 | 300.77 |
+| presslight* | 236.29 | 244.87 | 149.40 | 953.78 | -- | -- | -- | -- |
+| FRAP* | 130.53 | 159.54 | 750.68 | 713.48 | -- | -- | -- | -- |
+
+We also provide an implementation of the SOTL algorithm, but its performance depends heavily on hyperparameters such as `t_min` and `min_green_vehicle`, so we do not list its results here.
+
+The last two rows, `presslight*` and `FRAP*`, are the results of [tlc-baselines](https://github.com/gjzheng93/tlc-baselines), the code provided by the paper authors' team. We ran that [code](https://github.com/gjzheng93/tlc-baselines) changing only the yellow signal time and the action interval to match our config (and the papers), without changing any other parameters. `--` means the original code does not perform well on the last four `anon_4_4` datasets: its average travel time exceeds 1000. It might outperform `max_pressure` if you tune other hyperparameters of the DQN agents, such as the buffer size, `update_model_freq`, or gamma.
+
+## How to use
+### Dependencies
++ [parl>=1.4.3](https://github.com/PaddlePaddle/PARL)
++ torch==1.8.1+cu102
++ cityflow==0.1
+
+### Training
+First, download the data from [here](https://traffic-signal-control.github.io/) or [MPLight data](https://github.com/Chacha-Chen/MPLight/tree/master/data) and put it in the `data` directory, then run the training script. `train_presslight.py` trains presslight; by default each intersection has its own model. You can also let all intersections share one model (as in the MPLight paper) by setting `--is_share_model` to `True`, which is recommended when the number of intersections is large.
+```bash
+python train_presslight.py --is_share_model False
+```
+
+To train FRAP, run:
+```bash
+python train_FRAP.py
+```
+
+To compare different results, set the correct model path in `config.py` and the correct data path in `config.json`, then run:
+```bash
+python test.py
+```
+
+### Contents
+A minimal sketch of how these pieces fit together follows this list.
++ agent
+  + `agent.py`: the PARL agent used when training the RL methods such as `presslight` and `FRAP`.
+  + `max_pressure_agent.py` and `sotl_agent.py`: the classic (non-learning) TLC methods.
++ data
+  + Put the downloaded data (or any other data) here.
++ example
+  + Put the `config.json` here; the paths of the roadnet and flow data need to be changed in this `json` file.
++ model
+  + Different algorithms have different models.
++ obs_reward
+  + Different algorithms have different obs and reward generators.
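+
+The following is a condensed sketch of how these modules fit together. It mirrors the setup code in `train_presslight.py`; all class names and config keys are taken from this repository, and only the surrounding glue is abbreviated.
+```python
+import numpy as np
+
+from agent.agent import Agent
+from config import config
+from ddqn import DDQN
+from environment import CityFlowEnv
+from model.presslight_model import PressLightModel
+from obs_reward.presslight_obs_reward import PressureLightGenerator
+from world import World
+
+# Build the CityFlow world and the presslight obs/reward generator.
+world = World(config['config_path_name'], thread_num=config['thread_num'],
+              yellow_phase_time=config['yellow_phase_time'])
+generator = PressureLightGenerator(world, config['obs_fns'], config['reward_fns'],
+                                   config['is_only'], config['average'])
+env = CityFlowEnv(world, generator)
+obs = env.reset()
+
+# One model / algorithm / agent per intersection.
+models = [PressLightModel(generator.obs_dims[i], env.action_dims[i], config['algo'])
+          for i in range(env.n_agents)]
+agents = [Agent(DDQN(model, config), config) for model in models]
+
+# One control step: every agent picks a phase, then the simulator advances.
+actions = np.array([agents[i].sample(obs[i].reshape(1, -1))[0]
+                    for i in range(env.n_agents)])
+next_obs, rewards, dones, _ = env.step(actions)
+```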
+ +### Reference ++ [Parl](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) ++ [Reinforcement Learning for Traffic Signal Control](https://traffic-signal-control.github.io/) ++ [Toward A Thousand Lights: Decentralized Deep Reinforcement Learning for Large-Scale Traffic Signal Control](https://chacha-chen.github.io/papers/chacha-AAAI2020.pdf) ++ [Traffic Light Control Baselines](https://github.com/zhc134/tlc-baselines) ++ [PressLight: Learning Max Pressure Control to Coordinate Traffic Signals in Arterial Network](http://personal.psu.edu/hzw77/publications/presslight-kdd19.pdf) ++ [PressLight](https://github.com/wingsweihua/presslight) ++ [Learning Phase Competition for Traffic Signal Control](http://www.personal.psu.edu/~gjz5038/paper/cikm2019_frap/cikm2019_frap_paper.pdf) ++ [frap-pub](https://github.com/gjzheng93/frap-pub) diff --git a/examples/Application/Traffic-Light-Control/agent/agent.py b/examples/Application/Traffic-Light-Control/agent/agent.py new file mode 100644 index 000000000..e24c4ba5d --- /dev/null +++ b/examples/Application/Traffic-Light-Control/agent/agent.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import parl +import numpy as np + + +class Agent(parl.Agent): + def __init__(self, algorithm, config): + super(Agent, self).__init__(algorithm) + + self.config = config + self.epsilon = self.config['epsilon'] + + def sample(self, obs): + # The epsilon-greedy action selector. 
+ obs = paddle.to_tensor(obs, dtype='float32') + logits = self.alg.sample(obs) + act_dim = logits.shape[-1] + act_values = logits.numpy() + actions = np.argmax(act_values, axis=-1) + for i in range(obs.shape[0]): + if np.random.rand() <= self.epsilon: + actions[i] = np.random.randint(0, act_dim) + return actions + + def predict(self, obs): + + obs = paddle.to_tensor(obs, dtype='float32') + predict_actions = self.alg.predict(obs) + return predict_actions.numpy() + + def learn(self, obs, actions, dones, rewards, next_obs): + + obs = paddle.to_tensor(obs, dtype='float32') + actions = paddle.to_tensor(actions, dtype='float32') + dones = paddle.to_tensor(dones, dtype='float32') + next_obs = paddle.to_tensor(next_obs, dtype='float32') + rewards = paddle.to_tensor(rewards, dtype='float32') + + Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = self.alg.learn( + obs, actions, dones, rewards, next_obs) + + self.alg.sync_target(decay=self.config['decay']) + self.epsilon = epsilon + + return Q_loss.numpy(), pred_values.numpy(), target_values.numpy( + ), max_v_show_values.numpy(), train_count, lr, epsilon diff --git a/examples/Application/Traffic-Light-Control/agent/max_pressure_agent.py b/examples/Application/Traffic-Light-Control/agent/max_pressure_agent.py new file mode 100644 index 000000000..dd8c0d3af --- /dev/null +++ b/examples/Application/Traffic-Light-Control/agent/max_pressure_agent.py @@ -0,0 +1,32 @@ +# Third party code +# +# The following code is mainly referenced, modified and copied from: +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import numpy as np + + +class MaxPressureAgent(object): + """ + Agent using MaxPressure method to control traffic light + """ + + def __init__(self, world): + self.world = world + + def predict(self, lane_vehicle_count): + actions = [] + for I_id, I in enumerate(self.world.intersections): + action = I.current_phase + max_pressure = None + action = -1 + for phase_id in range(len(I.phases)): + pressure = sum([ + lane_vehicle_count[start] - lane_vehicle_count[end] + for start, end in I.phase_available_lanelinks[phase_id] + ]) + if max_pressure is None or pressure > max_pressure: + action = phase_id + max_pressure = pressure + actions.append(action) + return np.array(actions) diff --git a/examples/Application/Traffic-Light-Control/agent/sotl_agent.py b/examples/Application/Traffic-Light-Control/agent/sotl_agent.py new file mode 100644 index 000000000..906af1e9d --- /dev/null +++ b/examples/Application/Traffic-Light-Control/agent/sotl_agent.py @@ -0,0 +1,45 @@ +# Third party code +# +# The following code is mainly referenced, modified and copied from: +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import numpy as np + + +class SOTLAgent(object): + """ + Agent using Self-organizing Traffic Light(SOTL) Control method to control traffic light. + Note that different t_min, min_green_vehicle and max_red_vehicle may cause different results, which may not fair to compare to others. 
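+    In short: once the current phase has lasted at least t_min seconds, the controller switches
+    to the next phase if the green approaches have no more than min_green_vehicle waiting
+    vehicles while the red approaches have more than max_red_vehicle waiting vehicles.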
+ """ + + def __init__(self, world, t_min=3, min_green_vehicle=20, + max_red_vehicle=5): + self.world = world + # the minimum duration of time of one phase + self.t_min = t_min + # some threshold to deal with phase requests + self.min_green_vehicle = min_green_vehicle # 10 + self.max_red_vehicle = max_red_vehicle # 30 + self.action_dims = [] + for i in self.world.intersections: + self.action_dims.append(len(i.phases)) + + def predict(self, lane_waiting_count): + actions = [] + for I_id, I in enumerate(self.world.intersections): + action = I.current_phase + if I.current_phase_time >= self.t_min: + num_green_vehicles = sum([ + lane_waiting_count[lane] + for lane in I.phase_available_startlanes[I.current_phase] + ]) + num_red_vehicles = sum( + [lane_waiting_count[lane] for lane in I.startlanes]) + num_red_vehicles -= num_green_vehicles + if num_green_vehicles <= self.min_green_vehicle and num_red_vehicles > self.max_red_vehicle: + action = (action + 1) % self.action_dims[I_id] + actions.append(action) + return np.array(actions) + + def get_reward(self): + return None diff --git a/examples/Application/Traffic-Light-Control/config.py b/examples/Application/Traffic-Light-Control/config.py new file mode 100644 index 000000000..057cdc0f8 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/config.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +config = { + + #========== env config ========== + 'config_path_name': + './scenarios/config_hz_1.json', # note that the path of the data can be modified in the json file. + 'thread_num': 8, + 'obs_fns': ['lane_count'], + 'reward_fns': ['pressure'], + 'is_only': False, + 'average': None, + 'action_interval': 10, + 'metric_period': 3600, #3600 + 'yellow_phase_time': 5, + + #========== learner config ========== + 'gamma': 0.85, # also can be set to 0.95 + 'epsilon': 0.9, + 'epsilon_min': 0.2, + 'epsilon_decay': 0.99, + 'start_lr': 0.00025, + 'episodes': 200 + 100, + 'algo': 'DQN', # DQN + 'max_train_steps': int(1e6), + 'lr_decay_interval': 100, + 'epsilon_decay_interval': 100, + 'sample_batch_size': + 2048, # also can be set to 32, which doesn't matter much. 
+ 'learn_freq': 2, # update parameters every 2 or 5 steps + 'decay': 0.995, # soft update of double DQN + 'reward_normal_factor': 4, # rescale the rewards, also can be set to 20, + 'train_count_log': 5, # add to the tensorboard + 'is_show_log': False, # print in the screen + 'step_count_log': 1000, + + # save checkpoint frequent episode + 'save_rate': 100, + 'save_dir': './save_model/presslight', + 'train_log_dir': './train_log/presslight', + 'save_dir': './save_model/presslight4*4', + 'train_log_dir': './train_log/presslight4*4', + + # memory config + 'memory_size': 20000, + 'begin_train_mmeory_size': 3000 +} diff --git a/examples/Application/Traffic-Light-Control/ddqn.py b/examples/Application/Traffic-Light-Control/ddqn.py new file mode 100644 index 000000000..4f5369681 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/ddqn.py @@ -0,0 +1,95 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +import copy +import numpy as np +import parl +from parl.utils.scheduler import LinearDecayScheduler + + +class DDQN(parl.Algorithm): + def __init__(self, model, config): + + self.model = model + + clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=40.0) + self.optimizer = paddle.optimizer.Adam( + learning_rate=config['start_lr'], + parameters=self.model.parameters(), + grad_clip=clip) + + self.mse_loss = nn.MSELoss(reduction='mean') + + self.config = config + self.lr_scheduler = LinearDecayScheduler(config['start_lr'], + config['max_train_steps']) + self.lr = config['start_lr'] + self.target_model = copy.deepcopy(model) + + self.train_count = 0 + + self.epsilon = self.config['epsilon'] + self.epsilon_min = self.config['epsilon_min'] + self.epsilon_decay = self.config['epsilon_decay'] + + def sample(self, obs): + logits = self.model(obs) + return logits + + def predict(self, obs): + logits = self.model(obs) + predict_actions = paddle.argmax(logits, axis=-1) + return predict_actions + + def sync_target(self, decay=0.995): + # soft update + self.model.sync_weights_to(self.target_model, decay) + + def learn(self, obs, actions, dones, rewards, next_obs): + # Update the Q network with the data sampled from the memory buffer. 
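+        # Double DQN target used below:
+        #     a* = argmax_a Q_online(s', a)            (greedy action from the online model)
+        #     target = r + gamma * (1 - done) * Q_target(s', a*)
+        # The online network selects the action and the target network evaluates it.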
+ if self.train_count > 0 and self.train_count % self.config[ + 'lr_decay_interval'] == 0: + self.lr = self.lr_scheduler.step( + step_num=self.config['lr_decay_interval']) + terminal = dones + actions_onehot = F.one_hot( + actions.astype('int'), num_classes=self.model.act_dim) + # shape of the pred_values: batch_size + pred_values = paddle.sum(self.model(obs) * actions_onehot, axis=-1) + greedy_action = self.model(next_obs).argmax(1) + with paddle.no_grad(): + # target_model for evaluation, using the double DQN, the max_v_show just used for showing in the tensorborad + max_v_show = paddle.max(self.target_model(next_obs), axis=-1) + greedy_actions_onehot = F.one_hot( + greedy_action, num_classes=self.model.act_dim) + max_v = paddle.sum( + self.target_model(next_obs) * greedy_actions_onehot, axis=-1) + assert max_v.shape == rewards.shape + target = rewards + (1 - terminal) * self.config['gamma'] * max_v + Q_loss = 0.5 * self.mse_loss(pred_values, target) + + # optimize + self.optimizer.clear_grad() + Q_loss.backward() + self.optimizer.step() + self.train_count += 1 + if self.epsilon > self.epsilon_min and self.train_count % self.config[ + 'epsilon_decay_interval'] == 0: + self.epsilon *= self.epsilon_decay + return Q_loss, pred_values.mean(), target.mean(), max_v_show.mean( + ), self.train_count, self.lr, self.epsilon diff --git a/examples/Application/Traffic-Light-Control/environment.py b/examples/Application/Traffic-Light-Control/environment.py new file mode 100644 index 000000000..6addec9b9 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/environment.py @@ -0,0 +1,49 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/gjzheng93/tlc-baseline and https://github.com/zhc134/tlc-baselines + +import gym +import numpy as np +import cityflow + + +class CityFlowEnv(gym.Env): + """ + Environment for Traffic Signal Control task. + + Parameters + ---------- + world: World object + obs_reward_generator(object): generator of the obs and rewards + """ + + def __init__(self, world, obs_reward_generator): + + self.world = world + self.n_agents = len(self.world.intersection_ids) + self.n = self.n_agents + # agents action space dim, each roadnet file may have different action dims + self.action_dims = [] + for i in self.world.intersections: + self.action_dims.append(len(i.phases)) + self.action_space = gym.spaces.MultiDiscrete(self.action_dims) + self.obs_reward_generator = obs_reward_generator + + def step(self, actions): + """ + actions: list + """ + assert len(actions) == self.n_agents + self.world.step(actions) + + obs = self.obs_reward_generator.generate_obs() + rewards = self.obs_reward_generator.generate_reward() + dones = [False] * self.n_agents + infos = {} + return obs, rewards, dones, infos + + def reset(self, seed=False): + self.world.reset(seed) + obs = self.obs_reward_generator.generate_obs() + return obs diff --git a/examples/Application/Traffic-Light-Control/model/FRAP_model.py b/examples/Application/Traffic-Light-Control/model/FRAP_model.py new file mode 100644 index 000000000..7c22441d1 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/model/FRAP_model.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import parl + + +class PressLightFRAPModel(parl.Model): + def __init__(self, + obs_dim, + act_dim, + embedding_size=4, + constant=None, + algo='DQN'): + + super(PressLightFRAPModel, self).__init__() + self.constant = constant + self.phase_lanes_dim = (obs_dim - act_dim) // act_dim + self.obs_dim = obs_dim + self.act_dim = act_dim + + # Assert the input of phase is one-hot. + self.current_phase_embedding = nn.Embedding(2, embedding_size) + self.relation_embedding = nn.Embedding(2, embedding_size) + + relation_dim = 10 + self.relation_conv = nn.Conv2D( + embedding_size, relation_dim, kernel_size=1, stride=1, padding=0) + + self.d_fc = nn.Linear(self.phase_lanes_dim, embedding_size) + + self.lane_dim = 16 + self.lane_fc = nn.Linear(embedding_size * 2, self.lane_dim) + self.lane_conv = nn.Conv2D( + self.lane_dim * 2, + relation_dim, + kernel_size=1, + stride=1, + padding=0) + hidden_size = 10 + self.hidden_conv = nn.Conv2D( + relation_dim, hidden_size, kernel_size=1, stride=1, padding=0) + self.before_merge = nn.Conv2D( + hidden_size, 1, kernel_size=1, stride=1, padding=0) + + self.algo = algo + + def forward(self, x): + + batch_size = x.shape[0] + # The cur_phase is one-hot vector and only contains 0/1. + cur_phase = x[:, self.obs_dim - self.act_dim:].astype('int') + # cur_phase_em shape:[batch, act_dim, embedding_size] + cur_phase_em = self.current_phase_embedding(cur_phase) + + # Constant and relation_embedding's shape:[batchsize, constant.shape[0], constant.shape[1], 4] + constant = paddle.tile(self.constant, (batch_size, 1, 1)) + relation_embedding = self.relation_embedding(constant) + # From NHWC to NCHW + relation_embedding = paddle.transpose( + relation_embedding, perm=[0, 3, 1, 2]) + relation_conv = self.relation_conv(relation_embedding) + + # The x_lane_phases contain lane vehicle nums of each phase, + # there may be two or more lanes can pass because the phase set the lanes to green, + # and the obs may sightly different to the origin paper, but it may be not affect the fianl result in our experiment. 
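+        # Reshape the lane-count part of the obs into [batch, act_dim, phase_lanes_dim],
+        # so the lane counts belonging to each phase are grouped together.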
+ x_lane_phases = paddle.reshape( + x[:, :self.obs_dim - self.act_dim], + [-1, self.act_dim, self.phase_lanes_dim]) + # x_lane_phases_feature shape: [batch_size, act_dim, embedding_size] + x_lane_phases_feature = nn.Sigmoid()(self.d_fc(x_lane_phases)) + list_phase_pressure = [] + for i in range(self.act_dim): + # concat the embedding features + p1_concat = paddle.concat( + (x_lane_phases_feature[:, i], cur_phase_em[:, i]), axis=-1) + add_feature = nn.Sigmoid()(self.lane_fc(p1_concat)) + list_phase_pressure.append(add_feature) + + list_phase_pressure_recomb = [] + for i in range(self.act_dim): + for j in range(self.act_dim): + if i != j: + list_phase_pressure_recomb.append( + paddle.concat( + (list_phase_pressure[i], list_phase_pressure[j]), + axis=-1)) + list_phase_pressure_recomb = paddle.stack(list_phase_pressure_recomb) + list_phase_pressure_recomb = paddle.transpose( + list_phase_pressure_recomb, perm=[1, 0, 2]) + # list_phase_pressure_recomb shape: [batch_size, self.act_dim*self.act_dim-1, 32] + list_phase_pressure_recomb = paddle.reshape( + list_phase_pressure_recomb, + (-1, self.act_dim, self.act_dim - 1, self.lane_dim * 2)) + list_phase_pressure_recomb = paddle.transpose( + list_phase_pressure_recomb, perm=[0, 3, 1, 2]) + lane_conv = self.lane_conv(list_phase_pressure_recomb) + + combine_feature = paddle.multiply(lane_conv, relation_conv) + + hidden_layer = self.hidden_conv(combine_feature) + before_merge = self.before_merge(hidden_layer) + before_merge = paddle.reshape(before_merge, + (-1, self.act_dim, self.act_dim - 1)) + q_values = paddle.sum(before_merge, axis=-1) + assert q_values.shape[-1] == self.act_dim + return q_values diff --git a/examples/Application/Traffic-Light-Control/model/presslight_model.py b/examples/Application/Traffic-Light-Control/model/presslight_model.py new file mode 100644 index 000000000..0b08dd1b4 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/model/presslight_model.py @@ -0,0 +1,69 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import parl + + +class PressLightModel(parl.Model): + def __init__(self, obs_dim, act_dim, algo='DQN'): + super(PressLightModel, self).__init__() + + hid1_size = 20 + hid2_size = 20 + self.obs_dim = obs_dim + self.act_dim = act_dim + embedding_size = 10 + self.current_phase_embedding = nn.Embedding(act_dim, embedding_size) + + self.algo = algo + if self.algo == 'Dueling': + self.fc1_adv = nn.Linear(obs_dim - 1, hid1_size) + self.fc1_val = nn.Linear(obs_dim - 1, hid1_size) + + self.fc2_adv = nn.Linear(hid1_size, hid2_size) + self.fc2_val = nn.Linear(hid1_size, hid2_size) + + self.fc3_adv = nn.Linear(hid2_size + embedding_size, self.act_dim) + self.fc3_val = nn.Linear(hid2_size + embedding_size, 1) + else: + self.fc1 = nn.Linear(obs_dim - 1, hid1_size) + self.fc2 = nn.Linear(hid1_size, hid2_size) + self.fc3 = nn.Linear(hid2_size + embedding_size, self.act_dim) + + def forward(self, x): + cur_phase = x[:, -1] + cur_phase = cur_phase.astype('int') + cur_phase_em = self.current_phase_embedding(cur_phase) + x = x[:, :-1] + if self.algo == 'Dueling': + fc1_a = nn.ReLU()(self.fc1_adv(x)) + fc1_v = nn.ReLU()(self.fc1_val(x)) + + fc2_a = nn.ReLU()(self.fc2_adv(fc1_a)) + fc2_v = nn.ReLU()(self.fc2_val(fc1_v)) + + fc2_a = paddle.concat((fc2_a, cur_phase_em), axis=-1) + fc2_v = paddle.concat((fc2_v, cur_phase_em), axis=-1) + As = self.fc3_adv(fc2_a) + V = self.fc3_val(fc2_v) + Q = As + (V - As.mean(axis=1, keepdim=True)) + else: + x1 = nn.ReLU()(self.fc1(x)) + x2 = nn.ReLU()(self.fc2(x1)) + x2 = paddle.concat((x2, cur_phase_em), axis=-1) + Q = self.fc3(x2) + return Q diff --git a/examples/Application/Traffic-Light-Control/obs_reward/max_pressure_agent.py b/examples/Application/Traffic-Light-Control/obs_reward/max_pressure_agent.py new file mode 100644 index 000000000..dd8c0d3af --- /dev/null +++ b/examples/Application/Traffic-Light-Control/obs_reward/max_pressure_agent.py @@ -0,0 +1,32 @@ +# Third party code +# +# The following code is mainly referenced, modified and copied from: +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import numpy as np + + +class MaxPressureAgent(object): + """ + Agent using MaxPressure method to control traffic light + """ + + def __init__(self, world): + self.world = world + + def predict(self, lane_vehicle_count): + actions = [] + for I_id, I in enumerate(self.world.intersections): + action = I.current_phase + max_pressure = None + action = -1 + for phase_id in range(len(I.phases)): + pressure = sum([ + lane_vehicle_count[start] - lane_vehicle_count[end] + for start, end in I.phase_available_lanelinks[phase_id] + ]) + if max_pressure is None or pressure > max_pressure: + action = phase_id + max_pressure = pressure + actions.append(action) + return np.array(actions) diff --git a/examples/Application/Traffic-Light-Control/obs_reward/max_pressure_obs.py b/examples/Application/Traffic-Light-Control/obs_reward/max_pressure_obs.py new file mode 100644 index 000000000..4e2e8af88 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/obs_reward/max_pressure_obs.py @@ -0,0 +1,40 @@ +# Third party code +# +# The following code is mainly referenced, modified and copied from: +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import numpy as np + + +class MaxPressureGenerator(object): + """ + Generate State based on statistics of lane vehicles. 
+ Parameters + ---------- + world : World object + fns_obs/fns_reward : list of statistics to get, currently support "lane_count", "lane_waiting_count" , "lane_waiting_time_count", "lane_delay" and "pressure" + """ + + def __init__(self, world, fns_obs='lane_count', fns_reward=None): + + self.world = world + self.fns_obs = fns_obs + # subscribe functions for obs and reward + self.world.subscribe(self.fns_obs) + self.fns_reward = fns_reward + if self.fns_reward: + self.world.subscribe(fns_reward) + + def generate_obs(self): + """ + return: numpy array of all the intersections obs + """ + lane_waiting_count = self.world.get_info(self.fns_obs) + + return lane_waiting_count + + def generate_reward(self): + """ + getting the reward of each intersections + """ + return None diff --git a/examples/Application/Traffic-Light-Control/obs_reward/presslight_FRAP_obs_reward.py b/examples/Application/Traffic-Light-Control/obs_reward/presslight_FRAP_obs_reward.py new file mode 100644 index 000000000..721b453c5 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/obs_reward/presslight_FRAP_obs_reward.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The following code are referenced and modified from: +https://github.com/gjzheng93/frap-pub and https://github.com/zhc134/tlc-baselines +""" +import numpy as np + + +class PressureLightFRAPGenerator(object): + """PressureLightFRAPGenerator + + Args: + world (object): World used by this Generator. + fns_obs: functions to get the obs. + fns_reward: functions to get the rewards. + """ + + def __init__(self, world, fns_obs, fns_reward): + + self.world = world + self.fns_obs = fns_obs + # Get all the intersections, because that each intersection is one agent. + self.Is = self.world.intersections + # Get lanes of intersections, with the order of the list is same to the self.Is. + self.all_intersections_lanes = [] + # May be the dim of each intersection can be different? Assert the all the agents have the same dims here. + self.obs_dims = [] + for I in self.Is: + # each intersection's lane_ids is saved in the lanes, and the infos needed such as the lane vehicle num of obs can be got from the lane_ids here. 
+ lanes = [] + roads = I.roads + # get the lane_ids from the road_ids + for road in roads: + from_zero = (road["startIntersection"] == I.id + ) if self.world.RIGHT else ( + road["endIntersection"] == I.id) + lanes.append([ + road["id"] + "_" + str(i) + for i in range(len(road["lanes"]))[::( + 1 if from_zero else -1)] + ]) + # all the lanes of the all the intersections are saved in the self.all_intersections_lanes + self.all_intersections_lanes.append(lanes) + + # calculate result dim of obs of each agents + available_lanelinks = I.phase_available_lanelinks[0] + # phase_available_lanelinks of each phase contains start_end_lanelink_pair, + # here we use the vehicle nums of the start_end_lanelink_pair as the feature, + # so the dim of the obs is :len(I.phases)*len(available_lanelinks)*2 also plus the len(I.phases). + # which may be slight different to the paper, but many other papers using the different feature and also get the better results, + # and you can modify the feature_dim here and below. + self.obs_dims.append( + len(I.phases) * len(available_lanelinks) * 2 + len(I.phases)) + + # subscribe functions for obs and reward + self.world.subscribe(self.fns_obs) + self.world.subscribe(fns_reward) + self.fns_reward = fns_reward + + def generate_relation(self): + """ + getting the confilt relation matrix, which can only use when the act_dim is 8 or 4. + """ + relations_all = [] + for I in self.Is: + relations = [] + num_phase = len(I.phases) + if num_phase == 8: + for p1 in I.phase_available_roadlinks: + zeros = [0, 0, 0, 0, 0, 0, 0] + count = 0 + for p2 in I.phase_available_roadlinks: + if p1 == p2: + # That means that the two phase have one same direction. + continue + if len(list(set(p1 + p2))) == 3: + zeros[count] = 1 + count += 1 + relations.append(zeros) + relations = np.array(relations).reshape((8, 7)) + elif num_phase == 4: + relations = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0]]).reshape((4, 3)) + else: + assert 0 + relations_all.append(relations) + return np.array(relations_all) + + def generate_phase_pairs(self): + """ + pairs road set to green by the phase, each phase may set 2 roads to green light. + """ + phase_available_roadlinks_all = [] + for I in self.Is: + phase_available_roadlinks_all.append(I.phase_available_roadlinks) + return np.array(phase_available_roadlinks_all) + + def generate_obs(self): + """ + return: numpy array of all the intersections obs + """ + # get all the infos for calc the obs of each intersections + results = [self.world.get_info(fn) for fn in self.fns_obs] + result = results[0] + cur_phases = [I.current_phase for I in self.Is] + ret_all = [] + # only get the vehilce nums, [I_num, phase_num * dim] + all_ret = [] + for I_id, I in enumerate(self.Is): + phase_lane_vehicle_num = [] + phase_onehot = [0 for _ in range(len(I.phases))] + for phase_id in range(len(I.phases)): + available_lanelinks = I.phase_available_lanelinks[phase_id] + for start_end_lanelink_pair in available_lanelinks: + for lane_id in start_end_lanelink_pair: + # append the lane vehicle num for each lane_id in the available_lanelinks, both start and end road. + phase_lane_vehicle_num.append(result[lane_id]) + phase_onehot[cur_phases[I_id]] = 1 + phase_lane_vehicle_num.extend(phase_onehot) + # Note that the len(phase_lane_vehicle_num) that should be equal to the self.obs_dims[I_id] above. + all_ret.append(phase_lane_vehicle_num) + all_ret = np.array(all_ret) + return all_ret + + def generate_reward(self): + """ + getting the reward of each intersections. 
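+        The reward of each intersection is its negative pressure.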
+ """ + pressures = self.world.get_info(self.fns_reward[0]) + rewards = [] + for I in self.world.intersections: + rewards.append(-pressures[I.id]) + return rewards diff --git a/examples/Application/Traffic-Light-Control/obs_reward/presslight_obs_reward.py b/examples/Application/Traffic-Light-Control/obs_reward/presslight_obs_reward.py new file mode 100644 index 000000000..0b393754b --- /dev/null +++ b/examples/Application/Traffic-Light-Control/obs_reward/presslight_obs_reward.py @@ -0,0 +1,105 @@ +# Third party code +# +# The following code are referenced, copied or modified from: +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import numpy as np + + +class PressureLightGenerator(object): + """PressureLightGenerator + + Args: + world (object): World used by this Generator. + fns_obs: Functions to get the obs. + fns_reward: Functions to get the rewards. + in_only: Only the incoming roads or the all road. + average: Whether the average nums or the num each load. + """ + + def __init__(self, world, fns_obs, fns_reward, in_only=False, + average=None): + + self.world = world + self.fns_obs = fns_obs + # get all the intersections + self.Is = self.world.intersections + # get lanes of intersections, with the order of the list of self.Is + self.all_intersections_lanes = [] + self.obs_dims = [] + for I in self.Is: + # each intersection's lane_ids is saved in the lanes, and the infos needed for obs can be got from the lane_ids here. + lanes = [] + # road_ids + if in_only: + roads = I.in_roads + else: + roads = I.roads + # get the lane_ids from the road_ids + for road in roads: + from_zero = (road["startIntersection"] == I.id + ) if self.world.RIGHT else ( + road["endIntersection"] == I.id) + lanes.append([ + road["id"] + "_" + str(i) + for i in range(len(road["lanes"]))[::( + 1 if from_zero else -1)] + ]) + # all the lanes of the all the intersections are saved in the self.all_intersections_lanes + self.all_intersections_lanes.append(lanes) + # calculate result dim of obs of each agents + size = sum(len(x) for x in lanes) + if average == "road": + size = len(roads) + elif average == "all": + size = 1 + # In the pressure light len(self.fns_obs) is 1, and the curphase. + self.obs_dims.append(len(self.fns_obs) * size + 1) + # subscribe functions for obs and reward + self.world.subscribe(self.fns_obs) + + self.world.subscribe(fns_reward) + self.fns_reward = fns_reward + self.average = average + + def generate_obs(self): + """ + return: numpy array of all the intersections obs + assert that each lane's dim is same. + """ + # get all the infos for calc the obs of each intersections + results = [self.world.get_info(fn) for fn in self.fns_obs] + + cur_phases = [I.current_phase for I in self.Is] + ret_all = [] + for I_id, lanes in enumerate(self.all_intersections_lanes): + ret = np.array([]) + for i in range(len(self.fns_obs)): + result = results[i] + fn_result = np.array([]) + for road_lanes in lanes: + road_result = [] + for lane_id in road_lanes: + road_result.append(result[lane_id]) + if self.average == "road" or self.average == "all": + road_result = np.mean(road_result) + else: + road_result = np.array(road_result) + fn_result = np.append(fn_result, road_result) + if self.average == "all": + fn_result = np.mean(fn_result) + ret = np.append(ret, fn_result) + # append cur_phase in the last. 
+ ret = np.append(ret, cur_phases[I_id]) + ret_all.append(ret) + return np.array(ret_all) + + def generate_reward(self): + """ + getting the reward of each intersections, using the pressure. + """ + pressures = self.world.get_info(self.fns_reward[0]) + rewards = [] + for I in self.world.intersections: + rewards.append(-pressures[I.id]) + return rewards diff --git a/examples/Application/Traffic-Light-Control/obs_reward/sotl_obs.py b/examples/Application/Traffic-Light-Control/obs_reward/sotl_obs.py new file mode 100644 index 000000000..044b518c2 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/obs_reward/sotl_obs.py @@ -0,0 +1,40 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import numpy as np + + +class SotlGenerator(object): + """ + Generate State or Reward based on statistics of lane vehicles. + Parameters + ---------- + world : World object + fns_obs : list of statistics to get, currently support "lane_count", "lane_waiting_count" , "lane_waiting_time_count", "lane_delay" and "pressure" + fns_reward: default None, for sotl, it don't need the rewards. + """ + + def __init__(self, world, fns_obs='lane_waiting_count', fns_reward=None): + + self.world = world + self.fns_obs = fns_obs + # subscribe functions for obs and reward + self.world.subscribe(self.fns_obs) + self.fns_reward = fns_reward + if self.fns_reward: + self.world.subscribe(fns_reward) + + def generate_obs(self): + """ + return numpy array of all the intersections obs which the sotl agent can infer the actions from. + """ + lane_waiting_count = self.world.get_info(self.fns_obs) + return lane_waiting_count + + def generate_reward(self): + """ + getting the reward of each intersections, defalut None for sotl. + """ + return None diff --git a/examples/Application/Traffic-Light-Control/scenarios/config_hz_1.json b/examples/Application/Traffic-Light-Control/scenarios/config_hz_1.json new file mode 100644 index 000000000..09685b721 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/scenarios/config_hz_1.json @@ -0,0 +1,12 @@ +{ + "interval": 1.0, + "seed": 0, + "dir": "./data/hangzhou_1x1_tms-xy_18041608_1h/", + "roadnetFile": "roadnet.json", + "flowFile": "flow.json", + "rlTrafficLight": true, + "laneChange": false, + "saveReplay":false, + "roadnetLogFile": "replay_roadnet.json", + "replayLogFile": "replay.txt" +} diff --git a/examples/Application/Traffic-Light-Control/test.py b/examples/Application/Traffic-Light-Control/test.py new file mode 100644 index 000000000..e2e9313d4 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/test.py @@ -0,0 +1,361 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import paddle +import parl +from parl.utils import logger + +import numpy as np +from tqdm import tqdm + +from world import World +from environment import CityFlowEnv +from config import config + +from obs_reward.presslight_obs_reward import PressureLightGenerator +from model.presslight_model import PressLightModel + +from obs_reward.presslight_FRAP_obs_reward import PressureLightFRAPGenerator +from model.FRAP_model import PressLightFRAPModel + +from obs_reward.sotl_obs import SotlGenerator +from agent.sotl_agent import SOTLAgent + +from obs_reward.max_pressure_obs import MaxPressureGenerator +from agent.max_pressure_agent import MaxPressureAgent + + +def test_presslight(epsilon_num=1, episode_tag=300, is_replay=False): + """ + test the env. + """ + ######################################## + # creating the world and the env. + ######################################## + logger.info('building the env...') + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + PLGenerator = PressureLightGenerator(world, config['obs_fns'], + config['reward_fns'], + config['is_only'], config['average']) + obs_dims = PLGenerator.obs_dims + env = CityFlowEnv(world, PLGenerator) + obs = env.reset() + act_dims = env.action_dims + n_agents = env.n_agents + if is_replay: + env.world.eng.set_save_replay(is_replay) + env.world.eng.set_replay_file('replay_presslight.txt') + ######################################## + # creating the agents and + # each agent has it own model. + ######################################## + logger.info( + 'building {} agents, each agent has its model and algorithm...'.format( + n_agents)) + models = [ + PressLightModel(obs_dims[i], act_dims[i], config['algo']) + for i in range(n_agents) + ] + logger.info('successfully creating {} agents...'.format(n_agents)) + ######################################## + # loading the model from the ckpt model. + ######################################## + for model_id, model in enumerate(models): + model_path = os.path.join( + config['save_dir'], 'agentid{}_episode_count{}.ckpt'.format( + model_id, episode_tag)) + logger.info('agent: {}/{} loading model from {}...'.format( + model_id + 1, n_agents, model_path)) + checkpoint = paddle.load(model_path) + model.set_state_dict(checkpoint) + ######################################## + # testing the model with env. + ######################################## + total_avg_travel_time = [] + episode_count = 0 + + with tqdm(total=epsilon_num, desc='[Testing Model]') as pbar: + while episode_count < epsilon_num: + step_count = 0 + while step_count < config['metric_period']: + actions = [] + for agent_id, ob in enumerate(obs): + ob = ob.reshape(1, -1) + ob = paddle.to_tensor(ob, dtype='float32') + action = models[agent_id](ob) + action = action.numpy() + action = np.argmax(action) + actions.append(action) + actions = np.array(actions) + for _ in range(config['action_interval']): + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + obs = next_obs + if step_count % 200 == 0: + logger.info('esipode:{}, step_count:{}'.format( + episode_count, step_count)) + episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + + logger.info('esipode:{}, avg_time:{}'.format( + episode_count, avg_travel_time)) + total_avg_travel_time.append(avg_travel_time) + obs = env.reset(seed=True) + pbar.update(1) + return total_avg_travel_time + + +def test_sotl(epsilon_num=1, is_replay=False): + """ + test the env. 
+ """ + ######################################## + # creating the world and the env. + ######################################## + logger.info('building the env...') + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + SLGenerator = SotlGenerator(world) + env = CityFlowEnv(world, SLGenerator) + obs = env.reset() + act_dims = env.action_dims + n_agents = env.n_agents + if is_replay: + env.world.eng.set_save_replay() + env.world.eng.set_replay_file('replay_sotl.txt') + ######################################## + # creating the agents. + ######################################## + agent = SOTLAgent(world) + + ######################################## + # testing the agent with env. + ######################################## + total_avg_travel_time = [] + episode_count = 0 + with tqdm(total=epsilon_num, desc='[Testing Model]') as pbar: + while episode_count < epsilon_num: + step_count = 0 + while step_count < config['metric_period']: + actions = agent.predict(obs) + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + obs = next_obs + if step_count % 200 == 0: + logger.info('esipode:{}, step_count:{}'.format( + episode_count, step_count)) + episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + logger.info('esipode:{}, avg_time:{}'.format( + episode_count, avg_travel_time)) + total_avg_travel_time.append(avg_travel_time) + obs = env.reset(seed=True) + pbar.update(1) + + +def test_max_pressure(epsilon_num=1, is_replay=False): + """ + test the env. + """ + ######################################## + # creating the world and the env. + ######################################## + logger.info('building the env...') + logger.info('loading config from {}'.format(config['config_path_name'])) + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + MPGenerator = MaxPressureGenerator(world) + env = CityFlowEnv(world, MPGenerator) + obs = env.reset() + act_dims = env.action_dims + n_agents = env.n_agents + if is_replay: + env.world.eng.set_save_replay(True) + env.world.eng.set_replay_file('replay_maxpressure.txt') + ######################################## + # creating the agents. + ######################################## + agent = MaxPressureAgent(world) + ######################################## + # testing the agent with env. + ######################################## + total_avg_travel_time = [] + episode_count = 0 + with tqdm(total=epsilon_num, desc='[Testing Model]') as pbar: + while episode_count < epsilon_num: + step_count = 0 + while step_count < config['metric_period']: + actions = agent.predict(obs) + yellow_time = config['yellow_phase_time'] + yellow_time = 0 + for _ in range(config['action_interval'] + yellow_time): + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + obs = next_obs + if step_count % 200 == 0: + logger.info('esipode:{}, step_count:{}'.format( + episode_count, step_count)) + episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + logger.info('esipode:{}, avg_time:{}'.format( + episode_count, avg_travel_time)) + total_avg_travel_time.append(avg_travel_time) + obs = env.reset(seed=True) + pbar.update(1) + return total_avg_travel_time + + +def test_FRAP_light(epsilon_num=1, episode_tag=300, is_replay=False): + ''' + test the env. + ''' + ######################################## + # creating the world and the env. 
+ ######################################## + logger.info('building the env...') + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + + PLGenerator = PressureLightFRAPGenerator(world, config['obs_fns'], + config['reward_fns']) + relation_constants = PLGenerator.generate_relation() + + obs_dims = PLGenerator.obs_dims + env = CityFlowEnv(world, PLGenerator) + obs = env.reset() + act_dims = env.action_dims + n_agents = env.n_agents + if is_replay: + env.world.eng.set_save_replay(is_replay) + env.world.eng.set_replay_file('replay_presslight.txt') + ######################################## + # creating the agents and + # each agent has it own model. + ######################################## + logger.info( + 'building {} agents, each agent has its model and algorithm...'.format( + n_agents)) + relation_constant = paddle.to_tensor( + relation_constants[0], dtype='float32') + relation_constant = relation_constant.astype('int') + models = [ + PressLightFRAPModel( + obs_dims[i], act_dims[i], constant=relation_constant) + for i in range(n_agents) + ] + logger.info('successfully creating {} agents...'.format(n_agents)) + ######################################## + # loading the model + ######################################## + for model_id, model in enumerate(models): + model_path = os.path.join( + config['save_dir'], 'agentid{}_episode_count{}.ckpt'.format( + model_id, episode_tag)) + logger.info('agent: {}/{} loading model from {}...'.format( + model_id + 1, n_agents, model_path)) + checkpoint = paddle.load(model_path) + model.set_state_dict(checkpoint) + ######################################## + # testing the model with env. + ######################################## + total_avg_travel_time = [] + episode_count = 0 + + with tqdm(total=epsilon_num, desc='[Testing Model]') as pbar: + while episode_count < epsilon_num: + step_count = 0 + while step_count < config['metric_period']: + actions = [] + for agent_id, ob in enumerate(obs): + ob = ob.reshape(1, -1) + ob = paddle.to_tensor(ob, dtype='float32') + action = models[agent_id](ob) + action = action.detach().cpu().numpy() + action = np.argmax(action) + actions.append(action) + actions = np.array(actions) + for _ in range(config['action_interval']): + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + obs = next_obs + if step_count % 200 == 0: + logger.info('esipode:{}, step_count:{}'.format( + episode_count, step_count)) + episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + logger.info('esipode:{}, avg_time:{}'.format( + episode_count, avg_travel_time)) + total_avg_travel_time.append(avg_travel_time) + obs = env.reset(seed=True) + pbar.update(1) + return total_avg_travel_time + + +if __name__ == '__main__': + import argparse + import os + parser = argparse.ArgumentParser() + parser.add_argument( + '--config_path_name', + default='./scenarios/config_hz_1.json', + type=str, + help='config path') + parser.add_argument( + '--is_test_frap', + default=True, + type=bool, + help='test the frap algorithm or others') + parser.add_argument( + '--result_name', default='4_4', type=str, help='result path') + parser.add_argument( + '--save_dir', default='./save_model', type=str, help='config path') + parser.add_argument( + '--episode_tag', default=300, type=int, help='episode_tag') + + args = parser.parse_args() + + config['config_path_name'] = args.config_path_name + config['save_dir'] = args.save_dir + if args.is_test_frap: + results = 
test_FRAP_light( + is_replay=False, episode_tag=args.episode_tag) + path_name = 'result_frap' + else: + results = test_presslight( + is_replay=False, episode_tag=args.episode_tag) + path_name = 'result_max_pressure' + # result_sotl = test_sotl() + result_max_pressure = test_max_pressure(is_replay=False) + + result_path = path_name + '/{}'.format(args.result_name) + os.makedirs(result_path, exist_ok=True) + with open(os.path.join(result_path, 'avgtime.txt'), 'w') as f: + if args.is_test_frap: + f.writelines('result_FRAP: ') + else: + f.writelines('result_presslight: ') + f.writelines(str(results[0])) + f.writelines('\n') + f.writelines('result_max_pressure: ') + f.writelines(str(result_max_pressure[0])) diff --git a/examples/Application/Traffic-Light-Control/test.sh b/examples/Application/Traffic-Light-Control/test.sh new file mode 100755 index 000000000..315177b90 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/test.sh @@ -0,0 +1,2 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0 python test.py --config_path_name './scenarios/config_hz_1.json' --result_name 'hz_1' --is_test_frap False --save_dir 'save_model/presslight' diff --git a/examples/Application/Traffic-Light-Control/train_FRAP.py b/examples/Application/Traffic-Light-Control/train_FRAP.py new file mode 100644 index 000000000..6017aa289 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/train_FRAP.py @@ -0,0 +1,198 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import paddle +import parl +from parl.utils import logger, ReplayMemory + +import numpy as np +from tqdm import tqdm +from tensorboardX import SummaryWriter +from world import World +from obs_reward.presslight_FRAP_obs_reward import PressureLightFRAPGenerator +from config import config +from environment import CityFlowEnv +from model.FRAP_model import PressLightFRAPModel +from ddqn import DDQN +from agent.agent import Agent + + +def log_metrics(summary, datas, buffer_total_size, is_show=False): + """ + Log metrics + """ + Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = datas + metric = { + 'q_loss': Q_loss, + 'pred_values': pred_values, + 'target_values': target_values, + 'max_v_show_values': max_v_show_values, + 'lr': lr, + 'epsilon': epsilon, + 'memory_size': buffer_total_size, + 'train_count': train_count + } + if is_show: + logger.info(metric) + for key in metric: + if key != 'train_count': + summary.add_scalar(key, metric[key], train_count) + + +def main(): + """ + each intersection has each own model. 
+ """ + logger.info('building the env...') + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + PLGenerator = PressureLightFRAPGenerator(world, config['obs_fns'], + config['reward_fns']) + relation_constants = PLGenerator.generate_relation() + + obs_dims = PLGenerator.obs_dims + env = CityFlowEnv(world, PLGenerator) + obs = env.reset() + episode_count = 0 + step_forward = 0 + #################### + act_dims = env.action_dims + n_agents = env.n_agents + logger.info( + 'creating {} replay_buffers for {} agents, each agent has one replay buffer.' + .format(n_agents, n_agents)) + replay_buffers = [ + ReplayMemory(config['memory_size'], obs_dims[i], 0) + for i in range(n_agents) + ] + #################### + logger.info( + 'building {} agents, each agent has its model and algorithm...'.format( + n_agents)) + relation_constant = paddle.to_tensor( + relation_constants[0], dtype='float32') + relation_constant = relation_constant.astype('int') + models = [ + PressLightFRAPModel( + obs_dims[i], act_dims[i], constant=relation_constant) + for i in range(n_agents) + ] + algorithms = [DDQN(model, config) for model in models] + agents = [Agent(algorithm, config) for algorithm in algorithms] + logger.info('successfully creating {} agents...'.format(n_agents)) + #################### + # tensorboard list + summarys = [ + SummaryWriter(os.path.join(config['train_log_dir'], str(agent_id))) + for agent_id in range(n_agents) + ] + ################### + episodes_rewards = np.zeros(n_agents) + ################### + with tqdm(total=config['episodes'], desc='[Training Model]') as pbar: + while episode_count <= config['episodes']: + step_count = 0 + while step_count < config['metric_period']: + actions = [] + for agent_id, ob in enumerate(obs): + ob = ob.reshape(1, -1) + action = agents[agent_id].sample(ob) + actions.append(action[0]) + actions = np.array(actions) + rewards_list = [] + for _ in range(config['action_interval']): + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + rewards_list.append(rewards) + rewards = np.mean( + rewards_list, axis=0) / config['reward_normal_factor'] + # calc the episodes_rewards and will add it to the tensorboard + assert len(episodes_rewards) == len(rewards) + episodes_rewards += rewards + + for agent_id, replay_buffer in enumerate(replay_buffers): + replay_buffers[agent_id].append( + obs[agent_id], actions[agent_id], rewards[agent_id], + next_obs[agent_id], dones[agent_id]) + step_forward += 1 + obs = next_obs + if len(replay_buffers[0]) >= config[ + 'begin_train_mmeory_size'] and step_forward % config[ + 'learn_freq'] == 0: + for agent_id, agent in enumerate(agents): + sample_data = replay_buffers[agent_id].sample_batch( + config['sample_batch_size']) + train_obs, train_actions, train_rewards, train_next_obs, train_terminals = sample_data + + Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = \ + agent.learn(train_obs, train_actions, train_terminals, train_rewards, train_next_obs) + datas = [ + Q_loss, pred_values, target_values, + max_v_show_values, train_count, lr, epsilon + ] + # tensorboard + if train_count % config['train_count_log'] == 0: + log_metrics(summarys[agent_id], datas, + step_forward) + if step_count % config['step_count_log'] == 0 and config[ + 'is_show_log']: + logger.info('episode_count: {}, step_count: {}, buffer_size: {}, buffer_size_total_size: {}.'\ + .format(episode_count, step_count, len(replay_buffers[0]), step_forward)) + + 
episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + obs = env.reset() + for agent_id, summary in enumerate(summarys): + summary.add_scalar('episodes_reward', + episodes_rewards[agent_id], episode_count) + # the avg travel time is same for all agents. + summary.add_scalar('average_travel_time', avg_travel_time, + episode_count) + logger.info('episode_count: {}, average_travel_time: {}.'.format( + episode_count, avg_travel_time)) + # reset to zeros + episodes_rewards = np.zeros(n_agents) + # save the model + if episode_count % config['save_rate'] == 0: + for agent_id, agent in enumerate(agents): + save_path = "{}/agentid{}_episode_count{}.ckpt".format( + config['save_dir'], agent_id, episode_count) + agent.save(save_path) + pbar.update(1) + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_path_name", + default="./scenarios/config_hz_1.json", + type=str, + help='config path') + + parser.add_argument( + "--save_dir", default="./save_model", type=str, help='config path') + + parser.add_argument( + "--is_share_model", default=False, type=bool, help='share_model') + args = parser.parse_args() + + config['config_path_name'] = args.config_path_name + config['save_dir'] = args.save_dir + + main() diff --git a/examples/Application/Traffic-Light-Control/train_FRAP.sh b/examples/Application/Traffic-Light-Control/train_FRAP.sh new file mode 100755 index 000000000..af394f15d --- /dev/null +++ b/examples/Application/Traffic-Light-Control/train_FRAP.sh @@ -0,0 +1,2 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0 python train_FRAP.py --config_path_name './scenarios/config_hz_1.json' --save_dir 'save_model_frap/hz_1' diff --git a/examples/Application/Traffic-Light-Control/train_presslight.py b/examples/Application/Traffic-Light-Control/train_presslight.py new file mode 100644 index 000000000..06a9531f2 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/train_presslight.py @@ -0,0 +1,286 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
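+
+# Training entry point for PressLight: main() trains one DDQN model per intersection,
+# while main_all() (selected via --is_share_model) trains a single model shared by all intersections.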
+ +import os +import paddle +import parl +from parl.utils import logger, ReplayMemory + +import numpy as np +from tqdm import tqdm +from tensorboardX import SummaryWriter +from world import World +from obs_reward.presslight_obs_reward import PressureLightGenerator +from config import config +from environment import CityFlowEnv +from model.presslight_model import PressLightModel +from ddqn import DDQN +from agent.agent import Agent + + +def log_metrics(summary, datas, buffer_total_size, is_show=False): + """ Log metrics + """ + Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = datas + metric = { + 'q_loss': Q_loss, + 'pred_values': pred_values, + 'target_values': target_values, + 'max_v_show_values': max_v_show_values, + 'lr': lr, + 'epsilon': epsilon, + 'memory_size': buffer_total_size, + 'train_count': train_count + } + if is_show: + logger.info(metric) + for key in metric: + if key != 'train_count': + summary.add_scalar(key, metric[key], train_count) + + +def main(): + """ + each intersection has each own model. + """ + logger.info('building the env...') + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + PLGenerator = PressureLightGenerator(world, config['obs_fns'], + config['reward_fns'], + config['is_only'], config['average']) + obs_dims = PLGenerator.obs_dims + env = CityFlowEnv(world, PLGenerator) + obs = env.reset() + episode_count = 0 + step_forward = 0 + #################### + act_dims = env.action_dims + n_agents = env.n_agents + logger.info( + 'creating {} replay_buffers for {} agents, each agent has one replay buffer.' + .format(n_agents, n_agents)) + replay_buffers = [ + ReplayMemory(config['memory_size'], obs_dims[i], 0) + for i in range(n_agents) + ] + #################### + logger.info( + 'building {} agents, each agent has its model and algorithm...'.format( + n_agents)) + models = [ + PressLightModel(obs_dims[i], act_dims[i], config['algo']) + for i in range(n_agents) + ] + algorithms = [DDQN(model, config) for model in models] + agents = [Agent(algorithm, config) for algorithm in algorithms] + logger.info('successfully creating {} agents...'.format(n_agents)) + #################### + # tensorboard list + summarys = [ + SummaryWriter(os.path.join(config['train_log_dir'], str(agent_id))) + for agent_id in range(n_agents) + ] + + ################### + episodes_rewards = np.zeros(n_agents) + ################### + with tqdm(total=config['episodes'], desc='[Training Model]') as pbar: + while episode_count <= config['episodes']: + step_count = 0 + while step_count < config['metric_period']: + actions = [] + for agent_id, ob in enumerate(obs): + ob = ob.reshape(1, -1) + action = agents[agent_id].sample(ob) + actions.append(action[0]) + actions = np.array(actions) + rewards_list = [] + for _ in range(config['action_interval']): + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + rewards_list.append(rewards) + rewards = np.mean( + rewards_list, axis=0) / config['reward_normal_factor'] + # calc the episodes_rewards and will add it to the tensorboard + assert len(episodes_rewards) == len(rewards) + episodes_rewards += rewards + for agent_id, replay_buffer in enumerate(replay_buffers): + replay_buffers[agent_id].append( + obs[agent_id], actions[agent_id], rewards[agent_id], + next_obs[agent_id], dones[agent_id]) + step_forward += 1 + obs = next_obs + if len(replay_buffers[0]) >= config[ + 'begin_train_mmeory_size'] and step_forward % config[ + 
'learn_freq'] == 0: + for agent_id, agent in enumerate(agents): + sample_data = replay_buffers[agent_id].sample_batch( + config['sample_batch_size']) + train_obs, train_actions, train_rewards, train_next_obs, train_terminals = sample_data + + Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = \ + agent.learn(train_obs, train_actions, train_terminals, train_rewards, train_next_obs) + datas = [ + Q_loss, pred_values, target_values, + max_v_show_values, train_count, lr, epsilon + ] + # tensorboard + if train_count % config['train_count_log'] == 0: + log_metrics(summarys[agent_id], datas, + step_forward) + if step_count % config['step_count_log'] == 0 and config[ + 'is_show_log']: + logger.info('episode_count: {}, step_count: {}, buffer_size: {}, buffer_size_total_size: {}.'\ + .format(episode_count, step_count, len(replay_buffers[0]), step_forward)) + + episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + obs = env.reset() + for agent_id, summary in enumerate(summarys): + summary.add_scalar('episodes_reward', + episodes_rewards[agent_id], episode_count) + # the avg travel time is same for all agents. + summary.add_scalar('average_travel_time', avg_travel_time, + episode_count) + logger.info('episode_count: {}, average_travel_time: {}.'.format( + episode_count, avg_travel_time)) + # reset to zeros + episodes_rewards = np.zeros(n_agents) + # save the model + if episode_count % config['save_rate'] == 0: + for agent_id, agent in enumerate(agents): + save_path = "{}/agentid{}_episode_count{}.ckpt".format( + config['save_dir'], agent_id, episode_count) + agent.save(save_path) + pbar.update(1) + + +def main_all(): + """ + all intersections share one model. + """ + logger.info('building the env...') + world = World( + config['config_path_name'], + thread_num=config['thread_num'], + yellow_phase_time=config['yellow_phase_time']) + PLGenerator = PressureLightGenerator(world, config['obs_fns'], + config['reward_fns'], + config['is_only'], config['average']) + obs_dims = PLGenerator.obs_dims + env = CityFlowEnv(world, PLGenerator) + obs = env.reset() + episode_count = 0 + step_forward = 0 + #################### + act_dims = env.action_dims + n_agents = env.n_agents + replay_buffer = ReplayMemory(config['memory_size'] * n_agents, obs_dims[0], + 0) + ################### + model = PressLightModel(obs_dims[0], act_dims[0], config['algo']) + algorithm = DDQN(model, config) + agent = Agent(algorithm, config) + logger.info('successfully creating the agent...') + ################### + # tensorboard list + ################### + summary = SummaryWriter(os.path.join(config['train_log_dir'], 'same')) + ################### + # train the model + ################### + episodes_rewards = np.zeros(n_agents) + with tqdm(total=config['episodes'], desc='[Training Model]') as pbar: + while episode_count <= config['episodes']: + step_count = 0 + while step_count < config['metric_period']: + actions = agent.sample(obs) + rewards_list = [] + for _ in range(config['action_interval']): + step_count += 1 + next_obs, rewards, dones, _ = env.step(actions) + rewards_list.append(rewards) + rewards = np.mean( + rewards_list, axis=0) / config['reward_normal_factor'] + # calc the episodes_rewards and will add it to the tensorboard + assert len(episodes_rewards) == len(rewards) + episodes_rewards += rewards + for agent_id in range(n_agents): + replay_buffer.append(obs[agent_id], actions[agent_id], + rewards[agent_id], next_obs[agent_id], + dones[agent_id]) + step_forward += 1 + 
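+                # the loop above pools every intersection's transition into the single shared replay buffer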
obs = next_obs + if len(replay_buffer) >= config[ + 'begin_train_mmeory_size'] and step_forward % config[ + 'learn_freq'] == 0: + sample_data = replay_buffer.sample_batch( + config['sample_batch_size']) + train_obs, train_actions, train_rewards, train_next_obs, train_terminals = sample_data + Q_loss, pred_values, target_values, max_v_show_values, train_count, lr, epsilon = \ + agent.learn(train_obs, train_actions, train_terminals, train_rewards, train_next_obs) + datas = [ + Q_loss, pred_values, target_values, max_v_show_values, + train_count, lr, epsilon + ] + # tensorboard + if train_count % config['train_count_log'] == 0: + log_metrics(summary, datas, step_forward) + episode_count += 1 + avg_travel_time = env.world.eng.get_average_travel_time() + obs = env.reset() + # just calc the first agent's rewards for show. + summary.add_scalar('episodes_reward', episodes_rewards[0], + episode_count) + # the avg travel time is same for all agents. + summary.add_scalar('average_travel_time', avg_travel_time, + episode_count) + logger.info('episode_count: {}, average_travel_time: {}.'.format( + episode_count, avg_travel_time)) + # reset to zeros + episodes_rewards = np.zeros(n_agents) + # save the model + if episode_count % config['save_rate'] == 0: + save_path = "{}/agentid{}_episode_count{}.ckpt".format( + config['save_dir'], '_same', episode_count) + agent.save(save_path) + pbar.update(1) + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + '--config_path_name', + default='./scenarios/config_hz_1.json', + type=str, + help='config path') + + parser.add_argument( + '--save_dir', default='./save_model', type=str, help='config path') + + parser.add_argument( + '--is_share_model', default=False, type=bool, help='share_model') + args = parser.parse_args() + + config['config_path_name'] = args.config_path_name + config['save_dir'] = args.save_dir + if args.is_share_model: + main_all() + else: + main() diff --git a/examples/Application/Traffic-Light-Control/train_presslight.sh b/examples/Application/Traffic-Light-Control/train_presslight.sh new file mode 100755 index 000000000..b2d98bc93 --- /dev/null +++ b/examples/Application/Traffic-Light-Control/train_presslight.sh @@ -0,0 +1,2 @@ +#!/bin/bash +CUDA_VISIBLE_DEVICES=0 python train_presslight.py --config_path_name './scenarios/config_hz_1.json' --save_dir 'save_model/hz_1' --is_share_model False diff --git a/examples/Application/Traffic-Light-Control/world.py b/examples/Application/Traffic-Light-Control/world.py new file mode 100644 index 000000000..ef7c622eb --- /dev/null +++ b/examples/Application/Traffic-Light-Control/world.py @@ -0,0 +1,380 @@ +# Third party code +# +# The following code are copied or modified from +# https://github.com/zhc134/tlc-baselines and https://github.com/gjzheng93/tlc-baseline + +import json +import os.path as osp +import cityflow +import numpy as np +from math import atan2, pi +import sys + + +def _get_direction(road, out=True): + if out: + x = road["points"][1]["x"] - road["points"][0]["x"] + y = road["points"][1]["y"] - road["points"][0]["y"] + else: + x = road["points"][-2]["x"] - road["points"][-1]["x"] + y = road["points"][-2]["y"] - road["points"][-1]["y"] + tmp = atan2(x, y) + return tmp if tmp >= 0 else (tmp + 2 * pi) + + +class Intersection(object): + def __init__(self, intersection, world, yellow_phase_time=5): + + self.id = intersection["id"] + self.point = [intersection["point"]["x"], intersection["point"]["y"]] + # using the world eng + self.eng = 
world.eng + # incoming and outgoing roads of each intersection, clock-wise order from North + self.roads = [] + self.outs = [] + self.directions = [] + self.out_roads = None + self.in_roads = None + + # links and phase information of each intersection + self.roadlinks = [] + self.lanelinks_of_roadlink = [] + self.startlanes = [] + self.lanelinks = [] + self.phase_available_roadlinks = [] + self.phase_available_lanelinks = [] + self.phase_available_startlanes = [] + + # define yellow phases, currently the default yellow phase is 0, so make sure the first phase of the roadnet is yellow phase + self.yellow_phase_id = [0] + # the default time of the yellow signal time is 5 seconds, you can change it to the real case. + self.yellow_phase_time = yellow_phase_time + + # parsing links and phases + for roadlink in intersection["roadLinks"]: + self.roadlinks.append((roadlink["startRoad"], roadlink["endRoad"])) + lanelinks = [] + for lanelink in roadlink["laneLinks"]: + startlane = roadlink["startRoad"] + "_" + str( + lanelink["startLaneIndex"]) + self.startlanes.append(startlane) + endlane = roadlink["endRoad"] + "_" + str( + lanelink["endLaneIndex"]) + lanelinks.append((startlane, endlane)) + self.lanelinks.extend(lanelinks) + self.lanelinks_of_roadlink.append(lanelinks) + + self.startlanes = list(set(self.startlanes)) + + phases = intersection["trafficLight"]["lightphases"] + self.phases = [ + i for i in range(len(phases)) if not i in self.yellow_phase_id + ] + for i in self.phases: + phase = phases[i] + self.phase_available_roadlinks.append(phase["availableRoadLinks"]) + phase_available_lanelinks = [] + phase_available_startlanes = [] + for roadlink_id in phase["availableRoadLinks"]: + lanelinks_of_roadlink = self.lanelinks_of_roadlink[roadlink_id] + phase_available_lanelinks.extend(lanelinks_of_roadlink) + for lanelinks in lanelinks_of_roadlink: + phase_available_startlanes.append(lanelinks[0]) + self.phase_available_lanelinks.append(phase_available_lanelinks) + phase_available_startlanes = list(set(phase_available_startlanes)) + self.phase_available_startlanes.append(phase_available_startlanes) + + self.reset() + + def insert_road(self, road, out): + + self.roads.append(road) + self.outs.append(out) + self.directions.append(_get_direction(road, out)) + + def sort_roads(self, RIGHT): + + order = sorted( + range(len(self.roads)), + key= + lambda i: (self.directions[i], self.outs[i] if RIGHT else not self.outs[i]) + ) + self.roads = [self.roads[i] for i in order] + self.directions = [self.directions[i] for i in order] + self.outs = [self.outs[i] for i in order] + self.out_roads = [self.roads[i] for i, x in enumerate(self.outs) if x] + self.in_roads = [ + self.roads[i] for i, x in enumerate(self.outs) if not x + ] + + def _change_phase(self, phase, interval): + self.eng.set_tl_phase(self.id, phase) + self._current_phase = phase + self.current_phase_time = interval + + def step(self, action, interval): + # if current phase is yellow, then continue to finish the yellow phase + # recall self._current_phase means true phase id (including yellows) + # self.current_phase means phase id in self.phases (excluding yellow) + if self._current_phase in self.yellow_phase_id: + if self.current_phase_time >= self.yellow_phase_time: + self._change_phase(self.phases[self.action_before_yellow], + interval) + self.current_phase = self.action_before_yellow + else: + self.current_phase_time += interval + else: + if action == self.current_phase: + self.current_phase_time += interval + else: + if self.yellow_phase_time > 
0: + self._change_phase(self.yellow_phase_id[0], interval) + self.action_before_yellow = action + else: + self._change_phase(action, interval) + self.current_phase = action + + def reset(self): + # record phase info + self.current_phase = 0 # phase id in self.phases (excluding yellow) + self._current_phase = self.phases[ + 0] # true phase id (including yellow) + self.eng.set_tl_phase(self.id, self._current_phase) + self.current_phase_time = 0 + self.action_before_yellow = None + + +class World(object): + """ + Create a CityFlow engine and maintain informations about CityFlow world + """ + + def __init__(self, cityflow_config, thread_num, yellow_phase_time=3): + # loading the config and building the world.. + self.eng = cityflow.Engine(cityflow_config, thread_num=thread_num) + with open(cityflow_config) as f: + cityflow_config = json.load(f) + self.roadnet = self._get_roadnet(cityflow_config) + + # vehicles moves on the right side, currently always set to true due to CityFlow's mechanism. + self.RIGHT = True + self.interval = cityflow_config["interval"] + # get all non virtual intersections + self.intersections = [ + i for i in self.roadnet["intersections"] if not i["virtual"] + ] + self.intersection_ids = [i["id"] for i in self.intersections] + # create non-virtual Intersections + print("creating intersections...") + non_virtual_intersections = [ + i for i in self.roadnet["intersections"] if not i["virtual"] + ] + self.intersections = [ + Intersection(i, self, yellow_phase_time) + for i in non_virtual_intersections + ] + self.intersection_ids = [i["id"] for i in non_virtual_intersections] + self.id2intersection = {i.id: i for i in self.intersections} + print("intersections created.") + # id of all roads and lanes + print("parsing roads...") + self.all_roads = [] + self.all_lanes = [] + + for road in self.roadnet["roads"]: + self.all_roads.append(road["id"]) + i = 0 + for _ in road["lanes"]: + self.all_lanes.append(road["id"] + "_" + str(i)) + i += 1 + + iid = road["startIntersection"] + if iid in self.intersection_ids: + self.id2intersection[iid].insert_road(road, True) + iid = road["endIntersection"] + if iid in self.intersection_ids: + self.id2intersection[iid].insert_road(road, False) + + for i in self.intersections: + i.sort_roads(self.RIGHT) + print("roads parsed.") + + # initializing info functions + self.info_functions = { + "vehicles": (lambda: self.eng.get_vehicles(include_waiting=True)), + "lane_count": self.eng.get_lane_vehicle_count, + "lane_waiting_count": self.eng.get_lane_waiting_vehicle_count, + "lane_vehicles": self.eng.get_lane_vehicles, + "time": self.eng.get_current_time, + "vehicle_distance": self.eng.get_vehicle_distance, + "pressure": self.get_pressure, + "lane_waiting_time_count": self.get_lane_waiting_time_count, + "lane_delay": self.get_lane_delay, + "vehicle_trajectory": self.get_vehicle_trajectory, + "history_vehicles": self.get_history_vehicles + } + self.fns = [] + self.info = {} + + self.vehicle_waiting_time = { + } # key: vehicle_id, value: the waiting time of this vehicle since last halt. + self.vehicle_trajectory = { + } # key: vehicle_id, value: [[lane_id_1, enter_time, time_spent_on_lane_1], ... 
, [lane_id_n, enter_time, time_spent_on_lane_n]] + self.history_vehicles = set() + + print("world built successfully.") + + def get_pressure(self): + vehicles = self.eng.get_lane_vehicle_count() + pressures = {} + for i in self.intersections: + pressure = 0 + in_lanes = [] + for road in i.in_roads: + from_zero = ( + road["startIntersection"] == i.id) if self.RIGHT else ( + road["endIntersection"] == i.id) + for n in range(len(road["lanes"]))[::(1 if from_zero else -1)]: + in_lanes.append(road["id"] + "_" + str(n)) + out_lanes = [] + for road in i.out_roads: + from_zero = ( + road["endIntersection"] == i.id) if self.RIGHT else ( + road["startIntersection"] == i.id) + for n in range(len(road["lanes"]))[::(1 if from_zero else -1)]: + out_lanes.append(road["id"] + "_" + str(n)) + for lane in vehicles.keys(): + if lane in in_lanes: + pressure += vehicles[lane] + if lane in out_lanes: + pressure -= vehicles[lane] + pressures[i.id] = pressure + return pressures + + def get_vehicle_lane(self): + # get the current lane of each vehicle. {vehicle_id: lane_id} + vehicle_lane = {} + lane_vehicles = self.eng.get_lane_vehicles() + for lane in self.all_lanes: + for vehicle in lane_vehicles[lane]: + vehicle_lane[vehicle] = lane + return vehicle_lane + + def get_vehicle_waiting_time(self): + # the waiting time of vehicle since last halt. + vehicles = self.eng.get_vehicles(include_waiting=False) + vehicle_speed = self.eng.get_vehicle_speed() + for vehicle in vehicles: + if vehicle not in self.vehicle_waiting_time.keys(): + self.vehicle_waiting_time[vehicle] = 0 + if vehicle_speed[vehicle] < 0.1: + self.vehicle_waiting_time[vehicle] += 1 + else: + self.vehicle_waiting_time[vehicle] = 0 + return self.vehicle_waiting_time + + def get_lane_waiting_time_count(self): + # the sum of waiting times of vehicles on the lane since their last halt. 
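+        # note: get_vehicle_waiting_time() adds 1 to a stopped vehicle's counter on every call,
+        # so this implicitly assumes it is queried once per simulation step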
+ lane_waiting_time = {} + lane_vehicles = self.eng.get_lane_vehicles() + vehicle_waiting_time = self.get_vehicle_waiting_time() + for lane in self.all_lanes: + lane_waiting_time[lane] = 0 + for vehicle in lane_vehicles[lane]: + lane_waiting_time[lane] += vehicle_waiting_time[vehicle] + return lane_waiting_time + + def get_lane_delay(self, speed_limit=11.11): + # the delay of each lane: 1 - lane_avg_speed/speed_limit + # set speed limit to 11.11 by default + lane_vehicles = self.eng.get_lane_vehicles() + lane_delay = {} + lanes = self.all_lanes + vehicle_speed = self.eng.get_vehicle_speed() + + for lane in lanes: + vehicles = lane_vehicles[lane] + lane_vehicle_count = len(vehicles) + lane_avg_speed = 0.0 + for vehicle in vehicles: + speed = vehicle_speed[vehicle] + lane_avg_speed += speed + if lane_vehicle_count == 0: + lane_avg_speed = speed_limit + else: + lane_avg_speed /= lane_vehicle_count + lane_delay[lane] = 1 - lane_avg_speed / speed_limit + return lane_delay + + def get_vehicle_trajectory(self): + + # lane_id and time spent on the corresponding lane that each vehicle went through + vehicle_lane = self.get_vehicle_lane() + vehicles = self.eng.get_vehicles(include_waiting=False) + for vehicle in vehicles: + if vehicle not in self.vehicle_trajectory: + self.vehicle_trajectory[vehicle] = [[ + vehicle_lane[vehicle], + int(self.eng.get_current_time()), 0 + ]] + else: + if vehicle not in vehicle_lane.keys(): + continue + if vehicle_lane[vehicle] == self.vehicle_trajectory[vehicle][ + -1][0]: + self.vehicle_trajectory[vehicle][-1][2] += 1 + else: + self.vehicle_trajectory[vehicle].append([ + vehicle_lane[vehicle], + int(self.eng.get_current_time()), 0 + ]) + return self.vehicle_trajectory + + def get_history_vehicles(self): + + self.history_vehicles.update(self.eng.get_vehicles()) + return self.history_vehicles + + def _get_roadnet(self, cityflow_config): + roadnet_file = osp.join(cityflow_config["dir"], + cityflow_config["roadnetFile"]) + with open(roadnet_file) as f: + roadnet = json.load(f) + return roadnet + + def subscribe(self, fns): + if isinstance(fns, str): + fns = [fns] + for fn in fns: + if fn in self.info_functions: + if not fn in self.fns: + self.fns.append(fn) + else: + raise Exception("info function %s not exists" % fn) + + def step(self, actions=None): + if actions is not None: + for i, action in enumerate(actions): + self.intersections[i].step(action, self.interval) + self.eng.next_step() + self._update_infos() + + def reset(self, seed): + self.eng.reset(seed) + for I in self.intersections: + I.reset() + self._update_infos() + + def _update_infos(self): + self.info = {} + for fn in self.fns: + self.info[fn] = self.info_functions[fn]() + + def get_info(self, info): + return self.info[info] + + +if __name__ == "__main__": + # testing the env. + world = World("examples/config.json", thread_num=1) + print(world.intersections[0].phase_available_startlanes)
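+    # A minimal usage sketch added for illustration (assumes examples/config.json exists):
+    # subscribe to two built-in info functions, hold every signal on its first
+    # non-yellow phase for one simulation step, then print the per-intersection pressures.
+    world.subscribe(["lane_count", "pressure"])
+    world.step([0] * len(world.intersections))
+    print(world.get_info("pressure"))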