From 2f9eb86577b1ea76e15953e7d14d6be5fa21121b Mon Sep 17 00:00:00 2001 From: Siva Sravana Kumar Neeli <113718461+sineeli@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:43:15 -0700 Subject: [PATCH] Port Faster R-CNN to Keras3 (#2458) * Base structure for faster rcnn till rpn head * Add export for Faster R-CNN * add init file * initialize faster rcnn at model level * code fix for roi align * Forward Pass code for Faster R-CNN * Faster RCNN Base code for Keras3 (Draft 1) * Add local batch size * Add parameters to RPN Head * Make FPN more customizable with parameters and remove redundant code * Compute output shape for ROI Generator * Faster RCNN functional model with required import corrections * add clip boxes to forward pass * add prediction decoder and use "yxyx" as default internal bounding box format * feature pyramid correction * change ops.divide to ops.divide_no_nan * use from_logits=True for Non Max Suppression * include box conversions for both rois and ground truth boxes * Change number of detections in decoder * Use CategoricalCrossentropy to avoid -1 class error + added get_config for model saving * add basic test cases + linting * Add seed generator for sampling in RPN label encoding and ROI sampling layers * Use only spatial dimensions for ops.nn.avg_pool + use ops.convert_to_tensor for list type + linting * Convert list to tensor using keras ops * Remove seed number from seed generator * Remove print and add proper comments * - Use stddev(0.01) as per paper across RPN and R-CNN Heads - MaxPool2D as per torch implementation in FPN - Add prediction decoder * - Fix slicing for multiple backends - TensorFlow slicing can use [-1, -1, -1] for shape, but jax and torch need an explicit shape * - Add compute metrics method * Correct test cases and add missing args * Fix lint issues * - Fix lint and remove hard-coded params to make it user friendly.
* - Generate ROIs while decoding for predictions - Linting + Test Cases * - Add faster rcnn to build method * - Test only for Keras3 * - Correct test case - Add copyright * - Correct the test cases decorator to skip for Keras2 * - Skip Legacy test cases - Fix ROI Align ops for torch backend * - Remove unnecessary import in legacy code to fix lint * - Correct pytest complexity - Make bounding box test utils use 256,256 image size * - Fix Image Shape to 512, 512 default which will not break other test cases * - Lower image sizes for test cases - Add build method for fpn * - Pin keras to version 3.3.3 * - Generate api - Correct YOLOv8 preset test case * - Lint fix * - Increase the atol, rtol for YOLOv8 Detector forward pass --- .github/workflows/actions.yml | 1 + .kokoro/github/ubuntu/gpu/build.sh | 5 + keras_cv/api/models/__init__.py | 4 + keras_cv/api/models/faster_rcnn/__init__.py | 11 + .../api/models/object_detection/__init__.py | 3 + keras_cv/src/bounding_box/utils.py | 2 +- .../src/layers/object_detection/roi_align.py | 51 +- .../layers/object_detection/roi_generator.py | 8 +- .../layers/object_detection/roi_sampler.py | 44 +- .../object_detection/roi_sampler_test.py | 48 +- .../object_detection/rpn_label_encoder.py | 4 +- .../rpn_label_encoder_test.py | 13 +- .../src/layers/object_detection/sampling.py | 10 +- keras_cv/src/models/__init__.py | 3 + .../faster_rcnn/faster_rcnn.py | 12 +- .../faster_rcnn/faster_rcnn_test.py | 17 +- .../models/object_detection/__test_utils__.py | 6 +- .../object_detection/faster_rcnn/__init__.py | 18 + .../faster_rcnn/faster_rcnn.py | 807 ++++++++++++++++++ .../faster_rcnn/faster_rcnn_test.py | 346 ++++++++ .../faster_rcnn/feature_pyamid_test.py | 146 ++++ .../faster_rcnn/feature_pyramid.py | 251 ++++++ .../object_detection/faster_rcnn/rcnn_head.py | 107 +++ .../faster_rcnn/rcnn_head_test.py | 47 + .../object_detection/faster_rcnn/rpn_head.py | 113 +++ .../faster_rcnn/rpn_head_test.py | 89 ++ .../yolo_v8/yolo_v8_detector_test.py | 4 +- 27 files changed, 2085 insertions(+), 85 deletions(-) create mode 100644 keras_cv/api/models/faster_rcnn/__init__.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/__init__.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/feature_pyamid_test.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/rcnn_head_test.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py create mode 100644 keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 6b195ea458..e274eb6a34 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -94,6 +94,7 @@ jobs: keras_cv/src/models/classification \ keras_cv/src/models/object_detection/retinanet \ keras_cv/src/models/object_detection/yolo_v8 \ + keras_cv/src/models/object_detection/faster_rcnn \ keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ --durations 0 diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index a19b109f82..879f106558 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++
b/.kokoro/github/ubuntu/gpu/build.sh @@ -36,6 +36,7 @@ then pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000 pip install keras-nlp-nightly --no-deps pip install tensorflow-text~=2.16.0 + pip install keras~=3.3.3 elif [ "$KERAS_BACKEND" == "jax" ] then @@ -43,6 +44,7 @@ then pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000 pip install keras-nlp-nightly --no-deps pip install tensorflow-text~=2.16.0 + pip install keras~=3.3.3 elif [ "$KERAS_BACKEND" == "torch" ] then @@ -50,6 +52,7 @@ then pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000 pip install keras-nlp-nightly --no-deps pip install tensorflow-text~=2.16.0 + pip install keras~=3.3.3 fi pip install --no-deps -e "." --progress-bar off @@ -67,6 +70,7 @@ then keras_cv/src/models/classification \ keras_cv/src/models/object_detection/retinanet \ keras_cv/src/models/object_detection/yolo_v8 \ + keras_cv/src/models/object_detection/faster_rcnn \ keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ keras_cv/src/models/feature_extractor/clip \ @@ -82,6 +86,7 @@ else keras_cv/src/models/classification \ keras_cv/src/models/object_detection/retinanet \ keras_cv/src/models/object_detection/yolo_v8 \ + keras_cv/src/models/object_detection/faster_rcnn \ keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ keras_cv/src/models/feature_extractor/clip \ diff --git a/keras_cv/api/models/__init__.py b/keras_cv/api/models/__init__.py index 97f9bc577b..54be7764b8 100644 --- a/keras_cv/api/models/__init__.py +++ b/keras_cv/api/models/__init__.py @@ -5,6 +5,7 @@ """ from keras_cv.api.models import classification +from keras_cv.api.models import faster_rcnn from keras_cv.api.models import feature_extractor from keras_cv.api.models import object_detection from keras_cv.api.models import retinanet @@ -205,6 +206,9 @@ from keras_cv.src.models.classification.image_classifier import ImageClassifier from keras_cv.src.models.classification.video_classifier import VideoClassifier from keras_cv.src.models.feature_extractor.clip.clip_model import CLIP +from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( + FasterRCNN, +) from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, diff --git a/keras_cv/api/models/faster_rcnn/__init__.py b/keras_cv/api/models/faster_rcnn/__init__.py new file mode 100644 index 0000000000..70d632cdb9 --- /dev/null +++ b/keras_cv/api/models/faster_rcnn/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_cv.src.models.object_detection.faster_rcnn.feature_pyramid import ( + FeaturePyramid, +) +from keras_cv.src.models.object_detection.faster_rcnn.rcnn_head import RCNNHead +from keras_cv.src.models.object_detection.faster_rcnn.rpn_head import RPNHead diff --git a/keras_cv/api/models/object_detection/__init__.py b/keras_cv/api/models/object_detection/__init__.py index baba2a34be..c49389c0b4 100644 --- a/keras_cv/api/models/object_detection/__init__.py +++ b/keras_cv/api/models/object_detection/__init__.py @@ -4,6 +4,9 @@ since your modifications would be overwritten. 
""" +from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( + FasterRCNN, +) from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector import ( YOLOV8Detector, diff --git a/keras_cv/src/bounding_box/utils.py b/keras_cv/src/bounding_box/utils.py index 21525e2ba8..4f7db46299 100644 --- a/keras_cv/src/bounding_box/utils.py +++ b/keras_cv/src/bounding_box/utils.py @@ -141,7 +141,7 @@ def _clip_boxes(boxes, box_format, image_shape): if isinstance(image_shape, list) or isinstance(image_shape, tuple): height, width, _ = image_shape - max_length = [height, width, height, width] + max_length = ops.stack([height, width, height, width], axis=-1) else: image_shape = ops.cast(image_shape, dtype=boxes.dtype) height = image_shape[0] diff --git a/keras_cv/src/layers/object_detection/roi_align.py b/keras_cv/src/layers/object_detection/roi_align.py index feb6cfcf62..d1cebcc521 100644 --- a/keras_cv/src/layers/object_detection/roi_align.py +++ b/keras_cv/src/layers/object_detection/roi_align.py @@ -69,9 +69,7 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x): features, [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters], ) - features = ops.nn.average_pool( - features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID" - ) + features = ops.nn.average_pool(features, (2, 2), (2, 2), "VALID") features = ops.reshape( features, [batch_size, num_boxes, output_size, output_size, num_filters] ) @@ -242,6 +240,11 @@ def multilevel_crop_and_resize( for i in range(len(feature_widths) - 1): level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i]) batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1] + + level_dim_offsets = ops.convert_to_tensor(level_dim_offsets) + feature_widths = ops.convert_to_tensor(feature_widths) + feature_heights = ops.convert_to_tensor(feature_heights) + level_dim_offsets = ( ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets ) @@ -259,7 +262,9 @@ def multilevel_crop_and_resize( # following the FPN paper to divide by 224. levels = ops.cast( ops.floor_divide( - ops.log(ops.divide(areas_sqrt, 224.0)), + ops.log( + ops.divide_no_nan(areas_sqrt, ops.convert_to_tensor(224.0)) + ), ops.log(2.0), ) + 4.0, @@ -292,12 +297,18 @@ def multilevel_crop_and_resize( ops.concatenate( [ ops.expand_dims( - [[ops.cast(max_feature_height, "float32")]] / level_strides + ops.convert_to_tensor( + [[ops.cast(max_feature_height, "float32")]] + ) + / level_strides - 1, axis=-1, ), ops.expand_dims( - [[ops.cast(max_feature_width, "float32")]] / level_strides + ops.convert_to_tensor( + [[ops.cast(max_feature_width, "float32")]] + ) + / level_strides - 1, axis=-1, ), @@ -357,7 +368,7 @@ def multilevel_crop_and_resize( # TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get # similar performance. features_per_box = ops.reshape( - ops.take(features_r2, indices), + ops.take(features_r2, indices, axis=0), [ batch_size, num_boxes, @@ -378,7 +389,7 @@ def multilevel_crop_and_resize( # performance as this is mostly a duplicate of # https://github.com/tensorflow/models/blob/master/official/legacy/detection/ops/spatial_transform_ops.py#L324 @keras.utils.register_keras_serializable(package="keras_cv") -class _ROIAligner(keras.layers.Layer): +class ROIAligner(keras.layers.Layer): """Performs ROIAlign for the second stage processing.""" def __init__( @@ -397,13 +408,11 @@ def __init__( sample_offset: A `float` in [0, 1] of the subpixel sample offset. 
**kwargs: Additional keyword arguments passed to Layer. """ - # assert_tf_keras("keras_cv.layers._ROIAligner") - self._config_dict = { - "bounding_box_format": bounding_box_format, - "crop_size": target_size, - "sample_offset": sample_offset, - } super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.target_size = target_size + self.sample_offset = sample_offset + self.built = True def call( self, @@ -427,16 +436,22 @@ def call( """ boxes = bounding_box.convert_format( boxes, - source=self._config_dict["bounding_box_format"], + source=self.bounding_box_format, target="yxyx", ) roi_features = multilevel_crop_and_resize( features, boxes, - output_size=self._config_dict["crop_size"], - sample_offset=self._config_dict["sample_offset"], + output_size=self.target_size, + sample_offset=self.sample_offset, ) return roi_features def get_config(self): - return self._config_dict + config = super().get_config() + config["bounding_box_format"] = self.bounding_box_format + config["target_size"] = self.target_size + config["sample_offset"] = self.sample_offset + return config + + def compute_output_shape(self, input_shape): + return (None, None, self.target_size, self.target_size, 256) diff --git a/keras_cv/src/layers/object_detection/roi_generator.py b/keras_cv/src/layers/object_detection/roi_generator.py index fbde4fbcf2..37965acfe6 100644 --- a/keras_cv/src/layers/object_detection/roi_generator.py +++ b/keras_cv/src/layers/object_detection/roi_generator.py @@ -68,6 +68,7 @@ class ROIGenerator(keras.layers.Layer): applying NMS in inference mode. When RPN is run on multiple feature maps / levels (as in FPN) this number is per feature map / level. + nms_from_logits: bool. True if the input scores are logits, False if they are confidence values. Example: ```python @@ -90,6 +91,7 @@ def __init__( nms_score_threshold_test: float = 0.0, nms_iou_threshold_test: float = 0.7, post_nms_topk_test: int = 1000, + nms_from_logits: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -102,6 +104,7 @@ def __init__( self.nms_score_threshold_test = nms_score_threshold_test self.nms_iou_threshold_test = nms_iou_threshold_test self.post_nms_topk_test = post_nms_topk_test + self.nms_from_logits = nms_from_logits self.built = True def call( @@ -158,7 +161,7 @@ def per_level_gen(boxes, scores): # TODO(tanzhenyu): consider supporting soft / batched nms for accl boxes = NonMaxSuppression( bounding_box_format=self.bounding_box_format, - from_logits=False, + from_logits=self.nms_from_logits, iou_threshold=nms_iou_threshold, confidence_threshold=nms_score_threshold, max_detections=level_post_nms_topk, @@ -191,6 +194,9 @@ def per_level_gen(boxes, scores): return rois, roi_scores + def compute_output_shape(self, input_shape): + return (None, None, 4), (None, None, 1) + def get_config(self): config = { "bounding_box_format": self.bounding_box_format, diff --git a/keras_cv/src/layers/object_detection/roi_sampler.py b/keras_cv/src/layers/object_detection/roi_sampler.py index fe08865587..8b38ccad72 100644 --- a/keras_cv/src/layers/object_detection/roi_sampler.py +++ b/keras_cv/src/layers/object_detection/roi_sampler.py @@ -22,7 +22,7 @@ @keras.utils.register_keras_serializable(package="keras_cv") -class _ROISampler(keras.layers.Layer): +class ROISampler(keras.layers.Layer): """ Sample ROIs for loss related calculation. @@ -41,9 +41,10 @@ class _ROISampler(keras.layers.Layer): if its range is [0, num_classes). Args: - bounding_box_format: The format of bounding boxes to generate.
Refer + roi_bounding_box_format: The format of roi bounding boxes. Refer [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) for more details on supported bounding box formats. + gt_bounding_box_format: The format of ground truth bounding boxes. roi_matcher: a `BoxMatcher` object that matches proposals with ground truth boxes. The positive match must be 1 and negative match must be -1. Such assumption is not being validated here. @@ -59,7 +60,8 @@ class _ROISampler(keras.layers.Layer): def __init__( self, - bounding_box_format: str, + roi_bounding_box_format: str, + gt_bounding_box_format: str, roi_matcher: box_matcher.BoxMatcher, positive_fraction: float = 0.25, background_class: int = 0, @@ -68,12 +70,14 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.bounding_box_format = bounding_box_format + self.roi_bounding_box_format = roi_bounding_box_format + self.gt_bounding_box_format = gt_bounding_box_format self.roi_matcher = roi_matcher self.positive_fraction = positive_fraction self.background_class = background_class self.num_sampled_rois = num_sampled_rois self.append_gt_boxes = append_gt_boxes + self.seed_generator = keras.random.SeedGenerator() self.built = True # for debugging. self._positives = keras.metrics.Mean() @@ -97,6 +101,12 @@ def call( sampled_gt_classes: [batch_size, num_sampled_rois, 1] sampled_class_weights: [batch_size, num_sampled_rois, 1] """ + rois = bounding_box.convert_format( + rois, source=self.roi_bounding_box_format, target="yxyx" + ) + gt_boxes = bounding_box.convert_format( + gt_boxes, source=self.gt_bounding_box_format, target="yxyx" + ) if self.append_gt_boxes: # num_rois += num_gt rois = ops.concatenate([rois, gt_boxes], axis=1) @@ -110,12 +120,6 @@ def call( "num_rois must be less than `num_sampled_rois` " f"({self.num_sampled_rois}), got {num_rois}" ) - rois = bounding_box.convert_format( - rois, source=self.bounding_box_format, target="yxyx" - ) - gt_boxes = bounding_box.convert_format( - gt_boxes, source=self.bounding_box_format, target="yxyx" - ) # [batch_size, num_rois, num_gt] similarity_mat = iou.compute_iou( rois, gt_boxes, bounding_box_format="yxyx", use_masking=True @@ -171,6 +175,7 @@ def call( negative_matches, self.num_sampled_rois, self.positive_fraction, + seed=self.seed_generator, ) # [batch_size, num_sampled_rois] in the range of [0, num_rois) sampled_indicators, sampled_indices = ops.top_k( @@ -204,16 +209,15 @@ def call( ) def get_config(self): - config = { - "bounding_box_format": self.bounding_box_format, - "positive_fraction": self.positive_fraction, - "background_class": self.background_class, - "num_sampled_rois": self.num_sampled_rois, - "append_gt_boxes": self.append_gt_boxes, - "roi_matcher": self.roi_matcher.get_config(), - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) + config = super().get_config() + config["roi_bounding_box_format"] = self.roi_bounding_box_format + config["gt_bounding_box_format"] = self.gt_bounding_box_format + config["positive_fraction"] = self.positive_fraction + config["background_class"] = self.background_class + config["num_sampled_rois"] = self.num_sampled_rois + config["append_gt_boxes"] = self.append_gt_boxes + config["roi_matcher"] = self.roi_matcher.get_config() + return config @classmethod def from_config(cls, config, custom_objects=None): diff --git a/keras_cv/src/layers/object_detection/roi_sampler_test.py b/keras_cv/src/layers/object_detection/roi_sampler_test.py index 95bd90a715..7b5335affd 100644 
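A minimal sketch of how the reworked `ROISampler` is constructed now that ROI and ground-truth box formats are declared separately; the import paths follow this patch, the box values are made up for illustration, and both inputs are converted to yxyx internally before the ground-truth boxes are appended and matched:

```python
import numpy as np

from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.src.layers.object_detection.roi_sampler import ROISampler

matcher = BoxMatcher(thresholds=[0.0, 0.5], match_values=[-2, -1, 1])
sampler = ROISampler(
    roi_bounding_box_format="yxyx",  # RPN proposals arrive in yxyx
    gt_bounding_box_format="xyxy",   # ground truth may use another format
    roi_matcher=matcher,
    num_sampled_rois=4,
)

rois = np.array(
    [[[0, 0, 10, 10], [2, 2, 12, 12], [20, 20, 30, 30], [5, 5, 15, 15]]],
    dtype="float32",
)
gt_boxes = np.array([[[1, 1, 11, 11]]], dtype="float32")
gt_classes = np.array([[[1]]], dtype="int32")

# Returns the sampled ROIs plus box/class targets and weights, as
# documented in the layer's call() above.
outputs = sampler(rois, gt_boxes, gt_classes)
```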
--- a/keras_cv/src/layers/object_detection/roi_sampler_test.py +++ b/keras_cv/src/layers/object_detection/roi_sampler_test.py @@ -14,18 +14,22 @@ import numpy as np +import pytest from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher -from keras_cv.src.layers.object_detection.roi_sampler import _ROISampler +from keras_cv.src.layers.object_detection.roi_sampler import ROISampler from keras_cv.src.tests.test_case import TestCase class ROISamplerTest(TestCase): + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_roi_sampler(self): box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, num_sampled_rois=2, @@ -66,13 +70,15 @@ def test_roi_sampler(self): np.min(ops.convert_to_numpy(sampled_gt_classes)), ) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_roi_sampler_small_threshold(self): self.skipTest( "TODO: resolving flaky test, https://github.com/keras-team/keras-cv/issues/2336" # noqa ) box_matcher = BoxMatcher(thresholds=[0.1], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, num_sampled_rois=2, @@ -121,12 +127,14 @@ def test_roi_sampler_small_threshold(self): ) self.assertAllClose(expected_gt_classes, sampled_gt_classes) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_roi_sampler_large_threshold(self): # the 2nd roi and 2nd gt box has IOU of 0.923, setting # positive_threshold to 0.95 to ignore it. box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, num_sampled_rois=2, @@ -160,12 +168,14 @@ def test_roi_sampler_large_threshold(self): self.assertAllClose(expected_gt_boxes, sampled_gt_boxes) self.assertAllClose(expected_gt_classes, sampled_gt_classes) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_roi_sampler_large_threshold_custom_bg_class(self): # the 2nd roi and 2nd gt box has IOU of 0.923, setting # positive_threshold to 0.95 to ignore it. box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, background_class=-1, @@ -201,12 +211,14 @@ def test_roi_sampler_large_threshold_custom_bg_class(self): self.assertAllClose(expected_gt_boxes, sampled_gt_boxes) self.assertAllClose(expected_gt_classes, sampled_gt_classes) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_roi_sampler_large_threshold_append_gt_boxes(self): # the 2nd roi and 2nd gt box has IOU of 0.923, setting # positive_threshold to 0.95 to ignore it. 
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, num_sampled_rois=2, @@ -243,10 +255,12 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self): np.min(ops.convert_to_numpy(sampled_gt_classes)), 0 ) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_roi_sampler_large_num_sampled_rois(self): box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, num_sampled_rois=200, @@ -271,15 +285,17 @@ def test_roi_sampler_large_num_sampled_rois(self): with self.assertRaisesRegex(ValueError, "must be less than"): _, _, _ = roi_sampler(rois, gt_boxes, gt_classes) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_serialization(self): box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) - roi_sampler = _ROISampler( - bounding_box_format="xyxy", + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", roi_matcher=box_matcher, positive_fraction=0.5, num_sampled_rois=200, append_gt_boxes=True, ) sampler_config = roi_sampler.get_config() - new_sampler = _ROISampler.from_config(sampler_config) + new_sampler = ROISampler.from_config(sampler_config) self.assertAllEqual(new_sampler.roi_matcher.match_values, [-1, 1]) diff --git a/keras_cv/src/layers/object_detection/rpn_label_encoder.py b/keras_cv/src/layers/object_detection/rpn_label_encoder.py index fa314d9b66..11600166a0 100644 --- a/keras_cv/src/layers/object_detection/rpn_label_encoder.py +++ b/keras_cv/src/layers/object_detection/rpn_label_encoder.py @@ -24,7 +24,7 @@ @keras.utils.register_keras_serializable(package="keras_cv") -class _RpnLabelEncoder(keras.layers.Layer): +class RpnLabelEncoder(keras.layers.Layer): """Transforms the raw labels into training targets for region proposal network (RPN). @@ -84,6 +84,7 @@ def __init__( force_match_for_each_col=False, ) self.box_variance = box_variance + self.seed_generator = keras.random.SeedGenerator() self.built = True self._positives = keras.metrics.Mean(name="percent_boxes_matched") @@ -165,6 +166,7 @@ def call( negative_matches, self.samples_per_image, self.positive_fraction, + seed=self.seed_generator, ) # [num_anchors, 1] or [batch_size, num_anchors, 1] class_sample_weights = ops.cast( diff --git a/keras_cv/src/layers/object_detection/rpn_label_encoder_test.py b/keras_cv/src/layers/object_detection/rpn_label_encoder_test.py index 0de6f1a4e2..ddfbbf198c 100644 --- a/keras_cv/src/layers/object_detection/rpn_label_encoder_test.py +++ b/keras_cv/src/layers/object_detection/rpn_label_encoder_test.py @@ -13,17 +13,20 @@ # limitations under the License. 
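One note on the `SeedGenerator` additions above: each layer now owns a `keras.random.SeedGenerator` and threads it through every sampling call, so draws differ between steps yet the whole sequence is reproducible from one initial seed. A minimal sketch of the underlying Keras 3 behavior (the shapes and seed value are arbitrary):

```python
import keras
from keras import random

seed_gen = keras.random.SeedGenerator(1337)

# Each call advances the generator state, so noise_a != noise_b, but
# re-creating the generator with the same seed reproduces both draws.
noise_a = random.uniform((4,), minval=-0.2, maxval=0.2, seed=seed_gen)
noise_b = random.uniform((4,), minval=-0.2, maxval=0.2, seed=seed_gen)
```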
import numpy as np +import pytest from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 from keras_cv.src.layers.object_detection.rpn_label_encoder import ( - _RpnLabelEncoder, + RpnLabelEncoder, ) from keras_cv.src.tests.test_case import TestCase class RpnLabelEncoderTest(TestCase): + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_rpn_label_encoder(self): - rpn_encoder = _RpnLabelEncoder( + rpn_encoder = RpnLabelEncoder( anchor_format="xyxy", ground_truth_box_format="xyxy", positive_threshold=0.7, @@ -68,11 +71,12 @@ def test_rpn_label_encoder(self): self.assertAllClose(np.max(ops.convert_to_numpy(box_weights)), 1.0) self.assertAllClose(np.min(ops.convert_to_numpy(box_weights)), 0.0) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_rpn_label_encoder_multi_level(self): self.skipTest( "TODO: resolving flaky test, https://github.com/keras-team/keras-cv/issues/2336" # noqa ) - rpn_encoder = _RpnLabelEncoder( + rpn_encoder = RpnLabelEncoder( anchor_format="xyxy", ground_truth_box_format="xyxy", positive_threshold=0.7, @@ -97,8 +101,9 @@ def test_rpn_label_encoder_multi_level(self): self.assertAllClose(expected_cls_weights[2], cls_weights[2]) self.assertAllClose(expected_cls_weights[3], cls_weights[3]) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") def test_rpn_label_encoder_batched(self): - rpn_encoder = _RpnLabelEncoder( + rpn_encoder = RpnLabelEncoder( anchor_format="xyxy", ground_truth_box_format="xyxy", positive_threshold=0.7, diff --git a/keras_cv/src/layers/object_detection/sampling.py b/keras_cv/src/layers/object_detection/sampling.py index e756920304..491c5b98d2 100644 --- a/keras_cv/src/layers/object_detection/sampling.py +++ b/keras_cv/src/layers/object_detection/sampling.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional +from typing import Union + from keras_cv.src.backend import ops from keras_cv.src.backend import random @@ -21,6 +24,7 @@ def balanced_sample( negative_matches, num_samples: int, positive_fraction: float, + seed: Optional[Union[random.SeedGenerator, int]] = None, ): """ Sampling ops to balance positive and negative samples, deals with both @@ -51,10 +55,12 @@ def balanced_sample( # maxval=1.) 
zeros = ops.zeros_like(positive_matches, dtype="float32") ones = ops.ones_like(positive_matches, dtype="float32") - ones_rand = ones + random.uniform(ops.shape(ones), minval=-0.2, maxval=0.2) + ones_rand = ones + random.uniform( + ops.shape(ones), minval=-0.2, maxval=0.2, seed=seed + ) halfs = 0.5 * ops.ones_like(positive_matches, dtype="float32") halfs_rand = halfs + random.uniform( - ops.shape(halfs), minval=-0.2, maxval=0.2 + ops.shape(halfs), minval=-0.2, maxval=0.2, seed=seed ) values = zeros values = ops.where(positive_matches, ones_rand, values) diff --git a/keras_cv/src/models/__init__.py b/keras_cv/src/models/__init__.py index 77513eb8d8..ebe22b7709 100644 --- a/keras_cv/src/models/__init__.py +++ b/keras_cv/src/models/__init__.py @@ -206,6 +206,9 @@ from keras_cv.src.models.classification.image_classifier import ImageClassifier from keras_cv.src.models.classification.video_classifier import VideoClassifier from keras_cv.src.models.feature_extractor.clip import CLIP +from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( + FasterRCNN, +) from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, diff --git a/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn.py index df5e2981b7..0f6b36ff9e 100644 --- a/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn.py @@ -25,11 +25,11 @@ AnchorGenerator, ) from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher -from keras_cv.src.layers.object_detection.roi_align import _ROIAligner +from keras_cv.src.layers.object_detection.roi_align import ROIAligner from keras_cv.src.layers.object_detection.roi_generator import ROIGenerator -from keras_cv.src.layers.object_detection.roi_sampler import _ROISampler +from keras_cv.src.layers.object_detection.roi_sampler import ROISampler from keras_cv.src.layers.object_detection.rpn_label_encoder import ( - _RpnLabelEncoder, + RpnLabelEncoder, ) from keras_cv.src.models.object_detection import predict_utils from keras_cv.src.models.object_detection.__internal__ import unpack_input @@ -317,13 +317,13 @@ def __init__( self.box_matcher = BoxMatcher( thresholds=[0.0, 0.5], match_values=[-2, -1, 1] ) - self.roi_sampler = _ROISampler( + self.roi_sampler = ROISampler( bounding_box_format="yxyx", roi_matcher=self.box_matcher, background_class=num_classes, num_sampled_rois=512, ) - self.roi_pooler = _ROIAligner(bounding_box_format="yxyx") + self.roi_pooler = ROIAligner(bounding_box_format="yxyx") self.rcnn_head = rcnn_head or RCNNHead(num_classes) self.backbone = backbone or models.ResNet50Backbone() extractor_levels = ["P2", "P3", "P4", "P5"] @@ -334,7 +334,7 @@ def __init__( self.backbone, extractor_layer_names, extractor_levels ) self.feature_pyramid = FeaturePyramid() - self.rpn_labeler = label_encoder or _RpnLabelEncoder( + self.rpn_labeler = label_encoder or RpnLabelEncoder( anchor_format="yxyx", ground_truth_box_format="yxyx", positive_threshold=0.7, diff --git a/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py index 02c57f5dad..6e651994fa 100644 --- a/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py +++ 
b/keras_cv/src/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py @@ -18,7 +18,6 @@ from tensorflow import keras from tensorflow.keras import optimizers -from keras_cv.src.backend import config as backend_config from keras_cv.src.models import ResNet18V2Backbone from keras_cv.src.models.legacy.object_detection.faster_rcnn.faster_rcnn import ( # noqa: E501 FasterRCNN, @@ -40,10 +39,7 @@ class FasterRCNNTest(TestCase): ((2, 512, 512, 3),), ((2, 128, 128, 3),), ) - @pytest.mark.skipif( - not backend_config.keras_3(), - reason="TODO: Fails in Keras2", - ) + @pytest.mark.skip(reason="moved to stable models") def test_faster_rcnn_infer(self, batch_shape): model = FasterRCNN( num_classes=80, @@ -61,10 +57,7 @@ def test_faster_rcnn_infer(self, batch_shape): ((2, 512, 512, 3),), ((2, 128, 128, 3),), ) - @pytest.mark.skipif( - not backend_config.keras_3(), - reason="TODO: Fails in Keras2", - ) + @pytest.mark.skip(reason="moved to stable models") def test_faster_rcnn_train(self, batch_shape): model = FasterRCNN( num_classes=80, @@ -76,6 +69,7 @@ def test_faster_rcnn_train(self, batch_shape): self.assertAllEqual([2, 1000, 81], outputs[1].shape) self.assertAllEqual([2, 1000, 4], outputs[0].shape) + @pytest.mark.skip(reason="moved to stable models") def test_invalid_compile(self): model = FasterRCNN( num_classes=80, @@ -92,10 +86,7 @@ def test_invalid_compile(self): ) @pytest.mark.large # Fit is slow, so mark these large. - @pytest.mark.skipif( - not backend_config.keras_3(), - reason="TODO: Fails in Keras2", - ) + @pytest.mark.skip(reason="moved to stable models") def test_faster_rcnn_with_dictionary_input_format(self): faster_rcnn = FasterRCNN( num_classes=20, diff --git a/keras_cv/src/models/object_detection/__test_utils__.py b/keras_cv/src/models/object_detection/__test_utils__.py index ad795b9cbd..d0b0bdd0b4 100644 --- a/keras_cv/src/models/object_detection/__test_utils__.py +++ b/keras_cv/src/models/object_detection/__test_utils__.py @@ -19,11 +19,13 @@ def _create_bounding_box_dataset( - bounding_box_format, use_dictionary_box_format=False + bounding_box_format, + image_shape=(512, 512, 3), + use_dictionary_box_format=False, ): # Just about the easiest dataset you can have, all classes are 0, all boxes # are exactly the same. [1, 1, 2, 2] are the coordinates in xyxy. - xs = np.random.normal(size=(1, 512, 512, 3)) + xs = np.random.normal(size=(1,) + image_shape) xs = np.tile(xs, [5, 1, 1, 1]) y_classes = np.zeros((5, 3), "float32") diff --git a/keras_cv/src/models/object_detection/faster_rcnn/__init__.py b/keras_cv/src/models/object_detection/faster_rcnn/__init__.py new file mode 100644 index 0000000000..52dbb4a330 --- /dev/null +++ b/keras_cv/src/models/object_detection/faster_rcnn/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from keras_cv.src.models.object_detection.faster_rcnn.feature_pyramid import ( + FeaturePyramid, +) +from keras_cv.src.models.object_detection.faster_rcnn.rcnn_head import RCNNHead +from keras_cv.src.models.object_detection.faster_rcnn.rpn_head import RPNHead diff --git a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py new file mode 100644 index 0000000000..a73f40349e --- /dev/null +++ b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py @@ -0,0 +1,807 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tree + +from keras_cv.src import losses +from keras_cv.src.api_export import keras_cv_export +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.bounding_box import convert_format +from keras_cv.src.bounding_box.converters import _decode_deltas_to_boxes +from keras_cv.src.bounding_box.utils import _clip_boxes +from keras_cv.src.layers.object_detection.anchor_generator import ( + AnchorGenerator, +) +from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.src.layers.object_detection.non_max_suppression import ( + NonMaxSuppression, +) +from keras_cv.src.layers.object_detection.roi_align import ROIAligner +from keras_cv.src.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.src.layers.object_detection.roi_sampler import ROISampler +from keras_cv.src.layers.object_detection.rpn_label_encoder import ( + RpnLabelEncoder, +) +from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.src.models.object_detection.faster_rcnn import RPNHead +from keras_cv.src.models.task import Task +from keras_cv.src.utils.train import get_feature_extractor + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +@keras_cv_export( + [ + "keras_cv.models.FasterRCNN", + "keras_cv.models.object_detection.FasterRCNN", + ] +) +class FasterRCNN(Task): + """A Keras model implementing the Faster R-CNN architecture. + + This model is compatible with Keras 3 only. Implements the Faster R-CNN architecture + for object detection. The constructor requires `num_classes`, `bounding_box_format`, + and a backbone. Optionally, a custom label encoder, and prediction decoder + may be provided. 
+ + Example: + ```python + images = np.ones((1, 512, 512, 3)) + labels = { + "boxes": tf.cast([ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], dtype=tf.float32), + "classes": tf.cast([[1, 1, 1]], dtype=tf.float32), + } + model = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + + # Evaluate model without box decoding and NMS + model(images) + + # Prediction with box decoding and NMS + model.predict(images) + + # Train model + model.compile( + optimizer=keras.optimizers.SGD(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + model.fit(images, labels, batch_size=1) + ``` + + Args: + backbone: `keras.Model`. If the default `feature_pyramid` is used, + must implement the `pyramid_level_inputs` property with keys "P2" + through "P5" (for the default `fpn_min_level` and `fpn_max_level`) + and layer names as values. A sensible backbone to use in many + cases is + `keras_cv.models.ResNetBackbone.from_preset("resnet50_imagenet")`. + num_classes: the number of classes in your dataset excluding the + background class. Classes should be represented by integers in the + range [1, num_classes]. + bounding_box_format: The format of bounding boxes of input dataset. + Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. If + provided, the anchor generator will be passed to both the + `label_encoder` and the `prediction_decoder`. Only to be used when + both `label_encoder` and `prediction_decoder` are `None`. + Defaults to an anchor generator over levels P2-P6 with + `strides={f"P{i}": 2**i}`, `sizes={f"P{i}": 2**(3 + i)}`, + the `anchor_scales` and `anchor_aspect_ratios` below, and + `clip_boxes=True`. + anchor_scales: (Optional) list of anchor scales for + default anchor generator. + anchor_aspect_ratios: (Optional) list of anchor aspect ratios for + default anchor generator. + feature_pyramid: (Optional) A `keras.layers.Layer` that produces + a list of 4D feature maps (batch dimension included) + when called on the pyramid-level outputs of the `backbone`. + If not provided, the reference implementation from the paper will be used. + fpn_min_level: (Optional) the minimum level of the feature pyramid. + fpn_max_level: (Optional) the maximum level of the feature pyramid. + rpn_head: (Optional) A `keras.Layer` that performs regression and + classification (background or foreground) of the bounding boxes. + If not provided, a simple ConvNet with 3 layers will be used. + rpn_label_encoder_posistive_threshold: (Optional) the float threshold to set an + anchor to positive match to gt box. Values above it are positive matches. + rpn_label_encoder_negative_threshold: (Optional) the float threshold to set an + anchor to negative match to gt box. Values below it are negative matches. + rpn_label_encoder_samples_per_image: (Optional) for each image, the number of + positive and negative samples to generate. + rpn_label_encoder_positive_fraction: (Optional) the fraction of positive samples among the total samples. + rcnn_head: (Optional) A `keras.Layer` that performs regression and + classification (final prediction) of the bounding boxes. + If not provided, a simple network with 2 dense layers followed by + a box head and a classification head will be used.
+ label_encoder: (Optional) a keras.Layer that accepts an image Tensor, a + bounding box Tensor and a bounding box class Tensor to its `call()` + method, and returns RPN training targets. By default, a + KerasCV standard `RpnLabelEncoder` is created and used. + Results of this object's `call()` method are passed as the + `y_true` argument to the `rpn_box_loss` and + `rpn_classification_loss` objects. + prediction_decoder: (Optional) A `keras.layers.Layer` that is + responsible for transforming FasterRCNN predictions into usable + bounding box Tensors. If not provided, a default is used. The + default `prediction_decoder` layer is a + `keras_cv.layers.NonMaxSuppression` layer, which uses + Non-Max Suppression for box pruning. + num_max_decoder_detections: the maximum number of detections to keep + after NMS is applied. A large number may trigger significant + memory overhead, defaults to 100. + """ # noqa: E501 + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + anchor_scales=[1], + anchor_aspect_ratios=[0.5, 1.0, 2.0], + feature_pyramid=None, + fpn_min_level=2, + fpn_max_level=5, + rpn_head=None, + rpn_filters=256, + rpn_kernel_size=3, + rpn_label_encoder_posistive_threshold=0.7, + rpn_label_encoder_negative_threshold=0.3, + rpn_label_encoder_samples_per_image=256, + rpn_label_encoder_positive_fraction=0.5, + rcnn_head=None, + num_sampled_rois=512, + label_encoder=None, + prediction_decoder=None, + num_max_decoder_detections=100, + *args, + **kwargs, + ): + # Backbone + extractor_levels = [ + f"P{level}" for level in range(fpn_min_level, fpn_max_level + 1) + ] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + + # Feature Pyramid + feature_pyramid = feature_pyramid or FeaturePyramid( + min_level=fpn_min_level, max_level=fpn_max_level + ) + + # Anchors + anchor_generator = ( + anchor_generator + or FasterRCNN.default_anchor_generator( + fpn_min_level, + fpn_max_level + 1, + anchor_scales, + anchor_aspect_ratios, + "yxyx", + ) + ) + + # RPN Head + num_anchors_per_location = len(anchor_scales) * len( + anchor_aspect_ratios + ) + rpn_head = rpn_head or RPNHead( + num_anchors_per_location=num_anchors_per_location, + num_filters=rpn_filters, + kernel_size=rpn_kernel_size, + ) + + # RoI Generator + roi_generator = ROIGenerator( + bounding_box_format="yxyx", + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + nms_from_logits=True, + name="roi_generator", + ) + + # RoI Align + roi_aligner = ROIAligner(bounding_box_format="yxyx", name="roi_align") + + # R-CNN Head + rcnn_head = rcnn_head or RCNNHead(num_classes, name="rcnn_head") + + # Begin construction of forward pass + image_shape = feature_extractor.input_shape[1:] + if None in image_shape: + raise ValueError( + "Found `None` in image_shape; to build anchors, `image_shape` " + "is required without any `None`. Make sure to pass " + "`image_shape` to the backbone preset while passing to " + "the Faster R-CNN detector."
+ ) + + images = keras.layers.Input( + image_shape, + name="images", + ) + + # Forward through backbone + backbone_outputs = feature_extractor(images) + + # Forward through FPN decoder + feature_map = feature_pyramid(backbone_outputs) + + # [P2, P3, P4, P5, P6] -> ([BS, num_anchors, 4], [BS, num_anchors, 1]) + # Pass through RPN Head + rpn_boxes, rpn_scores = rpn_head(feature_map) + + # Reshape and Concatenate all the output boxes of all levels + for lvl in rpn_boxes: + rpn_boxes[lvl] = keras.layers.Reshape(target_shape=(-1, 4))( + rpn_boxes[lvl] + ) + + for lvl in rpn_scores: + rpn_scores[lvl] = keras.layers.Reshape(target_shape=(-1, 1))( + rpn_scores[lvl] + ) + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )(tree.flatten(rpn_scores)) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + tree.flatten(rpn_boxes) + ) + + anchors = anchor_generator(image_shape=image_shape) + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, roi_scores = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, "yxyx", image_shape) + + feature_map = roi_aligner(features=feature_map, boxes=rois) + + # Reshape the feature map [BS, H*W*K] + feature_map = keras.layers.Reshape( + target_shape=( + rois.shape[1], + (roi_aligner.target_size**2) * rpn_head.num_filters, + ) + )(feature_map) + + # Pass final feature map to RCNN Head for predictions + box_pred, cls_pred = rcnn_head(feature_map=feature_map) + + box_pred = keras.layers.Concatenate(axis=1, name="box")([box_pred]) + cls_pred = keras.layers.Concatenate(axis=1, name="classification")( + [cls_pred] + ) + + inputs = {"images": images} + outputs = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.num_classes = num_classes + self.feature_extractor = feature_extractor + self.backbone = backbone + self.feature_pyramid = feature_pyramid + self.rpn_head = rpn_head + self.label_encoder = label_encoder or RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format=bounding_box_format, + positive_threshold=rpn_label_encoder_posistive_threshold, + negative_threshold=rpn_label_encoder_negative_threshold, + samples_per_image=rpn_label_encoder_samples_per_image, + positive_fraction=rpn_label_encoder_positive_fraction, + box_variance=BOX_VARIANCE, + ) + self.roi_generator = roi_generator + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = ROISampler( + roi_bounding_box_format="yxyx", + gt_bounding_box_format=bounding_box_format, + roi_matcher=self.box_matcher, + num_sampled_rois=num_sampled_rois, + ) + + self.roi_aligner = roi_aligner + self.rcnn_head = rcnn_head + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + max_detections=num_max_decoder_detections, + ) + self.build(backbone.input_shape) + + def compile( + self, + rpn_box_loss=None, + rpn_classification_loss=None, + box_loss=None, + classification_loss=None, + weight_decay=0.0001, + loss=None, + metrics=None, + **kwargs, + ): + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. 
" + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." + ) + if ( + rpn_box_loss is None + or rpn_classification_loss is None + or box_loss is None + or classification_loss is None + ): + raise ValueError( + "`FasterRCNN` expects all of `rpn_box_loss`, " + "`rpn_classification_loss`," + "`box_loss`, and " + "`classification_loss` to be not `None`." + ) + + rpn_box_loss = _parse_box_loss(rpn_box_loss) + rpn_classification_loss = _parse_rpn_classification_loss( + rpn_classification_loss + ) + + if hasattr(rpn_classification_loss, "from_logits"): + if not rpn_classification_loss.from_logits: + raise ValueError( + "FasterRCNN.compile() expects `from_logits` to be True for " + "`rpn_classification_loss`. Got " + "`rpn_classification_loss.from_logits=" + f"{rpn_classification_loss.from_logits}`" + ) + box_loss = _parse_box_loss(box_loss) + classification_loss = _parse_classification_loss(classification_loss) + + if hasattr(classification_loss, "from_logits"): + if not classification_loss.from_logits: + raise ValueError( + "FasterRCNN.compile() expects `from_logits` to be True for " + "`classification_loss`. Got " + "`classification_loss.from_logits=" + f"{classification_loss.from_logits}`" + ) + if hasattr(box_loss, "bounding_box_format"): + if box_loss.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Wrong `bounding_box_format` passed to `box_loss` in " + "`FasterRCNN.compile()`. Got " + "`box_loss.bounding_box_format=" + f"{box_loss.bounding_box_format}`, want " + "`box_loss.bounding_box_format=" + f"{self.bounding_box_format}`" + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + "box": self.box_loss, + "classification": self.cls_loss, + } + self._has_user_metrics = metrics is not None and len(metrics) != 0 + self._user_metrics = metrics + super().compile(loss=losses, **kwargs) + + def compute_loss( + self, x, y, y_pred, sample_weight, training=True, **kwargs + ): + + # 1. Unpack the inputs + images = x + gt_boxes = y["boxes"] + if ops.ndim(y["classes"]) != 2: + raise ValueError( + "Expected 'classes' to be a Tensor of rank 2. " + f"Got y['classes'].shape={ops.shape(y['classes'])}." 
+ ) + + gt_classes = y["classes"] + gt_classes = ops.expand_dims(gt_classes, axis=-1) + + # Generate Anchors and Generate RPN Targets + local_batch = ops.shape(images)[0] + image_shape = ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + + # Label with the anchors -- exclusive to compute_loss + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.label_encoder( + anchors_dict=ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + gt_boxes=gt_boxes, + gt_classes=gt_classes, + ) + + # Computing the weights + rpn_box_weights /= ( + self.label_encoder.samples_per_image * local_batch * 0.25 + ) + rpn_cls_weights /= self.label_encoder.samples_per_image * local_batch + + # Call Backbone, FPN and RPN Head + backbone_outputs = self.feature_extractor(images) + feature_map = self.feature_pyramid(backbone_outputs) + rpn_boxes, rpn_scores = self.rpn_head(feature_map) + + for lvl in rpn_boxes: + rpn_boxes[lvl] = keras.layers.Reshape(target_shape=(-1, 4))( + rpn_boxes[lvl] + ) + + for lvl in rpn_scores: + rpn_scores[lvl] = keras.layers.Reshape(target_shape=(-1, 1))( + rpn_scores[lvl] + ) + + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )(tree.flatten(rpn_scores)) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + tree.flatten(rpn_boxes) + ) + + # Generate RoI's and RoI Sampling + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=training + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + + # Stop gradient from flowing into the ROI + # -- exclusive to compute_loss + rois = ops.stop_gradient(rois) + + # Sample the ROIS -- exclusive to compute_loss + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, gt_boxes, gt_classes) + + cls_targets = ops.squeeze(cls_targets, axis=-1) + cls_weights = ops.squeeze(cls_weights, axis=-1) + + # Box and class weights -- exclusive to compute loss + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + cls_targets = ops.one_hot(cls_targets, num_classes=self.num_classes + 1) + + # Call RoI Aligner and RCNN Head + feature_map = self.roi_aligner(features=feature_map, boxes=rois) + + # [BS, H*W*K] + feature_map = ops.reshape( + feature_map, + newshape=ops.shape(rois)[:2] + (-1,), + ) + + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + box_pred, cls_pred = self.rcnn_head(feature_map=feature_map) + + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=weights, **kwargs + ) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + def 
predict_step(self, *args): + outputs = super().predict_step(*args) + if type(outputs) is tuple: + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + else: + return self.decode_predictions(outputs, args[-1]) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and FasterRCNN to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def decode_predictions(self, predictions, images): + image_shape = ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + rpn_boxes, rpn_scores = ( + predictions["rpn_box"], + predictions["rpn_classification"], + ) + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=False + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + box_pred, cls_pred = predictions["box"], predictions["classification"] + + # box_pred is in "center_yxhw" format; convert to the target format. + box_pred = _decode_deltas_to_boxes( + anchors=rois, + boxes_delta=box_pred, + anchor_format=self.roi_aligner.bounding_box_format, + box_format=self.bounding_box_format, + variance=BOX_VARIANCE, + image_shape=image_shape, + ) + + box_pred = convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + image_shape=image_shape, + ) + cls_pred = ops.softmax(cls_pred) + cls_pred = ops.slice( + cls_pred, + start_indices=[0, 0, 1], + shape=[cls_pred.shape[0], cls_pred.shape[1], cls_pred.shape[2] - 1], + ) + + y_pred = self.prediction_decoder( + box_pred, cls_pred, image_shape=image_shape + ) + + y_pred["classes"] = ops.where( + y_pred["classes"] == -1, -1, y_pred["classes"] + 1 + ) + + y_pred["boxes"] = convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) + return y_pred + + def compute_metrics(self, x, y, y_pred, sample_weight): + metrics = {} + metrics.update(super().compute_metrics(x, {}, {}, sample_weight={})) + + if not self._has_user_metrics: + return metrics + + y_pred = self.decode_predictions(y_pred, x) + + for metric in self._user_metrics: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + for metric in self._user_metrics: + result = metric.result() + if isinstance(result, dict): + metrics.update(result) + else: + metrics[metric.name] = result + return metrics + + @staticmethod + def default_anchor_generator( + min_level, max_level, scales, aspect_ratios, bounding_box_format + ): + strides = {f"P{i}": 2**i for i in range(min_level, max_level + 1)} + sizes = {f"P{i}": 2 ** (3 + i) for i in range(min_level, max_level + 1)} + return AnchorGenerator( + bounding_box_format=bounding_box_format, + sizes=sizes, + aspect_ratios=aspect_ratios, + scales=scales, + strides=strides, + clip_boxes=True,
name="anchor_generator", + ) + + def get_config(self): + return { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": keras.saving.serialize_keras_object(self.backbone), + "label_encoder": keras.saving.serialize_keras_object( + self.label_encoder + ), + "rpn_head": keras.saving.serialize_keras_object(self.rpn_head), + "prediction_decoder": self._prediction_decoder, + "rcnn_head": self.rcnn_head, + } + + @classmethod + def from_config(cls, config): + if "rpn_head" in config and isinstance(config["rpn_head"], dict): + config["rpn_head"] = keras.layers.deserialize(config["rpn_head"]) + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): + config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) + + return super().from_config(config) + + +def _parse_box_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + # case insensitive comparison + if loss.lower() == "smoothl1": + return losses.SmoothL1Loss(l1_cutoff=1.0, reduction="sum") + if loss.lower() == "huber": + return keras.losses.Huber(reduction="sum") + + raise ValueError( + "Expected `box_loss` to be either a Keras Loss, " + f"callable, or the string 'SmoothL1'. Got loss={loss}." + ) + + +def _parse_rpn_classification_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + if loss.lower() == "binarycrossentropy": + return keras.losses.BinaryCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `rpn_classification_loss` to be either BinaryCrossentropy" + f" loss callable, or the string 'BinaryCrossentropy'. Got loss={loss}." + ) + + +def _parse_classification_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + # case insensitive comparison + if loss.lower() == "focal": + return losses.FocalLoss(reduction="sum", from_logits=True) + if loss.lower() == "categoricalcrossentropy": + return keras.losses.CategoricalCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `classification_loss` to be either a Keras Loss, " + f"callable, or the string 'Focal', CategoricalCrossentropy'. " + f"Got loss={loss}." + ) + + +def unpack_input(data): + if type(data) is dict: + return data["images"], data["bounding_boxes"] + else: + return data diff --git a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py new file mode 100644 index 0000000000..f3116d9247 --- /dev/null +++ b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -0,0 +1,346 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py
new file mode 100644
index 0000000000..f3116d9247
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn_test.py
@@ -0,0 +1,346 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+import keras_cv
+from keras_cv.src.backend import keras
+from keras_cv.src.backend import ops
+from keras_cv.src.backend.config import keras_3
+from keras_cv.src.models.object_detection.__test_utils__ import (
+    _create_bounding_box_dataset,
+)
+from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
+    FasterRCNN,
+)
+from keras_cv.src.tests.test_case import TestCase
+
+
+class FasterRCNNTest(TestCase):
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_faster_rcnn_construction(self):
+        faster_rcnn = FasterRCNN(
+            num_classes=80,
+            bounding_box_format="xyxy",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+        faster_rcnn.compile(
+            optimizer=keras.optimizers.Adam(),
+            box_loss="Huber",
+            classification_loss="CategoricalCrossentropy",
+            rpn_box_loss="Huber",
+            rpn_classification_loss="BinaryCrossentropy",
+        )
+
+    @pytest.mark.extra_large()
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_faster_rcnn_call(self):
+        faster_rcnn = FasterRCNN(
+            num_classes=3,
+            bounding_box_format="xywh",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+        images = np.random.uniform(size=(1, 32, 32, 3))
+        _ = faster_rcnn(images)
+        _ = faster_rcnn.predict(images)
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_wrong_logits(self):
+        faster_rcnn = FasterRCNN(
+            num_classes=80,
+            bounding_box_format="xywh",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "from_logits",
+        ):
+            faster_rcnn.compile(
+                optimizer=keras.optimizers.SGD(learning_rate=0.25),
+                box_loss=keras_cv.losses.SmoothL1Loss(
+                    l1_cutoff=1.0, reduction="none"
+                ),
+                classification_loss=keras_cv.losses.FocalLoss(
+                    from_logits=False, reduction="none"
+                ),
+                rpn_box_loss=keras_cv.losses.SmoothL1Loss(
+                    l1_cutoff=1.0, reduction="none"
+                ),
+                rpn_classification_loss=keras_cv.losses.FocalLoss(
+                    from_logits=False, reduction="none"
+                ),
+            )
+
+    @pytest.mark.large()
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_weights_contained_in_trainable_variables(self):
+        bounding_box_format = "xyxy"
+        faster_rcnn = FasterRCNN(
+            num_classes=80,
+            bounding_box_format=bounding_box_format,
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+        faster_rcnn.backbone.trainable = False
+        faster_rcnn.compile(
+            optimizer=keras.optimizers.Adam(),
+            box_loss="Huber",
+            classification_loss="CategoricalCrossentropy",
+            rpn_box_loss="Huber",
+            rpn_classification_loss="BinaryCrossentropy",
+        )
+        xs, ys = _create_bounding_box_dataset(
+            bounding_box_format, image_shape=(32, 32, 3)
+        )
+
+        # call once
+        _ = faster_rcnn(xs)
+        self.assertEqual(len(faster_rcnn.trainable_variables), 30)
+
+    @pytest.mark.extra_large  # Fit is slow, so mark these extra large.
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_no_nans(self):
+        faster_rcnn = FasterRCNN(
+            num_classes=5,
+            bounding_box_format="xyxy",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+        faster_rcnn.compile(
+            optimizer=keras.optimizers.Adam(),
+            box_loss="Huber",
+            classification_loss="CategoricalCrossentropy",
+            rpn_box_loss="Huber",
+            rpn_classification_loss="BinaryCrossentropy",
+        )
+
+        # only a -1 (padding) box
+        xs = np.ones((1, 32, 32, 3), "float32")
+        ys = {
+            "classes": np.array([[-1]], "float32"),
+            "boxes": np.array([[[0, 0, 0, 0]]], "float32"),
+        }
+        ds = tf.data.Dataset.from_tensor_slices((xs, ys))
+        ds = ds.repeat(1)
+        ds = ds.batch(1, drop_remainder=True)
+        faster_rcnn.fit(ds, epochs=1)
+
+        weights = faster_rcnn.get_weights()
+        for weight in weights:
+            self.assertFalse(ops.any(ops.isnan(weight)))
+
+    @pytest.mark.extra_large  # Fit is slow, so mark these extra large.
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_weights_change(self):
+        faster_rcnn = FasterRCNN(
+            num_classes=3,
+            bounding_box_format="xyxy",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(128, 128, 3)
+            ),
+        )
+        faster_rcnn.compile(
+            optimizer=keras.optimizers.Adam(),
+            box_loss="Huber",
+            classification_loss="CategoricalCrossentropy",
+            rpn_box_loss="Huber",
+            rpn_classification_loss="BinaryCrossentropy",
+        )
+
+        ds = _create_bounding_box_dataset(
+            "xyxy", image_shape=(128, 128, 3), use_dictionary_box_format=True
+        )
+
+        # call once
+        _ = faster_rcnn(ops.ones((1, 128, 128, 3)))
+        original_fpn_weights = faster_rcnn.feature_pyramid.get_weights()
+        original_rpn_head_weights = faster_rcnn.rpn_head.get_weights()
+        original_rcnn_head_weights = faster_rcnn.rcnn_head.get_weights()
+
+        faster_rcnn.fit(ds, epochs=1)
+        fpn_after_fit = faster_rcnn.feature_pyramid.get_weights()
+        rpn_head_after_fit_weights = faster_rcnn.rpn_head.get_weights()
+        rcnn_head_after_fit_weights = faster_rcnn.rcnn_head.get_weights()
+
+        for w1, w2 in zip(
+            original_rcnn_head_weights,
+            rcnn_head_after_fit_weights,
+        ):
+            self.assertNotAllClose(w1, w2)
+
+        for w1, w2 in zip(
+            original_rpn_head_weights, rpn_head_after_fit_weights
+        ):
+            self.assertNotAllClose(w1, w2)
+
+        for w1, w2 in zip(original_fpn_weights, fpn_after_fit):
+            self.assertNotAllClose(w1, w2)
+
+    @pytest.mark.large  # Saving is slow, so mark these large.
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_saved_model(self):
+        model = FasterRCNN(
+            num_classes=80,
+            bounding_box_format="xyxy",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+        )
+        input_batch = ops.ones(shape=(1, 32, 32, 3))
+        model_output = model(input_batch)
+        save_path = os.path.join(self.get_temp_dir(), "faster_rcnn.keras")
+        model.save(save_path)
+        restored_model = keras.models.load_model(save_path)
+
+        # Check we got the real object back.
+        self.assertIsInstance(restored_model, keras_cv.models.FasterRCNN)
+
+        # Check that output matches.
+        restored_output = restored_model(input_batch)
+        self.assertAllClose(
+            tf.nest.map_structure(ops.convert_to_numpy, model_output),
+            tf.nest.map_structure(ops.convert_to_numpy, restored_output),
+        )
+
+    @pytest.mark.large
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_faster_rcnn_infer(self):
+        model = FasterRCNN(
+            num_classes=80,
+            bounding_box_format="xyxy",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(128, 128, 3)
+            ),
+        )
+        images = ops.ones((1, 128, 128, 3))
+        outputs = model(images, training=False)
+        # 1000 proposals in inference
+        self.assertAllEqual([1, 1000, 81], outputs["classification"].shape)
+        self.assertAllEqual([1, 1000, 4], outputs["box"].shape)
+
+    @pytest.mark.large
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_faster_rcnn_train(self):
+        model = FasterRCNN(
+            num_classes=80,
+            bounding_box_format="xyxy",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(128, 128, 3)
+            ),
+        )
+        images = ops.ones((1, 128, 128, 3))
+        outputs = model(images, training=True)
+        self.assertAllEqual([1, 1000, 81], outputs["classification"].shape)
+        self.assertAllEqual([1, 1000, 4], outputs["box"].shape)
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_invalid_compile(self):
+        model = FasterRCNN(
+            num_classes=80,
+            bounding_box_format="yxyx",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+        with self.assertRaisesRegex(ValueError, "expects"):
+            model.compile(rpn_box_loss="binary_crossentropy")
+        with self.assertRaisesRegex(ValueError, "from_logits"):
+            model.compile(
+                box_loss="Huber",
+                classification_loss="CategoricalCrossentropy",
+                rpn_box_loss="Huber",
+                rpn_classification_loss=keras.losses.BinaryCrossentropy(
+                    from_logits=False
+                ),
+            )
+
+    @pytest.mark.extra_large  # Fit is slow, so mark these extra large.
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_faster_rcnn_with_dictionary_input_format(self):
+        faster_rcnn = FasterRCNN(
+            num_classes=3,
+            bounding_box_format="xywh",
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+
+        images, boxes = _create_bounding_box_dataset(
+            "xywh", image_shape=(32, 32, 3)
+        )
+        dataset = tf.data.Dataset.from_tensor_slices(
+            {"images": images, "bounding_boxes": boxes}
+        ).batch(1, drop_remainder=True)
+
+        faster_rcnn.compile(
+            optimizer=keras.optimizers.Adam(),
+            box_loss="Huber",
+            classification_loss="CategoricalCrossentropy",
+            rpn_box_loss="Huber",
+            rpn_classification_loss="BinaryCrossentropy",
+        )
+
+        faster_rcnn.fit(dataset, epochs=1)
+
+    @pytest.mark.extra_large  # Fit is slow, so mark these extra large.
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_fit_with_no_valid_gt_bbox(self):
+        bounding_box_format = "xywh"
+        faster_rcnn = FasterRCNN(
+            num_classes=2,
+            bounding_box_format=bounding_box_format,
+            backbone=keras_cv.models.ResNet18V2Backbone(
+                input_shape=(32, 32, 3)
+            ),
+            num_sampled_rois=256,
+        )
+
+        faster_rcnn.compile(
+            optimizer=keras.optimizers.Adam(),
+            box_loss="Huber",
+            classification_loss="CategoricalCrossentropy",
+            rpn_box_loss="Huber",
+            rpn_classification_loss="BinaryCrossentropy",
+        )
+        xs, ys = _create_bounding_box_dataset(
+            bounding_box_format, image_shape=(32, 32, 3)
+        )
+        xs = ops.convert_to_tensor(xs)
+        # Make all bounding_boxes invalid and filter them out
+        ys["classes"] = -ops.ones_like(ys["classes"])
+
+        faster_rcnn.fit(x=xs, y=ys, epochs=1, batch_size=1)
+
+
+# TODO: add presets test cases once model training is done.
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/feature_pyamid_test.py b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyamid_test.py
new file mode 100644
index 0000000000..7292a1837d
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyamid_test.py
@@ -0,0 +1,146 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from keras_cv.src.backend import keras
+from keras_cv.src.backend.config import keras_3
+from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid
+from keras_cv.src.tests.test_case import TestCase
+
+
+class FeaturePyramidTest(TestCase):
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_return_type_dict(self):
+        layer = FeaturePyramid(min_level=2, max_level=5)
+        c2 = np.ones([2, 32, 32, 3])
+        c3 = np.ones([2, 16, 16, 3])
+        c4 = np.ones([2, 8, 8, 3])
+        c5 = np.ones([2, 4, 4, 3])
+
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5}
+        output = layer(inputs)
+        self.assertTrue(isinstance(output, dict))
+        self.assertEquals(sorted(output.keys()), ["P2", "P3", "P4", "P5", "P6"])
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_result_shapes(self):
+        layer = FeaturePyramid(min_level=2, max_level=5)
+        c2 = np.ones([2, 32, 32, 3])
+        c3 = np.ones([2, 16, 16, 3])
+        c4 = np.ones([2, 8, 8, 3])
+        c5 = np.ones([2, 4, 4, 3])
+
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5}
+        output = layer(inputs)
+        for level in inputs.keys():
+            self.assertEquals(output[level].shape[1], inputs[level].shape[1])
+            self.assertEquals(output[level].shape[2], inputs[level].shape[2])
+            self.assertEquals(output[level].shape[3], layer.num_channels)
+
+        # Test with different resolutions and channel sizes
+        c2 = np.ones([2, 64, 128, 4])
+        c3 = np.ones([2, 32, 64, 8])
+        c4 = np.ones([2, 16, 32, 16])
+        c5 = np.ones([2, 8, 16, 32])
+
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5}
+        layer = FeaturePyramid(min_level=2, max_level=5)
+        output = layer(inputs)
+        for level in inputs.keys():
+            self.assertEquals(output[level].shape[1], inputs[level].shape[1])
+            self.assertEquals(output[level].shape[2], inputs[level].shape[2])
+            self.assertEquals(output[level].shape[3], layer.num_channels)
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_with_keras_input_tensor(self):
+        # This mimics model building with a Backbone network
+        layer = FeaturePyramid(min_level=2, max_level=5)
+        c2 = keras.layers.Input([32, 32, 3])
+        c3 = keras.layers.Input([16, 16, 3])
+        c4 = keras.layers.Input([8, 8, 3])
+        c5 = keras.layers.Input([4, 4, 3])
+
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5}
+        output = layer(inputs)
+        for level in inputs.keys():
+            self.assertEquals(output[level].shape[1], inputs[level].shape[1])
+            self.assertEquals(output[level].shape[2], inputs[level].shape[2])
+            self.assertEquals(output[level].shape[3], layer.num_channels)
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_invalid_lateral_layers(self):
+        lateral_layers = [keras.layers.Conv2D(256, 1)] * 3
+        with self.assertRaisesRegexp(
+            ValueError, "Expect lateral_layers to be a dict"
+        ):
+            _ = FeaturePyramid(
+                min_level=2, max_level=5, lateral_layers=lateral_layers
+            )
+        lateral_layers = {
+            "P2": keras.layers.Conv2D(256, 1),
+            "P3": keras.layers.Conv2D(256, 1),
+            "P4": keras.layers.Conv2D(256, 1),
+        }
+        with self.assertRaisesRegexp(
+            ValueError, "with keys as .* ['P2', 'P3', 'P4', 'P5']"
+        ):
+            _ = FeaturePyramid(
+                min_level=2, max_level=5, lateral_layers=lateral_layers
+            )
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_invalid_output_layers(self):
+        output_layers = [keras.layers.Conv2D(256, 3)] * 3
+        with self.assertRaisesRegexp(
+            ValueError, "Expect output_layers to be a dict"
+        ):
+            _ = FeaturePyramid(
+                min_level=2, max_level=5, output_layers=output_layers
+            )
+        output_layers = {
+            "P2": keras.layers.Conv2D(256, 3),
+            "P3": keras.layers.Conv2D(256, 3),
+            "P4": keras.layers.Conv2D(256, 3),
+        }
+        with self.assertRaisesRegexp(
+            ValueError, "with keys as .* ['P2', 'P3', 'P4', 'P5']"
+        ):
+            _ = FeaturePyramid(
+                min_level=2, max_level=5, output_layers=output_layers
+            )
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_invalid_input_features(self):
+        layer = FeaturePyramid(min_level=2, max_level=5)
+
+        c2 = np.ones([2, 32, 32, 3])
+        c3 = np.ones([2, 16, 16, 3])
+        c4 = np.ones([2, 8, 8, 3])
+        c5 = np.ones([2, 4, 4, 3])
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5}
+        # Build required for Keras 3
+        _ = layer(inputs)
+        list_input = [c2, c3, c4, c5]
+        with self.assertRaisesRegexp(
+            ValueError, "expects input features to be a dict"
+        ):
+            layer(list_input)
+
+        dict_input_with_missing_feature = {"P2": c2, "P3": c3, "P4": c4}
+        with self.assertRaisesRegexp(
+            ValueError, "Expect feature keys.*['P2', 'P3', 'P4', 'P5']"
+        ):
+            layer(dict_input_with_missing_feature)
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py
new file mode 100644
index 0000000000..4e7bd6b884
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/feature_pyramid.py
@@ -0,0 +1,251 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+
+from keras_cv.src.api_export import keras_cv_export
+from keras_cv.src.backend import keras
+
+
+@keras_cv_export(
+    "keras_cv.models.faster_rcnn.FeaturePyramid",
+    package="keras_cv.models.faster_rcnn",
+)
+class FeaturePyramid(keras.layers.Layer):
+    """Implements a Feature Pyramid Network.
+
+    This implements the paper:
+    Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
+    and Serge Belongie. Feature Pyramid Networks for Object Detection.
+    (https://arxiv.org/pdf/1612.03144)
+
+    Feature Pyramid Networks (FPNs) are basic components that are added to an
+    existing feature extractor (CNN) to combine features at different scales.
+    For the basic FPN, the inputs are features `Ci` from different levels of a
+    CNN, which is usually the last block for each level, where the feature is
+    scaled from the image by a factor of `1/2^i`.
+
+    There is an output associated with each level in the basic FPN. The output
+    Pi at level `i` (corresponding to Ci) is given by performing a merge
+    operation on the outputs of:
+
+    1) a lateral operation on Ci (usually a conv2D layer with kernel = 1 and
+       strides = 1)
+    2) a top-down upsampling operation from Pi+1 (except for the top most
+       level)
+
+    The final output of each level will also have a conv2D operation
+    (typically with kernel = 3 and strides = 1).
+
+    The inputs to the layer should be a dict with string keys that match the
+    pyramid levels, e.g. for `min_level=2` and `max_level=5`, the expected
+    input dict should be `{"P2": c2, "P3": c3, "P4": c4, "P5": c5}`.
+
+    The output of the layer will have the same structure as the inputs: a
+    dict with string level keys and a feature map value for each level.
+
+    Args:
+        min_level: a python int for the lowest level of the pyramid for
+            feature extraction.
+        max_level: a python int for the highest level of the pyramid for
+            feature extraction.
+        num_channels: an integer representing the number of channels for the
+            FPN operations, defaults to 256.
+        lateral_layers: a python dict with keys that match each of the
+            pyramid levels (e.g. "P2"). The values of the dict should be
+            `keras.Layer`, which will be called with feature activation
+            outputs from backbone at each level. Defaults to None, and a
+            `keras.Conv2D` layer with kernel 1x1 will be created for each
+            pyramid level.
+        output_layers: a python dict with keys that match each of the
+            pyramid levels (e.g. "P2"). The values of the dict should be
+            `keras.Layer`, which will be called with feature inputs and
+            merged result from upstream levels. Defaults to None, and a
+            `keras.Conv2D` layer with kernel 3x3 will be created for each
+            pyramid level.
+
+    Example:
+    ```python
+    images = np.ones((1, 512, 512, 3))
+    extractor_levels = ["P2", "P3", "P4", "P5"]
+    backbone = keras_cv.models.ResNetV2Backbone.from_preset(
+        "resnet50_v2_imagenet", include_rescaling=True
+    )
+
+    extractor_layer_names = [
+        backbone.pyramid_level_inputs[i] for i in extractor_levels
+    ]
+
+    feature_extractor = get_feature_extractor(
+        backbone,
+        extractor_layer_names,
+        extractor_levels
+    )
+    feature_pyramid = FeaturePyramid(min_level=2, max_level=5)
+
+    backbone_outputs = feature_extractor(images)
+    feature_map = feature_pyramid(backbone_outputs)
+    ```
+    """
+
+    def __init__(
+        self,
+        min_level,
+        max_level,
+        num_channels=256,
+        lateral_layers=None,
+        output_layers=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.min_level = min_level
+        self.max_level = max_level
+        self.pyramid_levels = [
+            f"P{level}" for level in range(min_level, max_level + 1)
+        ]
+        self.num_channels = num_channels
+
+        # required for successful serialization
+        self.lateral_layers_passed = lateral_layers
+        self.output_layers_passed = output_layers
+
+        if not lateral_layers:
+            # populate self.lateral_layers with default FPN Conv2D 1x1 layers
+            self.lateral_layers = {}
+            for i in self.pyramid_levels:
+                self.lateral_layers[i] = keras.layers.Conv2D(
+                    self.num_channels,
+                    kernel_size=1,
+                    strides=1,
+                    padding="same",
+                    name=f"lateral_{i}",
+                )
+        else:
+            self._validate_user_layers(lateral_layers, "lateral_layers")
+            self.lateral_layers = lateral_layers
+
+        # Output conv2d layers.
+        if not output_layers:
+            self.output_layers = {}
+            for i in self.pyramid_levels:
+                self.output_layers[i] = keras.layers.Conv2D(
+                    self.num_channels,
+                    kernel_size=3,
+                    strides=1,
+                    padding="same",
+                    name=f"output_{i}",
+                )
+        else:
+            self._validate_user_layers(output_layers, "output_layers")
+            self.output_layers = output_layers
+        # Applies a max_pool2d (pool_size 1, stride 2; i.e. plain
+        # subsampling, not a true max pool) on top of the last feature map
+        # to simulate stride-2 subsampling for the extra top level
+        self.max_pool = keras.layers.MaxPool2D(
+            pool_size=(1, 1), strides=2, padding="same"
+        )
+
+        # the same upsampling layer is used for all levels
+        self.top_down_op = keras.layers.UpSampling2D(size=2)
+        # the same merge layer is used for all levels
+        self.merge_op = keras.layers.Add()
+
+    def _validate_user_layers(self, user_input, param_name):
+        if (
+            not isinstance(user_input, dict)
+            or sorted(user_input.keys()) != self.pyramid_levels
+        ):
+            raise ValueError(
+                f"Expect {param_name} to be a dict with keys as "
+                f"{self.pyramid_levels}, got {user_input}"
+            )
+
+    def call(self, features):
+        # Note that this assertion might not be true for all the subclasses.
+        # It is possible to have an FPN with more levels than the backbone
+        # produces outputs for.
+        if (
+            not isinstance(features, dict)
+            or sorted(features.keys()) != self.pyramid_levels
+        ):
+            raise ValueError(
+                "FeaturePyramid expects input features to be a dict with "
+                "string keys that match the values provided in "
+                f"pyramid_levels. Expect feature keys: {self.pyramid_levels}, "
+                f"got: {features}"
+            )
+        return self.build_feature_pyramid(features)
+
+    def build_feature_pyramid(self, input_features):
+        # To illustrate the connection/topology, the basic flow for an FPN
+        # with levels 2 to 5 (plus the extra max-pooled level 6) is like
+        # below:
+        #
+        #                                                    output_l6
+        #                                                        ^
+        #                                                        |
+        #                                                   max_pool_2d
+        #                                                        ^
+        #                                                        |
+        # input_l5 -> conv2d_1x1_l5 ----V---> conv2d_3x3_l5 -> output_l5
+        #                               V
+        #                          upsample2d
+        #                               V
+        # input_l4 -> conv2d_1x1_l4 -> Add -> conv2d_3x3_l4 -> output_l4
+        #                               V
+        #                          upsample2d
+        #                               V
+        # input_l3 -> conv2d_1x1_l3 -> Add -> conv2d_3x3_l3 -> output_l3
+        #                               V
+        #                          upsample2d
+        #                               V
+        # input_l2 -> conv2d_1x1_l2 -> Add -> conv2d_3x3_l2 -> output_l2
+        output_features = {}
+        reversed_levels = list(sorted(input_features.keys(), reverse=True))
+
+        for i in range(self.max_level, self.min_level - 1, -1):
+            level = f"P{i}"
+            output = self.lateral_layers[level](input_features[level])
+            if i < self.max_level:
+                # the top most output doesn't need to merge with any
+                # upstream outputs
+                upstream_output = self.top_down_op(output_features[f"P{i + 1}"])
+                output = self.merge_op([output, upstream_output])
+            output_features[level] = output
+
+        # Post apply the output layers so that we don't leak them to the
+        # downstream levels
+        for level in reversed_levels:
+            output_features[level] = self.output_layers[level](
+                output_features[level]
+            )
+        output_features[f"P{self.max_level + 1}"] = self.max_pool(
+            output_features[f"P{self.max_level}"]
+        )
+        output_features = OrderedDict(sorted(output_features.items()))
+        return output_features
+
+    def get_config(self):
+        config = super().get_config()
+        config["min_level"] = self.min_level
+        config["max_level"] = self.max_level
+        config["num_channels"] = self.num_channels
+        config["lateral_layers"] = self.lateral_layers_passed
+        config["output_layers"] = self.output_layers_passed
+        return config
+
+    def build(self, input_shape):
+        for level in self.pyramid_levels:
+            self.lateral_layers[level].build(
+                (None, None, None, input_shape[level][-1])
+            )
+
+        for level in self.pyramid_levels:
+            self.output_layers[level].build(
+                (None, None, None, self.num_channels)
+            )
+        self.max_pool.build((None, None, None, self.num_channels))
+        self.built = True
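A quick shape sketch of the layer (consistent with the unit tests above; the sizes are illustrative):

```python
import numpy as np
from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid

fpn = FeaturePyramid(min_level=2, max_level=5)
feats = {
    "P2": np.ones((1, 64, 64, 256)),
    "P3": np.ones((1, 32, 32, 256)),
    "P4": np.ones((1, 16, 16, 256)),
    "P5": np.ones((1, 8, 8, 256)),
}
outs = fpn(feats)
# outs has keys P2..P6; P2..P5 keep their spatial sizes with
# `num_channels` channels, and P6 is a stride-2 subsample of P5.
```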
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py b/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py
new file mode 100644
index 0000000000..5d77928e2d
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head.py
@@ -0,0 +1,107 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.src.api_export import keras_cv_export
+from keras_cv.src.backend import keras
+
+
+@keras_cv_export(
+    "keras_cv.models.faster_rcnn.RCNNHead",
+    package="keras_cv.models.faster_rcnn",
+)
+class RCNNHead(keras.layers.Layer):
+    """A Keras layer implementing the R-CNN Head.
+
+    Args:
+        num_classes: The number of object classes to be detected.
+        conv_dims: (Optional) a list of integers specifying the number of
+            filters for each convolutional layer. Defaults to [].
+        fc_dims: (Optional) a list of integers specifying the number of
+            units for each fully-connected layer. Defaults to [1024, 1024].
+    """
+
+    def __init__(
+        self,
+        num_classes,
+        conv_dims=[],
+        fc_dims=[1024, 1024],
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_classes = num_classes
+        self.conv_dims = conv_dims
+        self.fc_dims = fc_dims
+        self.convs = []
+        for conv_dim in conv_dims:
+            layer = keras.layers.Conv2D(
+                filters=conv_dim,
+                kernel_size=3,
+                strides=1,
+                padding="same",
+                activation="relu",
+            )
+            self.convs.append(layer)
+        self.fcs = []
+        for fc_dim in fc_dims:
+            layer = keras.layers.Dense(
+                units=fc_dim,
+                activation="relu",
+            )
+            self.fcs.append(layer)
+        self.box_pred = keras.layers.Dense(
+            units=4,
+            kernel_initializer=keras.initializers.RandomNormal(stddev=0.01),
+        )
+        self.cls_score = keras.layers.Dense(
+            units=num_classes + 1,
+            kernel_initializer=keras.initializers.RandomNormal(stddev=0.01),
+        )
+
+    def call(self, feature_map, training=False):
+        x = feature_map
+        for conv in self.convs:
+            x = conv(x, training=training)
+        for fc in self.fcs:
+            x = fc(x, training=training)
+        rcnn_boxes = self.box_pred(x, training=training)
+        rcnn_scores = self.cls_score(x, training=training)
+        return rcnn_boxes, rcnn_scores
+
+    def build(self, input_shape):
+        intermediate_shape = input_shape
+        if self.conv_dims:
+            for idx in range(len(self.convs)):
+                self.convs[idx].build(intermediate_shape)
+                intermediate_shape = tuple(intermediate_shape[:-1]) + (
+                    self.conv_dims[idx],
+                )
+
+        for idx in range(len(self.fc_dims)):
+            self.fcs[idx].build(intermediate_shape)
+            intermediate_shape = tuple(intermediate_shape[:-1]) + (
+                self.fc_dims[idx],
+            )
+
+        self.box_pred.build(intermediate_shape)
+        self.cls_score.build(intermediate_shape)
+
+        self.built = True
+
+    def get_config(self):
+        config = super().get_config()
+        config["num_classes"] = self.num_classes
+        config["conv_dims"] = self.conv_dims
+        config["fc_dims"] = self.fc_dims
+
+        return config
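A quick usage sketch of the head (it mirrors the unit test that follows; the flattened RoI feature size `7 * 7 * 256` is illustrative):

```python
from keras_cv.src.backend import ops
from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead

head = RCNNHead(num_classes=20)
# Pooled RoI features, flattened per RoI: [batch, num_rois, H * W * C]
features = ops.ones((2, 64, 7 * 7 * 256))
boxes, scores = head(features)
# boxes:  [2, 64, 4]
# scores: [2, 64, 21]  (num_classes + 1, including background)
```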
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head_test.py b/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head_test.py
new file mode 100644
index 0000000000..7607359ef8
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/rcnn_head_test.py
@@ -0,0 +1,47 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from absl.testing import parameterized
+
+from keras_cv.src.backend import ops
+from keras_cv.src.backend.config import keras_3
+from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead
+from keras_cv.src.tests.test_case import TestCase
+
+
+class RCNNHeadTest(TestCase):
+    @parameterized.parameters(
+        (2, 256, 20, 7, 256),
+        (1, 512, 80, 14, 512),
+    )
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_rcnn_head_output_shapes(
+        self,
+        batch_size,
+        num_rois,
+        num_classes,
+        roi_align_target_size,
+        num_filters,
+    ):
+        layer = RCNNHead(num_classes)
+
+        feature_map_size = (roi_align_target_size**2) * num_filters
+        inputs = ops.ones(shape=(batch_size, num_rois, feature_map_size))
+        outputs = layer(inputs)
+
+        self.assertEqual((batch_size, num_rois, 4), outputs[0].shape)
+        self.assertEqual(
+            (batch_size, num_rois, num_classes + 1), outputs[1].shape
+        )
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py
new file mode 100644
index 0000000000..10880297be
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head.py
@@ -0,0 +1,113 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.src.api_export import keras_cv_export
+from keras_cv.src.backend import keras
+
+
+@keras_cv_export(
+    "keras_cv.models.faster_rcnn.RPNHead",
+    package="keras_cv.models.faster_rcnn",
+)
+class RPNHead(keras.layers.Layer):
+    """A Keras layer implementing the RPN architecture.
+
+    The Region Proposal Network (RPN) was first proposed in
+    [FasterRCNN](https://arxiv.org/abs/1506.01497).
+    This is an end-to-end trainable layer which proposes regions
+    for a detector (RCNN).
+
+    Args:
+        num_anchors_per_location: (Optional) the number of anchors per
+            location, defaults to 3.
+        num_filters: (Optional) the number of convolution filters, defaults
+            to 256.
+        kernel_size: (Optional) kernel size of the convolution filters,
+            defaults to 3.
+ """ + + def __init__( + self, + num_anchors_per_location=3, + num_filters=256, + kernel_size=3, + **kwargs, + ): + super().__init__(**kwargs) + self.num_anchors = num_anchors_per_location + self.num_filters = num_filters + self.kernel_size = kernel_size + + self.conv = keras.layers.Conv2D( + filters=num_filters, + kernel_size=kernel_size, + strides=1, + padding="same", + activation="relu", + kernel_initializer=keras.initializers.RandomNormal(stddev=0.01), + ) + self.objectness_logits = keras.layers.Conv2D( + filters=self.num_anchors * 1, + kernel_size=1, + strides=1, + padding="valid", + kernel_initializer=keras.initializers.RandomNormal(stddev=0.01), + ) + self.anchor_deltas = keras.layers.Conv2D( + filters=self.num_anchors * 4, + kernel_size=1, + strides=1, + padding="valid", + kernel_initializer=keras.initializers.RandomNormal(stddev=0.01), + ) + + def call(self, feature_map, training=False): + def call_single_level(f_map): + # [BS, H, W, C] + t = self.conv(f_map, training=training) + # [BS, H, W, K] + rpn_scores = self.objectness_logits(t, training=training) + # [BS, H, W, K * 4] + rpn_boxes = self.anchor_deltas(t, training=training) + return rpn_boxes, rpn_scores + + if not isinstance(feature_map, (dict, list, tuple)): + return call_single_level(feature_map) + elif isinstance(feature_map, (list, tuple)): + rpn_boxes = [] + rpn_scores = [] + for f_map in feature_map: + rpn_box, rpn_score = call_single_level(f_map) + rpn_boxes.append(rpn_box) + rpn_scores.append(rpn_score) + return rpn_boxes, rpn_scores + else: + rpn_boxes = {} + rpn_scores = {} + for lvl, f_map in feature_map.items(): + rpn_box, rpn_score = call_single_level(f_map) + rpn_boxes[lvl] = rpn_box + rpn_scores[lvl] = rpn_score + return rpn_boxes, rpn_scores + + def get_config(self): + config = super().get_config() + config["num_anchors_per_location"] = self.num_anchors + config["num_filters"] = self.num_filters + config["kernel_size"] = self.kernel_size + return config + + def build(self, input_shape): + self.conv.build((None, None, None, self.num_filters)) + self.objectness_logits.build((None, None, None, self.num_filters)) + self.anchor_deltas.build((None, None, None, self.num_filters)) + self.built = True diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py new file mode 100644 index 0000000000..56a11af706 --- /dev/null +++ b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py @@ -0,0 +1,89 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py
new file mode 100644
index 0000000000..56a11af706
--- /dev/null
+++ b/keras_cv/src/models/object_detection/faster_rcnn/rpn_head_test.py
@@ -0,0 +1,89 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from absl.testing import parameterized
+
+from keras_cv.src.backend import keras
+from keras_cv.src.backend import ops
+from keras_cv.src.backend.config import keras_3
+from keras_cv.src.models.object_detection.faster_rcnn import RPNHead
+from keras_cv.src.tests.test_case import TestCase
+
+
+class RPNHeadTest(TestCase):
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_return_type_dict(
+        self,
+    ):
+        layer = RPNHead()
+        c2 = ops.ones([2, 64, 64, 256])
+        c3 = ops.ones([2, 32, 32, 256])
+        c4 = ops.ones([2, 16, 16, 256])
+        c5 = ops.ones([2, 8, 8, 256])
+        c6 = ops.ones([2, 4, 4, 256])
+
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5, "P6": c6}
+        rpn_boxes, rpn_scores = layer(inputs)
+        self.assertTrue(isinstance(rpn_boxes, dict))
+        self.assertTrue(isinstance(rpn_scores, dict))
+        self.assertEquals(
+            sorted(rpn_boxes.keys()), ["P2", "P3", "P4", "P5", "P6"]
+        )
+        self.assertEquals(
+            sorted(rpn_scores.keys()), ["P2", "P3", "P4", "P5", "P6"]
+        )
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    def test_return_type_list(self):
+        layer = RPNHead()
+        c2 = ops.ones([2, 64, 64, 256])
+        c3 = ops.ones([2, 32, 32, 256])
+        c4 = ops.ones([2, 16, 16, 256])
+        c5 = ops.ones([2, 8, 8, 256])
+        c6 = ops.ones([2, 4, 4, 256])
+
+        inputs = [c2, c3, c4, c5, c6]
+        rpn_boxes, rpn_scores = layer(inputs)
+        self.assertTrue(isinstance(rpn_boxes, list))
+        self.assertTrue(isinstance(rpn_scores, list))
+
+    @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2")
+    @parameterized.parameters(
+        (3,),
+        (9,),
+    )
+    def test_with_keras_input_tensor_and_num_anchors(self, num_anchors):
+        layer = RPNHead(num_anchors_per_location=num_anchors)
+        c2 = keras.layers.Input([64, 64, 256])
+        c3 = keras.layers.Input([32, 32, 256])
+        c4 = keras.layers.Input([16, 16, 256])
+        c5 = keras.layers.Input([8, 8, 256])
+        c6 = keras.layers.Input([4, 4, 256])
+
+        inputs = {"P2": c2, "P3": c3, "P4": c4, "P5": c5, "P6": c6}
+        rpn_boxes, rpn_scores = layer(inputs)
+        for level in inputs.keys():
+            self.assertEquals(rpn_boxes[level].shape[1], inputs[level].shape[1])
+            self.assertEquals(rpn_boxes[level].shape[2], inputs[level].shape[2])
+            self.assertEquals(rpn_boxes[level].shape[3], layer.num_anchors * 4)
+
+        for level in inputs.keys():
+            self.assertEquals(
+                rpn_scores[level].shape[1], inputs[level].shape[1]
+            )
+            self.assertEquals(
+                rpn_scores[level].shape[2], inputs[level].shape[2]
+            )
+            self.assertEquals(rpn_scores[level].shape[3], layer.num_anchors * 1)
diff --git a/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector_test.py b/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector_test.py
index 70ba79e92b..8d0faf6c58 100644
--- a/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector_test.py
+++ b/keras_cv/src/models/object_detection/yolo_v8/yolo_v8_detector_test.py
@@ -246,7 +246,9 @@ def test_preset_with_forward_pass(self):
         self.assertAllClose(
             ops.convert_to_numpy(encoded_predictions["boxes"][0, 0:5, 0]),
-            [-0.8303556, 0.75213313, 1.809204, 1.6576759, 1.4134747],
+            [-0.830356, 0.752131, 1.809205, 1.657676, 1.413475],
+            atol=1e-5,
+            rtol=1e-5,
         )
         self.assertAllClose(
             ops.convert_to_numpy(encoded_predictions["classes"][0, 0:5, 0]),
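For completeness, a minimal inference sketch for the ported model (the input size and backbone are illustrative; `decode_predictions` returns "boxes" and "classes" as in the code above, with -1 padding for empty detection slots):

```python
import numpy as np
import keras_cv

model = keras_cv.models.FasterRCNN(
    num_classes=20,
    bounding_box_format="xyxy",
    backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(128, 128, 3)),
)
images = np.random.uniform(size=(1, 128, 128, 3)).astype("float32")
preds = model.predict(images)  # runs decode_predictions under the hood
# preds["boxes"]: decoded boxes in the model's bounding_box_format
# preds["classes"]: class ids shifted past the background class; -1 = padding
```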