
Fine tuning SAM, 'The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4.', ' got torch.Size([2, 4]).' #93

leemorton opened this issue Nov 21, 2024 · 0 comments

Comments

@leemorton
Copy link

Hi,

I am trying to fine-tune SAM on custom images and masks, but I am struggling and hoping someone can point me in the right direction.

I have been referencing 331_fine_tune_SAM_mito.ipynb.

I cannot get the training to work as I get this message at the forward pass step:
'The input_points must be a 3D tensor. Of shape batch_size, nb_boxes, 4.', ' got torch.Size([2, 4]).'

I think the input_boxes I am passing are shaped wrong somehow?
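For context on what the processor's shape check is complaining about, here is a minimal sketch (with made-up box coordinates) of how the nesting level changes the resulting shape:

```python
import numpy as np

# Made-up box coordinates, [x_min, y_min, x_max, y_max]
box = [10.0, 20.0, 110.0, 220.0]

# Three nesting levels give the (batch_size, nb_boxes, 4) layout
# that the error message asks for
nested = [[box]]
print(np.array(nested).shape)   # (1, 1, 4)

# A flat list of two boxes has no batch axis and comes out 2D,
# which matches the torch.Size([2, 4]) in the error
flat = [box, box]
print(np.array(flat).shape)     # (2, 4)
```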

[screenshot: error traceback from the forward pass]

The images I am using are colour PNG images rather than the TIFF images in the reference code, and they show 3 channels here:

[screenshot: image shape output showing 3 channels]

My SAMDataset code is:

import numpy as np
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  """
  This class is used to create a dataset that serves input images and masks.
  It takes a dataset and a processor as input and overrides the __len__ and __getitem__ methods of the Dataset class.
  """
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["label"])

    # get bounding box prompt
    # prompt = get_bounding_box(ground_truth_mask)
    prompt = item["bounding_box"]

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs
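One way to narrow this down (a hypothetical helper, not from the notebook) is to assert that the prompt really is a single flat box before wrapping it in `[[prompt]]`; if `item["bounding_box"]` already contains a list of boxes, the extra nesting produces the wrong shape:

```python
import numpy as np

def check_single_box(prompt):
    """Hypothetical sanity check: raise early if the bounding-box prompt
    is not one flat [x_min, y_min, x_max, y_max] list."""
    arr = np.asarray(prompt, dtype=float)
    if arr.shape != (4,):
        raise ValueError(
            f"expected a single box of 4 coords, got shape {tuple(arr.shape)}"
        )
    return arr.tolist()
```

Calling this on `prompt` inside `__getitem__`, just before the processor call, would pinpoint whether the dataset is handing back something other than one box per image.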

and this is where I run into trouble:

[screenshot: error during the forward pass]
