Skip to content

Commit

Permalink
refactor: 🐛 fix unrenamed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sigureling committed Apr 3, 2024
1 parent 11994b0 commit 6d79011
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 206 deletions.
126 changes: 126 additions & 0 deletions CAPI/python/PyAPI/Space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from collections import deque, namedtuple
from typing import Optional, Union, List, Sequence
import numpy as np
import random
import PyAPI.structures as THUAI7
from PyAPI.State import State

Transition = namedtuple("Transition", ["state", "action", "reward", "next_state"])


class SingleObservation:
    """Numeric snapshot of one player's ``State``.

    The attributes stored on the instance depend on the type of
    ``state.self``:

    * ``THUAI7.Ship``: ``selfShipInfo``, ``otherShipInfo``, ``bulletsInfo``
      and ``homeInfo``.
    * ``THUAI7.Team``: ``shipInfo``, ``enemyShipInfo``, ``factoryInfo``,
      ``communityInfo``, ``fortInfo``, ``bridgeInfo`` and ``garbageInfo``.

    ``gameInfo`` is stored in every case. All attributes are ``numpy``
    arrays built from plain Python lists.
    """

    def __init__(self, state: State):
        """Extract the observation arrays from ``state``.

        Args:
            state: The player state to convert. ``state.self`` selects which
                group of attributes is populated (see the class docstring).
        """
        if isinstance(state.self, THUAI7.Ship):
            self.selfShipInfo = np.array(self.__ObsShip(state.self))
            # Every ship in view except the observing ship itself.
            self.otherShipInfo = np.array(
                [
                    self.__ObsShip(ship)
                    for ship in state.ships
                    if ship.playerID != state.self.playerID
                ]
            )
            self.bulletsInfo = np.array(
                [self.__ObsBullet(bullet) for bullet in state.bullets]
            )
            # Each row is the unpacked key followed by the unpacked value.
            self.homeInfo = np.array(
                [[*key, *value] for key, value in state.mapInfo.homeState.items()]
            )
        elif isinstance(state.self, THUAI7.Team):
            self.shipInfo = np.array([self.__ObsShip(ship) for ship in state.ships])
            self.enemyShipInfo = np.array(
                [self.__ObsShip(ship) for ship in state.enemyShips]
            )
            self.factoryInfo = np.array(
                [[*key, *value] for key, value in state.mapInfo.factoryState.items()]
            )
            self.communityInfo = np.array(
                [[*key, *value] for key, value in state.mapInfo.communityState.items()]
            )
            # NOTE(review): fortInfo is built from communityState, exactly like
            # communityInfo above. This looks like a copy-paste slip; confirm
            # which mapInfo attribute holds the fort data and switch to it.
            self.fortInfo = np.array(
                [[*key, *value] for key, value in state.mapInfo.communityState.items()]
            )
            # For these two maps the value is appended as a single element.
            self.bridgeInfo = np.array(
                [[*key, value] for key, value in state.mapInfo.bridgeState.items()]
            )
            self.garbageInfo = np.array(
                [[*key, value] for key, value in state.mapInfo.garbageState.items()]
            )
        self.gameInfo = np.array(self.__ObsGame(state.gameInfo))

    # The three helpers below read no instance state. They were previously
    # declared without ``self`` while being called as ``self.__ObsX(obj)``,
    # which raised TypeError on every call; ``@staticmethod`` makes that call
    # form valid.
    @staticmethod
    def __ObsShip(ship: THUAI7.Ship):
        """Return the fields of ``ship`` as a flat list (enums via ``.value``)."""
        return [
            ship.x,
            ship.y,
            ship.speed,
            ship.facingDirection,
            ship.viewRange,
            ship.hp,
            ship.armor,
            ship.shield,
            ship.shipState.value,
            ship.shipType.value,
            ship.producerType.value,
            ship.constructorType.value,
            ship.armorType.value,
            ship.shieldType.value,
            ship.weaponType.value,
        ]

    @staticmethod
    def __ObsBullet(bullet: THUAI7.Bullet):
        """Return the fields of ``bullet`` as a flat list."""
        return [
            bullet.x,
            bullet.y,
            bullet.facingDirection,
            bullet.speed,
            bullet.damage,
            bullet.bombRange,
            bullet.explodeRange,
        ]

    @staticmethod
    def __ObsGame(gameInfo: THUAI7.GameInfo):
        """Return ``[[red hp, energy, score], [blue hp, energy, score]]``."""
        return [
            [gameInfo.redHomeHp, gameInfo.redEnergy, gameInfo.redScore],
            [gameInfo.blueHomeHp, gameInfo.blueEnergy, gameInfo.blueScore],
        ]


