diff --git a/game_wrappers/ai_sys.py b/game_wrappers/ai_sys.py index ffd7cd4..3bdf909 100644 --- a/game_wrappers/ai_sys.py +++ b/game_wrappers/ai_sys.py @@ -12,11 +12,19 @@ def __init__(self, args, env, logger): self.args = args self.use_model = True self.p1_model = None - if args.load_p1_model is '': - self.use_model = False - else: - self.p1_model = init_model(None, args.load_p1_model, args.alg, args, env, logger) + model_path = '' + try: + model_path = args.load_p1_model if args.load_p1_model != '' else args.model_1 + except AttributeError: + try: + model_path = args.model_1 + except AttributeError: + self.use_model = False + print('No model attribute found') + + if model_path and self.use_model: + self.p1_model = init_model(None, model_path, args.alg, args, env, logger) def predict(self, state, info, deterministic): if self.use_model: diff --git a/model_vs_game.py b/model_vs_game.py index 1d39dff..b47b021 100644 --- a/model_vs_game.py +++ b/model_vs_game.py @@ -51,7 +51,10 @@ def __init__(self, args, logger, need_display=True): self.ai_sys = games.wrappers.ai_sys(args, self.p1_env, logger) if args.model_1 != '' or args.model_2 != '': models = [args.model_1, args.model_2] - self.ai_sys.SetModels(models) + try: + self.ai_sys.SetModels(models) + except AttributeError: + print("SetModels method not found in ai_sys") self.need_display = need_display self.args = args @@ -70,7 +73,10 @@ def play(self, continuous=True, need_reset=True): self.display_env.action_probabilities = [] for i in range(4): - self.display_env.set_ai_sys_info(self.ai_sys) + try: + self.display_env.set_ai_sys_info(self.ai_sys) + except AttributeError: + pass state, reward, done, info = self.display_env.step(p1_actions) total_rewards += reward