Skip to content

Commit

Permalink
merge develop
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Dec 7, 2023
2 parents 8e88660 + abeece2 commit ca1ba1d
Show file tree
Hide file tree
Showing 21 changed files with 287 additions and 389 deletions.
232 changes: 126 additions & 106 deletions all/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ class State(dict):
The torch device on which component tensors are stored.
"""

def __init__(self, x, device='cpu', **kwargs):
def __init__(self, x, device="cpu", **kwargs):
if not isinstance(x, dict):
x = {'observation': x}
x = {"observation": x}
for k, v in kwargs.items():
x[k] = v
if 'observation' not in x:
raise Exception('State must contain an observation')
if 'reward' not in x:
x['reward'] = 0.
if 'done' not in x:
x['done'] = False
if 'mask' not in x:
x['mask'] = 1. - x['done']
if "observation" not in x:
raise Exception("State must contain an observation")
if "reward" not in x:
x["reward"] = 0.0
if "done" not in x:
x["done"] = False
if "mask" not in x:
x["mask"] = 1.0 - x["done"]
super().__init__(x)
self._shape = ()
self.device = device
Expand All @@ -71,17 +71,33 @@ def array(cls, list_of_states):
v = list_of_states[0][key]
try:
if isinstance(v, list) and len(v) > 0 and torch.is_tensor(v[0]):
x[key] = torch.stack([torch.stack(state[key]) for state in list_of_states])
x[key] = torch.stack(
[torch.stack(state[key]) for state in list_of_states]
)
elif torch.is_tensor(v):
x[key] = torch.stack([state[key] for state in list_of_states])
else:
x[key] = torch.tensor([state[key] for state in list_of_states], device=device)
x[key] = torch.tensor(
[state[key] for state in list_of_states], device=device
)
except KeyError:
warnings.warn('KeyError while creating StateArray for key "{}", omitting.'.format(key))
warnings.warn(
'KeyError while creating StateArray for key "{}", omitting.'.format(
key
)
)
except ValueError:
warnings.warn('ValueError while creating StateArray for key "{}", omitting.'.format(key))
warnings.warn(
'ValueError while creating StateArray for key "{}", omitting.'.format(
key
)
)
except TypeError:
warnings.warn('TypeError while creating StateArray for key "{}", omitting.'.format(key))
warnings.warn(
'TypeError while creating StateArray for key "{}", omitting.'.format(
key
)
)

return StateArray(x, shape, device=device)

Expand All @@ -100,7 +116,9 @@ def apply(self, model, *keys):
Returns:
The output of the model.
"""
return self.apply_mask(self.as_output(model(*[self.as_input(key) for key in keys])))
return self.apply_mask(
self.as_output(model(*[self.as_input(key) for key in keys]))
)