class ObservatonSpace:
    """Five per-player observations: one team view followed by four ship views."""

    def __init__(self, state: List[State]):
        """Wrap each entry of ``state`` in a ``SingleObservation``.

        Args:
            state: Exactly five states. Index 0 becomes ``teamObs`` and
                indices 1-4 become ``ship_1_Obs`` .. ``ship_4_Obs``.
        """
        assert len(state) == 5
        attribute_names = (
            "teamObs",
            "ship_1_Obs",
            "ship_2_Obs",
            "ship_3_Obs",
            "ship_4_Obs",
        )
        # Attribute order mirrors the index order of ``state``.
        for index, attribute in enumerate(attribute_names):
            setattr(self, attribute, SingleObservation(state[index]))


class ShipAction:
    """A single ship command: an action code plus an optional attack angle."""

    def __init__(self, action: int, angle: Optional[float] = None):
        """Store the action code and angle.

        Args:
            action: Action code; must be a member of ``range(16)``.
            angle: Optional angle, stored as ``attackAngle``.

        Raises:
            AssertionError: If ``action`` is not in ``range(16)``.
        """
        allowed_codes = range(16)
        assert action in allowed_codes, "ship action out of range"
        self.attackAngle = angle
        self.action = action


class ActionSpace:
    """Joint action for one step: a two-entry team action and four ship actions."""

    def __init__(self, teamAction: Sequence[int], shipsAction: Sequence["ShipAction"]):
        """Validate and store the team and ship actions.

        Args:
            teamAction: Exactly two entries; the first must be in
                ``range(0, 14)`` and the second in ``range(0, 19)``.
            shipsAction: Exactly four ship actions.

        Raises:
            AssertionError: If either sequence has the wrong length or a
                team action entry is out of range.
        """
        # The previous guard demanded ``len(teamAction) == 0`` and then indexed
        # entries 0 and 1, so no input could ever pass. Two entries are
        # indexed, hence the length must be 2.
        assert len(teamAction) == 2, "teamAction must hold exactly two entries"
        assert teamAction[0] in range(0, 14), "teamAction[0] out of range"
        assert teamAction[1] in range(0, 19), "teamAction[1] out of range"
        assert len(shipsAction) == 4, "shipsAction must hold exactly four entries"
        self.teamAction = teamAction
        self.shipsAction = shipsAction
# move
# attack angle
# recover
# produce
# rebuild constructiontype
# construct constructiontype
# wait
# endallaction
# wait
# endall
# install
# recycle
# buildship
101 changes: 9 additions & 92 deletions CAPI/python/PyAPI/gym.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,16 @@
from collections import deque, namedtuple
import random
import PyAPI.structures as THUAI7
from PyAPI.State import State
from PyAPI.utils import AssistFunction
from typing import Union, Final, cast, List
from PyAPI.constants import Constants
from logic import Logic
from PyAPI.Space import ActionSpace, ObservatonSpace, Transition
import time

Transition = namedtuple("Transition", ["state", "action", "reward", "next_state"])


