From 8a4c6076e31e0410d724770f7559943112a240d8 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Tue, 5 Oct 2021 00:59:37 +0200 Subject: [PATCH] Add Enum observation space unflattening (from one hot encoding) --- gym_csgo/spaces/enum.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gym_csgo/spaces/enum.py b/gym_csgo/spaces/enum.py index 382e122..18e9ca1 100644 --- a/gym_csgo/spaces/enum.py +++ b/gym_csgo/spaces/enum.py @@ -41,3 +41,9 @@ def flatten_enum(space, x): onehot[space.values.tolist().index(x)] = 1 # Return one hot encoded flat enum space return onehot + +# Reverse the flatten_enum operation +@gym.spaces.unflatten.register(Enum) +def unflatten_enum(space, x): + # Lookup enum value at encoded position + return space.values[np.nonzero(x)[0][0]]