Skip to content

Commit

Permalink
finalize PR
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Aug 29, 2023
1 parent d2283e2 commit 12bb9df
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 28 deletions.
2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class AblationDataModuleParams:

@dataclass
class ModelParams:
    """Configuration for the learning model selection and checkpoint loading."""

    # Which model architecture to instantiate; one of:
    # "LinearRnvp", "SimpleMLP", "SimpleGCN", "DoubleMLP".
    # (The diff rendering duplicated this field line; exactly one is kept.)
    name: str = "LinearRnvp"
    # Optional path to a checkpoint to restore weights from; None trains from scratch.
    load_ckpt: Optional[str] = None

@dataclass
Expand Down
2 changes: 1 addition & 1 deletion wild_visual_navigation/feature_extractor/dino_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
model_type: str = "vit_small",
patch_size: int = 8,
dim: int = 384,
projection_type: str = None, # nonlinear or None
projection_type: str = None, # nonlinear or None
dropout: bool = False, # True or False
):
self.dim = dim # 90 or 384
Expand Down
8 changes: 4 additions & 4 deletions wild_visual_navigation/learning/model/linear_rnvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def backward(self, x, y=None):
_mx = mx

s, t = self.st(_mx)
s = torch.tanh(s) # Adding an activation function here with non-linearities
s = torch.tanh(s) # Adding an activation function here with non-linearities