class ObservationSpace:
    """List-based snapshot of one player's ``State``.

    The attributes stored on the instance depend on the type of
    ``state.self``:

    * ``THUAI7.Sweeper``: ``selfSweeperInfo``, ``otherSweeperInfo``,
      ``bulletsInfo`` and ``homeInfo``.
    * ``THUAI7.Team``: ``sweeperInfo``, ``enemySweeperInfo``,
      ``recycleBankInfo``, ``chargeStationInfo``, ``signalTowerInfo``,
      ``bridgeInfo`` and ``garbageInfo``.

    ``gameInfo`` is stored in every case.
    """

    def __init__(self, state: State):
        """Extract the observation lists from ``state``.

        Args:
            state: The player state to convert. ``state.self`` selects which
                group of attributes is populated (see the class docstring).
        """
        if isinstance(state.self, THUAI7.Sweeper):
            self.selfSweeperInfo = self.__ObsSweeper(state.self)
            # Every sweeper in view except the observing sweeper itself.
            self.otherSweeperInfo = [
                self.__ObsSweeper(sweeper)
                for sweeper in state.sweepers
                if sweeper.playerID != state.self.playerID
            ]
            self.bulletsInfo = [self.__ObsBullet(bullet) for bullet in state.bullets]
            # Each row is the unpacked key followed by the unpacked value.
            self.homeInfo = [
                [*key, *value] for key, value in state.mapInfo.homeState.items()
            ]
        elif isinstance(state.self, THUAI7.Team):
            self.sweeperInfo = [
                self.__ObsSweeper(sweeper) for sweeper in state.sweepers
            ]
            self.enemySweeperInfo = [
                self.__ObsSweeper(sweeper) for sweeper in state.enemySweepers
            ]
            self.recycleBankInfo = [
                [*key, *value] for key, value in state.mapInfo.recycleBankState.items()
            ]
            self.chargeStationInfo = [
                [*key, *value] for key, value in state.mapInfo.chargeStationState.items()
            ]
            # NOTE(review): signalTowerInfo is built from chargeStationState,
            # exactly like chargeStationInfo above. This looks like a
            # copy-paste slip; confirm the intended mapInfo attribute.
            self.signalTowerInfo = [
                [*key, *value] for key, value in state.mapInfo.chargeStationState.items()
            ]
            # For these two maps the value is appended as a single element.
            self.bridgeInfo = [
                [*key, value] for key, value in state.mapInfo.bridgeState.items()
            ]
            self.garbageInfo = [
                [*key, value] for key, value in state.mapInfo.garbageState.items()
            ]
        self.gameInfo = self.__ObsGame(state.gameInfo)

    # The helpers below read no instance state. They were previously declared
    # without ``self`` while being called as ``self.__ObsX(obj)``, which raised
    # TypeError on every call; ``@staticmethod`` makes that call form valid.
    @staticmethod
    def __ObsSweeper(sweeper: THUAI7.Sweeper):
        """Return the fields of ``sweeper`` as a flat list (enums via ``.value``)."""
        return [
            sweeper.x,
            sweeper.y,
            sweeper.speed,
            sweeper.facingDirection,
            sweeper.viewRange,
            sweeper.hp,
            sweeper.armor,
            sweeper.shield,
            sweeper.sweeperState.value,
            sweeper.sweeperType.value,
            sweeper.producerType.value,
            sweeper.constructorType.value,
            sweeper.armorType.value,
            sweeper.shieldType.value,
            sweeper.weaponType.value,
        ]

    @staticmethod
    def __ObsBullet(bullet: THUAI7.Bullet):
        """Return the fields of ``bullet`` as a flat list."""
        return [
            bullet.x,
            bullet.y,
            bullet.facingDirection,
            bullet.speed,
            bullet.damage,
            bullet.bombRange,
            bullet.explodeRange,
        ]

    @staticmethod
    def __ObsGame(gameInfo: THUAI7.GameInfo):
        """Return ``[[red hp, energy, score], [blue hp, energy, score]]``."""
        return [
            [gameInfo.redHomeHp, gameInfo.redEnergy, gameInfo.redScore],
            [gameInfo.blueHomeHp, gameInfo.blueEnergy, gameInfo.blueScore],
        ]


class ActionSpace:
class gym:
def __init__(self):
pass


class Memory:
def __init__(self, max_len):
self.memory = deque([], maxlen=max_len)
self.memory = []
self.logic1 = Logic(playerID=0, teamID=0)

def push(self, *args):
self.memory.append(Transition(*args))

def sample(self, batch_size):
return random.sample(self.memory, batch_size)

def __len__(self):
return len(self.memory)
Loading

0 comments on commit 6d79011

Please sign in to comment.