From e778eca667b34bdc22b83c7218acfadf7ca26ca6 Mon Sep 17 00:00:00 2001 From: liueo Date: Tue, 26 May 2020 06:33:52 +0000 Subject: [PATCH 1/7] matrix net --- gluoncv/data/transforms/presets/matrix_net.py | 194 +++++++++++ gluoncv/model_zoo/matrix_net/matrix_net.py | 329 ++++++++++++++++++ .../model_zoo/matrix_net/target_generator.py | 197 +++++++++++ gluoncv/model_zoo/model_zoo.py | 2 + gluoncv/nn/coder.py | 69 ++++ .../detection/matrix_net/train_matrix_net.py | 307 ++++++++++++++++ 6 files changed, 1098 insertions(+) create mode 100644 gluoncv/data/transforms/presets/matrix_net.py create mode 100644 gluoncv/model_zoo/matrix_net/matrix_net.py create mode 100644 gluoncv/model_zoo/matrix_net/target_generator.py create mode 100644 scripts/detection/matrix_net/train_matrix_net.py diff --git a/gluoncv/data/transforms/presets/matrix_net.py b/gluoncv/data/transforms/presets/matrix_net.py new file mode 100644 index 0000000000..6bafdf4cba --- /dev/null +++ b/gluoncv/data/transforms/presets/matrix_net.py @@ -0,0 +1,194 @@ +"""Transforms described in https://arxiv.org/abs/1904.07850 and https://arxiv.org/abs/2001.03194.""" +# pylint: disable=too-many-function-args +from __future__ import absolute_import +import numpy as np +import mxnet as mx +from .. import bbox as tbbox +from .. import image as timage +from .. import experimental +from ....utils.filesystem import try_import_cv2 + +__all__ = ['MatrixNetDefaultTrainTransform', 'MatrixNetDefaultValTransform', + 'get_post_transform'] + +class MatrixNetDefaultTrainTransform(object): + """Default MatrixNet training transform which includes tons of image augmentations. + + Parameters + ---------- + width : int + Image width. + height : int + Image height. + num_class : int + Number of categories + scale_factor : int, default is 4 + The downsampling scale factor between input image and output heatmap + mean : array-like of size 3 + Mean pixel values to be subtracted from image tensor. Default is [0.485, 0.456, 0.406]. + std : array-like of size 3 + Standard deviation to be divided from image. Default is [0.229, 0.224, 0.225]. 
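+
+    Notes
+    -----
+    With the default COCO `layers_range` (19 kept matrix layers), `__call__`
+    returns a 96-element tuple: the image tensor, then 19 heatmaps, 19
+    width/height targets, 19 width/height masks, 19 center-offset targets and
+    19 center-offset masks (1 + 5 * 19 = 96). A minimal usage sketch,
+    mirroring the training script (`net` is assumed to be a constructed
+    MatrixNet and `dataset` a detection dataset):
+
+    >>> transform = MatrixNetDefaultTrainTransform(
+    ...     512, 512, num_class=80, layers_range=net.layers_range)
+    >>> train_dataset = dataset.transform(transform)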
+    """
+    def __init__(self, width, height, num_class, layers_range, scale_factor=4, mean=(0.485, 0.456, 0.406),
+                 std=(0.229, 0.224, 0.225), **kwargs):
+        self._kwargs = kwargs
+        self._width = width
+        self._height = height
+        self._num_class = num_class
+        self._layers_range = layers_range
+        self._scale_factor = scale_factor
+        self._mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3)
+        self._std = np.array(std, dtype=np.float32).reshape(1, 1, 3)
+        self._data_rng = np.random.RandomState(123)
+        self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571],
+                                 dtype=np.float32)
+        self._eig_vec = np.array([
+            [-0.58752847, -0.69563484, 0.41340352],
+            [-0.5832747, 0.00994535, -0.81221408],
+            [-0.56089297, 0.71832671, 0.41158938]
+        ], dtype=np.float32)
+
+        from ....model_zoo.matrix_net.target_generator import MatrixNetTargetGenerator
+        self._target_generator = MatrixNetTargetGenerator(
+            num_class, width, height, self._layers_range)
+
+    def __call__(self, src, label):
+        """Apply transform to training image/label."""
+        img = src
+        bbox = label
+
+        # random horizontal flip
+        h, w, _ = img.shape
+        img, flips = timage.random_flip(img, px=0.5)
+        bbox = tbbox.flip(bbox, (w, h), flip_x=flips[0])
+
+        # random scale and shift via an affine warp to the network input size
+        cv2 = try_import_cv2()
+        input_h, input_w = self._height, self._width
+        s = max(h, w) * 1.0
+        c = np.array([w / 2., h / 2.], dtype=np.float32)
+        sf = 0.4
+        w_border = _get_border(128, img.shape[1])
+        h_border = _get_border(128, img.shape[0])
+        c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
+        c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
+        s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
+        trans_input = tbbox.get_affine_transform(c, s, 0, [input_w, input_h])
+        inp = cv2.warpAffine(img.asnumpy(), trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
+
+        trans_output = tbbox.get_affine_transform(c, s, 0, [input_w, input_h])
+        for i in range(bbox.shape[0]):
+            bbox[i, :2] = tbbox.affine_transform(bbox[i, :2], trans_output)
+            bbox[i, 2:4] = tbbox.affine_transform(bbox[i, 2:4], trans_output)
+        bbox[:, :2] = np.clip(bbox[:, :2], 0, input_w - 1)
+        bbox[:, 2:4] = np.clip(bbox[:, 2:4], 0, input_h - 1)
+        img = inp
+
+        # random color jittering, normalization and conversion to tensor
+        img = img.astype(np.float32) / 255.
+        experimental.image.np_random_color_distort(img, data_rng=self._data_rng)
+        img = (img - self._mean) / self._std
+        img = img.transpose(2, 0, 1).astype(np.float32)
+        img = mx.nd.array(img)
+
+        # generate training target so cpu workers can help reduce the workload on gpu
+        gt_bboxes = bbox[:, :4]
+        gt_ids = bbox[:, 4:5]
+        heatmaps, wh_targets, wh_masks, center_regs, center_reg_masks = self._target_generator(
+            gt_bboxes, gt_ids)
+        results = []
+        results.append(img)
+        for heatmap in heatmaps:
+            results.append(heatmap)
+        for wh_target in wh_targets:
+            results.append(wh_target)
+        for wh_mask in wh_masks:
+            results.append(wh_mask)
+        for center_reg in center_regs:
+            results.append(center_reg)
+        for center_reg_mask in center_reg_masks:
+            results.append(center_reg_mask)
+        return tuple(results)
+
+
+class MatrixNetDefaultValTransform(object):
+    """Default MatrixNet validation transform.
+
+    Parameters
+    ----------
+    width : int
+        Image width.
+    height : int
+        Image height.
+    mean : array-like of size 3
+        Mean pixel values to be subtracted from image tensor. Default is [0.485, 0.456, 0.406].
+    std : array-like of size 3
+        Standard deviation to be divided from image. Default is [0.229, 0.224, 0.225].
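+
+    Notes
+    -----
+    Returned bounding boxes are expressed in the coordinate frame of the
+    resized network input; `get_post_transform` below provides the inverse
+    affine transform used to map predictions back to the original image
+    resolution during evaluation.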
+ + """ + def __init__(self, width, height, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + self._width = width + self._height = height + self._mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) + self._std = np.array(std, dtype=np.float32).reshape(1, 1, 3) + + def __call__(self, src, label): + """Apply transform to validation image/label.""" + # resize + img, bbox = src.asnumpy(), label + cv2 = try_import_cv2() + input_h, input_w = self._height, self._width + h, w, _ = src.shape + s = max(h, w) * 1.0 + c = np.array([w / 2., h / 2.], dtype=np.float32) + trans_input = tbbox.get_affine_transform(c, s, 0, [input_w, input_h]) + inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR) + output_w = input_w + output_h = input_h + trans_output = tbbox.get_affine_transform(c, s, 0, [output_w, output_h]) + for i in range(bbox.shape[0]): + bbox[i, :2] = tbbox.affine_transform(bbox[i, :2], trans_output) + bbox[i, 2:4] = tbbox.affine_transform(bbox[i, 2:4], trans_output) + bbox[:, :2] = np.clip(bbox[:, :2], 0, output_w - 1) + bbox[:, 2:4] = np.clip(bbox[:, 2:4], 0, output_h - 1) + img = inp + + # to tensor + img = img.astype(np.float32) / 255. + img = (img - self._mean) / self._std + img = img.transpose(2, 0, 1).astype(np.float32) + img = mx.nd.array(img) + return img, bbox.astype(img.dtype) + +def get_post_transform(orig_w, orig_h, out_w, out_h): + """Get the post prediction affine transforms. This will be used to adjust the prediction results + according to original coco image resolutions. + + Parameters + ---------- + orig_w : int + Original width of the image. + orig_h : int + Original height of the image. + out_w : int + Width of the output image after prediction. + out_h : int + Height of the output image after prediction. + + Returns + ------- + numpy.ndarray + Affine transform matrix 3x2. + + """ + s = max(orig_w, orig_h) * 1.0 + c = np.array([orig_w / 2., orig_h / 2.], dtype=np.float32) + trans_output = tbbox.get_affine_transform(c, s, 0, [out_w, out_h], inv=True) + return trans_output + +def _get_border(border, size): + """Get the border size of the image""" + i = 1 + while size - border // i <= border // i: + i *= 2 + return border // i diff --git a/gluoncv/model_zoo/matrix_net/matrix_net.py b/gluoncv/model_zoo/matrix_net/matrix_net.py new file mode 100644 index 0000000000..5cad55bfa6 --- /dev/null +++ b/gluoncv/model_zoo/matrix_net/matrix_net.py @@ -0,0 +1,329 @@ +"""MatrixNet object detector: https://arxiv.org/abs/2001.03194""" +from __future__ import absolute_import + +import os +import warnings +from collections import OrderedDict + +import mxnet as mx +from mxnet.gluon import nn +from mxnet import autograd +from ...nn.coder import MatrixNetDecoder + +__all__ = ['MatrixNet', 'get_matrix_net','matrix_net_resnet101_v1d_coco'] + +class MatrixNet(nn.HybridBlock): + """https://arxiv.org/abs/2001.03194 + + Parameters + ---------- + base_network : mxnet.gluon.nn.HybridBlock + The base feature extraction network. + heads : OrderedDict + OrderedDict with specifications for each head. + For example: OrderedDict([ + ('heatmap', {'num_output': len(classes), 'bias': -2.19}), + ('wh', {'num_output': 2}), + ('reg', {'num_output': 2}) + ]) + classes : list of str + Category names. + head_conv_channel : int, default is 0 + If > 0, will use an extra conv layer before each of the real heads. + scale : float, default is 4.0 + The downsampling ratio of the entire network. + topk : int, default is 100 + Number of outputs . 
+ flip_test : bool + Whether apply flip test in inference (training mode not affected). + nms_thresh : float, default is 0. + Non-maximum suppression threshold. You can specify < 0 or > 1 to disable NMS. + By default nms is disabled. + nms_topk : int, default is 400 + Apply NMS to top k detection results, use -1 to disable so that every Detection + result is used in NMS. + post_nms : int, default is 100 + Only return top `post_nms` detection results, the rest is discarded. The number is + based on COCO dataset which has maximum 100 objects per image. You can adjust this + number if expecting more objects. You can use -1 to return all detections. + + """ + def __init__(self, base_network, heads, classes, layers_range, + head_conv_channel=0, base_layer_scale=8.0, topk=100, flip_test=False, + nms_thresh=0.5, nms_topk=300, post_nms=100, **kwargs): + if 'norm_layer' in kwargs: + kwargs.pop('norm_layer') + if 'norm_kwargs' in kwargs: + kwargs.pop('norm_kwargs') + super(MatrixNet, self).__init__(**kwargs) + assert isinstance(heads, OrderedDict), \ + "Expecting heads to be a OrderedDict per head, given {}" \ + .format(type(heads)) + self.classes = classes + self.topk = topk + self.nms_thresh = nms_thresh + self.nms_topk = nms_topk + post_nms = min(post_nms, topk) + self.post_nms = post_nms + self.base_layer_scale = base_layer_scale + self.layers_range = layers_range + self.flip_test = flip_test + with self.name_scope(): + self.base_network = base_network + self.heatmap_nms = nn.MaxPool2D(pool_size=3, strides=1, padding=1) + weight_initializer = mx.init.Normal(0.01) + self.pyramid_transformation_7 = nn.Conv2D( + 256, kernel_size=3, padding=1, strides=2, use_bias=True, + weight_initializer=weight_initializer, bias_initializer='zeros') + self.downsample_transformation_12 = nn.Conv2D( + 256, kernel_size=3, padding=1, strides=(1,2), use_bias=True, + weight_initializer=weight_initializer, bias_initializer='zeros') + self.downsample_transformation_21 = nn.Conv2D( + 256, kernel_size=3, padding=1, strides=(2,1), use_bias=True, + weight_initializer=weight_initializer, bias_initializer='zeros') + self.decoder = MatrixNetDecoder(topk=topk, base_layer_scale=base_layer_scale) + self.heads = nn.HybridSequential('heads') + for name, values in heads.items(): + head = nn.HybridSequential(name) + num_output = values['num_output'] + bias = values.get('bias', 0.0) + weight_initializer = mx.init.Normal(0.001) if bias == 0 else mx.init.Xavier() + if head_conv_channel > 0: + head.add(nn.Conv2D( + head_conv_channel, kernel_size=3, padding=1, use_bias=True, + weight_initializer=weight_initializer, bias_initializer='zeros')) + head.add(nn.Activation('relu')) + head.add(nn.Conv2D(num_output, kernel_size=1, strides=1, padding=0, use_bias=True, + weight_initializer=weight_initializer, + bias_initializer=mx.init.Constant(bias))) + + self.heads.add(head) + + @property + def num_classes(self): + """Return number of foreground classes. + + Returns + ------- + int + Number of foreground classes + + """ + return len(self.classes) + + def set_nms(self, nms_thresh=0, nms_topk=400, post_nms=100): + """Set non-maximum suppression parameters. + + Parameters + ---------- + nms_thresh : float, default is 0. + Non-maximum suppression threshold. You can specify < 0 or > 1 to disable NMS. + By default NMS is disabled. + nms_topk : int, default is 400 + Apply NMS to top k detection results, use -1 to disable so that every Detection + result is used in NMS. 
+ post_nms : int, default is 100 + Only return top `post_nms` detection results, the rest is discarded. The number is + based on COCO dataset which has maximum 100 objects per image. You can adjust this + number if expecting more objects. You can use -1 to return all detections. + + Returns + ------- + None + + """ + self._clear_cached_op() + self.nms_thresh = nms_thresh + self.nms_topk = nms_topk + post_nms = min(post_nms, self.nms_topk) + self.post_nms = post_nms + + def reset_class(self, classes, reuse_weights=None): + """Reset class categories and class predictors. + + Parameters + ---------- + classes : iterable of str + The new categories. ['apple', 'orange'] for example. + reuse_weights : dict + A {new_integer : old_integer} or mapping dict or {new_name : old_name} mapping dict, + or a list of [name0, name1,...] if class names don't change. + This allows the new predictor to reuse the + previously trained weights specified. + + Example + ------- + >>> net = gluoncv.model_zoo.get_model('center_net_resnet50_v1b_voc', pretrained=True) + >>> # use direct name to name mapping to reuse weights + >>> net.reset_class(classes=['person'], reuse_weights={'person':'person'}) + >>> # or use interger mapping, person is the 14th category in VOC + >>> net.reset_class(classes=['person'], reuse_weights={0:14}) + >>> # you can even mix them + >>> net.reset_class(classes=['person'], reuse_weights={'person':14}) + >>> # or use a list of string if class name don't change + >>> net.reset_class(classes=['person'], reuse_weights=['person']) + + """ + raise NotImplementedError("Not yet implemented, please wait for future updates.") + + def hybrid_forward(self, F, x): + # pylint: disable=arguments-differ + """Hybrid forward of matrixnet""" + feature_2, feature_3, feature_4, feature_5, feature_6 = self.base_network(x) + _dict = {} + _dict[11] = feature_3 + _dict[22] = feature_4 + _dict[33] = feature_5 + _dict[44] = feature_6 + _dict[55] = self.pyramid_transformation_7(_dict[44]) + _dict[12] = self.downsample_transformation_21(_dict[11]) + _dict[13] = self.downsample_transformation_21(_dict[12]) + _dict[23] = self.downsample_transformation_21(_dict[22]) + _dict[24] = self.downsample_transformation_21(_dict[23]) + _dict[34] = self.downsample_transformation_21(_dict[33]) + _dict[35] = self.downsample_transformation_21(_dict[34]) + _dict[45] = self.downsample_transformation_21(_dict[44]) + _dict[21] = self.downsample_transformation_12(_dict[11]) + _dict[31] = self.downsample_transformation_12(_dict[21]) + _dict[32] = self.downsample_transformation_12(_dict[22]) + _dict[42] = self.downsample_transformation_12(_dict[32]) + _dict[43] = self.downsample_transformation_12(_dict[33]) + _dict[53] = self.downsample_transformation_12(_dict[43]) + _dict[54] = self.downsample_transformation_12(_dict[44]) + ys = [ _dict[i] for i in sorted(_dict)] + heatmaps = [self.heads[0](y) for y in ys] + wh_preds = [self.heads[1](y) for y in ys] + center_regrs = [self.heads[2](y) for y in ys] + heatmaps = [F.sigmoid(heatmap) for heatmap in heatmaps] + if autograd.is_training(): + heatmaps = [F.clip(heatmap, 1e-4, 1 - 1e-4) for heatmap in heatmaps] + return heatmaps, wh_preds, center_regrs + if self.flip_test: + feature_2_flip, feature_3_flip, feature_4_flip, feature_5_flip, feature_6_flip = self.base_network(x.flip(axis=3)) + _dict_flip = {} + _dict_flip[11] = feature_3_flip + _dict_flip[22] = feature_4_flip + _dict_flip[33] = feature_5_flip + _dict_flip[44] = feature_6_flip + _dict_flip[55] = self.pyramid_transformation_7(_dict_flip[44]) 
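+            # key convention for _dict/_dict_flip: key ij denotes the matrix
+            # layer at width level i and height level j. Diagonal entries are
+            # the base network outputs (55 via an extra stride-2 conv);
+            # downsample_transformation_21 (strides (2, 1)) halves the height
+            # to give entry i(j+1), and downsample_transformation_12
+            # (strides (1, 2)) halves the width to give entry (i+1)j.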
+ _dict_flip[12] = self.downsample_transformation_21(_dict_flip[11]) + _dict_flip[13] = self.downsample_transformation_21(_dict_flip[12]) + _dict_flip[23] = self.downsample_transformation_21(_dict_flip[22]) + _dict_flip[24] = self.downsample_transformation_21(_dict_flip[23]) + _dict_flip[34] = self.downsample_transformation_21(_dict_flip[33]) + _dict_flip[35] = self.downsample_transformation_21(_dict_flip[34]) + _dict_flip[45] = self.downsample_transformation_21(_dict_flip[44]) + _dict_flip[21] = self.downsample_transformation_12(_dict_flip[11]) + _dict_flip[31] = self.downsample_transformation_12(_dict_flip[21]) + _dict_flip[32] = self.downsample_transformation_12(_dict_flip[22]) + _dict_flip[42] = self.downsample_transformation_12(_dict_flip[32]) + _dict_flip[43] = self.downsample_transformation_12(_dict_flip[33]) + _dict_flip[53] = self.downsample_transformation_12(_dict_flip[43]) + _dict_flip[54] = self.downsample_transformation_12(_dict_flip[44]) + ys_flip = [ _dict_flip[i] for i in sorted(_dict_flip)] + heatmaps_flip = [self.heads[0](y) for y in ys_flip] + wh_preds_flip = [self.heads[1](y) for y in ys_flip] + center_regrs_flip = [self.heads[2](y) for y in ys_flip] + heatmaps_flip = [F.sigmoid(heatmap) for heatmap in heatmaps_flip] + for i in range(len(heatmaps)): + heatmaps[i] = (heatmaps[i] + heatmaps_flip[i].flip(axis=3)) * 0.5 + wh_preds[i] = (wh_preds[i] + wh_preds_flip[i].flip(axis=3)) * 0.5 + + + keeps = [F.broadcast_equal(self.heatmap_nms(heatmap), heatmap) for heatmap in heatmaps] + results = self.decoder(keeps, heatmaps, wh_preds, center_regrs) + if self.nms_thresh > 0 and self.nms_thresh < 1: + results = F.contrib.box_nms( + results, overlap_thresh=self.nms_thresh, topk=self.nms_topk, valid_thresh=0.01, + id_index=0, score_index=1, coord_start=2, force_suppress=False) + if self.post_nms > 0: + results = results.slice_axis(axis=1, begin=0, end=self.post_nms) + ids = F.slice_axis(results, axis=2, begin=0, end=1) + scores = F.slice_axis(results, axis=2, begin=1, end=2) + bboxes = F.slice_axis(results, axis=2, begin=2, end=6) + return ids, scores, bboxes + + +def get_matrix_net(name, dataset, pretrained=False, ctx=mx.cpu(), + root=os.path.join('~', '.mxnet', 'models'), **kwargs): + """Get a matrix net instance. + + Parameters + ---------- + name : str or None + Model name, if `None` is used, you must specify `features` to be a `HybridBlock`. + dataset : str + Name of dataset. This is used to identify model name because models trained on + different datasets are going to be very different. + pretrained : bool or str + Boolean value controls whether to load the default pretrained weights for model. + String value represents the hashtag for a certain version of pretrained weights. + ctx : mxnet.Context + Context such as mx.cpu(), mx.gpu(0). + root : str + Model weights storing path. + + Returns + ------- + HybridBlock + A MatrixNet detection network. 
+ + """ + # pylint: disable=unused-variable + net = MatrixNet(**kwargs) + if pretrained: + from ..model_store import get_model_file + full_name = '_'.join(('matrix_net', name, dataset)) + net.load_parameters(get_model_file(full_name, tag=pretrained, root=root), ctx=ctx) + else: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + net.initialize() + for v in net.collect_params().values(): + try: + v.reset_ctx(ctx) + except ValueError: + pass + return net + +def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwargs): + """Matrix net with resnet101_v1d base network on coco dataset. + + Parameters + ---------- + classes : iterable of str + Names of custom foreground classes. `len(classes)` is the number of foreground classes. + pretrained_base : bool or str, optional, default is True + Load pretrained base network, the extra layers are randomized. + + Returns + ------- + HybridBlock + A CenterNet detection network. + + """ + from ....model_zoo.resnetv1b import resnet101_v1d + from ....data import COCODetection + classes = COCODetection.CLASSES + pretrained_base = False if pretrained else pretrained_base + base_network = resnet101_v1d(pretrained=pretrained_base, dilated=False, + use_global_stats=True, **kwargs) + features = FPNFeatureExpander( + network=base_network, + outputs=['layers1_relu8_fwd', 'layers2_relu11_fwd', 'layers3_relu68_fwd', + 'layers4_relu8_fwd'], num_filters=[256, 256, 256, 256], use_1x1=True, + use_upsample=True, use_elewadd=True, use_p6=True, no_bias=False, pretrained=pretrained_base) + heads = OrderedDict([ + ('heatmap', {'num_output': len(classes), 'bias': -2.19}), # use bias = -log((1 - 0.1) / 0.1) + ('wh', {'num_output': 2}), + ('reg', {'num_output': 2}) + ]) + layers_range = [[[0,48,0,48],[48,96,0,48],[96,192,0,48], -1, -1], + [[0,48,48,96],[48,96,48,96],[96,192,48,96],[192,384,0,96], -1], + [[0,48,96,192],[48,96,96,192],[96,192,96,192],[192,384,96,192],[384,2000,96,192]], + [-1, [0,96,192,384],[96,192,192,384],[192,384,192,384],[384,2000,192,384]], + [-1, -1, [0,192,384,2000],[192,384,384,2000],[384,2000,384,2000]]] + return get_matrix_net('resnet101_v1d', 'coco', base_network=features, heads=heads, layers_range = layers_range + head_conv_channel=64, pretrained=pretrained, classes=classes, + base_layer_scale=8.0, topk=100, **kwargs) + diff --git a/gluoncv/model_zoo/matrix_net/target_generator.py b/gluoncv/model_zoo/matrix_net/target_generator.py new file mode 100644 index 0000000000..be29fa4103 --- /dev/null +++ b/gluoncv/model_zoo/matrix_net/target_generator.py @@ -0,0 +1,197 @@ +"""MatrixNet training target generator.""" +from __future__ import absolute_import + +import numpy as np + +from mxnet import nd +from mxnet import gluon + +def layer_map_using_ranges(width, height, layer_ranges, fpn_flag=0): + layers = [] + + for i, layer_range in enumerate(layer_ranges): + if fpn_flag ==0: + if (width >= 0.8 * layer_range[2]) and (width <= 1.3 * layer_range[3]) and (height >= 0.8 * layer_range[0]) and (height <= 1.3 * layer_range[1]): + layers.append(i) + else: + max_dim = max(height, width) + if max_dim <= 1.3*layer_range[1] and max_dim >= 0.8* layer_range[0]: + layers.append(i) + if len(layers) > 0: + return layers + else: + return [len(layer_ranges) - 1] + + +class MatrixNetTargetGenerator(gluon.Block): + """Target generator for CenterNet. + + Parameters + ---------- + num_class : int + Number of categories. + output_width : int + Width of the network output. + output_height : int + Height of the network output. 
+ + """ + def __init__(self, num_class, input_width, input_height, layers_range): + super(MatrixNetTargetGenerator, self).__init__() + self._num_class = num_class + self._input_width = int(input_width) + self._input_height = int(input_height) + self._layers_range = layers_range + + def forward(self, gt_boxes, gt_ids): + """Target generation""" + # pylint: disable=arguments-differ + _dict={} + output_sizes=[] + # indexing layer map + for i,l in enumerate(self._layers_range): + for j,e in enumerate(l): + if e !=-1: + output_sizes.append([self._input_height//(8*2**(j)), self._input_width//(8*2**(i))]) + _dict[(i+1)*10+(j+1)]=e + + self._layers_range=[_dict[i] for i in sorted(_dict)] + fpn_flag = set(_dict.keys()) == set([11,22,33,44,55]) + + heatmaps = [np.zeros((self._num_class, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes] + wh_targets = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes] + wh_masks = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes] + center_regs = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes] + center_reg_masks = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes] + for bbox, cid in zip(gt_boxes, gt_ids): + for olayer_idx in layer_map_using_ranges(bbox[2] - bbox[0], bbox[3] - bbox[1], self._layers_range, fpn_flag): + cid = int(cid) + width_ratio = output_sizes[olayer_idx][1] / self._input_width + height_ratio = output_sizes[olayer_idx][0] / self._input_height + xtl, ytl = bbox[0], bbox[1] + xbr, ybr = bbox[2], bbox[3] + fxtl = (xtl * width_ratio) + fytl = (ytl * height_ratio) + fxbr = (xbr * width_ratio) + fybr = (ybr * height_ratio) + + box_h, box_w = fybr - fytl, fxbr - fxtl + if box_h > 0 and box_w > 0: + radius = _gaussian_radius((np.ceil(box_h), np.ceil(box_w))) + radius = max(0, int(radius)) + center = np.array( + [(fxtl + fxbr) / 2 , (fytl + fybr) / 2 ], + dtype=np.float32) + center_int = center.astype(np.int32) + center_x, center_y = center_int + assert center_x < output_sizes[olayer_idx][1], \ + 'center_x: {} > output_width: {}'.format(center_x, output_sizes[olayer_idx][1]) + assert center_y < output_sizes[olayer_idx][0], \ + 'center_y: {} > output_height: {}'.format(center_y, output_sizes[olayer_idx][0]) + _draw_umich_gaussian(heatmaps[olayer_idx][cid], center_int, radius) + wh_targets[olayer_idx][0, center_y, center_x] = box_w + wh_targets[olayer_idx][1, center_y, center_x] = box_h + wh_masks[olayer_idx][:, center_y, center_x] = 1.0 + center_regs[olayer_idx][:, center_y, center_x] = center - center_int + center_reg_masks[olayer_idx][:, center_y, center_x] = 1.0 + heatmaps = [nd.array(heatmap) for heatmap in heatmaps] + wh_targets = [nd.array(wh_target) for wh_target in wh_targets] + wh_masks = [nd.array(wh_mask) for wh_mask in wh_masks] + center_regs = [nd.array(center_reg) for center_reg in center_regs] + center_reg_masks = [nd.array(center_reg_mask) for center_reg_mask in center_reg_masks] + return heatmaps, wh_targets, wh_masks, center_regs, center_reg_masks + + +def _gaussian_radius(det_size, min_overlap=0.7): + """Calculate gaussian radius for foreground objects. + + Parameters + ---------- + det_size : tuple of int + Object size (h, w). + min_overlap : float + Minimal overlap between objects. + + Returns + ------- + float + Gaussian radius. 
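+
+    Example
+    -------
+    The three quadratic cases below keep the arithmetic of the original
+    CornerNet implementation. For a 64 x 64 box with `min_overlap=0.7`, the
+    third case is the smallest and the returned radius is roughly 17 pixels.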
+
+    """
+    height, width = det_size
+
+    a1 = 1
+    b1 = (height + width)
+    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
+    r1 = (b1 + sq1) / 2
+
+    a2 = 4
+    b2 = 2 * (height + width)
+    c2 = (1 - min_overlap) * width * height
+    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
+    r2 = (b2 + sq2) / 2
+
+    a3 = 4 * min_overlap
+    b3 = -2 * min_overlap * (height + width)
+    c3 = (min_overlap - 1) * width * height
+    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
+    r3 = (b3 + sq3) / 2
+    return min(r1, r2, r3)
+
+def _gaussian_2d(shape, sigma=1):
+    """Generate 2d gaussian.
+
+    Parameters
+    ----------
+    shape : tuple of int
+        The shape of the gaussian.
+    sigma : float
+        Sigma for gaussian.
+
+    Returns
+    -------
+    numpy.ndarray
+        2D gaussian kernel.
+
+    """
+    m, n = [(ss - 1.) / 2. for ss in shape]
+    y, x = np.ogrid[-m:m+1, -n:n+1]
+
+    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+    h[h < np.finfo(h.dtype).eps * h.max()] = 0
+    return h
+
+def _draw_umich_gaussian(heatmap, center, radius, k=1):
+    """Draw a 2D gaussian heatmap.
+
+    Parameters
+    ----------
+    heatmap : numpy.ndarray
+        Heatmap to be written in place.
+    center : tuple of int
+        Center of object (x, y).
+    radius : int
+        The radius of gaussian.
+
+    Returns
+    -------
+    numpy.ndarray
+        Drawn gaussian heatmap.
+
+    """
+    diameter = 2 * radius + 1
+    gaussian = _gaussian_2d((diameter, diameter), sigma=diameter / 6)
+
+    x, y = int(center[0]), int(center[1])
+
+    height, width = heatmap.shape[0:2]
+
+    left, right = min(x, radius), min(width - x, radius + 1)
+    top, bottom = min(y, radius), min(height - y, radius + 1)
+
+    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
+    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
+        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+    return heatmap
diff --git a/gluoncv/model_zoo/model_zoo.py b/gluoncv/model_zoo/model_zoo.py
index bc785614ef..d26f0440f1 100644
--- a/gluoncv/model_zoo/model_zoo.py
+++ b/gluoncv/model_zoo/model_zoo.py
@@ -42,6 +42,7 @@
 from .icnet import *
 from .fastscnn import *
 from .danet import *
+from .matrix_net import *
 
 __all__ = ['get_model', 'get_model_list']
 
@@ -365,6 +366,7 @@
     'hrnet_w18_small_v2_s' : hrnet_w18_small_v2_s,
     'hrnet_w48_s' : hrnet_w48_s,
     'siamrpn_alexnet_v2_otb15': siamrpn_alexnet_v2_otb15,
+    'matrix_net_resnet101_v1d_coco': matrix_net_resnet101_v1d_coco,
     }
 
 
diff --git a/gluoncv/nn/coder.py b/gluoncv/nn/coder.py
index 3199b1ec7e..4278ba61b3 100644
--- a/gluoncv/nn/coder.py
+++ b/gluoncv/nn/coder.py
@@ -498,3 +498,72 @@ def hybrid_forward(self, F, x, wh, reg):
         results = [topk_xs - half_w, topk_ys - half_h, topk_xs + half_w, topk_ys + half_h]
         results = F.concat(*[tmp.expand_dims(-1) for tmp in results], dim=-1)
         return topk_classes, scores, results * self._scale
+
+class MatrixNetDecoder(gluon.HybridBlock):
+    """Decoder for MatrixNet.
+
+    Parameters
+    ----------
+    topk : int
+        Only keep `topk` results.
+    base_layer_scale : float, default is 8.0
+        Downsampling scale of the first (top-left) matrix layer.
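+
+    Notes
+    -----
+    Inputs are lists with one entry per matrix layer. The decoder concatenates
+    each layer's `topk` detections along axis 1, so the output has shape
+    (batch, num_layers * topk, 6), each row being
+    [class, score, xmin, ymin, xmax, ymax] with box coordinates in
+    input-image pixels after the final multiplication by `base_layer_scale`.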
+ + """ + def __init__(self, topk=100, base_layer_scale=8.0): + super(MatrixNetDecoder, self).__init__() + self._topk = topk + self._base_layer_scale = base_layer_scale + + def hybrid_forward(self, F, keeps, xs, whs, regs): + """Forward of decoder""" + _, _, out_h0, out_w0 = xs[0].shape_array().split(num_outputs=4, axis=0) + results = [] + for i in range(len(xs)): + x = keeps[i] * xs[i] + wh = whs[i] + reg = regs[i] + _, _, out_h, out_w = x.shape_array().split(num_outputs=4, axis=0) + height_scale = out_h0.asscalar() / out_h.asscalar() + width_scale = out_w0.asscalar() / out_w.asscalar() + scores, indices = x.reshape((0, -1)).topk(k=self._topk, ret_typ='both') + indices = F.cast(indices, 'int64') + topk_classes = F.cast(F.broadcast_div(indices, (out_h * out_w)), 'float32') + topk_indices = F.broadcast_mod(indices, (out_h * out_w)) + topk_ys = F.broadcast_div(topk_indices, out_w) + topk_xs = F.broadcast_mod(topk_indices, out_w) + center = reg.transpose((0, 2, 3, 1)).reshape((0, -1, 2)) + wh = wh.transpose((0, 2, 3, 1)).reshape((0, -1, 2)) + batch_indices = F.cast(F.arange(256).slice_like( + center, axes=(0)).expand_dims(-1).tile(reps=(1, self._topk)), 'int64') + reg_xs_indices = F.zeros_like(batch_indices, dtype='int64') + reg_ys_indices = F.ones_like(batch_indices, dtype='int64') + reg_xs = F.concat(batch_indices, topk_indices, reg_xs_indices, dim=0).reshape((3, -1)) + reg_ys = F.concat(batch_indices, topk_indices, reg_ys_indices, dim=0).reshape((3, -1)) + xs = F.cast(F.gather_nd(center, reg_xs).reshape((-1, self._topk)), 'float32') + ys = F.cast(F.gather_nd(center, reg_ys).reshape((-1, self._topk)), 'float32') + topk_xs = F.cast(topk_xs, 'float32') + xs + topk_ys = F.cast(topk_ys, 'float32') + ys + w = F.cast(F.gather_nd(wh, reg_xs).reshape((-1, self._topk)), 'float32') + h = F.cast(F.gather_nd(wh, reg_ys).reshape((-1, self._topk)), 'float32') + half_w = w / 2 + half_h = h / 2 + result = [topk_xs - half_w, topk_ys - half_h, topk_xs + half_w, topk_ys + half_h] + result = F.concat(*[tmp.expand_dims(-1) for tmp in result], dim=-1) + result[:,:,0:4:2] *= width_scale + result[:,:,1:4:2] *= height_scale + result = F.concat(*[topk_classes.expand_dims(-1), scores.expand_dims(-1), result],dim=-1) + results.append(result) + results = F.concat(*results, dim=1) + results[:,:,2:6] *= self._base_layer_scale + + ''' + batch_num = len(results) + batch_indices = F.cast(F.arange(256).slice_like( + results, axes=(0)).expand_dims(-1).tile(reps=(1, 300*6)), 'int64') + topk_indices = F.cast(results[:,:,1].topk(k=300).expand_dims(-1).tile(reps=(1, 1, 6)), 'int64').reshape((0, -1)) + val_indices = F.cast(nd.arange(6).expand_dims(0).tile(reps=(300 * batch_num, 1)), 'int64').reshape((batch_num,-1)) + inds = F.concat(batch_indices, topk_indices, val_indices, dim=0).reshape((3, -1)) + results = F.cast(F.gather_nd(results, inds).reshape((batch_num, 300, -1)), 'float32') + ''' + return results diff --git a/scripts/detection/matrix_net/train_matrix_net.py b/scripts/detection/matrix_net/train_matrix_net.py new file mode 100644 index 0000000000..c6b6fc3df1 --- /dev/null +++ b/scripts/detection/matrix_net/train_matrix_net.py @@ -0,0 +1,307 @@ +"""Train MatrixNet""" +import argparse +import os +import logging +import warnings +import time +import numpy as np +import mxnet as mx +from mxnet import nd +from mxnet import gluon +from mxnet import autograd +import gluoncv as gcv +gcv.utils.check_version('0.6.0') +from gluoncv import data as gdata +from gluoncv import utils as gutils +from gluoncv.model_zoo import get_model 
+from gluoncv.data.batchify import Tuple, Stack, Pad +from gluoncv.data.transforms.presets.matrix_net import MatrixNetDefaultTrainTransform +from gluoncv.data.transforms.presets.matrix_net import MatrixNetDefaultValTransform, get_post_transform + +from gluoncv.utils.metrics.voc_detection import VOC07MApMetric +from gluoncv.utils.metrics.coco_detection import COCODetectionMetric +from gluoncv.utils.metrics.accuracy import Accuracy +from gluoncv.utils import LRScheduler, LRSequential + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train MatrixNet networks.') + parser.add_argument('--network', type=str, default='resnet18_v1b', + help="Base network name which serves as feature extraction base.") + parser.add_argument('--data-shape', type=int, default=512, + help="Input data shape, use 300, 512.") + parser.add_argument('--batch-size', type=int, default=32, + help='Training mini-batch size') + parser.add_argument('--dataset', type=str, default='voc', + help='Training dataset. Now support voc.') + parser.add_argument('--dataset-root', type=str, default='~/.mxnet/datasets/', + help='Path of the directory where the dataset is located.') + parser.add_argument('--num-workers', '-j', dest='num_workers', type=int, + default=4, help='Number of data workers, you can use larger ' + 'number to accelerate data loading, if you CPU and GPUs are powerful.') + parser.add_argument('--gpus', type=str, default='0', + help='Training with GPUs, you can specify 1,3 for example.') + parser.add_argument('--epochs', type=int, default=140, + help='Training epochs.') + parser.add_argument('--resume', type=str, default='', + help='Resume from previously saved parameters if not None. ' + 'For example, you can resume from ./ssd_xxx_0123.params') + parser.add_argument('--start-epoch', type=int, default=0, + help='Starting epoch for resuming, default is 0 for new training.' + 'You can specify it to 100 for example to start from 100 epoch.') + parser.add_argument('--lr', type=float, default=1.25e-4, + help='Learning rate, default is 0.000125') + parser.add_argument('--lr-decay', type=float, default=0.1, + help='decay rate of learning rate. default is 0.1.') + parser.add_argument('--lr-decay-epoch', type=str, default='90,120', + help='epochs at which learning rate decays. default is 90,120.') + parser.add_argument('--lr-mode', type=str, default='step', + help='learning rate scheduler mode. options are step, poly and cosine.') + parser.add_argument('--warmup-lr', type=float, default=0.0, + help='starting warmup learning rate. default is 0.0.') + parser.add_argument('--warmup-epochs', type=int, default=0, + help='number of warmup epochs.') + parser.add_argument('--momentum', type=float, default=0.9, + help='SGD momentum, default is 0.9') + parser.add_argument('--wd', type=float, default=0.0001, + help='Weight decay, default is 1e-4') + parser.add_argument('--log-interval', type=int, default=100, + help='Logging mini-batch interval. Default is 100.') + parser.add_argument('--num-samples', type=int, default=-1, + help='Training images. 
Use -1 to automatically get the number.')
+    parser.add_argument('--save-prefix', type=str, default='',
+                        help='Saving parameter prefix')
+    parser.add_argument('--save-interval', type=int, default=10,
+                        help='Saving parameters epoch interval, best model will always be saved.')
+    parser.add_argument('--val-interval', type=int, default=1,
+                        help='Epoch interval for validation, increasing the number will reduce the '
+                             'training time if validation is slow.')
+    parser.add_argument('--seed', type=int, default=233,
+                        help='Random seed to be fixed.')
+    parser.add_argument('--wh-weight', type=float, default=0.1,
+                        help='Loss weight for width/height')
+    parser.add_argument('--center-reg-weight', type=float, default=1.0,
+                        help='Center regression loss weight')
+    parser.add_argument('--flip-validation', action='store_true',
+                        help='flip data augmentation in validation.')
+
+    args = parser.parse_args()
+    return args
+
+def get_dataset(dataset, args):
+    if dataset.lower() == 'voc':
+        train_dataset = gdata.VOCDetection(
+            splits=[(2007, 'trainval'), (2012, 'trainval')])
+        val_dataset = gdata.VOCDetection(
+            splits=[(2007, 'test')])
+        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
+    elif dataset.lower() == 'coco':
+        train_dataset = gdata.COCODetection(root=args.dataset_root + "/coco", splits='instances_train2017')
+        val_dataset = gdata.COCODetection(root=args.dataset_root + "/coco", splits='instances_val2017', skip_empty=False)
+        val_metric = COCODetectionMetric(
+            val_dataset, args.save_prefix + '_eval', cleanup=True,
+            data_shape=(args.data_shape, args.data_shape), post_affine=get_post_transform)
+        # coco validation is slow, consider increasing the validation interval
+        if args.val_interval == 1:
+            args.val_interval = 10
+    else:
+        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
+    if args.num_samples < 0:
+        args.num_samples = len(train_dataset)
+    return train_dataset, val_dataset, val_metric
+
+def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers, ctx):
+    """Get dataloader."""
+    width, height = data_shape, data_shape
+    num_class = len(train_dataset.classes)
+    # stack the image plus 5 target arrays (heatmap, wh target, wh mask,
+    # center offset, center offset mask) for each of the 19 matrix layers:
+    # 1 + 5 * 19 = 96 entries per sample
+    batchify_fn = Tuple([Stack() for _ in range(96)])
+    train_loader = gluon.data.DataLoader(
+        train_dataset.transform(MatrixNetDefaultTrainTransform(
+            width, height, num_class=num_class, layers_range=net.layers_range)),
+        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
+    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
+    val_loader = gluon.data.DataLoader(
+        val_dataset.transform(MatrixNetDefaultValTransform(width, height)),
+        batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=num_workers)
+    return train_loader, val_loader
+
+def save_params(net, best_map, current_map, epoch, save_interval, prefix):
+    current_map = float(current_map)
+    if current_map > best_map[0]:
+        best_map[0] = current_map
+        net.save_parameters('{:s}_best.params'.format(prefix))
+        with open(prefix+'_best_map.log', 'a') as f:
+            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
+    if save_interval and epoch % save_interval == 0:
+        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))
+
+def validate(net, val_data, ctx, eval_metric, flip_test=False):
+    """Test on validation dataset."""
+    eval_metric.reset()
+    net.flip_test = flip_test
+    mx.nd.waitall()
+    net.hybridize()
+    for batch in val_data:
+        data = 
gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False) + label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False) + det_bboxes = [] + det_ids = [] + det_scores = [] + gt_bboxes = [] + gt_ids = [] + gt_difficults = [] + for x, y in zip(data, label): + # get prediction results + ids, scores, bboxes = net(x) + det_ids.append(ids) + det_scores.append(scores) + # clip to image size + det_bboxes.append(bboxes.clip(0, batch[0].shape[2])) + # split ground truths + gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5)) + gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4)) + gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None) + + # update metric + eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults) + return eval_metric.get() + +def train(net, train_data, val_data, eval_metric, ctx, args): + """Training pipeline""" + net.collect_params().reset_ctx(ctx) + # lr decay policy + lr_decay = float(args.lr_decay) + lr_steps = sorted([int(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()]) + lr_decay_epoch = [e - args.warmup_epochs for e in lr_steps] + num_batches = args.num_samples // args.batch_size + lr_scheduler = LRSequential([ + LRScheduler('linear', base_lr=0, target_lr=args.lr, + nepochs=args.warmup_epochs, iters_per_epoch=num_batches), + LRScheduler(args.lr_mode, base_lr=args.lr, + nepochs=args.epochs - args.warmup_epochs, + iters_per_epoch=num_batches, + step_epoch=lr_decay_epoch, + step_factor=args.lr_decay, power=2), + ]) + + for k, v in net.collect_params('.*bias').items(): + v.wd_mult = 0.0 + trainer = gluon.Trainer( + net.collect_params(), 'adam', + {'learning_rate': args.lr, 'wd': args.wd, + 'lr_scheduler': lr_scheduler}) + + heatmap_loss = gcv.loss.HeatmapFocalLoss(from_logits=True) + wh_loss = gcv.loss.MaskedL1Loss(weight=args.wh_weight) + center_reg_loss = gcv.loss.MaskedL1Loss(weight=args.center_reg_weight) + heatmap_loss_metric = mx.metric.Loss('HeatmapFocal') + wh_metric = mx.metric.Loss('WHL1') + center_reg_metric = mx.metric.Loss('CenterRegL1') + + # set up logger + logging.basicConfig() + logger = logging.getLogger() + logger.setLevel(logging.INFO) + log_file_path = args.save_prefix + '_train.log' + log_dir = os.path.dirname(log_file_path) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir) + fh = logging.FileHandler(log_file_path) + logger.addHandler(fh) + logger.info(args) + logger.info('Start training from [Epoch {}]'.format(args.start_epoch)) + best_map = [0] + + for epoch in range(args.start_epoch, args.epochs): + wh_metric.reset() + center_reg_metric.reset() + tic = time.time() + btic = time.time() + net.hybridize() + + for i, batch in enumerate(train_data): + split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=ctx, batch_axis=0) for ind in range(96)] + inter_num = len(split_data[0]) + batch_size = args.batch_size + with autograd.record(): + sum_losses = [] + mid_layers_num = 19 + heatmap_losses = [0 for _ in range(inter_num)] + wh_losses = [0 for _ in range(inter_num)] + center_reg_losses = [0 for _ in range(inter_num)] + wh_preds = [[] for _ in range(mid_layers_num)] + center_reg_preds = [[] for _ in range(mid_layers_num)] + for ind in range(inter_num): + + heatmap_pred, wh_pred, center_reg_pred = net(split_data[0][ind]) + for ii in range(mid_layers_num): + wh_preds[ii].append(wh_pred[ii]) + center_reg_preds[ii].append(center_reg_pred[ii]) + wh_losses[ind] += wh_loss(wh_pred[ii], 
split_data[1+mid_layers_num+ii][ind],\ + split_data[1+mid_layers_num*2+ii][ind]) + center_reg_losses[ind] += center_reg_loss(center_reg_pred[ii], split_data[1+mid_layers_num*3+ii][ind],\ + split_data[1+mid_layers_num*4+ii][ind]) + heatmap_losses[ind] += heatmap_loss(heatmap_pred[ii], split_data[1+ii][ind]) + + sum_losses = [heatmap_losses[ii]+wh_losses[ii]+center_reg_losses[ii] for ii in range(inter_num)] + autograd.backward(sum_losses) + trainer.step(len(sum_losses)) # step with # gpus + + heatmap_loss_metric.update(0, heatmap_losses) + wh_metric.update(0, wh_losses) + center_reg_metric.update(0, center_reg_losses) + if args.log_interval and not (i + 1) % args.log_interval: + name2, loss2 = wh_metric.get() + name3, loss3 = center_reg_metric.get() + name4, loss4 = heatmap_loss_metric.get() + logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, LR={}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format( + epoch, i, batch_size/(time.time()-btic), trainer.learning_rate, name2, loss2, name3, loss3, name4, loss4)) + btic = time.time() + + name2, loss2 = wh_metric.get() + name3, loss3 = center_reg_metric.get() + name4, loss4 = heatmap_loss_metric.get() + logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format( + epoch, (time.time()-tic), name2, loss2, name3, loss3, name4, loss4)) + if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0) or (epoch == args.epochs - 1): + # consider reduce the frequency of validation to save time + map_name, mean_ap = validate(net, val_data, ctx, eval_metric, flip_test=args.flip_validation) + val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)]) + logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg)) + current_map = float(mean_ap[-1]) + else: + current_map = 0. + save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix) + +if __name__ == '__main__': + args = parse_args() + + # fix seed for mxnet, numpy and python builtin random generator. 
+ gutils.random.seed(args.seed) + + # training contexts + ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()] + ctx = ctx if ctx else [mx.cpu()] + + # network + net_name = '_'.join(('matrix_net', args.network, args.dataset)) + args.save_prefix += net_name + net = get_model(net_name, pretrained_base=True, norm_layer=gluon.nn.BatchNorm) + if args.resume.strip(): + net.load_parameters(args.resume.strip()) + else: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + net.initialize() + # needed for net to be first gpu when using AMP + net.collect_params().reset_ctx(ctx[0]) + + # training data + train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args) + batch_size = args.batch_size + train_data, val_data = get_dataloader( + net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers, ctx[0]) + + + # training + train(net, train_data, val_data, eval_metric, ctx, args) From a376d0cacde04081a8572ea7ae1d992e9aaacd5e Mon Sep 17 00:00:00 2001 From: liueo Date: Tue, 26 May 2020 10:08:25 +0000 Subject: [PATCH 2/7] matrix net --- gluoncv/data/transforms/presets/__init__.py | 1 + gluoncv/model_zoo/matrix_net/__init__.py | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 gluoncv/model_zoo/matrix_net/__init__.py diff --git a/gluoncv/data/transforms/presets/__init__.py b/gluoncv/data/transforms/presets/__init__.py index c3ff9943f5..1a1ad668ca 100644 --- a/gluoncv/data/transforms/presets/__init__.py +++ b/gluoncv/data/transforms/presets/__init__.py @@ -6,3 +6,4 @@ from . import imagenet from . import simple_pose from . import segmentation +from . import matrix_net diff --git a/gluoncv/model_zoo/matrix_net/__init__.py b/gluoncv/model_zoo/matrix_net/__init__.py new file mode 100644 index 0000000000..551b081eba --- /dev/null +++ b/gluoncv/model_zoo/matrix_net/__init__.py @@ -0,0 +1,5 @@ +"""MatrixNet""" +# pylint: disable=wildcard-import +from __future__ import absolute_import + +from .matrix_net import * From 06a687924c1e83407f0e9f4ab24b8cc094b63e69 Mon Sep 17 00:00:00 2001 From: liueo Date: Tue, 26 May 2020 10:18:24 +0000 Subject: [PATCH 3/7] matrix_net --- gluoncv/model_zoo/matrix_net/matrix_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluoncv/model_zoo/matrix_net/matrix_net.py b/gluoncv/model_zoo/matrix_net/matrix_net.py index 5cad55bfa6..3295935be0 100644 --- a/gluoncv/model_zoo/matrix_net/matrix_net.py +++ b/gluoncv/model_zoo/matrix_net/matrix_net.py @@ -323,7 +323,7 @@ def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwar [[0,48,96,192],[48,96,96,192],[96,192,96,192],[192,384,96,192],[384,2000,96,192]], [-1, [0,96,192,384],[96,192,192,384],[192,384,192,384],[384,2000,192,384]], [-1, -1, [0,192,384,2000],[192,384,384,2000],[384,2000,384,2000]]] - return get_matrix_net('resnet101_v1d', 'coco', base_network=features, heads=heads, layers_range = layers_range + return get_matrix_net('resnet101_v1d', 'coco', base_network=features, heads=heads, layers_range = layers_range, head_conv_channel=64, pretrained=pretrained, classes=classes, base_layer_scale=8.0, topk=100, **kwargs) From d67621ddb72780b253e65170322a0bc71b7d55db Mon Sep 17 00:00:00 2001 From: liueo Date: Tue, 26 May 2020 10:31:08 +0000 Subject: [PATCH 4/7] matrix net --- gluoncv/model_zoo/matrix_net/matrix_net.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gluoncv/model_zoo/matrix_net/matrix_net.py b/gluoncv/model_zoo/matrix_net/matrix_net.py index 3295935be0..ad2afa96b6 100644 --- 
a/gluoncv/model_zoo/matrix_net/matrix_net.py +++ b/gluoncv/model_zoo/matrix_net/matrix_net.py @@ -302,8 +302,8 @@ def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwar A CenterNet detection network. """ - from ....model_zoo.resnetv1b import resnet101_v1d - from ....data import COCODetection + from ...model_zoo.resnetv1b import resnet101_v1d + from ...data import COCODetection classes = COCODetection.CLASSES pretrained_base = False if pretrained else pretrained_base base_network = resnet101_v1d(pretrained=pretrained_base, dilated=False, From 1208890dbe58df1c69af330651d8f198f235ee72 Mon Sep 17 00:00:00 2001 From: liueo Date: Tue, 26 May 2020 10:36:45 +0000 Subject: [PATCH 5/7] matrix net --- gluoncv/model_zoo/matrix_net/matrix_net.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gluoncv/model_zoo/matrix_net/matrix_net.py b/gluoncv/model_zoo/matrix_net/matrix_net.py index ad2afa96b6..15927201ce 100644 --- a/gluoncv/model_zoo/matrix_net/matrix_net.py +++ b/gluoncv/model_zoo/matrix_net/matrix_net.py @@ -9,6 +9,7 @@ from mxnet.gluon import nn from mxnet import autograd from ...nn.coder import MatrixNetDecoder +from ...nn.feature import FPNFeatureExpander __all__ = ['MatrixNet', 'get_matrix_net','matrix_net_resnet101_v1d_coco'] From bf4ac7543e4b3f057ae3231eac7d4ed2cbafb1df Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 1 Jun 2020 17:56:34 +0000 Subject: [PATCH 6/7] update --- gluoncv/data/transforms/presets/matrix_net.py | 2 + gluoncv/model_zoo/matrix_net/matrix_net.py | 79 ++++++++++--------- .../model_zoo/matrix_net/target_generator.py | 77 ++++++++++++------ gluoncv/nn/coder.py | 42 +++++----- .../detection/matrix_net/train_matrix_net.py | 34 +++++--- 5 files changed, 137 insertions(+), 97 deletions(-) diff --git a/gluoncv/data/transforms/presets/matrix_net.py b/gluoncv/data/transforms/presets/matrix_net.py index 6bafdf4cba..3364818c75 100644 --- a/gluoncv/data/transforms/presets/matrix_net.py +++ b/gluoncv/data/transforms/presets/matrix_net.py @@ -22,6 +22,8 @@ class MatrixNetDefaultTrainTransform(object): Image height. num_class : int Number of categories + layers_range : list of list of number(list of number) + Represents the same meaning as that of MatrixNet scale_factor : int, default is 4 The downsampling scale factor between input image and output heatmap mean : array-like of size 3 diff --git a/gluoncv/model_zoo/matrix_net/matrix_net.py b/gluoncv/model_zoo/matrix_net/matrix_net.py index 15927201ce..443efa035d 100644 --- a/gluoncv/model_zoo/matrix_net/matrix_net.py +++ b/gluoncv/model_zoo/matrix_net/matrix_net.py @@ -1,4 +1,5 @@ -"""MatrixNet object detector: https://arxiv.org/abs/2001.03194""" +"""MatrixNet object detector: using matrix layers to extract features and heads of CenterNet to predict + matrix layers: https://arxiv.org/abs/2001.03194 and its code on github""" from __future__ import absolute_import import os @@ -18,8 +19,9 @@ class MatrixNet(nn.HybridBlock): Parameters ---------- - base_network : mxnet.gluon.nn.HybridBlock + base_network : mxnet.gluon.nn.SymbolBlock The base feature extraction network. + Currently just using pre-defined resnet101_v1d_fpn heads : OrderedDict OrderedDict with specifications for each head. For example: OrderedDict([ @@ -29,20 +31,25 @@ class MatrixNet(nn.HybridBlock): ]) classes : list of str Category names. 
+ layers_range : list of list of number(list of number) + Denotes the size of objects assigned to each matrix layer + layers_range is a 5 * 5 matrix, where each element is -1 or a list of 4 numbers + -1 denotes this layer is pruned, a list of 4 numbers is min height, max height, + min width, max width of the objects respectively. head_conv_channel : int, default is 0 If > 0, will use an extra conv layer before each of the real heads. - scale : float, default is 4.0 - The downsampling ratio of the entire network. + base_layer_scale : float, default is 4.0 + The downsampling ratio of the first (top-left) layer in the matrix. topk : int, default is 100 Number of outputs . flip_test : bool Whether apply flip test in inference (training mode not affected). - nms_thresh : float, default is 0. + nms_thresh : float, default is 0.5. Non-maximum suppression threshold. You can specify < 0 or > 1 to disable NMS. - By default nms is disabled. - nms_topk : int, default is 400 + nms_topk : int, default is 300 Apply NMS to top k detection results, use -1 to disable so that every Detection result is used in NMS. + Choose the default value according to the code of matrixnets. post_nms : int, default is 100 Only return top `post_nms` detection results, the rest is discarded. The number is based on COCO dataset which has maximum 100 objects per image. You can adjust this @@ -50,7 +57,7 @@ class MatrixNet(nn.HybridBlock): """ def __init__(self, base_network, heads, classes, layers_range, - head_conv_channel=0, base_layer_scale=8.0, topk=100, flip_test=False, + head_conv_channel=0, base_layer_scale=4.0, topk=100, flip_test=False, nms_thresh=0.5, nms_topk=300, post_nms=100, **kwargs): if 'norm_layer' in kwargs: kwargs.pop('norm_layer') @@ -73,9 +80,7 @@ def __init__(self, base_network, heads, classes, layers_range, self.base_network = base_network self.heatmap_nms = nn.MaxPool2D(pool_size=3, strides=1, padding=1) weight_initializer = mx.init.Normal(0.01) - self.pyramid_transformation_7 = nn.Conv2D( - 256, kernel_size=3, padding=1, strides=2, use_bias=True, - weight_initializer=weight_initializer, bias_initializer='zeros') + # the following two layers are used to generate the off-diagonal layers' features from diagonal layers' features self.downsample_transformation_12 = nn.Conv2D( 256, kernel_size=3, padding=1, strides=(1,2), use_bias=True, weight_initializer=weight_initializer, bias_initializer='zeros') @@ -83,6 +88,7 @@ def __init__(self, base_network, heads, classes, layers_range, 256, kernel_size=3, padding=1, strides=(2,1), use_bias=True, weight_initializer=weight_initializer, bias_initializer='zeros') self.decoder = MatrixNetDecoder(topk=topk, base_layer_scale=base_layer_scale) + # using heads of CenterNet( Objects as Point ) self.heads = nn.HybridSequential('heads') for name, values in heads.items(): head = nn.HybridSequential(name) @@ -152,31 +158,22 @@ def reset_class(self, classes, reuse_weights=None): This allows the new predictor to reuse the previously trained weights specified. 
- Example - ------- - >>> net = gluoncv.model_zoo.get_model('center_net_resnet50_v1b_voc', pretrained=True) - >>> # use direct name to name mapping to reuse weights - >>> net.reset_class(classes=['person'], reuse_weights={'person':'person'}) - >>> # or use interger mapping, person is the 14th category in VOC - >>> net.reset_class(classes=['person'], reuse_weights={0:14}) - >>> # you can even mix them - >>> net.reset_class(classes=['person'], reuse_weights={'person':14}) - >>> # or use a list of string if class name don't change - >>> net.reset_class(classes=['person'], reuse_weights=['person']) - """ raise NotImplementedError("Not yet implemented, please wait for future updates.") def hybrid_forward(self, F, x): # pylint: disable=arguments-differ """Hybrid forward of matrixnet""" + # following lines computes the features of 19 matrix layers + # 5 diagonal features are FPN outputs, others are computed from the diagonal features + # this part can be impoved by modifying the code of FPNFeatureExpander feature_2, feature_3, feature_4, feature_5, feature_6 = self.base_network(x) _dict = {} - _dict[11] = feature_3 - _dict[22] = feature_4 - _dict[33] = feature_5 - _dict[44] = feature_6 - _dict[55] = self.pyramid_transformation_7(_dict[44]) + _dict[11] = feature_2 + _dict[22] = feature_3 + _dict[33] = feature_4 + _dict[44] = feature_5 + _dict[55] = feature_6 _dict[12] = self.downsample_transformation_21(_dict[11]) _dict[13] = self.downsample_transformation_21(_dict[12]) _dict[23] = self.downsample_transformation_21(_dict[22]) @@ -191,6 +188,8 @@ def hybrid_forward(self, F, x): _dict[43] = self.downsample_transformation_12(_dict[33]) _dict[53] = self.downsample_transformation_12(_dict[43]) _dict[54] = self.downsample_transformation_12(_dict[44]) + + # run the shared heads on the 19 features ys = [ _dict[i] for i in sorted(_dict)] heatmaps = [self.heads[0](y) for y in ys] wh_preds = [self.heads[1](y) for y in ys] @@ -199,14 +198,16 @@ def hybrid_forward(self, F, x): if autograd.is_training(): heatmaps = [F.clip(heatmap, 1e-4, 1 - 1e-4) for heatmap in heatmaps] return heatmaps, wh_preds, center_regrs + print('whether flip_test: {}'.format(self.flip_test)) if self.flip_test: + # some duplicate code, can be optimized by modifying the code of FPNFeatureExpander. 
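+            # flip test: rerun the matrix-layer construction on the
+            # horizontally flipped input, flip the per-layer heatmaps and
+            # width/height maps back, and average them with the normal pass.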
@@ -199,14 +198,15 @@ def hybrid_forward(self, F, x):
        if autograd.is_training():
            heatmaps = [F.clip(heatmap, 1e-4, 1 - 1e-4) for heatmap in heatmaps]
            return heatmaps, wh_preds, center_regrs
        if self.flip_test:
+            # this duplicates the code above and could be refactored by extending FPNFeatureExpander
            feature_2_flip, feature_3_flip, feature_4_flip, feature_5_flip, feature_6_flip = self.base_network(x.flip(axis=3))
            _dict_flip = {}
-            _dict_flip[11] = feature_3_flip
-            _dict_flip[22] = feature_4_flip
-            _dict_flip[33] = feature_5_flip
-            _dict_flip[44] = feature_6_flip
-            _dict_flip[55] = self.pyramid_transformation_7(_dict_flip[44])
+            _dict_flip[11] = feature_2_flip
+            _dict_flip[22] = feature_3_flip
+            _dict_flip[33] = feature_4_flip
+            _dict_flip[44] = feature_5_flip
+            _dict_flip[55] = feature_6_flip
            _dict_flip[12] = self.downsample_transformation_21(_dict_flip[11])
            _dict_flip[13] = self.downsample_transformation_21(_dict_flip[12])
            _dict_flip[23] = self.downsample_transformation_21(_dict_flip[22])
@@ -233,6 +233,7 @@ def hybrid_forward(self, F, x):
        keeps = [F.broadcast_equal(self.heatmap_nms(heatmap), heatmap) for heatmap in heatmaps]
        results = self.decoder(keeps, heatmaps, wh_preds, center_regrs)
+        # the 19 matrix layers may produce duplicate detections, so apply NMS in post-processing
        if self.nms_thresh > 0 and self.nms_thresh < 1:
            results = F.contrib.box_nms(
                results, overlap_thresh=self.nms_thresh, topk=self.nms_topk, valid_thresh=0.01,
@@ -288,7 +289,7 @@ def get_matrix_net(name, dataset, pretrained=False, ctx=mx.cpu(),
    return net

def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwargs):
-    """Matrix net with resnet101_v1d base network on coco dataset.
+    """MatrixNet with resnet101_v1d base network on coco dataset.

    Parameters
    ----------
@@ -300,7 +301,7 @@ def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwar
    Returns
    -------
    HybridBlock
-        A CenterNet detection network.
+        A MatrixNet detection network.

    """
    from ...model_zoo.resnetv1b import resnet101_v1d
@@ -319,6 +320,11 @@ def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwar
        ('wh', {'num_output': 2}),
        ('reg', {'num_output': 2})
    ])
+    # according to the reference code of the paper (https://arxiv.org/abs/2001.03194), there can be up to 25 matrix layers
+    # layers_range is the configuration containing the 5*5 elements
+    # each element is either -1 (meaning the layer is pruned, so this position is empty and no features are generated)
+    # or a list of 4 numbers giving the min_height, max_height, min_width and max_width of the objects
+    # assigned to this layer as training targets
    layers_range = [[[0,48,0,48],[48,96,0,48],[96,192,0,48], -1, -1],
                    [[0,48,48,96],[48,96,48,96],[96,192,48,96],[192,384,0,96], -1],
                    [[0,48,96,192],[48,96,96,192],[96,192,96,192],[192,384,96,192],[384,2000,96,192]],
@@ -326,5 +332,5 @@ def matrix_net_resnet101_v1d_coco(pretrained=False, pretrained_base=True, **kwar
                    [-1, -1, [0,192,384,2000],[192,384,384,2000],[384,2000,384,2000]]]
    return get_matrix_net('resnet101_v1d', 'coco', base_network=features, heads=heads, layers_range = layers_range,
                          head_conv_channel=64, pretrained=pretrained, classes=classes,
-                          base_layer_scale=8.0, topk=100, **kwargs)
+                          base_layer_scale=4.0, topk=100, **kwargs)
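To make the `layers_range` semantics concrete before the target generator below, here is an illustrative sketch; the object size 100x60 and the helper `in_range` are hypothetical, while the 0.8/1.3 tolerances and the range entry mirror `layer_map_using_ranges` in target_generator.py and the COCO config above:

    # Sketch: does an object of height 100, width 60 belong to the layer whose
    # range entry is [96, 192, 48, 96] (min_h, max_h, min_w, max_w)?
    def in_range(height, width, layer_range):
        min_h, max_h, min_w, max_w = layer_range
        return (0.8 * min_w <= width <= 1.3 * max_w
                and 0.8 * min_h <= height <= 1.3 * max_h)

    print(in_range(100, 60, [96, 192, 48, 96]))  # True

Because the 0.8/1.3 tolerances widen each range past its neighbors, one object can be assigned to several layers at once.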
diff --git a/gluoncv/model_zoo/matrix_net/target_generator.py b/gluoncv/model_zoo/matrix_net/target_generator.py
index be29fa4103..e794814304 100644
--- a/gluoncv/model_zoo/matrix_net/target_generator.py
+++ b/gluoncv/model_zoo/matrix_net/target_generator.py
@@ -7,10 +7,27 @@
 from mxnet import gluon

 def layer_map_using_ranges(width, height, layer_ranges, fpn_flag=0):
-    layers = []
-
+    """Map each object to some of the 19 matrix layers according to the object's height and width.
+
+    Parameters
+    ----------
+    width : int
+        Width of the object.
+    height : int
+        Height of the object.
+    layer_ranges : list of list of list of number
+        The object-size range corresponding to each matrix layer.
+    fpn_flag : bool
+        Whether the matrix layers contain only the diagonal layers.
+
+    Returns
+    -------
+    list of int
+        Indices of the layer(s) to which the object is mapped.
+    """
+    layers = []
    for i, layer_range in enumerate(layer_ranges):
        if fpn_flag ==0:
            if (width >= 0.8 * layer_range[2]) and (width <= 1.3 * layer_range[3]) and (height >= 0.8 * layer_range[0]) and (height <= 1.3 * layer_range[1]):
                layers.append(i)
        else:
@@ -30,10 +47,12 @@ class MatrixNetTargetGenerator(gluon.Block):
    ----------
    num_class : int
        Number of categories.
-    output_width : int
-        Width of the network output.
-    output_height : int
-        Height of the network output.
+    input_width : int
+        Width of the network input.
+    input_height : int
+        Height of the network input.
+    layers_range : list of list
+        Same meaning as the ``layers_range`` argument of MatrixNet.
    """
    def __init__(self, num_class, input_width, input_height, layers_range):
@@ -42,32 +61,37 @@ def __init__(self, num_class, input_width, input_height, layers_range):
        self._input_width = int(input_width)
        self._input_height = int(input_height)
        self._layers_range = layers_range
-
-    def forward(self, gt_boxes, gt_ids):
-        """Target generation"""
-        # pylint: disable=arguments-differ
+        # output_sizes lists the feature-map height and width of each matrix layer
+        # _dict drops the -1 (pruned) entries from layers_range
        _dict={}
        output_sizes=[]
        # indexing layer map
-        for i,l in enumerate(self._layers_range):
+        for i,l in enumerate(layers_range):
            for j,e in enumerate(l):
                if e !=-1:
-                    output_sizes.append([self._input_height//(8*2**(j)), self._input_width//(8*2**(i))])
+                    output_sizes.append([self._input_height//(4*2**(j)), self._input_width//(4*2**(i))])
                    _dict[(i+1)*10+(j+1)]=e
        self._layers_range=[_dict[i] for i in sorted(_dict)]
-        fpn_flag = set(_dict.keys()) == set([11,22,33,44,55])
+        self._fpn_flag = set(_dict.keys()) == set([11,22,33,44,55])
+        self._output_sizes = output_sizes

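For intuition about the `output_sizes` formula just above, a sketch assuming a hypothetical 512x512 input (the numbers are illustrative, not from the patch):

    # Sketch: diagonal-layer feature sizes for a 512x512 input, using the same
    # formula as __init__ above: [input_h // (4 * 2**j), input_w // (4 * 2**i)].
    input_h = input_w = 512
    for k in range(5):  # the diagonal entries, where i == j == k
        print((k + 1) * 11, [input_h // (4 * 2 ** k), input_w // (4 * 2 ** k)])
    # 11 [128, 128], 22 [64, 64], 33 [32, 32], 44 [16, 16], 55 [8, 8]

Off-diagonal layers get rectangular maps; for example, key 12 yields [64, 128] for the same input.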
-        heatmaps = [np.zeros((self._num_class, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes]
-        wh_targets = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes]
-        wh_masks = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes]
-        center_regs = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes]
-        center_reg_masks = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in output_sizes]
+
+    def forward(self, gt_boxes, gt_ids):
+        """Target generation"""
+        # pylint: disable=arguments-differ
+        # each of the following five variables is a list with one np.ndarray per matrix layer
+        heatmaps = [np.zeros((self._num_class, output_size[0], output_size[1]), dtype=np.float32) for output_size in self._output_sizes]
+        wh_targets = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in self._output_sizes]
+        wh_masks = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in self._output_sizes]
+        center_regs = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in self._output_sizes]
+        center_reg_masks = [np.zeros((2, output_size[0], output_size[1]), dtype=np.float32) for output_size in self._output_sizes]
        for bbox, cid in zip(gt_boxes, gt_ids):
-            for olayer_idx in layer_map_using_ranges(bbox[2] - bbox[0], bbox[3] - bbox[1], self._layers_range, fpn_flag):
+            for olayer_idx in layer_map_using_ranges(bbox[2] - bbox[0], bbox[3] - bbox[1], self._layers_range, self._fpn_flag):
                cid = int(cid)
-                width_ratio = output_sizes[olayer_idx][1] / self._input_width
-                height_ratio = output_sizes[olayer_idx][0] / self._input_height
+                # these two ratios rescale the bounding-box coordinates to each layer's feature-map size
+                width_ratio = self._output_sizes[olayer_idx][1] / self._input_width
+                height_ratio = self._output_sizes[olayer_idx][0] / self._input_height
                xtl, ytl = bbox[0], bbox[1]
                xbr, ybr = bbox[2], bbox[3]
                fxtl = (xtl * width_ratio)
@@ -84,16 +108,17 @@ def forward(self, gt_boxes, gt_ids):
                    dtype=np.float32)
                center_int = center.astype(np.int32)
                center_x, center_y = center_int
-                assert center_x < output_sizes[olayer_idx][1], \
-                    'center_x: {} > output_width: {}'.format(center_x, output_sizes[olayer_idx][1])
-                assert center_y < output_sizes[olayer_idx][0], \
-                    'center_y: {} > output_height: {}'.format(center_y, output_sizes[olayer_idx][0])
+                assert center_x < self._output_sizes[olayer_idx][1], \
+                    'center_x: {} >= output_width: {}'.format(center_x, self._output_sizes[olayer_idx][1])
+                assert center_y < self._output_sizes[olayer_idx][0], \
+                    'center_y: {} >= output_height: {}'.format(center_y, self._output_sizes[olayer_idx][0])
                _draw_umich_gaussian(heatmaps[olayer_idx][cid], center_int, radius)
                wh_targets[olayer_idx][0, center_y, center_x] = box_w
                wh_targets[olayer_idx][1, center_y, center_x] = box_h
                wh_masks[olayer_idx][:, center_y, center_x] = 1.0
                center_regs[olayer_idx][:, center_y, center_x] = center - center_int
                center_reg_masks[olayer_idx][:, center_y, center_x] = 1.0
+
        heatmaps = [nd.array(heatmap) for heatmap in heatmaps]
        wh_targets = [nd.array(wh_target) for wh_target in wh_targets]
        wh_masks = [nd.array(wh_mask) for wh_mask in wh_masks]
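The ratio arithmetic in `forward` is easiest to follow with concrete numbers. The sketch below assumes a 512x512 input, a 64x64 layer map, and hypothetical box coordinates; it mirrors the visible width_ratio/height_ratio and center-offset logic above:

    import numpy as np

    # Sketch: project a ground-truth box onto a 64x64 layer map of a 512x512 input.
    xtl, ytl, xbr, ybr = 100.0, 80.0, 260.0, 240.0  # hypothetical box
    width_ratio = 64 / 512    # layer map width / input width
    height_ratio = 64 / 512   # layer map height / input height
    center = np.array([(xtl + xbr) / 2 * width_ratio,
                       (ytl + ybr) / 2 * height_ratio], dtype=np.float32)
    center_int = center.astype(np.int32)
    print(center, center_int, center - center_int)
    # [22.5 20.] [22 20] [0.5 0.] -> the sub-pixel remainder is the reg target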
""" - def __init__(self, topk=100, base_layer_scale=8.0): + def __init__(self, topk=100, base_layer_scale=4.0): super(MatrixNetDecoder, self).__init__() self._topk = topk self._base_layer_scale = base_layer_scale - def hybrid_forward(self, F, keeps, xs, whs, regs): + def hybrid_forward(self, F, keeps, in_xs, whs, regs): """Forward of decoder""" - _, _, out_h0, out_w0 = xs[0].shape_array().split(num_outputs=4, axis=0) + #keeps, in_xs, whs, regs are all lists containing #(matrix layer) NDArrays + _, _, out_h0, out_w0 = in_xs[0].shape_array().split(num_outputs=4, axis=0) results = [] - for i in range(len(xs)): - x = keeps[i] * xs[i] + for i in range(len(in_xs)): + x = keeps[i] * in_xs[i] wh = whs[i] reg = regs[i] _, _, out_h, out_w = x.shape_array().split(num_outputs=4, axis=0) - height_scale = out_h0.asscalar() / out_h.asscalar() - width_scale = out_w0.asscalar() / out_w.asscalar() + height_scale = out_h0 / out_h + width_scale = out_w0 / out_w scores, indices = x.reshape((0, -1)).topk(k=self._topk, ret_typ='both') indices = F.cast(indices, 'int64') topk_classes = F.cast(F.broadcast_div(indices, (out_h * out_w)), 'float32') @@ -548,22 +549,17 @@ def hybrid_forward(self, F, keeps, xs, whs, regs): h = F.cast(F.gather_nd(wh, reg_ys).reshape((-1, self._topk)), 'float32') half_w = w / 2 half_h = h / 2 - result = [topk_xs - half_w, topk_ys - half_h, topk_xs + half_w, topk_ys + half_h] + result = [] + # adjust to the size of first (top-left) layer firstly + result.append(F.broadcast_mul((topk_xs - half_w), F.cast(width_scale, 'float32'))) + result.append(F.broadcast_mul((topk_ys - half_h), F.cast(height_scale, 'float32'))) + result.append(F.broadcast_mul((topk_xs + half_w), F.cast(width_scale, 'float32'))) + result.append(F.broadcast_mul((topk_ys + half_h), F.cast(height_scale, 'float32'))) result = F.concat(*[tmp.expand_dims(-1) for tmp in result], dim=-1) - result[:,:,0:4:2] *= width_scale - result[:,:,1:4:2] *= height_scale + # adjust to the size of original input + result = result * self._base_layer_scale result = F.concat(*[topk_classes.expand_dims(-1), scores.expand_dims(-1), result],dim=-1) results.append(result) results = F.concat(*results, dim=1) - results[:,:,2:6] *= self._base_layer_scale - ''' - batch_num = len(results) - batch_indices = F.cast(F.arange(256).slice_like( - results, axes=(0)).expand_dims(-1).tile(reps=(1, 300*6)), 'int64') - topk_indices = F.cast(results[:,:,1].topk(k=300).expand_dims(-1).tile(reps=(1, 1, 6)), 'int64').reshape((0, -1)) - val_indices = F.cast(nd.arange(6).expand_dims(0).tile(reps=(300 * batch_num, 1)), 'int64').reshape((batch_num,-1)) - inds = F.concat(batch_indices, topk_indices, val_indices, dim=0).reshape((3, -1)) - results = F.cast(F.gather_nd(results, inds).reshape((batch_num, 300, -1)), 'float32') - ''' - return results + return results \ No newline at end of file diff --git a/scripts/detection/matrix_net/train_matrix_net.py b/scripts/detection/matrix_net/train_matrix_net.py index c6b6fc3df1..7d59d06dde 100644 --- a/scripts/detection/matrix_net/train_matrix_net.py +++ b/scripts/detection/matrix_net/train_matrix_net.py @@ -114,7 +114,8 @@ def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_ """Get dataloader.""" width, height = data_shape, data_shape num_class = len(train_dataset.classes) - batchify_fn = Tuple([Stack() for _ in range(96)]) # stack image, cls_targets, box_targets + # 96 = 1(image) + 19 (number of matrix layers) * 5 (heatmap, wh_target, wh_mask, offset_target, offset_mask) + batchify_fn = Tuple([Stack() for _ 
diff --git a/scripts/detection/matrix_net/train_matrix_net.py b/scripts/detection/matrix_net/train_matrix_net.py
index c6b6fc3df1..7d59d06dde 100644
--- a/scripts/detection/matrix_net/train_matrix_net.py
+++ b/scripts/detection/matrix_net/train_matrix_net.py
@@ -114,7 +114,8 @@ def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_
    """Get dataloader."""
    width, height = data_shape, data_shape
    num_class = len(train_dataset.classes)
-    batchify_fn = Tuple([Stack() for _ in range(96)])  # stack image, cls_targets, box_targets
+    # 96 = 1 (image) + 19 (matrix layers) * 5 (heatmap, wh_target, wh_mask, offset_target, offset_mask)
+    batchify_fn = Tuple([Stack() for _ in range(96)])
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(MatrixNetDefaultTrainTransform(
            width, height, num_class=num_class, layers_range=net.layers_range)),
@@ -221,28 +222,35 @@ def train(net, train_data, val_data, eval_metric, ctx, args):
        for i, batch in enumerate(train_data):
            split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=ctx, batch_axis=0) for ind in range(96)]
-            inter_num = len(split_data[0])
+            inter_num = len(split_data[0])  # number of gpus
            batch_size = args.batch_size
            with autograd.record():
                sum_losses = []
-                mid_layers_num = 19
-                heatmap_losses = [0 for _ in range(inter_num)]
-                wh_losses = [0 for _ in range(inter_num)]
-                center_reg_losses = [0 for _ in range(inter_num)]
+                mid_layers_num = 19  # number of matrix layers
+                heatmap_losses = []
+                wh_losses = []
+                center_reg_losses = []
                wh_preds = [[] for _ in range(mid_layers_num)]
                center_reg_preds = [[] for _ in range(mid_layers_num)]
                for ind in range(inter_num):
-
+                    # sum each loss over the 19 matrix layers; heatmap_losses, wh_losses and center_reg_losses each hold one entry per gpu
                    heatmap_pred, wh_pred, center_reg_pred = net(split_data[0][ind])
                    for ii in range(mid_layers_num):
                        wh_preds[ii].append(wh_pred[ii])
                        center_reg_preds[ii].append(center_reg_pred[ii])
-                        wh_losses[ind] += wh_loss(wh_pred[ii], split_data[1+mid_layers_num+ii][ind],\
-                                                  split_data[1+mid_layers_num*2+ii][ind])
-                        center_reg_losses[ind] += center_reg_loss(center_reg_pred[ii], split_data[1+mid_layers_num*3+ii][ind],\
-                                                                  split_data[1+mid_layers_num*4+ii][ind])
-                        heatmap_losses[ind] += heatmap_loss(heatmap_pred[ii], split_data[1+ii][ind])
-
+                        if ii == 0:
+                            wh_losses.append(wh_loss(wh_pred[ii], split_data[1+mid_layers_num+ii][ind],
+                                                     split_data[1+mid_layers_num*2+ii][ind]))
+                            center_reg_losses.append(center_reg_loss(center_reg_pred[ii], split_data[1+mid_layers_num*3+ii][ind],
+                                                                     split_data[1+mid_layers_num*4+ii][ind]))
+                            heatmap_losses.append(heatmap_loss(heatmap_pred[ii], split_data[1+ii][ind]))
+                        else:
+                            wh_losses[-1] = wh_losses[-1] + wh_loss(wh_pred[ii], split_data[1+mid_layers_num+ii][ind],
+                                                                    split_data[1+mid_layers_num*2+ii][ind])
+                            center_reg_losses[-1] = center_reg_losses[-1] + center_reg_loss(center_reg_pred[ii],
+                                                                                            split_data[1+mid_layers_num*3+ii][ind], split_data[1+mid_layers_num*4+ii][ind])
+                            heatmap_losses[-1] = heatmap_losses[-1] + heatmap_loss(heatmap_pred[ii], split_data[1+ii][ind])
+
                sum_losses = [heatmap_losses[ii]+wh_losses[ii]+center_reg_losses[ii] for ii in range(inter_num)]
                autograd.backward(sum_losses)
            trainer.step(len(sum_losses))  # step with # gpus

From c73ae4f51525b70e7ccd74b66db38f1c7f551b2c Mon Sep 17 00:00:00 2001
From: liueo
Date: Mon, 1 Jun 2020 18:00:33 +0000
Subject: [PATCH 7/7] update

---
 gluoncv/data/transforms/presets/matrix_net.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gluoncv/data/transforms/presets/matrix_net.py b/gluoncv/data/transforms/presets/matrix_net.py
index 3364818c75..f9435bc7f6 100644
--- a/gluoncv/data/transforms/presets/matrix_net.py
+++ b/gluoncv/data/transforms/presets/matrix_net.py
@@ -114,7 +114,7 @@ def __call__(self, src, label):

 class MatrixNetDefaultValTransform(object):
-    """Default SSD validation transform.
+    """Default MatrixNet validation transform.

    Parameters
    ----------