u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)

Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__(
super().__init__()

self.register_buffer("prior_mean", torch.zeros(input_dim)) # Normal Gaussian with zero mean
self.register_buffer("prior_var", torch.ones(input_dim)) # Normal Gaussian with unit variance
self.register_buffer("prior_var", torch.ones(input_dim)) # Normal Gaussian with unit variance

if mask_type == "odds":
mask = torch.arange(0, input_dim).float() % 2
Expand Down Expand Up @@ -258,11 +258,11 @@ def __init__(
self.flows = SequentialFlow(*blocks)

def logprob(self, x):
return self.prior.log_prob(x) # Compute log probability of the input at the Gaussian distribution
return self.prior.log_prob(x) # Compute log probability of the input at the Gaussian distribution

@property
def prior(self):
return distributions.Normal(self.prior_mean, self.prior_var) # Normal Gaussian with zero mean and unit variance
return distributions.Normal(self.prior_mean, self.prior_var) # Normal Gaussian with zero mean and unit variance

def forward(self, data: Data):
x = data.x
Expand Down
12 changes: 10 additions & 2 deletions wild_visual_navigation/learning/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@ def __init__(self, confidence_std_factor, method):
# )

def forward(
self, graph: Optional[Data], res: dict, loss_mean: int = None, loss_std: int = None, train: bool = False, update_generator: bool = True, step: int = 0, log_step: bool = False
self,
graph: Optional[Data],
res: dict,
loss_mean: int = None,
loss_std: int = None,
train: bool = False,
update_generator: bool = True,
step: int = 0,
log_step: bool = False,
):
loss_aux = {}
loss_aux["loss_trav"] = torch.tensor([0.0])
loss_aux["loss_reco"] = torch.tensor([0.0])
loss_aux["confidence"] = torch.tensor([0.0])

losses = -(res["logprob"].sum(1) + res["log_det"]) # Sum over all channels, resulting in h*w output dimensions
losses = -(res["logprob"].sum(1) + res["log_det"]) # Sum over all channels, resulting in h*w output dimensions

# print(torch.mean(losses))
l_clip = losses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,13 @@ def load_checkpoint(self, checkpoint_path: str):
self._pause_training = False

@accumulate_time
def make_batch(self, batch_size: int = 8, anomaly_detection: bool = False, n_features: int = 200, vis_training_samples: bool = False):
def make_batch(
self,
batch_size: int = 8,
anomaly_detection: bool = False,
n_features: int = 200,
vis_training_samples: bool = False,
):
"""Samples a batch from the mission_graph
Args:
Expand Down Expand Up @@ -580,7 +586,9 @@ def make_batch(self, batch_size: int = 8, anomaly_detection: bool = False, n_fea

# Visualize supervision mask
if vis_training_samples:
self._last_image_mask_pub.publish(self._bridge.cv2_to_imgmsg(mask.cpu().numpy().astype(np.uint8) * 255, "mono8"))
self._last_image_mask_pub.publish(
self._bridge.cv2_to_imgmsg(mask.cpu().numpy().astype(np.uint8) * 255, "mono8")
)

# Save mask as numpy with opencv
# cv2.imwrite(os.path.join("/home/rschmid/ext", "mask", f"{rospy.get_time()}.png"), mask.cpu().numpy().astype(np.uint8) * 255)
Expand Down Expand Up @@ -622,7 +630,11 @@ def train(self):
return_dict = {"mission_graph_num_valid_node": num_valid_nodes}
if num_valid_nodes > self._min_samples_for_training:
# Prepare new batch
graph = self.make_batch(self._exp_cfg["ablation_data_module"]["batch_size"], anomaly_detection=self._anomaly_detection, vis_training_samples=self._vis_training_samples)
graph = self.make_batch(
self._exp_cfg["ablation_data_module"]["batch_size"],
anomaly_detection=self._anomaly_detection,
vis_training_samples=self._vis_training_samples,
)
if graph is not None:

self._loss_mean = None
Expand All @@ -634,7 +646,15 @@ def train(self):
res = self._model(graph)

log_step = (self._step % 20) == 0
self._loss, loss_aux, trav = self._traversability_loss(graph, res, step=self._step, log_step=log_step, loss_mean=self._loss_mean, loss_std=self._loss_std, train=True)
self._loss, loss_aux, trav = self._traversability_loss(
graph,
res,
step=self._step,
log_step=log_step,
loss_mean=self._loss_mean,
loss_std=self._loss_std,
train=True,
)

self._loss_mean = loss_aux["loss_mean"]
self._loss_std = loss_aux["loss_std"]
Expand Down
5 changes: 3 additions & 2 deletions wild_visual_navigation_ros/launch/robot.launch
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

<!-- Launch node -->
<include file="$(find wild_visual_navigation_ros)/launch/wild_visual_navigation.launch">
<arg name="camera" value="$(arg camera)"/>
<arg name="stack" value="$(arg stack)"/>
<arg name="camera" value="$(arg camera)"/>
<arg name="stack" value="$(arg stack)"/>
<arg name="reload_default_params" value="True"/>
</include>
</launch>
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<arg name="params_file" default="$(find wild_visual_navigation_ros)/config/wild_visual_navigation/default.yaml"/>
<arg name="overlay_images" default="True"/>
<arg name="resize_images" default="True"/>
<arg name="reload_default_params" default="False"/>

<!-- Load parameters -->
<rosparam command="load" file="$(arg params_file)" ns="wvn_learning_node"/>
Expand All @@ -19,10 +20,12 @@
<node name="wvn_learning_node" pkg="wild_visual_navigation_ros" type="wvn_learning_node.py" output="screen">
<param if="$(eval arg('stack') == 'rsl')" name="desired_twist_topic" value="/log/state/desiredRobotTwist"/>
<param if="$(eval arg('stack') == 'anybotics')" name="desired_twist_topic" value="/motion_reference/command_twist"/>
<param name="reload_default_params" value="$(arg reload_default_params)" />
</node>
<node name="wvn_feature_extractor_node" pkg="wild_visual_navigation_ros" type="wvn_feature_extractor_node.py" output="screen">
<param if="$(eval arg('stack') == 'rsl')" name="desired_twist_topic" value="/log/state/desiredRobotTwist"/>
<param if="$(eval arg('stack') == 'anybotics')" name="desired_twist_topic" value="/motion_reference/command_twist"/>
<param name="reload_default_params" value="$(arg reload_default_params)" />
</node>

<include if="$(arg resize_images)" file="$(find wild_visual_navigation_ros)/launch/resize_images.launch"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
if self.clip_to_binary:
out_trav = torch.where(out_trav.squeeze() <= self.traversability_threshold, 0.0, 1.0)


msg = rc.numpy_to_ros_image(out_trav.cpu().numpy(), "passthrough")
msg.header = image_msg.header
msg.width = out_trav.shape[0]
Expand Down Expand Up @@ -381,7 +380,7 @@ def load_model(self):
node_name = "wvn_feature_extractor_node"
rospy.init_node(node_name)

if True:
if rospy.get_param("~reload_default_params", True):
import rospkg

rospack = rospkg.RosPack()
Expand All @@ -390,7 +389,9 @@ def load_model(self):
os.system(
f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml wvn_feature_extractor_node"
)
print(f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml wvn_feature_extractor_node")
print(
f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml wvn_feature_extractor_node"
)

wvn = WvnFeatureExtractor()
rospy.spin()
2 changes: 1 addition & 1 deletion wild_visual_navigation_ros/scripts/wvn_learning_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def visualize_mission_graph(self):
if __name__ == "__main__":
node_name = "wvn_learning_node"
rospy.init_node(node_name)
if True:
if rospy.get_param("~reload_default_params", True):
import rospkg

rospack = rospkg.RosPack()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,18 @@ def ros_tf_to_torch(tf_pose, device="cpu"):


def ros_image_to_torch(ros_img, desired_encoding="rgb8", device="cpu"):
    """Convert a ROS image message (raw or compressed) to a torch tensor.

    Args:
        ros_img: a sensor_msgs ``Image`` or ``CompressedImage``; also accepts
            the autogenerated classes whose type name is
            "_sensor_msgs__Image" / "_sensor_msgs__CompressedImage"
            (produced by some ROS deserialization paths).
        desired_encoding: encoding passed to cv_bridge for raw images.
        device: torch device the returned tensor is moved to.

    Returns:
        The tensor produced by ``TO_TENSOR(np_image)`` moved to ``device``.

    Raises:
        ValueError: if the message type is not supported.
    """
    # Compare type names with `==`, not `is`: identity comparison against a
    # string literal is implementation-dependent and raises SyntaxWarning on
    # modern CPython.
    type_name = type(ros_img).__name__
    if type_name == "_sensor_msgs__Image" or isinstance(ros_img, Image):
        np_image = CV_BRIDGE.imgmsg_to_cv2(ros_img, desired_encoding=desired_encoding)

    elif type_name == "_sensor_msgs__CompressedImage" or isinstance(ros_img, CompressedImage):
        # np.fromstring is deprecated for binary data; frombuffer is the
        # zero-copy replacement with identical semantics here.
        np_arr = np.frombuffer(ros_img.data, np.uint8)
        np_image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if "bgr" in ros_img.format:
            np_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)

    else:
        raise ValueError("Image message type is not implemented.")

    return TO_TENSOR(np_image).to(device)


Expand Down

0 comments on commit 12bb9df

Please sign in to comment.