-
-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
11994b0
commit 6d79011
Showing
4 changed files
with
249 additions
and
206 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from collections import deque, namedtuple | ||
from typing import Optional, Union, List, Sequence | ||
import numpy as np | ||
import random | ||
import PyAPI.structures as THUAI7 | ||
from PyAPI.State import State | ||
|
||
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state"]) | ||
|
||
|
||
class SingleObservation: | ||
def __init__(self, state: State): | ||
if isinstance(state.self, THUAI7.Ship): | ||
self.selfShipInfo = np.array(self.__ObsShip(state.self)) | ||
self.otherShipInfo = np.array( | ||
[ | ||
self.__ObsShip(ship) | ||
for ship in state.ships | ||
if ship.playerID != state.self.playerID | ||
] | ||
) | ||
self.bulletsInfo = np.array( | ||
[self.__ObsBullet(bullet) for bullet in state.bullets] | ||
) | ||
self.homeInfo = np.array( | ||
[[*key, *value] for key, value in state.mapInfo.homeState.items()] | ||
) | ||
elif isinstance(state.self, THUAI7.Team): | ||
self.shipInfo = np.array([self.__ObsShip(ship) for ship in state.ships]) | ||
self.enemyShipInfo = np.array( | ||
[self.__ObsShip(ship) for ship in state.enemyShips] | ||
) | ||
self.factoryInfo = np.array( | ||
[[*key, *value] for key, value in state.mapInfo.factoryState.items()] | ||
) | ||
self.communityInfo = np.array( | ||
[[*key, *value] for key, value in state.mapInfo.communityState.items()] | ||
) | ||
self.fortInfo = np.array( | ||
[[*key, *value] for key, value in state.mapInfo.communityState.items()] | ||
) | ||
self.bridgeInfo = np.array( | ||
[[*key, value] for key, value in state.mapInfo.bridgeState.items()] | ||
) | ||
self.garbageInfo = np.array( | ||
[[*key, value] for key, value in state.mapInfo.garbageState.items()] | ||
) | ||
self.gameInfo = np.array(self.__ObsGame(state.gameInfo)) | ||
|
||
def __ObsShip(ship: THUAI7.Ship): | ||
return [ | ||
ship.x, | ||
ship.y, | ||
ship.speed, | ||
ship.facingDirection, | ||
ship.viewRange, | ||
ship.hp, | ||
ship.armor, | ||
ship.shield, | ||
ship.shipState.value, | ||
ship.shipType.value, | ||
ship.producerType.value, | ||
ship.constructorType.value, | ||
ship.armorType.value, | ||
ship.shieldType.value, | ||
ship.weaponType.value, | ||
] | ||
|
||
def __ObsBullet(bullet: THUAI7.Bullet): | ||
return [ | ||
bullet.x, | ||
bullet.y, | ||
bullet.facingDirection, | ||
bullet.speed, | ||
bullet.damage, | ||
bullet.bombRange, | ||
bullet.explodeRange, | ||
] | ||
|
||
def __ObsGame(gameInfo: THUAI7.GameInfo): | ||
return [ | ||
[gameInfo.redHomeHp, gameInfo.redEnergy, gameInfo.redScore], | ||
[gameInfo.blueHomeHp, gameInfo.blueEnergy, gameInfo.blueScore], | ||
] | ||
|
||
|
||
class ObservatonSpace: | ||
def __init__(self, state: List[State]): | ||
assert len(state) == 5 | ||
self.teamObs = SingleObservation(state[0]) | ||
self.ship_1_Obs = SingleObservation(state[1]) | ||
self.ship_2_Obs = SingleObservation(state[2]) | ||
self.ship_3_Obs = SingleObservation(state[3]) | ||
self.ship_4_Obs = SingleObservation(state[4]) | ||
|
||
|
||
class ShipAction: | ||
def __init__(self, action: int, angle: Optional[float] = None): | ||
assert action in range(0, 16), "ship action out of range" | ||
self.action = action | ||
self.attackAngle = angle | ||
|
||
|
||
class ActionSpace: | ||
def __init__(self, teamAction: Sequence[int], shipsAction: Sequence[ShipAction]): | ||
assert ( | ||
len(teamAction) == 0 | ||
and (teamAction[0] in range(0, 14)) | ||
and (teamAction[1] in range(0, 19)) | ||
) | ||
assert len(shipsAction) == 4 | ||
self.teamAction = teamAction | ||
self.shipsAction = shipsAction | ||
# move | ||
# attack angle | ||
# recover | ||
# produce | ||
# rebuild constructiontype | ||
# constrct constructiontype | ||
# wait | ||
# endallaction | ||
# wait | ||
# endall | ||
# install | ||
# recycle | ||
# buildship |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,99 +1,16 @@ | ||
from collections import deque, namedtuple | ||
import random | ||
import PyAPI.structures as THUAI7 | ||
from PyAPI.State import State | ||
from PyAPI.utils import AssistFunction | ||
from typing import Union, Final, cast, List | ||
from PyAPI.constants import Constants | ||
from logic import Logic | ||
from PyAPI.Space import ActionSpace, ObservatonSpace, Transition | ||
import time | ||
|
||
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state"]) | ||
|
||
|
||
class ObservationSpace: | ||
def __init__(self, state: State): | ||
if isinstance(state.self, THUAI7.Sweeper): | ||
self.selfSweeperInfo = self.__ObsSweeper(state.self) | ||
self.otherSweeperInfo = [ | ||
self.__ObsSweeper(sweeper) | ||
for sweeper in state.sweepers | ||
if sweeper.playerID != state.self.playerID | ||
] | ||
self.bulletsInfo = [self.__ObsBullet(bullet) for bullet in state.bullets] | ||
self.homeInfo = [ | ||
[*key, *value] for key, value in state.mapInfo.homeState.items() | ||
] | ||
elif isinstance(state.self, THUAI7.Team): | ||
self.sweeperInfo = [ | ||
self.__ObsSweeper(sweeper) for sweeper in state.sweepers | ||
] | ||
self.enemySweeperInfo = [ | ||
self.__ObsSweeper(sweeper) for sweeper in state.enemySweepers | ||
] | ||
self.recycleBankInfo = [ | ||
[*key, *value] for key, value in state.mapInfo.recycleBankState.items() | ||
] | ||
self.chargeStationInfo = [ | ||
[*key, *value] for key, value in state.mapInfo.chargeStationState.items() | ||
] | ||
self.signalTowerInfo = [ | ||
[*key, *value] for key, value in state.mapInfo.chargeStationState.items() | ||
] | ||
self.bridgeInfo = [ | ||
[*key, value] for key, value in state.mapInfo.bridgeState.items() | ||
] | ||
self.garbageInfo = [ | ||
[*key, value] for key, value in state.mapInfo.garbageState.items() | ||
] | ||
self.gameInfo = self.__ObsGame(state.gameInfo) | ||
|
||
def __ObsSweeper(sweeper: THUAI7.Sweeper): | ||
return [ | ||
sweeper.x, | ||
sweeper.y, | ||
sweeper.speed, | ||
sweeper.facingDirection, | ||
sweeper.viewRange, | ||
sweeper.hp, | ||
sweeper.armor, | ||
sweeper.shield, | ||
sweeper.sweeperState.value, | ||
sweeper.sweeperType.value, | ||
sweeper.producerType.value, | ||
sweeper.constructorType.value, | ||
sweeper.armorType.value, | ||
sweeper.shieldType.value, | ||
sweeper.weaponType.value, | ||
] | ||
|
||
def __ObsBullet(bullet: THUAI7.Bullet): | ||
return [ | ||
bullet.x, | ||
bullet.y, | ||
bullet.facingDirection, | ||
bullet.speed, | ||
bullet.damage, | ||
bullet.bombRange, | ||
bullet.explodeRange, | ||
] | ||
|
||
def __ObsGame(gameInfo: THUAI7.GameInfo): | ||
return [ | ||
[gameInfo.redHomeHp, gameInfo.redEnergy, gameInfo.redScore], | ||
[gameInfo.blueHomeHp, gameInfo.blueEnergy, gameInfo.blueScore], | ||
] | ||
|
||
|
||
class ActionSpace: | ||
class gym: | ||
def __init__(self): | ||
pass | ||
|
||
|
||
class Memory: | ||
def __init__(self, max_len): | ||
self.memory = deque([], maxlen=max_len) | ||
self.memory = [] | ||
self.logic1 = Logic(playerID=0, teamID=0) | ||
|
||
def push(self, *args): | ||
self.memory.append(Transition(*args)) | ||
|
||
def sample(self, batch_size): | ||
return random.sample(self.memory, batch_size) | ||
|
||
def __len__(self): | ||
return len(self.memory) |
Oops, something went wrong.