def as_input(self, key):
"""
Expand Down Expand Up @@ -158,7 +176,7 @@ def update(self, key, value):
return self.__class__(x, device=self.device)

@classmethod
def from_gym(cls, gym_output, device='cpu', dtype=np.float32):
def from_gym(cls, gym_output, device="cpu", dtype=np.float32):
"""
Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.
Expand All @@ -170,27 +188,35 @@ def from_gym(cls, gym_output, device='cpu', dtype=np.float32):
Returns:
A State object.
"""
if not isinstance(gym_output, tuple) and (len(gym_output) == 2 or len(gym_output) == 5):
raise TypeError(f"gym_output should be a tuple, either (observation, info) or (observation, reward, terminated, truncated, info). Recieved {gym_output}.")

# extract info from timestep
if len(gym_output) == 5:
if isinstance(gym_output, tuple) and len(gym_output) == 5:
# gymanisum step()
observation, reward, terminated, truncated, info = gym_output
if len(gym_output) == 2:
elif isinstance(gym_output, tuple) and len(gym_output) == 4:
# legacy gym step()
observation, reward, done, info = gym_output
terminated = done
truncated = False
elif isinstance(gym_output, tuple) and len(gym_output) == 2:
# gymnasium reset()
observation, info = gym_output
reward = 0.
reward = 0.0
terminated = False
truncated = False
else:
# legacy gym reset()
observation = gym_output
reward = 0.0
terminated = False
truncated = False
info = {}
x = {
'observation': torch.from_numpy(
np.array(
observation,
dtype=dtype
),
).to(device),
'reward': float(reward),
'done': terminated or truncated,
'mask': 1. - terminated
"observation": torch.from_numpy(
np.array(observation, dtype=dtype),
).to(device),
"reward": float(reward),
"done": terminated or truncated,
"mask": 1.0 - terminated,
}
info = info if info else {}
for key in info:
Expand All @@ -211,22 +237,22 @@ def to(self, device):
@property
def observation(self):
"""A tensor containing the current observation."""
return self['observation']
return self["observation"]

@property
def reward(self):
"""A float representing the reward for the previous state/action pair."""
return self['reward']
return self["reward"]

@property
def done(self):
"""A boolean that is true if the state is a terminal state, and false otherwise."""
return self['done']
return self["done"]

@property
def mask(self):
"""A float that is 1. if the state is non-terminal, or 0. otherwise."""
return self['mask']
return self["mask"]

@property
def shape(self):
Expand All @@ -239,47 +265,49 @@ def __len__(self):

class StateArray(State):
"""
An n-dimensional array of environment State objects.
An n-dimensional array of environment State objects.
Internally, all components of the states are represented as n-dimensional tensors.
This allows for batch-style processing and easy manipulation of states.
Usually, a StateArray should be constructed using the State.array() function.
Internally, all components of the states are represented as n-dimensional tensors.
This allows for batch-style processing and easy manipulation of states.
Usually, a StateArray should be constructed using the State.array() function.
Args:
x (dict):
A dictionary containing all state information.
Each value should be a tensor in which the first n-dimensions
match the shape of the StateArray.
The following keys are standard:
Args:
x (dict):
A dictionary containing all state information.
Each value should be a tensor in which the first n-dimensions
match the shape of the StateArray.
The following keys are standard:
observation (torch.tensor) (required):
A tensor representing the observations for each state
observation (torch.tensor) (required):
A tensor representing the observations for each state
reward (torch.FloatTensor) (optional):
A tensor representing rewards for the previous state/action pairs
reward (torch.FloatTensor) (optional):
A tensor representing rewards for the previous state/action pairs
done (torch.BoolTensors) (optional):
A tensor representing whether each state is terminal
done (torch.BoolTensors) (optional):
A tensor representing whether each state is terminal
mask (torch.FloatTensor) (optional):
A tensor representing the mask for each state.
device (string):
The torch device on which component tensors are stored.
mask (torch.FloatTensor) (optional):
A tensor representing the mask for each state.
device (string):
The torch device on which component tensors are stored.
"""

def __init__(self, x, shape, device='cpu', **kwargs):
def __init__(self, x, shape, device="cpu", **kwargs):
if not isinstance(x, dict):
x = {'observation': x}
x = {"observation": x}
for k, v in kwargs.items():
x[k] = v
if 'observation' not in x:
raise Exception('StateArray must contain an observation')
if 'reward' not in x:
x['reward'] = torch.zeros(shape, device=device)
if 'done' not in x:
x['done'] = torch.tensor([False] * np.prod(shape), device=device).view(shape)
if 'mask' not in x:
x['mask'] = 1. - x['done'].float()
if "observation" not in x:
raise Exception("StateArray must contain an observation")
if "reward" not in x:
x["reward"] = torch.zeros(shape, device=device)
if "done" not in x:
x["done"] = torch.tensor([False] * np.prod(shape), device=device).view(
shape
)
if "mask" not in x:
x["mask"] = 1.0 - x["done"].float()
super().__init__(x, device=device)
self._shape = shape

Expand All @@ -305,7 +333,9 @@ def update(self, key, value):

def as_input(self, key):
value = self[key]
return value.view((np.prod(self.shape), *value.shape[len(self.shape):])).float()
return value.view(
(np.prod(self.shape), *value.shape[len(self.shape):])
).float()

def as_output(self, tensor):
return tensor.view((*self.shape, *tensor.shape[1:]))
Expand Down Expand Up @@ -343,31 +373,33 @@ def view(self, shape):

@property
def observation(self):
return self['observation']
return self["observation"]

@property
def reward(self):
return self['reward']
return self["reward"]

@property
def done(self):
return self['done']
return self["done"]

@property
def mask(self):
return self['mask']
return self["mask"]

def __getitem__(self, key):
if isinstance(key, slice) or isinstance(key, int):
shape = self['mask'][key].shape
shape = self["mask"][key].shape
if len(shape) == 0:
return State({k: v[key] for (k, v) in self.items()}, device=self.device)
return StateArray({k: v[key] for (k, v) in self.items()}, shape, device=self.device)
return StateArray(
{k: v[key] for (k, v) in self.items()}, shape, device=self.device
)
if torch.is_tensor(key):
# some things may get lost
d = {}
shape = self['mask'][key].shape
for (k, v) in self.items():
shape = self["mask"][key].shape
for k, v in self.items():
try:
d[k] = v[key]
except KeyError:
Expand All @@ -389,7 +421,7 @@ def __len__(self):

@classmethod
def cat(cls, state_array_list, axis=0):
'''Concatenates along batch dimention'''
"""Concatenates along batch dimention"""
if len(state_array_list) == 0:
raise ValueError("cat accepts a non-zero size list of StateArrays")

Expand All @@ -400,13 +432,15 @@ def cat(cls, state_array_list, axis=0):
new_shape = tuple(new_shape)
keys = list(state_array_list[0].keys())
for key in keys:
d[key] = torch.cat([state_array[key] for state_array in state_array_list], axis=axis)
d[key] = torch.cat(
[state_array[key] for state_array in state_array_list], axis=axis
)
return StateArray(d, new_shape, device=state_array_list[0].device)

def batch_execute(self, minibatch_size, fn):
'''
"""
execute in batches to reduce memory consumption
'''
"""
data = self
batch_size = self.shape[0]
results = []
Expand All @@ -424,17 +458,17 @@ def batch_execute(self, minibatch_size, fn):


class MultiagentState(State):
def __init__(self, x, device='cpu', **kwargs):
if 'agent' not in x:
raise Exception('MultiagentState must contain an agent ID')
def __init__(self, x, device="cpu", **kwargs):
if "agent" not in x:
raise Exception("MultiagentState must contain an agent ID")
super().__init__(x, device=device, **kwargs)

@property
def agent(self):
return self['agent']
return self["agent"]

@classmethod
def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
def from_zoo(cls, agent, state, device="cpu", dtype=np.float32):
"""
Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.
Expand All @@ -446,29 +480,15 @@ def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
Returns:
A State object.
"""
if not isinstance(state, tuple):
return MultiagentState({
'agent': agent,
'observation': torch.from_numpy(
np.array(
state,
dtype=dtype
),
).to(device)
}, device=device)

observation, reward, done, info = state
observation = torch.from_numpy(
np.array(
observation,
dtype=dtype
),
).to(device)
observation, reward, terminated, truncated, info = state
x = {
'agent': agent,
'observation': observation,
'reward': float(reward),
'done': done,
"agent": agent,
"observation": torch.from_numpy(
np.array(observation, dtype=dtype),
).to(device),
"reward": float(reward),
"done": terminated or truncated,
"mask": 1.0 - terminated,
}
info = info if info else {}
for key in info:
Expand Down
Loading

0 comments on commit ca1ba1d

Please sign in to comment.