Skip to content

Commit

Permalink
Port Faster R-CNN to Keras3 (#2458)
Browse files Browse the repository at this point in the history
* Base structure for faster rcnn till rpn head

* Add export for Faster RNN

* add init file

* initalize faster rcnn at model level

* code fix fo roi align

* Forward Pass code for Faster R-CNN

* Faster RCNN Base code for Keras3(Draft-1)

* Add local batch size

* Add parameters to RPN Head

* Make FPN more customizable with parameters and remove redudant code

* Compute output shape for ROI Generator

* Faster RCNN functional model with required import corrections

* add clip boxes to forward pass

* add prediction decoder and use "yxyx" as default internal bounding box format

* feature pryamid correction

* change ops.divide to ops.divide_no_nan

* use from logits=True for Non Max supression

* include box convertions for both rois and ground truth boxes

* Change number of detections in decoder

* Use categoricalcrossentropy to avoid -1 class error + added get_config for model saving

* add basic test cases + linting

* Add seed generator for sampling in RPN label encoding and ROI sampling layers

* Use only spatial dimension for ops.nn.avg_pool + use ops.convert_to_tensor for list type + linting

* Convert list to tensor using keras ops

* Remove seed number from seed generator

* Remove print and add proper comments

* - Use stddev(0.01) as per paper across RPN and R-CNN Heads
- Maxpool2d as per torch implementation in FPN
-  Add prediction decoder

* - Fixes slice for multi backend
- Slice for tensorflow can use [-1, -1, -1] for shape but not jax and torch, they should have explicit shape

* - Add compute metrics method

* Correct test cases and add missing args

* Fix lint issues

* - Fix lint and remove hard coded params to make it user friendly.

* - Generate ROI's while decoding for predictions
- Liniting + Test Cases

* - Add faster rcnn to build method

* - Test only for Keras3

* - Correct test case
- Add copyright

* - Correct the test cases decorator to skip for Keras2

* - Skip Legacy test cases
- Fix ROI Align ops for torch backend

* - Remove unecessary import in legacy code to fix lint

* - Correct pytest complexity
- Make bounding box test utils use 256,256 image size

* - FIx Image Shape to 512, 512 default which will not break other test cases

* - Lower image sizes for test cases
- Add build method for fpn

* - fix keras to 3.3.3 version

* - Generate api
- Correct YOLOv8 preset test case

* - Lint fix

* - Increase the atol, rtol for YOLOv8 Detector forward pass
  • Loading branch information
sineeli authored Aug 20, 2024
1 parent 3d417ea commit 2f9eb86
Show file tree
Hide file tree
Showing 27 changed files with 2,085 additions and 85 deletions.
1 change: 1 addition & 0 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ jobs:
keras_cv/src/models/classification \
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
--durations 0
Expand Down
5 changes: 5 additions & 0 deletions .kokoro/github/ubuntu/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,23 @@ then
pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text~=2.16.0
pip install keras~=3.3.3

elif [ "$KERAS_BACKEND" == "jax" ]
then
echo "JAX backend detected."
pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text~=2.16.0
pip install keras~=3.3.3

elif [ "$KERAS_BACKEND" == "torch" ]
then
echo "PyTorch backend detected."
pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text~=2.16.0
pip install keras~=3.3.3
fi

pip install --no-deps -e "." --progress-bar off
Expand All @@ -67,6 +70,7 @@ then
keras_cv/src/models/classification \
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
keras_cv/src/models/feature_extractor/clip \
Expand All @@ -82,6 +86,7 @@ else
keras_cv/src/models/classification \
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
keras_cv/src/models/feature_extractor/clip \
Expand Down
4 changes: 4 additions & 0 deletions keras_cv/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from keras_cv.api.models import classification
from keras_cv.api.models import faster_rcnn
from keras_cv.api.models import feature_extractor
from keras_cv.api.models import object_detection
from keras_cv.api.models import retinanet
Expand Down Expand Up @@ -205,6 +206,9 @@
from keras_cv.src.models.classification.image_classifier import ImageClassifier
from keras_cv.src.models.classification.video_classifier import VideoClassifier
from keras_cv.src.models.feature_extractor.clip.clip_model import CLIP
from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import (
YOLOV8Backbone,
Expand Down
11 changes: 11 additions & 0 deletions keras_cv/api/models/faster_rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras_cv.src.models.object_detection.faster_rcnn.feature_pyramid import (
FeaturePyramid,
)
from keras_cv.src.models.object_detection.faster_rcnn.rcnn_head import RCNNHead
from keras_cv.src.models.object_detection.faster_rcnn.rpn_head import RPNHead
3 changes: 3 additions & 0 deletions keras_cv/api/models/object_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector import (
YOLOV8Detector,
Expand Down
2 changes: 1 addition & 1 deletion keras_cv/src/bounding_box/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _clip_boxes(boxes, box_format, image_shape):

if isinstance(image_shape, list) or isinstance(image_shape, tuple):
height, width, _ = image_shape
max_length = [height, width, height, width]
max_length = ops.stack([height, width, height, width], axis=-1)
else:
image_shape = ops.cast(image_shape, dtype=boxes.dtype)
height = image_shape[0]
Expand Down
51 changes: 33 additions & 18 deletions keras_cv/src/layers/object_detection/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
features,
[batch_size * num_boxes, output_size * 2, output_size * 2, num_filters],
)
features = ops.nn.average_pool(
features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID"
)
features = ops.nn.average_pool(features, (2, 2), (2, 2), "VALID")
features = ops.reshape(
features, [batch_size, num_boxes, output_size, output_size, num_filters]
)
Expand Down Expand Up @@ -242,6 +240,11 @@ def multilevel_crop_and_resize(
for i in range(len(feature_widths) - 1):
level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]

level_dim_offsets = ops.convert_to_tensor(level_dim_offsets)
feature_widths = ops.convert_to_tensor(feature_widths)
feature_heights = ops.convert_to_tensor(feature_heights)

level_dim_offsets = (
ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets
)
Expand All @@ -259,7 +262,9 @@ def multilevel_crop_and_resize(
# following the FPN paper to divide by 224.
levels = ops.cast(
ops.floor_divide(
ops.log(ops.divide(areas_sqrt, 224.0)),
ops.log(
ops.divide_no_nan(areas_sqrt, ops.convert_to_tensor(224.0))
),
ops.log(2.0),
)
+ 4.0,
Expand Down Expand Up @@ -292,12 +297,18 @@ def multilevel_crop_and_resize(
ops.concatenate(
[
ops.expand_dims(
[[ops.cast(max_feature_height, "float32")]] / level_strides
ops.convert_to_tensor(
[[ops.cast(max_feature_height, "float32")]]
)
/ level_strides
- 1,
axis=-1,
),
ops.expand_dims(
[[ops.cast(max_feature_width, "float32")]] / level_strides
ops.convert_to_tensor(
[[ops.cast(max_feature_width, "float32")]]
)
/ level_strides
- 1,
axis=-1,
),
Expand Down Expand Up @@ -357,7 +368,7 @@ def multilevel_crop_and_resize(
# TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get
# similar performance.
features_per_box = ops.reshape(
ops.take(features_r2, indices),
ops.take(features_r2, indices, axis=0),
[
batch_size,
num_boxes,
Expand All @@ -378,7 +389,7 @@ def multilevel_crop_and_resize(
# performance as this is mostly a duplicate of
# https://github.com/tensorflow/models/blob/master/official/legacy/detection/ops/spatial_transform_ops.py#L324
@keras.utils.register_keras_serializable(package="keras_cv")
class _ROIAligner(keras.layers.Layer):
class ROIAligner(keras.layers.Layer):
"""Performs ROIAlign for the second stage processing."""

def __init__(
Expand All @@ -397,13 +408,11 @@ def __init__(
sample_offset: A `float` in [0, 1] of the subpixel sample offset.
**kwargs: Additional keyword arguments passed to Layer.
"""
# assert_tf_keras("keras_cv.layers._ROIAligner")
self._config_dict = {
"bounding_box_format": bounding_box_format,
"crop_size": target_size,
"sample_offset": sample_offset,
}
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.target_size = target_size
self.sample_offset = sample_offset
self.built = True

def call(
self,
Expand All @@ -427,16 +436,22 @@ def call(
"""
boxes = bounding_box.convert_format(
boxes,
source=self._config_dict["bounding_box_format"],
source=self.bounding_box_format,
target="yxyx",
)
roi_features = multilevel_crop_and_resize(
features,
boxes,
output_size=self._config_dict["crop_size"],
sample_offset=self._config_dict["sample_offset"],
output_size=self.target_size,
sample_offset=self.sample_offset,
)
return roi_features

def get_config(self):
return self._config_dict
config = super().get_config()
config["bounding_box_format"] = self.bounding_box_format
config["target_size"] = self.target_size
config["sample_offset"] = self.sample_offset

def compute_output_shape(self, input_shape):
return (None, None, self.target_size, self.target_size, 256)
8 changes: 7 additions & 1 deletion keras_cv/src/layers/object_detection/roi_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class ROIGenerator(keras.layers.Layer):
applying NMS in inference mode. When RPN is run on multiple
feature maps / levels (as in FPN) this number is per
feature map / level.
nms_from_logits: bool. True means input score is logits, False means confidence.
Example:
```python
Expand All @@ -90,6 +91,7 @@ def __init__(
nms_score_threshold_test: float = 0.0,
nms_iou_threshold_test: float = 0.7,
post_nms_topk_test: int = 1000,
nms_from_logits: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -102,6 +104,7 @@ def __init__(
self.nms_score_threshold_test = nms_score_threshold_test
self.nms_iou_threshold_test = nms_iou_threshold_test
self.post_nms_topk_test = post_nms_topk_test
self.nms_from_logits = nms_from_logits
self.built = True

def call(
Expand Down Expand Up @@ -158,7 +161,7 @@ def per_level_gen(boxes, scores):
# TODO(tanzhenyu): consider supporting soft / batched nms for accl
boxes = NonMaxSuppression(
bounding_box_format=self.bounding_box_format,
from_logits=False,
from_logits=self.nms_from_logits,
iou_threshold=nms_iou_threshold,
confidence_threshold=nms_score_threshold,
max_detections=level_post_nms_topk,
Expand Down Expand Up @@ -191,6 +194,9 @@ def per_level_gen(boxes, scores):

return rois, roi_scores

def compute_output_shape(self, input_shape):
return (None, None, 4), (None, None, 1)

def get_config(self):
config = {
"bounding_box_format": self.bounding_box_format,
Expand Down
44 changes: 24 additions & 20 deletions keras_cv/src/layers/object_detection/roi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@keras.utils.register_keras_serializable(package="keras_cv")
class _ROISampler(keras.layers.Layer):
class ROISampler(keras.layers.Layer):
"""
Sample ROIs for loss related calculation.
Expand All @@ -41,9 +41,10 @@ class _ROISampler(keras.layers.Layer):
if its range is [0, num_classes).
Args:
bounding_box_format: The format of bounding boxes to generate. Refer
roi_bounding_box_format: The format of roi bounding boxes. Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
gt_bounding_box_format: The format of ground truth bounding boxes.
roi_matcher: a `BoxMatcher` object that matches proposals with ground
truth boxes. The positive match must be 1 and negative match must be -1.
Such assumption is not being validated here.
Expand All @@ -59,7 +60,8 @@ class _ROISampler(keras.layers.Layer):

def __init__(
self,
bounding_box_format: str,
roi_bounding_box_format: str,
gt_bounding_box_format: str,
roi_matcher: box_matcher.BoxMatcher,
positive_fraction: float = 0.25,
background_class: int = 0,
Expand All @@ -68,12 +70,14 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.roi_bounding_box_format = roi_bounding_box_format
self.gt_bounding_box_format = gt_bounding_box_format
self.roi_matcher = roi_matcher
self.positive_fraction = positive_fraction
self.background_class = background_class
self.num_sampled_rois = num_sampled_rois
self.append_gt_boxes = append_gt_boxes
self.seed_generator = keras.random.SeedGenerator()
self.built = True
# for debugging.
self._positives = keras.metrics.Mean()
Expand All @@ -97,6 +101,12 @@ def call(
sampled_gt_classes: [batch_size, num_sampled_rois, 1]
sampled_class_weights: [batch_size, num_sampled_rois, 1]
"""
rois = bounding_box.convert_format(
rois, source=self.roi_bounding_box_format, target="yxyx"
)
gt_boxes = bounding_box.convert_format(
gt_boxes, source=self.gt_bounding_box_format, target="yxyx"
)
if self.append_gt_boxes:
# num_rois += num_gt
rois = ops.concatenate([rois, gt_boxes], axis=1)
Expand All @@ -110,12 +120,6 @@ def call(
"num_rois must be less than `num_sampled_rois` "
f"({self.num_sampled_rois}), got {num_rois}"
)
rois = bounding_box.convert_format(
rois, source=self.bounding_box_format, target="yxyx"
)
gt_boxes = bounding_box.convert_format(
gt_boxes, source=self.bounding_box_format, target="yxyx"
)
# [batch_size, num_rois, num_gt]
similarity_mat = iou.compute_iou(
rois, gt_boxes, bounding_box_format="yxyx", use_masking=True
Expand Down Expand Up @@ -171,6 +175,7 @@ def call(
negative_matches,
self.num_sampled_rois,
self.positive_fraction,
seed=self.seed_generator,
)
# [batch_size, num_sampled_rois] in the range of [0, num_rois)
sampled_indicators, sampled_indices = ops.top_k(
Expand Down Expand Up @@ -204,16 +209,15 @@ def call(
)

def get_config(self):
config = {
"bounding_box_format": self.bounding_box_format,
"positive_fraction": self.positive_fraction,
"background_class": self.background_class,
"num_sampled_rois": self.num_sampled_rois,
"append_gt_boxes": self.append_gt_boxes,
"roi_matcher": self.roi_matcher.get_config(),
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
config = super().get_config()
config["roi_bounding_box_format"] = self.roi_bounding_box_format
config["gt_bounding_box_format"] = self.gt_bounding_box_format
config["positive_fraction"] = self.positive_fraction
config["background_class"] = self.background_class
config["num_sampled_rois"] = self.num_sampled_rois
config["append_gt_boxes"] = self.append_gt_boxes
config["roi_matcher"] = self.roi_matcher.get_config()
return config

@classmethod
def from_config(cls, config, custom_objects=None):
Expand Down
Loading

0 comments on commit 2f9eb86

Please sign in to comment.