# deep_q_learning_v1.py
import numpy as np
import torch

from brain import *            # provides Agent, device, PATH (assumed from the wildcard import)
from env import *              # provides Servers, State and the job definitions
from schedule_policy import *  # provides SJFPolicy and judge


def deepQLearning(job_sequence):
    agent = Agent()
    S = Servers()
    history, ret_history = [], []
    cost_min, cost_min1 = 1000000, 1000000
    i, wrong = 0, 0
    for episode in range(300):
        S = Servers()  # reset the servers at the start of every episode
        cost, cost1, period_reward = 0, 0, 0
        history = []
        i, idx = i + 1, 0
        state, state_next = State(), State()
        for time in range(1, 100000):
            # stop once every job has been dispatched and all servers are idle
            if idx >= len(job_sequence) and S.done():
                break
            # collect the jobs that arrive at the current time step
            job_in_time = []
            while idx < len(job_sequence) and job_sequence[idx].depart_time == time:
                job_in_time.append(job_sequence[idx])
                idx += 1
            for job in job_in_time:
                # 1. current state
                _, _, state_server = state.updateState(S, time, job)
                tensor_state = torch.from_numpy(
                    np.array([[state.count(), job.job_id]])
                ).type(torch.FloatTensor).to(device)
                # 2.1 choose an action (a server) for the job
                action = agent.get_action(tensor_state, episode)
                server_id = action.cpu().item()
                # 2.2 execute the action: dispatch the job to the chosen server
                job.arrive_time = job.arriveTime(server_id)
                S.servers[server_id].server.append(job)
                # 3. next state
                _, _, state_server_next = state_next.updateState(S, time + 1, None)
                tensor_state_next = torch.from_numpy(
                    np.array([[state_next.count(), -1]])
                ).type(torch.FloatTensor).to(device)
                # 4. reward: the negative cost of the resulting state
                reward = torch.from_numpy(
                    np.array([-state_next.getCost()])
                ).type(torch.FloatTensor).to(device)
                # 5. Q-learning: store the transition and update the main Q-network
                agent.memorize(tensor_state, action, tensor_state_next, reward)
                agent.update_q_function()
            # run the SJF policy on the servers, which updates cost directly
            S, cost, cost1, history = SJFPolicy(S, time, cost, cost1, history)
        # sync the target network every other episode
        if episode % 2 == 0:
            agent.update_target_q_function()
        cost_min = min(cost_min, cost)
        cost_min1 = min(cost_min1, cost1)
        # count episodes whose schedule is invalid or incomplete
        if not judge(history) or len(history) != len(job_sequence):
            wrong += 1
    torch.save(agent.brain.main_q_network, PATH)
    # print("wrong DQN", wrong)
    # note: ret_history is never filled in this version, so it is returned empty
    return cost_min, cost_min1, ret_history
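

# --- Hypothetical usage sketch (not part of the original file) ---
# A minimal driver showing how deepQLearning might be called. The helper name
# generate_job_sequence is an assumption for illustration only; the real job
# generator lives in env and is not shown in this file.
if __name__ == "__main__":
    job_sequence = generate_job_sequence()  # hypothetical helper, assumed to come from env
    cost_min, cost_min1, ret_history = deepQLearning(job_sequence)
    print("DQN min cost:", cost_min, "min cost1:", cost_min1)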