
Roialign fix and half_pixel mode support #3482

Open · bpickrel wants to merge 75 commits into develop
Conversation

bpickrel (Contributor) commented Sep 26, 2024

Fix bugs in the implementation of the ROIAlign operation that were found when attempting to run it with the half_pixel coordinate conversion mode, and add more thorough tests. Some bugs are mode-specific and some are not.

The ROIAlign operation was first proposed in the paper that introduced the Mask R-CNN model (https://arxiv.org/abs/1703.06870v3). It is a variant of the ROIPool operation that was found to give significantly better accuracy. In the Torch, Onnxruntime, and MIGraphX implementations, ROIPool and ROIAlign are implemented as the same op with different values of the coordinate conversion mode attribute: output_half_pixel for ROIPool and half_pixel for ROIAlign. Thus, there is no working ROIAlign op without fixing the half_pixel mode.
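
To illustrate the difference between the two modes, here is a minimal NumPy sketch of how an ROI is mapped onto the feature map, based on my reading of the onnxruntime kernel. The helper is illustrative only; the example values mirror one of the test ROIs below.

import numpy as np

def scale_roi(roi, spatial_scale, half_pixel):
    # Illustrative only, not MIGraphX or ORT code. roi is (x1, y1, x2, y2).
    # half_pixel shifts the scaled coordinates by -0.5 and does not clamp the
    # extent; output_half_pixel (the legacy ROIPool-style mode) skips the shift
    # but clamps the ROI width/height to at least 1.
    offset = 0.5 if half_pixel else 0.0
    x1, y1, x2, y2 = np.asarray(roi, dtype=np.float64) * spatial_scale - offset
    w, h = x2 - x1, y2 - y1
    if not half_pixel:
        w, h = max(w, 1.0), max(h, 1.0)
    return (x1, y1), (w, h)

# The same ROI yields different sampling regions under the two modes:
print(scale_roi([1.1, 0.73, 2.6, 1.13], spatial_scale=2.0, half_pixel=True))
print(scale_roi([1.1, 0.73, 2.6, 1.13], spatial_scale=2.0, half_pixel=False))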

Note, by the way, that these same coordinate conversion modes are also attributes of the Resize op.

MIGraphX uses the Onnxruntime implementation of ROIAlign as its functional specification and should give identical results.

This change is a prerequisite for torch-migraphx PR #143 but does not close it.

bpickrel and others added 30 commits July 17, 2024 22:23
…code. Tests need to be completed, including updating generated onnx test files.
…_half_pixel_verify_test for first roi but fails for second
@@ -41,7 +41,7 @@ TEST_CASE(roialign_test)
{{"coordinate_transformation_mode", "output_half_pixel"},
{"spatial_scale", 2.0f},
{"output_height", 5},
- {"output_width", 5},
+ {"output_width", 3},
Contributor Author

This particular change was a big deal, btw, since the old code appeared to work fine until I gave it an output_height and output_width that were not the same.

migraphx::shape srois{migraphx::shape::float_type, {2, 4}};
std::vector<float> rois_data = {1.1, 0.73, 1.7, 1.13, 1.1, 0.73, 2.6, 1.13};
migraphx::shape sbi{migraphx::shape::int64_type, {2}}; // batch_index
std::vector<int64_t> bi_data = {0, 1};
Contributor

For good measure, can you add tests for the following cases:

  • Repeated batch indices
  • Missing batch indices (ie. not all batch items are computed on)
  • Number of ROIs != batch_size

You can probably just create one test case to cover all of these. Make the input batch_size 3 and the batch_indices something like {1,2,2,1} (and hence the rois shape will be {4,4}).
It would also be good to have a GPU verify test for this same case, just to be sure the GPU implementation matches.
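
For concreteness, a NumPy sketch of inputs matching this suggestion; the shapes follow the comment above, but the data values are placeholders.

import numpy as np

# batch_size 3 input, but only batches 1 and 2 are referenced; indices 1 and 2
# are each repeated, and the number of ROIs (4) != batch_size (3).
x = np.random.rand(3, 1, 10, 10).astype(np.float32)        # (N=3, C=1, H, W)
rois = np.array([[0.5, 0.5, 7.5, 7.5],
                 [1.0, 1.0, 5.0, 5.0],
                 [2.0, 2.0, 8.0, 8.0],
                 [0.0, 0.0, 9.0, 9.0]], dtype=np.float32)  # shape (4, 4)
batch_indices = np.array([1, 2, 2, 1], dtype=np.int64)     # batch 0 never used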

Contributor Author

Added such a case in ref. tests, along with other updates. Some of the new cases fail and I'm now debugging those.

@spolifroni-amd (Contributor) left a comment

A trailing whitespace character in the table needs to be removed.

@@ -697,7 +697,7 @@ Operator Support Matrix
| | | | functions are |
| | | | not enabled |
+--------------------------+-----------+-----------------+------------------------------+
| RoiAlign | ✅ | FP8, FP16, | |
| RoiAlign | ✅ | FP8, FP16, | |
Contributor

The extra whitespace at the end of this row is causing the table to be improperly formatted and not appear on the doc page.

Suggested change
| RoiAlign | ✅ | FP8, FP16, | |
| RoiAlign | ✅ | FP8, FP16, | |

…variety of options including pooling mode, transformation type, spatial scale, multiple input channels, non-symmetrical output shape, and ROI index list with skips and duplicates. Changed roialign_half_pixel_verify_test to match one of the new ref test cases. Cases using max pooling do not pass the test.
@@ -84,114 +84,164 @@ TEST_CASE(roialign_out_of_bound_test)
}
}

auto create_program(const std::string& trans_mode = "half_pixel",
Contributor Author

Notes on ref tests: added cases with all 4 combinations of trans_mode and pooling_mode and split them apart into separate named cases. The modified create_program() has a reshaped input with multiple channels, multiple layers, an ROI list that doesn't match 1-1 with the layers, non-unity scale and sampling ratio, one negative data value (any float value is legal) and an ROI that goes out of bounds (also legal). Also, the output height and width are no longer equal, which masked errors in the original implementation.

Need to debug the max pooling cases!

@bpickrel (Contributor Author)

The licensing check fail now occurring is for a file not related to this PR:

Error: The licenses for the following 1 file(s) either... do not match the year of commit, have a different copyright format or have not been synced from the latest roialign_fix branch:
['src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp']

@bpickrel (Contributor Author)

Looks fine, just a few small things. I haven't been able to fully wrap my head around all the math in the ref and gpu impl, but the index changes look reasonable. Do we have a way to directly test against ORT (without manually extracting gold outputs)? If so, I think it would be worthwhile to add a few more tests comparing with ORT.

I think it would be possible to add a test following the model of the existing tests in test/py/. With luck it wouldn't be very much extra work, half a day or so. @pfultz2 what do you think? The rationale for adding an op test here is that the ROIAlign op is defined in terms of the Onnxruntime implementation, so it makes sense to have a specialized test with ORT as the reference.

Note my recent comment that I learned the ORT implementation of the max pooling option is buggy and can't be used for a test reference until the fix is released. I don't know whether max pooling is widely used with this op or not.
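
For reference, a rough sketch of what such a test/py/-style check might look like. This is not the eventual test_roialign.py; the model construction, tensor names, and attribute values are illustrative, and the MIGraphX Python calls (parse_onnx, get_target, compile, run, argument) are assumed from the standard Python bindings.

import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort
import migraphx


def make_roialign_model(path):
    # Build a small RoiAlign model in half_pixel mode (illustrative shapes/attributes).
    node = helper.make_node(
        "RoiAlign", ["x", "rois", "batch_ind"], ["y"],
        coordinate_transformation_mode="half_pixel",
        mode="avg",
        output_height=5,
        output_width=3,
        sampling_ratio=2,
        spatial_scale=2.0)
    graph = helper.make_graph(
        [node], "roialign_check",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 2, 4, 3]),
         helper.make_tensor_value_info("rois", TensorProto.FLOAT, [2, 4]),
         helper.make_tensor_value_info("batch_ind", TensorProto.INT64, [2])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 2, 5, 3])])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
    onnx.save(model, path)


def test_roialign_vs_ort(path="roialign_check.onnx"):
    make_roialign_model(path)
    data = np.random.rand(2, 2, 4, 3).astype(np.float32)
    roi_data = np.array([[0.1, 0.15, 0.6, 0.35],
                         [1.1, 0.73, 2.6, 1.13]], dtype=np.float32)
    index_data = np.array([0, 1], dtype=np.int64)

    # Reference result from onnxruntime.
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    gold = sess.run(["y"], {"x": data, "rois": roi_data, "batch_ind": index_data})[0]

    # MIGraphX result on the reference target ("gpu" would be exercised the same way).
    prog = migraphx.parse_onnx(path)
    prog.compile(migraphx.get_target("ref"))
    out = prog.run({"x": migraphx.argument(data),
                    "rois": migraphx.argument(roi_data),
                    "batch_ind": migraphx.argument(index_data)})[-1]
    # Conversion assumes migraphx.argument exposes the buffer protocol.
    assert np.allclose(np.array(out), gold, rtol=1e-05, atol=1e-08)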

@bpickrel (Contributor Author)

Looks fine, just a few small things. I haven't been able to fully wrap my head around all the math in the ref and gpu impl, but the index changes look reasonable.

Do you want me to go over it with you? I can explain the intent of nearly everything but the indexing is still very difficult to unravel.

# XXXXX 0x562d956ec8f0 (0x562d956ec8f0 + 0 * 2 + channel 0) * 4 * 3
# XXXXX 0x562d956ec920 (0x562d956ec8f0 + 0 * 2 + channel 1) * 4 * 3
res = sess.run(['y'], {'x': data, 'rois': roi_data, 'batch_ind': index_data})
assert np.allclose(mgx_result, res, rtol=1e-05, atol=1e-08, equal_nan=False)
Contributor Author

The tolerances are the NumPy defaults.

@bpickrel bpickrel marked this pull request as draft October 29, 2024 15:49
@bpickrel (Contributor Author)

Requesting re-review after a recent change: I added a Python test, test_roialign.py, to check MIGraphX output directly against onnxruntime, and found that MIGraphX results were internally consistent but produced the right values in a transposed shape. Fixing this required changes to internal computations, and I updated both the ref. and GPU implementations to emit the corrected shape.

Repeat of an earlier comment: we can't do a similar check vs. onnxruntime for "max" pooling mode because the ORT implementation of max pooling in ROIAlign has a known bug.

@bpickrel bpickrel marked this pull request as ready for review November 12, 2024 16:44
@migraphx-bot (Collaborator)

Test  Batch  Rate new (400bd0)  Rate old (c51bea)  Diff  Compare
torchvision-resnet50 64 3,258.94 3,257.81 0.03%
torchvision-resnet50_fp16 64 6,988.19 6,987.81 0.01%
torchvision-densenet121 32 2,431.87 2,434.57 -0.11%
torchvision-densenet121_fp16 32 4,099.62 4,065.61 0.84%
torchvision-inceptionv3 32 1,636.68 1,637.17 -0.03%
torchvision-inceptionv3_fp16 32 2,761.86 2,759.26 0.09%
cadene-inceptionv4 16 775.66 776.31 -0.08%
cadene-resnext64x4 16 808.05 811.75 -0.46%
slim-mobilenet 64 7,525.92 7,533.16 -0.10%
slim-nasnetalarge 64 211.28 211.39 -0.05%
slim-resnet50v2 64 3,497.50 3,504.83 -0.21%
bert-mrpc-onnx 8 1,147.54 1,146.47 0.09%
bert-mrpc-tf 1 464.53 473.89 -1.98%
pytorch-examples-wlang-gru 1 413.15 425.31 -2.86%
pytorch-examples-wlang-lstm 1 389.69 408.68 -4.65% 🔴
torchvision-resnet50_1 1 806.71 771.75 4.53% 🔆
cadene-dpn92_1 1 399.87 399.01 0.22%
cadene-resnext101_1 1 382.86 383.85 -0.26%
onnx-taau-downsample 1 343.04 343.09 -0.02%
dlrm-criteoterabyte 1 33.33 33.31 0.05%
dlrm-criteoterabyte_fp16 1 52.71 52.71 0.01%
agentmodel 1 7,901.33 8,235.67 -4.06% 🔴
unet_fp16 2 58.79 58.90 -0.19%
resnet50v1_fp16 1 948.60 940.89 0.82%
resnet50v1_int8 1 1,002.37 1,025.93 -2.30%
bert_base_cased_fp16 64 1,171.54 1,170.88 0.06%
bert_large_uncased_fp16 32 363.60 363.69 -0.02%
bert_large_fp16 1 200.49 200.14 0.18%
distilgpt2_fp16 16 2,202.57 2,200.77 0.08%
yolov5s 1 543.48 535.15 1.56%
tinyllama 1 43.46 43.41 0.10%
vicuna-fastchat 1 175.77 178.09 -1.30%
whisper-tiny-encoder 1 417.88 418.18 -0.07%
whisper-tiny-decoder 1 427.73 427.58 0.03%

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@pfultz2 (Collaborator) commented Nov 12, 2024

You should capture the onnxruntime results and just create a ref test.
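
One possible shape for that suggestion, sketched below: run the model once through onnxruntime and print the output in a form that can be pasted into a C++ ref test as the gold data. The model file name and the input names/values are placeholders, not taken from the repository.

import numpy as np
import onnxruntime as ort


def dump_gold(model_path, feeds, precision=6):
    # Run the model once through onnxruntime and print its single output as a
    # C++ brace-initializer list, ready to paste into a ref test.
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    (gold,) = sess.run(None, feeds)
    flat = ", ".join(f"{v:.{precision}f}" for v in gold.flatten())
    print(f"// gold shape: {list(gold.shape)}")
    print("std::vector<float> gold = {" + flat + "};")


# Placeholder file and inputs; fixed data keeps the gold values reproducible.
dump_gold("roialign_test.onnx",
          {"x": (np.arange(2 * 2 * 4 * 3, dtype=np.float32) / 10.0).reshape(2, 2, 4, 3),
           "rois": np.array([[0.1, 0.15, 0.6, 0.35],
                             [1.1, 0.73, 2.6, 1.13]], dtype=np.float32),
           "batch_ind": np.array([0, 1], dtype=np.int64)})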



if __name__ == "__main__":
    test_roialign()
Collaborator

This file is not even used in our test suite. It should just be removed and a ref test should be used.
