diff --git a/simpler_env/main_inference.py b/simpler_env/main_inference.py index 7ebe20c2..699c00af 100644 --- a/simpler_env/main_inference.py +++ b/simpler_env/main_inference.py @@ -7,7 +7,7 @@ from simpler_env.evaluation.maniskill2_evaluator import maniskill2_evaluator from simpler_env.policies.octo.octo_server_model import OctoServerInference from simpler_env.policies.rt1.rt1_model import RT1Inference -from simpler_env.policies.openvla.openvla_model import OpenVALInference +from simpler_env.policies.openvla.openvla_model import OpenVLAInference try: from simpler_env.policies.octo.octo_model import OctoInference @@ -56,7 +56,7 @@ ) elif args.policy_model == "openvla": assert args.ckpt_path is not None - model = OpenVALInference( + model = OpenVLAInference( saved_model_path=args.ckpt_path, policy_setup=args.policy_setup, action_scale=args.action_scale, diff --git a/simpler_env/policies/openvla/openvla_model.py b/simpler_env/policies/openvla/openvla_model.py index 66bdf0d4..b21f1378 100644 --- a/simpler_env/policies/openvla/openvla_model.py +++ b/simpler_env/policies/openvla/openvla_model.py @@ -11,7 +11,7 @@ import cv2 as cv -class OpenVALInference: +class OpenVLAInference: def __init__( self, saved_model_path: str = "openvla/openvla-7b", @@ -144,7 +144,7 @@ def step( relative_gripper_action = self.previous_gripper_action - current_gripper_action self.previous_gripper_action = current_gripper_action - if np.abs(relative_gripper_action) > 0.5 and self.sticky_action_is_on is False: + if np.abs(relative_gripper_action) > 0.5 and (not self.sticky_action_is_on): self.sticky_action_is_on = True self.sticky_gripper_action = relative_gripper_action