Standardize sample dict keys #985
-
I'm okay with standardizing on
-
@isaaccorley correctly pointed out that this works:

    augs = AugmentationSequential(...)
    augs(sample["image"], sample["mask"], data_keys=["input", "mask"])

So there's actually no reason to force ourselves to match our keys to Kornia's. I think we should:
Does that sound reasonable? I agree that "input" isn't a good variable name: it's a builtin function in Python, and it's too generic.
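The point above can be illustrated with a minimal sketch of a hypothetical helper, apply_augs. It keeps our own sample keys and passes the values to a Kornia-style augmentation positionally, using data_keys to say how each value should be treated. The sketch makes three assumptions. First, the augmentation follows the call convention in the snippet above: positional inputs plus a data_keys keyword, returning one output per input in the same order. Second, the KORNIA_KEYS mapping is illustrative. Third, a dummy callable stands in for AugmentationSequential, so the code runs without Kornia.

```python
from typing import Any, Callable, Dict, List, Sequence

# Hypothetical mapping from our sample keys to Kornia DataKey names.
KORNIA_KEYS: Dict[str, str] = {"image": "input", "mask": "mask"}


def apply_augs(
    augs: Callable[..., Sequence[Any]],
    sample: Dict[str, Any],
    keys: List[str],
) -> Dict[str, Any]:
    """Augment the values stored under ``keys`` and return a new sample.

    Values are passed positionally in the order of ``keys``; ``data_keys``
    tells the augmentation how to treat each one. Assumes ``augs`` returns
    one output per input, in the same order.
    """
    data_keys = [KORNIA_KEYS[key] for key in keys]
    outputs = augs(*[sample[key] for key in keys], data_keys=data_keys)
    # Copy so the caller's sample is not modified; unaugmented keys are kept.
    result = dict(sample)
    for key, value in zip(keys, outputs):
        result[key] = value
    return result


def fake_augs(*inputs: Any, data_keys: List[str]) -> List[Any]:
    """Stand-in for AugmentationSequential that returns its inputs unchanged."""
    return list(inputs)


sample = {"image": [1.0, 2.0], "mask": [0, 1], "crs": "EPSG:4326"}
augmented = apply_augs(fake_augs, sample, ["image", "mask"])
```

The mapping lives in one place, so the sample keys can be renamed freely. Only KORNIA_KEYS needs to know Kornia's vocabulary.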
-
We can actually standardize the dtypes at the same time. I propose:

    from typing import TypedDict

    from torch import FloatTensor, LongTensor


    class Sample(TypedDict, total=False):
        input: FloatTensor
        label: LongTensor
        mask: LongTensor
        bbox: LongTensor
        bbox_xyxy: LongTensor
        bbox_xywh: LongTensor
        keypoints: LongTensor

Kornia still requires everything to be float, but torchmetrics and torchvision/timm seem to require all targets to be long.
-
While we're at it, we should also standardize all transforms/data augmentation:

    SampleTransform = Callable[[Sample], Sample]
    BatchTransform = Callable[[Batch], Batch]
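The following sketch shows how these aliases could be used: a hypothetical compose helper chains several SampleTransform callables into one. To keep it runnable without PyTorch, Sample here is a stand-in TypedDict holding plain Python lists rather than tensors. The two transforms are illustrative only. A BatchTransform pipeline would be composed the same way.

```python
from typing import Callable, List, TypedDict


class Sample(TypedDict, total=False):
    # Stand-in for the proposed Sample; plain lists replace tensors.
    image: List[float]
    mask: List[int]


SampleTransform = Callable[[Sample], Sample]


def compose(*transforms: SampleTransform) -> SampleTransform:
    """Chain transforms so each one receives the previous one's output."""

    def composed(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(sample)
        return sample

    return composed


def scale_image(sample: Sample) -> Sample:
    # Illustrative transform: rescale pixel values from [0, 255] to [0, 1].
    out = sample.copy()
    if "image" in out:
        out["image"] = [v / 255 for v in out["image"]]
    return out


def hflip(sample: Sample) -> Sample:
    # Illustrative transform: reverse image and mask together so they stay aligned.
    out = sample.copy()
    if "image" in out:
        out["image"] = out["image"][::-1]
    if "mask" in out:
        out["mask"] = out["mask"][::-1]
    return out


pipeline = compose(scale_image, hflip)
result = pipeline({"image": [0.0, 255.0], "mask": [0, 1]})
```

Because every transform has the same Sample -> Sample signature, any of them can be reordered, dropped, or wrapped without changing the others.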
-
In #1997 we note that some datasets contain both regression and classification labels, but these can't both be returned since they both use the same "label" key. We should consider using a different key for each type of task so that both classification and regression can be supported for the same dataset.
-
We can make subclasses for these: https://stackoverflow.com/a/71814659/5828163
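The following is a minimal sketch of what such subclasses might look like, assuming the idea is TypedDict inheritance. A shared base class holds the common keys, and each task adds its own label key. The class names and the classification_label / regression_label keys are hypothetical. They only show how a dataset with both kinds of labels could return them without colliding on "label". Plain Python types stand in for tensors.

```python
from typing import List, TypedDict


class Sample(TypedDict, total=False):
    # Stand-in base class: keys shared by every task.
    image: List[float]


class ClassificationSample(Sample, total=False):
    # Hypothetical key for integer class labels.
    classification_label: int


class RegressionSample(Sample, total=False):
    # Hypothetical key for continuous targets.
    regression_label: float


class MultiTaskSample(ClassificationSample, RegressionSample, total=False):
    """A dataset with both label types can return both keys at once."""


sample: MultiTaskSample = {
    "image": [0.1, 0.2],
    "classification_label": 3,
    "regression_label": 0.75,
}
```

With this layout, a type checker accepts a classification trainer that only reads ClassificationSample keys, while a multi-task dataset can still return everything.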
-
Summary
In TorchGeo, each Dataset returns a sample dictionary with key/value pairs for each object. I propose we standardize the names of these keys and the types and dimensions of their corresponding values as follows:
I also propose we standardize batch dictionaries (mini-batches of samples) using the exact same definition, but with an additional batch dimension in front.
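The batch convention can be illustrated with a hypothetical collate function. It stacks a list of samples into a batch dict with the same keys, adding a leading batch dimension to each value. This is a sketch under two assumptions: nested Python lists stand in for tensors (in PyTorch the stacking step would be torch.stack), and every sample in the list has the same keys.

```python
from typing import Any, Dict, List

Sample = Dict[str, Any]  # stand-in for the proposed Sample TypedDict
Batch = Dict[str, List[Any]]  # same keys; each value gains a leading batch dimension


def collate(samples: List[Sample]) -> Batch:
    """Stack samples key by key; index i of each value belongs to sample i."""
    if not samples:
        return {}
    keys = samples[0].keys()
    for sample in samples[1:]:
        if sample.keys() != keys:
            raise ValueError("all samples in a batch must have the same keys")
    return {key: [sample[key] for sample in samples] for key in keys}


batch = collate([
    {"image": [[1, 2]], "label": 0},
    {"image": [[3, 4]], "label": 1},
])
```

Because the batch reuses the sample keys, code that reads batch["image"] does not need to know which dataset produced it.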
Rationale
Our current keys are not uniform (bbox and boxes, label and labels):
This makes it difficult to create dataset-independent trainers that handle multiple datamodules.
We would also like to be able to type check each of these so we don't have a mix of ints/floats and Tensors (e.g., for label).
TorchGeo relies heavily on Kornia for data augmentation. Kornia provides an AugmentationSequential container for composing and applying transforms. Currently, we create our own wrapper around AugmentationSequential that maps our keys to the keys expected by Kornia, but a better solution would be to standardize on the same keys that Kornia uses.
Implementation
Python 3.8 introduced the TypedDict type. I propose we use it (and add a dependency on typing-extensions for older versions of Python). The code would look like:
Each dataset's __getitem__ would return an object of type Sample, while collation functions would take in List[Sample] and return Batch. This would replace our current Dict[str, Tensor] type hints and be much more strict.
There are a few important caveats to note here:
- With total=False, all keys are optional.
- … labels and boxes keys again.
- … Batch will have its own separate implementation. For now, the difference is only semantic.
Alternatives
The keys allowed by Kornia are defined in the kornia.constants.DataKey enum. Unfortunately, I couldn't find any documentation or discussion on this enum. Kornia does not standardize "label", so we could alternatively call this "target".
Kornia isn't the only library for data augmentation. There is also Albumentations, which seems to standardize on:
Torchvision does not use sample dicts, so it doesn't have a standard of its own. Torchvision uses PIL, while Albumentations uses OpenCV, neither of which supports multispectral imagery (MSI) or GPUs, so I think we're better off matching the Kornia standard.
The final alternative is to convince Kornia to change their key names to match ours. I think this is unlikely, but @isaaccorley is more optimistic.
Additional information
We will likely need many additional keys for things that don't fall into this list. How strict should we be about this? It would be really nice if PyTorch's default collate function worked. We currently roll our own solely because we need to include things like rasterio.crs.CRS, which isn't a type that the default function knows how to handle.
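One point bears on the strictness question: TypedDict only constrains keys for static type checkers, not at runtime. Any runtime enforcement would therefore have to be explicit. The sketch below uses hypothetical helpers to compare a sample's keys with those declared on a stand-in Sample class. It reads the class's __required_keys__ and __optional_keys__ attributes (available since Python 3.9), then either rejects or tolerates extras such as a crs entry.

```python
from typing import Any, Dict, FrozenSet, List, TypedDict


class Sample(TypedDict, total=False):
    # Stand-in for the proposed Sample; plain types replace tensors.
    image: List[float]
    mask: List[int]
    label: int


# Every key declared on Sample, whether required or optional.
ALLOWED_KEYS: FrozenSet[str] = frozenset(
    Sample.__required_keys__ | Sample.__optional_keys__
)


def unknown_keys(sample: Dict[str, Any]) -> FrozenSet[str]:
    """Return the keys in ``sample`` that Sample does not declare."""
    return frozenset(sample) - ALLOWED_KEYS


def check_sample(sample: Dict[str, Any], strict: bool = True) -> None:
    """Raise KeyError on undeclared keys when strict; otherwise accept them."""
    extra = unknown_keys(sample)
    if strict and extra:
        raise KeyError(f"undeclared sample keys: {sorted(extra)}")
```

A strict mode catches typos like "labels" early. A lenient mode lets datasets carry extra metadata (such as a CRS) that a custom collate function can handle separately.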
How should we handle predictions? We currently prefix them with prediction_.
How should we handle multi-label problems? The dimensions of these would be different. Should we use a different key?
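For the prediction_ prefix convention, a hypothetical helper could separate predictions from everything else by stripping the prefix, so that prediction_mask lines up with mask. This is a sketch of that convention only. It uses plain values and does not address the multi-label question.

```python
from typing import Any, Dict, Tuple

PREFIX = "prediction_"


def split_predictions(
    batch: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a batch into (other entries, predictions with the prefix stripped)."""
    rest: Dict[str, Any] = {}
    predictions: Dict[str, Any] = {}
    for key, value in batch.items():
        if key.startswith(PREFIX):
            predictions[key[len(PREFIX):]] = value
        else:
            rest[key] = value
    return rest, predictions


rest, predictions = split_predictions(
    {"image": [1.0], "mask": [0, 1], "prediction_mask": [0, 0]}
)
```

After the split, targets and predictions share the same key names, so a metric can iterate over predictions and look up rest[key] directly.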
Should we also standardize dtypes? All Kornia transforms require {float16, float32, float64} as input, while all torchmetrics metrics require int? I can't actually find evidence of the latter, but we have several datamodules that assert this. Depending on whether labels are for classification or regression, the dtype can and should change. We could either have separate keys for these or ignore dtype for them.