-
Notifications
You must be signed in to change notification settings - Fork 0
/
simu_rl.py
38 lines (34 loc) · 1.06 KB
/
simu_rl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
from stable_baselines3.common.env_checker import check_env
from tqdm import tqdm
from agents.rl_simple import RLSimple
from associations.simple import SimpleAssociation
from channels.simple import SimpleChannel
from mobilities.simple import SimpleMobility
from sixg_radio_mgmt import CommunicationEnv
from traffics.simple import SimpleTraffic
# Seed handed to the agent constructor below (and reused later in the script).
seed = 10

# The four "simple" component classes are passed positionally, in this exact
# order, followed by the string "simple"; the observation/action space
# callables come from RLSimple itself.
env_components = (
    SimpleChannel,
    SimpleTraffic,
    SimpleMobility,
    SimpleAssociation,
)
comm_env = CommunicationEnv(
    *env_components,
    "simple",
    obs_space=RLSimple.get_obs_space,
    action_space=RLSimple.get_action_space,
)

# NOTE(review): the meaning of the positional arguments (2, 2, [2, 2]) is
# defined by RLSimple.__init__, which is not visible here -- confirm there.
rl_agent = RLSimple(comm_env, 2, 2, np.array([2, 2]), seed=seed)

# Register the agent's formatting/reward callables with the environment, in
# the order: observation formatter, action formatter, reward function.
agent_callbacks = (
    rl_agent.obs_space_format,
    rl_agent.action_format,
    rl_agent.calculate_reward,
)
comm_env.set_agent_functions(*agent_callbacks)

# Run the Stable-Baselines3 environment checker before any training starts.
check_env(comm_env)
# Train the agent for a fixed number of steps, then drive the environment with
# the agent's decisions for the same number of steps.
total_number_steps = 10000
rl_agent.train(total_number_steps)

# reset() returns a tuple whose first element is the initial observation.
obs = comm_env.reset(seed=seed)[0]
for _ in tqdm(range(total_number_steps)):
    sched_decision = rl_agent.step(obs)
    # step() returns a 5-tuple; positions 3 and 4 are the two end-of-episode
    # flags (terminated / truncated in the Gymnasium convention).
    obs, _, end_ep, truncated_ep, _ = comm_env.step(sched_decision)
    if end_ep or truncated_ep:
        # Continue from the fresh initial observation of the new episode.
        # The previous code discarded reset()'s return value, so the agent's
        # next decision was computed from the stale final observation of the
        # episode that had just ended.
        obs = comm_env.reset()[0]