From 134a6805c91628c536a3cc825bf7a0f0675d1d12 Mon Sep 17 00:00:00 2001
From: JunkyByte
Date: Tue, 11 Jul 2023 16:22:47 +0200
Subject: [PATCH 1/2] Added MPS support and device auto-selection, minor style
 fixes

---
 persam.py             | 53 +++++++++++++++----------------
 persam_f.py           | 70 ++++++++++++++++++++---------------------
 persam_f_multi_obj.py | 72 +++++++++++++++++++++----------------------
 3 files changed, 96 insertions(+), 99 deletions(-)

diff --git a/persam.py b/persam.py
index 190ef5d..702ee94 100644
--- a/persam.py
+++ b/persam.py
@@ -10,21 +10,25 @@ import warnings
 warnings.filterwarnings('ignore')
 
-from show import *
 from per_segment_anything import sam_model_registry, SamPredictor
+from show import *
 
+# Priority is cuda > mps > cpu
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+                  'mps' if torch.backends.mps.is_available() else
+                  'cpu')
 
 def get_arguments():
-    
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--data', type=str, default='./data')
     parser.add_argument('--outdir', type=str, default='persam')
     parser.add_argument('--ckpt', type=str, default='sam_vit_h_4b8939.pth')
+    parser.add_argument('--device', type=str, default=DEFAULT_DEVICE)
     parser.add_argument('--ref_idx', type=str, default='00')
     parser.add_argument('--sam_type', type=str, default='vit_h')
-    
+
     args = parser.parse_args()
     return args
 
@@ -40,7 +44,7 @@ def main():
 
     if not os.path.exists('./outputs/'):
         os.mkdir('./outputs/')
-    
+
     for obj_name in os.listdir(images_path):
         if ".DS" not in obj_name:
             persam(args, obj_name, images_path, masks_path, output_path)
@@ -49,7 +53,7 @@
 def persam(args, obj_name, images_path, masks_path, output_path):
 
     print("\n------------> Segment " + obj_name)
-    
+
     # Path preparation
     ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg')
     ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png')
@@ -64,21 +68,19 @@ def persam(args, obj_name, images_path, masks_path, output_path):
     ref_mask = cv2.imread(ref_mask_path)
     ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)
-    
-    print("======> Load SAM" )
+    print("======> Load SAM")
     if args.sam_type == 'vit_h':
         sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(args.device)
     elif args.sam_type == 'vit_t':
         sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
+        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=args.device)
     sam.eval()
 
     predictor = SamPredictor(sam)
 
-    print("======> Obtain Location Prior" )
+    print("======> Obtain Location Prior")
     # Image features encoding
     ref_mask = predictor.set_image(ref_image, ref_mask)
     ref_feat = predictor.features.squeeze().permute(1, 2, 0)
@@ -92,10 +94,9 @@ def persam(args, obj_name, images_path, masks_path, output_path):
     target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
     target_embedding = target_embedding.unsqueeze(0)
 
-
     print('======> Start Testing')
     for test_idx in tqdm(range(len(os.listdir(test_images_path)))):
-        
+
         # Load test image
         test_idx = '%02d' % test_idx
         test_image_path = test_images_path + '/' + test_idx + '.jpg'
@@ -115,9 +116,9 @@ def persam(args, obj_name, images_path, masks_path, output_path):
         sim = sim.reshape(1, 1, h, w)
         sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
         sim = predictor.model.postprocess_masks(
-                    sim, 
-                    input_size=predictor.input_size,
-                    original_size=predictor.original_size).squeeze()
+            sim,
+            input_size=predictor.input_size,
+            original_size=predictor.original_size).squeeze()
 
         # Positive-negative location prior
         topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
@@ -131,8 +132,8 @@ def persam(args, obj_name, images_path, masks_path, output_path):
         # First-step prediction
         masks, scores, logits, _ = predictor.predict(
-            point_coords=topk_xy, 
-            point_labels=topk_label, 
+            point_coords=topk_xy,
+            point_labels=topk_label,
             multimask_output=False,
             attn_sim=attn_sim,  # Target-guided Attention
             target_embedding=target_embedding  # Target-semantic Prompting
         )
@@ -141,10 +142,10 @@ def persam(args, obj_name, images_path, masks_path, output_path):
 
         # Cascaded Post-refinement-1
         masks, scores, logits, _ = predictor.predict(
-                    point_coords=topk_xy,
-                    point_labels=topk_label,
-                    mask_input=logits[best_idx: best_idx + 1, :, :],
-                    multimask_output=True)
+            point_coords=topk_xy,
+            point_labels=topk_label,
+            mask_input=logits[best_idx: best_idx + 1, :, :],
+            multimask_output=True)
         best_idx = np.argmax(scores)
 
         # Cascaded Post-refinement-2
@@ -158,7 +159,7 @@ def persam(args, obj_name, images_path, masks_path, output_path):
             point_coords=topk_xy,
             point_labels=topk_label,
             box=input_box[None, :],
-            mask_input=logits[best_idx: best_idx + 1, :, :], 
+            mask_input=logits[best_idx: best_idx + 1, :, :],
             multimask_output=True)
         best_idx = np.argmax(scores)
 
@@ -189,7 +190,7 @@ def point_selection(mask_sim, topk=1):
     topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
     topk_label = np.array([1] * topk)
     topk_xy = topk_xy.cpu().numpy()
-    
+
     # Top-last point selection
     last_xy = mask_sim.flatten(0).topk(topk, largest=False)[1]
     last_x = (last_xy // h).unsqueeze(0)
@@ -197,9 +198,9 @@ def point_selection(mask_sim, topk=1):
     last_xy = torch.cat((last_y, last_x), dim=0).permute(1, 0)
     last_label = np.array([0] * topk)
     last_xy = last_xy.cpu().numpy()
-    
+
     return topk_xy, topk_label, last_xy, last_label
-    
+
 
 if __name__ == "__main__":
     main()

diff --git a/persam_f.py b/persam_f.py
index 443d8fc..a199dee 100644
--- a/persam_f.py
+++ b/persam_f.py
@@ -11,25 +11,29 @@ import warnings
 warnings.filterwarnings('ignore')
 
-from show import *
 from per_segment_anything import sam_model_registry, SamPredictor
+from show import *
 
-
+# Priority is cuda > mps > cpu
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+                  'mps' if torch.backends.mps.is_available() else
+                  'cpu')
 
 def get_arguments():
-    
+
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, default='./data')
     parser.add_argument('--outdir', type=str, default='persam_f')
+    parser.add_argument('--device', type=str, default=DEFAULT_DEVICE)
     parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth')
     parser.add_argument('--sam_type', type=str, default='vit_h')
 
-    parser.add_argument('--lr', type=float, default=1e-3) 
+    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--train_epoch', type=int, default=1000)
     parser.add_argument('--log_epoch', type=int, default=200)
     parser.add_argument('--ref_idx', type=str, default='00')
-    
+
     args = parser.parse_args()
     return args
 
@@ -45,16 +49,16 @@ def main():
 
     if not os.path.exists('./outputs/'):
         os.mkdir('./outputs/')
-    
+
     for obj_name in os.listdir(images_path):
         if ".DS" not in obj_name:
             persam_f(args, obj_name, images_path, masks_path, output_path)
 
 
 def persam_f(args, obj_name, images_path, masks_path, output_path):
-    
+
     print("\n------------> Segment " + obj_name)
-    
+
     # Path preparation
     ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg')
     ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png')
@@ -70,27 +74,23 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     ref_mask = cv2.imread(ref_mask_path)
     ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)
 
-    gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 
-    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
+    gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
+    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(args.device)
 
-    
-    print("======> Load SAM" )
+    print("======> Load SAM")
     if args.sam_type == 'vit_h':
         sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(args.device)
     elif args.sam_type == 'vit_t':
         sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
+        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=args.device)
     sam.eval()
-    
-    
+
     for name, param in sam.named_parameters():
         param.requires_grad = False
     predictor = SamPredictor(sam)
 
-    print("======> Obtain Self Location Prior" )
+    print("======> Obtain Self Location Prior")
     # Image features encoding
     ref_mask = predictor.set_image(ref_image, ref_mask)
     ref_feat = predictor.features.squeeze().permute(1, 2, 0)
@@ -114,19 +114,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     sim = sim.reshape(1, 1, h, w)
     sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
     sim = predictor.model.postprocess_masks(
-                sim, 
-                input_size=predictor.input_size,
-                original_size=predictor.original_size).squeeze()
+        sim,
+        input_size=predictor.input_size,
+        original_size=predictor.original_size).squeeze()
 
     # Positive location prior
     topk_xy, topk_label = point_selection(sim, topk=1)
 
-
     print('======> Start Training')
     # Learnable mask weights
-    mask_weights = Mask_Weights().cuda()
+    mask_weights = Mask_Weights().to(args.device)
     mask_weights.train()
-    
+
     optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch)
@@ -158,7 +157,6 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
             current_lr = scheduler.get_last_lr()[0]
             print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))
 
-
     mask_weights.eval()
     weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
     weights_np = weights.detach().cpu().numpy()
@@ -186,18 +184,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
         sim = sim.reshape(1, 1, h, w)
         sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
         sim = predictor.model.postprocess_masks(
-                    sim, 
-                    input_size=predictor.input_size,
-                    original_size=predictor.original_size).squeeze()
+            sim,
+            input_size=predictor.input_size,
+            original_size=predictor.original_size).squeeze()
 
         # Positive location prior
         topk_xy, topk_label = point_selection(sim, topk=1)
 
         # First-step prediction
         masks, scores, logits, logits_high = predictor.predict(
-                    point_coords=topk_xy, 
-                    point_labels=topk_label, 
-                    multimask_output=True)
+            point_coords=topk_xy,
+            point_labels=topk_label,
+            multimask_output=True)
 
         # Weighted sum three-scale masks
         logits_high = logits_high * weights.unsqueeze(-1)
@@ -236,7 +234,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
             mask_input=logits[best_idx: best_idx + 1, :, :],
             multimask_output=True)
         best_idx = np.argmax(scores)
-        
+
         # Save masks
         plt.figure(figsize=(10, 10))
         plt.imshow(test_image)
@@ -270,11 +268,11 @@ def point_selection(mask_sim, topk=1):
     topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
     topk_label = np.array([1] * topk)
     topk_xy = topk_xy.cpu().numpy()
-    
+
     return topk_xy, topk_label
 
 
-def calculate_dice_loss(inputs, targets, num_masks = 1):
+def calculate_dice_loss(inputs, targets, num_masks=1):
     """
     Compute the DICE loss, similar to generalized IOU for masks
     Args:
@@ -292,7 +290,7 @@ def calculate_dice_loss(inputs, targets, num_masks = 1):
     return loss.sum() / num_masks
 
 
-def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2):
+def calculate_sigmoid_focal_loss(inputs, targets, num_masks=1, alpha: float = 0.25, gamma: float = 2):
     """
     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
     Args:

diff --git a/persam_f_multi_obj.py b/persam_f_multi_obj.py
index 22d8a9c..7c022ad 100644
--- a/persam_f_multi_obj.py
+++ b/persam_f_multi_obj.py
@@ -11,18 +11,22 @@ import warnings
 warnings.filterwarnings('ignore')
 
-from show import *
 from per_segment_anything import sam_model_registry, SamPredictor
+from show import *
 
+# Priority is cuda > mps > cpu
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+                  'mps' if torch.backends.mps.is_available() else
+                  'cpu')
 
 def get_arguments():
-    
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, default='./data')
     parser.add_argument('--outdir', type=str, default='persam_f')
     parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth')
+    parser.add_argument('--device', type=str, default=DEFAULT_DEVICE)
     parser.add_argument('--sam_type', type=str, default='vit_h')
 
     parser.add_argument('--lr', type=int, default=1e-3)
@@ -30,10 +34,10 @@ def get_arguments():
     parser.add_argument('--train_epoch_inside', type=int, default=200)
     parser.add_argument('--log_epoch', type=int, default=200)
     parser.add_argument('--training_percentage', type=float, default=0.5)
-    
+
     parser.add_argument('--max_objects', type=int, default=10)
     parser.add_argument('--iou_threshold', type=float, default=0.8)
-    
+
     args = parser.parse_args()
     return args
 
@@ -47,27 +51,23 @@ def main():
     masks_path = args.data + '/Annotations/'
     output_path = './outputs/' + args.outdir
-    
-    
     if not os.path.exists('./outputs/'):
         os.mkdir('./outputs/')
-    
+
     for obj_name in os.listdir(images_path):
         persam_f(args, obj_name, images_path, masks_path, output_path)
 
 
 def persam_f(args, obj_name, images_path, masks_path, output_path):
-    print("======> Load SAM" )
+    print("======> Load SAM")
     if args.sam_type == 'vit_h':
         sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(args.device)
     elif args.sam_type == 'vit_t':
         sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
+        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=args.device)
     sam.eval()
-    
-    
+
     for name, param in sam.named_parameters():
         param.requires_grad = False
     predictor = SamPredictor(sam)
@@ -76,7 +76,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
     for i in tqdm(range(args.train_epoch_outside)):
         output_path = os.path.join(output_path, obj_name)
         os.makedirs(output_path, exist_ok=True)
-        training_size = int(len(os.listdir(os.path.join(images_path, obj_name))) * args.training_percentage) 
+        training_size = int(len(os.listdir(os.path.join(images_path, obj_name))) * args.training_percentage)
         for ref_idx in range(training_size):
             # Path preparation
             ref_image_path = os.path.join(images_path, obj_name, '{:02}.jpg'.format(ref_idx))
@@ -90,9 +90,9 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
             ref_mask = cv2.imread(ref_mask_path)
             ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)
-            gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 
-            gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
-            
+            gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
+            gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(args.device)
+
             # print("======> Obtain Self Location Prior" )
             # Image features encoding
             ref_mask = predictor.set_image(ref_image, ref_mask)
@@ -117,19 +117,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
             sim = sim.reshape(1, 1, h, w)
             sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
             sim = predictor.model.postprocess_masks(
-                        sim, 
-                        input_size=predictor.input_size,
-                        original_size=predictor.original_size).squeeze()
+                sim,
+                input_size=predictor.input_size,
+                original_size=predictor.original_size).squeeze()
 
             # Positive location prior
             topk_xy, topk_label = point_selection(sim, topk=1)
 
-
             # print('======> Start Training')
             # Learnable mask weights
-            mask_weights = Mask_Weights().cuda()
+            mask_weights = Mask_Weights().to(args.device)
             mask_weights.train()
-            
+
             optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch_inside)
@@ -158,7 +157,6 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
                     # print('Train Epoch: {:} / {:}'.format(train_idx, args.train_epoch_inside))
                     current_lr = scheduler.get_last_lr()[0]
 
-
             mask_weights.eval()
             weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
             weights_np = weights.detach().cpu().numpy()
@@ -175,7 +173,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
         test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
         test_image_original = cv2.imread(test_image_path)
         test_image_original = cv2.cvtColor(test_image_original, cv2.COLOR_BGR2RGB)
-        
+
         history_masks = []
         plt.figure(figsize=(10, 10))
         for i in tqdm(range(args.max_objects)):
@@ -192,18 +190,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
             sim = sim.reshape(1, 1, h, w)
             sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
             sim = predictor.model.postprocess_masks(
-                        sim, 
-                        input_size=predictor.input_size,
-                        original_size=predictor.original_size).squeeze()
+                sim,
+                input_size=predictor.input_size,
+                original_size=predictor.original_size).squeeze()
 
             # Positive location prior
             topk_xy, topk_label = point_selection(sim, topk=1)
 
             # First-step prediction
             masks, scores, logits, logits_high = predictor.predict(
-                        point_coords=topk_xy, 
-                        point_labels=topk_label, 
-                        multimask_output=True)
+                point_coords=topk_xy,
+                point_labels=topk_label,
+                multimask_output=True)
 
             # Weighted sum three-scale masks
             logits_high = logits_high * weights.unsqueeze(-1)
@@ -243,7 +241,6 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
                 multimask_output=True)
             best_idx = np.argmax(scores)
 
-
            final_mask = masks[best_idx]
             mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
             mask_colors[final_mask, :] = np.array([[0, 0, 128]])
@@ -261,7 +258,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
             show_points(topk_xy, topk_label, plt.gca())
             history_masks.append(mask_colors)
         # Save masks
-        
+
         plt.imshow(test_image_original)
         vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}_objects:{len(history_masks)}.jpg')
         with open(vis_mask_output_path, 'wb') as outfile:
@@ -271,7 +268,6 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
         cv2.imwrite(mask_output_path, mask_colors)
 
 
-
 class Mask_Weights(nn.Module):
     def __init__(self):
         super().__init__()
@@ -287,11 +283,11 @@ def point_selection(mask_sim, topk=1):
     topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
     topk_label = np.array([1] * topk)
     topk_xy = topk_xy.cpu().numpy()
-    
+
     return topk_xy, topk_label
 
 
-def calculate_dice_loss(inputs, targets, num_masks = 1):
+def calculate_dice_loss(inputs, targets, num_masks=1):
     """
     Compute the DICE loss, similar to generalized IOU for masks
     Args:
@@ -309,7 +305,7 @@ def calculate_dice_loss(inputs, targets, num_masks = 1):
     return loss.sum() / num_masks
 
 
-def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2):
+def calculate_sigmoid_focal_loss(inputs, targets, num_masks=1, alpha: float = 0.25, gamma: float = 2):
     """
     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
     Args:
@@ -336,6 +332,7 @@ def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float =
     return loss.mean(1).sum() / num_masks
 
 
+
 def calculate_iou(mask1, mask2):
     """
     Calculate the Intersection over Union (IoU) score between two masks.
@@ -358,5 +355,6 @@ def calculate_iou(mask1, mask2):
     iou = intersection / union
     return iou
 
+
 if __name__ == '__main__':
     main()

From 7b156eecd2db9725c6a34e81496abe401f57bbc6 Mon Sep 17 00:00:00 2001
From: JunkyByte
Date: Tue, 11 Jul 2023 16:27:00 +0200
Subject: [PATCH 2/2] Update README

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index b0d5dcd..bd452d2 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ Official implementation of ['Personalize Segment Anything Model with One Shot'](
 
 ## News
 
+* MPS (Metal Performance Shaders) support added 🔥 for faster performance on Apple Silicon devices.
 * Support [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) 🔥 with significant efficiency improvement. Thanks for their wonderful work!
 * **TODO**: Release the PerSAM-assisted [Dreambooth](https://arxiv.org/pdf/2208.12242.pdf) for better fine-tuning [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 📌.
 * We release the code of PerSAM and PerSAM-F 🔥. Check our [video](https://www.youtube.com/watch?v=QlunvXpYQXM) here!
@@ -83,6 +84,7 @@ For **Multi-Object** segmentation of the same category by PerSAM-F (Great thanks
 python persam_f_multi_obj.py --sam_type --outdir
 ```
 
+Specify the device to use with `--device`; it currently supports `cpu`, `cuda`, and `mps` (Apple Silicon), and defaults to `cuda` or `mps` when available.
 After running, the output masks and visualizations will be stored at `outputs/`.
 
 ### Evaluation
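
For a quick sanity check of the behaviour these patches introduce, the snippet below mirrors the `DEFAULT_DEVICE` fallback added to `persam.py`, `persam_f.py`, and `persam_f_multi_obj.py`. It is a minimal standalone sketch rather than part of the patch series, and it assumes PyTorch >= 1.12, the first release that ships `torch.backends.mps`:

```python
import torch

# Same priority as the patches: cuda > mps > cpu.
DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
                  'mps' if torch.backends.mps.is_available() else
                  'cpu')

if __name__ == '__main__':
    # The auto-selected value only serves as the argparse default;
    # an explicit flag such as `python persam.py --device cpu` overrides it.
    print(f'Auto-selected device: {DEFAULT_DEVICE}')
```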