From 038db92ad6d0112719e5b17b3236fbd0481097a6 Mon Sep 17 00:00:00 2001 From: Nick Jong Date: Sat, 7 Mar 2020 12:45:35 -0800 Subject: [PATCH] Refactor Object Detection inference to use new Model Trainer type (#3034) --- src/ml/neural_net/CMakeLists.txt | 1 + src/ml/neural_net/mps_compute_context.hpp | 3 +- src/ml/neural_net/mps_compute_context.mm | 27 ++- src/ml/neural_net/mps_od_backend.hpp | 63 +++++++ src/ml/neural_net/mps_od_backend.mm | 86 +++++++++ .../object_detection/object_detector.cpp | 171 ++++-------------- .../object_detection/object_detector.hpp | 3 + .../od_darknet_yolo_model_trainer.cpp | 139 +++++++++++++- .../od_darknet_yolo_model_trainer.hpp | 61 ++++++- .../object_detection/od_model_trainer.cpp | 17 +- .../object_detection/od_model_trainer.hpp | 60 ++++++ .../object_detection/test_object_detector.cxx | 144 ++++++++++++--- 12 files changed, 584 insertions(+), 191 deletions(-) create mode 100644 src/ml/neural_net/mps_od_backend.hpp create mode 100644 src/ml/neural_net/mps_od_backend.mm diff --git a/src/ml/neural_net/CMakeLists.txt b/src/ml/neural_net/CMakeLists.txt index 5378e4610f..6445dc93a0 100644 --- a/src/ml/neural_net/CMakeLists.txt +++ b/src/ml/neural_net/CMakeLists.txt @@ -31,6 +31,7 @@ if(APPLE AND HAS_MPS AND NOT TC_BUILD_IOS) mps_weight.mm mps_device_manager.m mps_descriptor_utils.m + mps_od_backend.mm style_transfer/mps_style_transfer.m style_transfer/mps_style_transfer_backend.mm style_transfer/mps_style_transfer_utils.m diff --git a/src/ml/neural_net/mps_compute_context.hpp b/src/ml/neural_net/mps_compute_context.hpp index 0df8c811e6..0bc76c139d 100644 --- a/src/ml/neural_net/mps_compute_context.hpp +++ b/src/ml/neural_net/mps_compute_context.hpp @@ -62,8 +62,7 @@ class mps_compute_context: public compute_context { std::function rng); private: - - std::unique_ptr command_queue_; + std::shared_ptr command_queue_; }; } // namespace neural_net diff --git a/src/ml/neural_net/mps_compute_context.mm b/src/ml/neural_net/mps_compute_context.mm index 07414aeb92..4169e4861a 100644 --- a/src/ml/neural_net/mps_compute_context.mm +++ b/src/ml/neural_net/mps_compute_context.mm @@ -10,11 +10,13 @@ #include #include +#include +#include +#include + #include #include -#include -#import namespace turi { namespace neural_net { @@ -125,18 +127,23 @@ float_array_map multiply_mps_od_loss_multiplier(float_array_map config, std::unique_ptr mps_compute_context::create_object_detector( int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map& config, const float_array_map& weights) { - float_array_map updated_config; + mps_od_backend::parameters params; + params.command_queue = command_queue_; + params.n = n; + params.c_in = c_in; + params.h_in = h_in; + params.w_in = w_in; + params.c_out = c_out; + params.h_out = h_out; + params.w_out = w_out; + params.weights = weights; + std::vector update_keys = { "learning_rate", "od_scale_class", "od_scale_no_object", "od_scale_object", "od_scale_wh", "od_scale_xy", "gradient_clipping"}; - updated_config = multiply_mps_od_loss_multiplier(config, update_keys); - std::unique_ptr result( - new mps_graph_cnn_module(*command_queue_)); - - result->init(/* network_id */ kODGraphNet, n, c_in, h_in, w_in, c_out, h_out, - w_out, updated_config, weights); + params.config = multiply_mps_od_loss_multiplier(config, update_keys); - return result; + return std::unique_ptr(new mps_od_backend(std::move(params))); } std::unique_ptr mps_compute_context::create_activity_classifier( diff --git a/src/ml/neural_net/mps_od_backend.hpp b/src/ml/neural_net/mps_od_backend.hpp new file mode 100644 index 0000000000..76241b6331 --- /dev/null +++ b/src/ml/neural_net/mps_od_backend.hpp @@ -0,0 +1,63 @@ +/* Copyright © 2020 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at + * https://opensource.org/licenses/BSD-3-Clause + */ + +#ifndef MPS_OD_BACKEND_HPP_ +#define MPS_OD_BACKEND_HPP_ + +#include +#include + +namespace turi { +namespace neural_net { + +/** + * Model backend for object detection that uses a separate mps_graph_cnnmodule + * for training and for inference, since mps_graph_cnnmodule doesn't currently + * support doing both. + */ +class mps_od_backend : public model_backend { + public: + struct parameters { + std::shared_ptr command_queue; + int n; + int c_in; + int h_in; + int w_in; + int c_out; + int h_out; + int w_out; + float_array_map config; + float_array_map weights; + }; + + mps_od_backend(parameters params); + + // Training + void set_learning_rate(float lr) override; + float_array_map train(const float_array_map& inputs) override; + + // Inference + float_array_map predict(const float_array_map& inputs) const override; + + float_array_map export_weights() const override; + + private: + void ensure_training_module(); + void ensure_prediction_module() const; + + parameters params_; + + std::unique_ptr training_module_; + + // Cleared whenever the training module is updated. + mutable std::unique_ptr prediction_module_; +}; + +} // namespace neural_net +} // namespace turi + +#endif // MPS_OD_BACKEND_HPP_ diff --git a/src/ml/neural_net/mps_od_backend.mm b/src/ml/neural_net/mps_od_backend.mm new file mode 100644 index 0000000000..931a8c14a5 --- /dev/null +++ b/src/ml/neural_net/mps_od_backend.mm @@ -0,0 +1,86 @@ +/* Copyright © 2020 Apple Inc. All rights reserved. + * + * Use of this source code is governed by a BSD-3-clause license that can + * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + */ + +#include + +namespace turi { +namespace neural_net { + +void mps_od_backend::ensure_training_module() { + if (training_module_) return; + + training_module_.reset(new mps_graph_cnn_module(*params_.command_queue)); + training_module_->init(/* network_id */ kODGraphNet, params_.n, params_.c_in, params_.h_in, + params_.w_in, params_.c_out, params_.h_out, params_.w_out, params_.config, + params_.weights); + + // Clear params_.weights to free up memory, since they are now superceded by + // whatever the training module contains. + params_.weights.clear(); +} + +void mps_od_backend::ensure_prediction_module() const { + if (prediction_module_) return; + + // Adjust configuration for prediction. + float_array_map config = params_.config; + config["mode"] = shared_float_array::wrap(2.0f); + config["od_include_loss"] = shared_float_array::wrap(0.0f); + + // Take weights from training module if present, else from original weights. + float_array_map weights; + if (training_module_) { + weights = training_module_->export_weights(); + } else { + weights = params_.weights; + } + + prediction_module_.reset(new mps_graph_cnn_module(*params_.command_queue)); + prediction_module_->init(/* network_id */ kODGraphNet, params_.n, params_.c_in, params_.h_in, + params_.w_in, params_.c_out, params_.h_out, params_.w_out, config, + weights); +} + +mps_od_backend::mps_od_backend(parameters params) : params_(std::move(params)) { + // Immediate instantiate at least one module, since at present we can't + // guarantee that the weights will remain valid after we return. + // TODO: Remove this eager construction once we stop putting weak pointers in + // float_array_map. + if (params_.config.at("mode").data()[0] == 0.f) { + ensure_training_module(); + } else { + ensure_prediction_module(); + } +} + +void mps_od_backend::set_learning_rate(float lr) { + ensure_training_module(); + training_module_->set_learning_rate(lr); +} + +float_array_map mps_od_backend::train(const float_array_map& inputs) { + // Invalidate prediction_module, since its weights will be stale. + prediction_module_.reset(); + + ensure_training_module(); + return training_module_->train(inputs); +} + +float_array_map mps_od_backend::predict(const float_array_map& inputs) const { + ensure_prediction_module(); + return prediction_module_->predict(inputs); +} + +float_array_map mps_od_backend::export_weights() const { + if (training_module_) { + return training_module_->export_weights(); + } else { + return params_.weights; + } +} + +} // namespace neural_net +} // namespace turi diff --git a/src/toolkits/object_detection/object_detector.cpp b/src/toolkits/object_detection/object_detector.cpp index 3b0a8c6711..73d1c599c2 100644 --- a/src/toolkits/object_detection/object_detector.cpp +++ b/src/toolkits/object_detection/object_detector.cpp @@ -57,6 +57,7 @@ using turi::coreml::MLModelWrapper; using turi::neural_net::compute_context; using turi::neural_net::deferred_float_array; using turi::neural_net::float_array_map; +using turi::neural_net::FuturesStream; using turi::neural_net::image_annotation; using turi::neural_net::image_augmenter; using turi::neural_net::labeled_image; @@ -79,16 +80,11 @@ constexpr int DEFAULT_BATCH_SIZE = 32; // Empircally, we need 4GB to support batch size 32. constexpr size_t MEMORY_REQUIRED_FOR_DEFAULT_BATCH_SIZE = 4294967296; -// We assume RGB input. -constexpr int NUM_INPUT_CHANNELS = 3; - // The spatial reduction depends on the input size of the pre-trained model // (relative to the grid size). // TODO: When we support alternative base models, we will have to generalize. constexpr int SPATIAL_REDUCTION = 32; -constexpr float BASE_LEARNING_RATE = 0.001f; - constexpr float DEFAULT_NON_MAXIMUM_SUPPRESSION_THRESHOLD = 0.45f; constexpr float DEFAULT_CONFIDENCE_THRESHOLD_PREDICT = 0.25f; @@ -110,35 +106,6 @@ const std::vector>& anchor_boxes() { return *default_boxes; }; -// These are the fixed values that the Python implementation currently passes -// into TCMPS. -// TODO: These should be exposed in a way that facilitates experimentation. -// TODO: A struct instead of a map would be nice, too. - -float_array_map get_base_config() { - float_array_map config; - config["learning_rate"] = - shared_float_array::wrap(BASE_LEARNING_RATE); - config["gradient_clipping"] = shared_float_array::wrap(0.025f); - // TODO: Have MPS path use these parameters, instead - // of the values hardcoded in the MPS code. - config["od_rescore"] = shared_float_array::wrap(1.0f); - config["lmb_noobj"] = shared_float_array::wrap(5.0); - config["lmb_obj"] = shared_float_array::wrap(100.0); - config["lmb_coord_xy"] = shared_float_array::wrap(10.0); - config["lmb_coord_wh"] = shared_float_array::wrap(10.0); - config["lmb_class"] = shared_float_array::wrap(2.0); - return config; -} - -float_array_map get_prediction_config() { - float_array_map config = get_base_config(); - config["mode"] = shared_float_array::wrap(2.0f); - config["od_include_loss"] = shared_float_array::wrap(0.0f); - config["od_include_network"] = shared_float_array::wrap(1.0f); - return config; -} - flex_int estimate_max_iterations(flex_int num_instances, flex_int batch_size) { // Scale with square root of number of labeled instances. @@ -666,6 +633,12 @@ gl_sframe object_detector::convert_types_to_sframe( return sframe_data; } +std::unique_ptr object_detector::create_inference_trainer( + const Checkpoint& checkpoint, + std::unique_ptr context) const { + return checkpoint.CreateModelTrainer(context.get()); +} + void object_detector::perform_predict( gl_sframe data, std::function&, @@ -673,12 +646,8 @@ void object_detector::perform_predict( const std::pair&)> consumer, float confidence_threshold, float iou_threshold) { - std::string image_column_name = read_state("feature"); - std::string annotations_column_name = read_state("annotations"); flex_list class_labels = read_state("classes"); int batch_size = read_state("batch_size"); - int grid_height = read_state("grid_height"); - int grid_width = read_state("grid_width"); // return if the data is empty if (data.size() == 0) return; @@ -694,105 +663,41 @@ void object_detector::perform_predict( log_and_throw("No neural network compute context provided"); } - // Instantiate the data augmenter. Don't enable any of the actual - // augmentations, just resize the input images to the desired shape. - image_augmenter::options augmenter_opts; - augmenter_opts.batch_size = batch_size; - augmenter_opts.output_height = grid_height * SPATIAL_REDUCTION; - augmenter_opts.output_width = grid_width * SPATIAL_REDUCTION; - std::unique_ptr augmenter = - ctx->create_image_augmenter(augmenter_opts); - - // Instantiate the NN backend. - // For each anchor box, we have 4 bbox coords + 1 conf + one-hot class labels - int num_outputs_per_anchor = 5 + static_cast(class_labels.size()); - int num_output_channels = static_cast(num_outputs_per_anchor * anchor_boxes().size()); - - float_array_map pred_config = get_prediction_config(); - pred_config["num_iterations"] = - shared_float_array::wrap(get_max_iterations()); - pred_config["num_classes"] = - shared_float_array::wrap(get_num_classes()); - - std::unique_ptr model = ctx->create_object_detector( - /* n */ read_state("batch_size"), - /* c_in */ NUM_INPUT_CHANNELS, - /* h_in */ grid_height * SPATIAL_REDUCTION, - /* w_in */ grid_width * SPATIAL_REDUCTION, - /* c_out */ num_output_channels, - /* h_out */ grid_height, - /* w_out */ grid_width, - /* config */ pred_config, - /* weights */ strip_fwd(checkpoint_->weights())); - - // To support double buffering, use a queue of pending inference results. - std::queue pending_batches; - - // Helper function to process results until the queue reaches a given size. - auto pop_until_size = [&](size_t remaining) { - while (pending_batches.size() > remaining) { - - // Pop one batch from the queue. - inference_batch batch = pending_batches.front(); - - pending_batches.pop(); - for (size_t i = 0; i < batch.annotations_batch.size(); ++i) { - // For this row (corresponding to one image), extract the prediction. - shared_float_array raw_prediction = batch.image_batch[i]; - - // Translate the raw output into predicted labels and bounding boxes. - std::vector predicted_annotations = - convert_yolo_to_annotations(raw_prediction, anchor_boxes(), - confidence_threshold); - // Remove overlapping predictions. - predicted_annotations = apply_non_maximum_suppression( - std::move(predicted_annotations), iou_threshold); - - consumer(predicted_annotations, batch.annotations_batch[i], - batch.image_dimensions_batch[i]); + // Construct a pipeline generating inference results. + std::unique_ptr model_trainer = + create_inference_trainer(read_checkpoint(), std::move(ctx)); + std::shared_ptr> inference_futures = + model_trainer + ->AsInferenceBatchPublisher(std::move(data_iter), batch_size, + confidence_threshold, iou_threshold) + ->AsFutures(); + + // Consume the results, ensuring that we have the next batch in progress in + // the background while we consume the previous batch. + std::future> pending_batch = + inference_futures->Next(); + while (pending_batch.valid()) { + // Start the next batch before we handle the pending batch. + std::future> next_batch = + inference_futures->Next(); + + // Wait for the pending batch to be complete. + std::unique_ptr encoded_batch = pending_batch.get(); + if (encoded_batch) { + // We have more raw results. Decode them. + InferenceOutputBatch batch = model_trainer->DecodeOutputBatch( + *encoded_batch, confidence_threshold, iou_threshold); + + // Consume the results. + for (size_t i = 0; i < batch.annotations.size(); ++i) { + consumer(batch.predictions[i], batch.annotations[i], + batch.image_sizes[i]); } - } - }; - // Iterate through the data once. - std::vector input_batch = data_iter->next_batch(batch_size); - - while (!input_batch.empty()) { - // Wait until we have just one asynchronous batch outstanding. The work - // below should be concurrent with the neural net inference for that batch. - pop_until_size(1); - - inference_batch result_batch; - - // Instead of giving the ground truth data to the image augmenter and the - // neural net, instead save them for later, pairing them with the future - // predictions. - result_batch.annotations_batch.resize(input_batch.size()); - result_batch.image_dimensions_batch.resize(input_batch.size()); - for (size_t i = 0; i < input_batch.size(); ++i) { - result_batch.annotations_batch[i] = std::move(input_batch[i].annotations); - result_batch.image_dimensions_batch[i] = std::make_pair( - input_batch[i].image.m_height, input_batch[i].image.m_width); - input_batch[i].annotations.clear(); + // Continue iterating. + pending_batch = std::move(next_batch); } - - // Use the image augmenter to format the images into float arrays, and - // submit them to the neural net. - image_augmenter::result prepared_input_batch = - augmenter->prepare_images(std::move(input_batch)); - - std::map prediction_results = - model->predict({{"input", prepared_input_batch.image_batch}}); - - result_batch.image_batch = prediction_results.at("output"); - - // Add the pending result to our queue and move on to the next input batch. - pending_batches.push(std::move(result_batch)); - input_batch = data_iter->next_batch(batch_size); } - - // Process all remaining batches. - pop_until_size(0); } // TODO: Should accept model_backend as an optional argument to avoid diff --git a/src/toolkits/object_detection/object_detector.hpp b/src/toolkits/object_detection/object_detector.hpp index 24ce4b652e..5617dc34c4 100644 --- a/src/toolkits/object_detection/object_detector.hpp +++ b/src/toolkits/object_detection/object_detector.hpp @@ -193,6 +193,9 @@ class EXPORT object_detector: public ml_model_base { const Config& config, const std::string& pretrained_model_path, int random_seed, std::unique_ptr context) const; + virtual std::unique_ptr create_inference_trainer( + const Checkpoint& checkpoint, + std::unique_ptr context) const; // Establishes training pipelines from the backend. void connect_trainer(std::unique_ptr trainer, diff --git a/src/toolkits/object_detection/od_darknet_yolo_model_trainer.cpp b/src/toolkits/object_detection/od_darknet_yolo_model_trainer.cpp index debaae70b3..39dc26a94e 100644 --- a/src/toolkits/object_detection/od_darknet_yolo_model_trainer.cpp +++ b/src/toolkits/object_detection/od_darknet_yolo_model_trainer.cpp @@ -7,6 +7,7 @@ #include +#include #include #include @@ -206,7 +207,7 @@ std::unique_ptr InitializeDarknetYOLO( } // namespace -image_augmenter::options DarknetYOLOTrainingAugmentationOptions( +image_augmenter::options DarknetYOLOInferenceAugmentationOptions( int batch_size, int output_height, int output_width) { image_augmenter::options opts; @@ -214,6 +215,13 @@ image_augmenter::options DarknetYOLOTrainingAugmentationOptions( opts.batch_size = static_cast(batch_size); opts.output_height = static_cast(output_height * SPATIAL_REDUCTION); opts.output_width = static_cast(output_width * SPATIAL_REDUCTION); + return opts; +} + +image_augmenter::options DarknetYOLOTrainingAugmentationOptions( + int batch_size, int output_height, int output_width) { + image_augmenter::options opts = DarknetYOLOInferenceAugmentationOptions( + batch_size, output_height, output_width); // Apply random crops. opts.crop_prob = 0.9f; @@ -252,6 +260,7 @@ EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch, result.iteration_id = input_batch.iteration_id; result.images = std::move(input_batch.images); result.annotations = std::move(input_batch.annotations); + result.image_sizes = std::move(input_batch.image_sizes); // Allocate a float buffer of sufficient size. // TODO: Recycle these allocations. @@ -282,7 +291,33 @@ EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch, return result; } -TrainingOutputBatch DarknetYOLOTrainer::Invoke(EncodedInputBatch input_batch) { +InferenceOutputBatch DecodeDarknetYOLOInference(EncodedBatch batch, + float confidence_threshold, + float iou_threshold) { + InferenceOutputBatch result; + result.iteration_id = batch.iteration_id; + + result.predictions.resize(batch.image_sizes.size()); + for (size_t i = 0; i < result.predictions.size(); ++i) { + // For this row (corresponding to one image), extract the prediction. + shared_float_array raw_prediction = batch.encoded_data.at("output")[i]; + + // Translate the raw output into predicted labels and bounding boxes. + result.predictions[i] = convert_yolo_to_annotations( + raw_prediction, GetAnchorBoxes(), confidence_threshold); + + // Remove overlapping predictions. + result.predictions[i] = apply_non_maximum_suppression( + std::move(result.predictions[i]), iou_threshold); + } + + result.annotations = std::move(batch.annotations); + result.image_sizes = std::move(batch.image_sizes); + return result; +} + +TrainingOutputBatch DarknetYOLOBackendTrainingWrapper::Invoke( + EncodedInputBatch input_batch) { ApplyLearningRateSchedule(input_batch.iteration_id); auto results = impl_->train( @@ -294,7 +329,8 @@ TrainingOutputBatch DarknetYOLOTrainer::Invoke(EncodedInputBatch input_batch) { return output_batch; } -void DarknetYOLOTrainer::ApplyLearningRateSchedule(int iteration_id) { +void DarknetYOLOBackendTrainingWrapper::ApplyLearningRateSchedule( + int iteration_id) { // Leave the learning rate unchanged for the first half of the expected number // of iterations. if (iteration_id == 1 + max_iterations_ / 2) { @@ -310,6 +346,16 @@ void DarknetYOLOTrainer::ApplyLearningRateSchedule(int iteration_id) { } } +EncodedBatch DarknetYOLOBackendInferenceWrapper::Invoke( + EncodedInputBatch input_batch) { + EncodedBatch output_batch; + output_batch.iteration_id = input_batch.iteration_id; + output_batch.encoded_data = impl_->predict({{"input", input_batch.images}}); + output_batch.annotations = std::move(input_batch.annotations); + output_batch.image_sizes = std::move(input_batch.image_sizes); + return output_batch; +} + std::unique_ptr DarknetYOLOCheckpointer::Next() { // Copy the weights out from the backend. float_array_map backend_weights = impl_->export_weights(); @@ -374,11 +420,7 @@ float_array_map DarknetYOLOCheckpoint::internal_weights() const { DarknetYOLOModelTrainer::DarknetYOLOModelTrainer( const DarknetYOLOCheckpoint& checkpoint, neural_net::compute_context* context) - : ModelTrainer(context->create_image_augmenter( - DarknetYOLOTrainingAugmentationOptions( - checkpoint.config().batch_size, checkpoint.config().output_height, - checkpoint.config().output_width))), - config_(checkpoint.config()), + : config_(checkpoint.config()), backend_(context->create_object_detector( /* n */ config_.batch_size, /* c_in */ 3, // RGB input @@ -388,7 +430,82 @@ DarknetYOLOModelTrainer::DarknetYOLOModelTrainer( /* h_out */ config_.output_height, /* w_out */ config_.output_width, /* config */ checkpoint.internal_config(), - /* weights */ checkpoint.internal_weights())) {} + /* weights */ checkpoint.internal_weights())), + training_augmenter_( + std::make_shared(context->create_image_augmenter( + DarknetYOLOTrainingAugmentationOptions( + checkpoint.config().batch_size, + checkpoint.config().output_height, + checkpoint.config().output_width)))), + inference_augmenter_( + std::make_shared(context->create_image_augmenter( + DarknetYOLOInferenceAugmentationOptions( + checkpoint.config().batch_size, + checkpoint.config().output_height, + checkpoint.config().output_width)))) {} + +std::shared_ptr> +DarknetYOLOModelTrainer::AsTrainingBatchPublisher( + std::unique_ptr training_data, size_t batch_size, + int offset) { + // Wrap the data_iterator to incorporate into a Combine pipeline. + auto iterator = std::make_shared(std::move(training_data), + batch_size, offset); + + // Define a lambda that applies EncodeDarknetYOLO to the raw annotations. + Config config = config_; + auto encoder = [config](InputBatch input_batch) { + return EncodeDarknetYOLO( + std::move(input_batch), config.output_height, config.output_width, + static_cast(GetAnchorBoxes().size()), config.num_classes); + }; + + // Wrap the model_backend. + auto trainer = std::make_shared( + backend_, BASE_LEARNING_RATE, config_.max_iterations); + + // Construct the training pipeline. + return iterator->AsPublisher() + ->Map(training_augmenter_) + ->Map(encoder) + ->Map(trainer); +} + +std::shared_ptr> +DarknetYOLOModelTrainer::AsInferenceBatchPublisher( + std::unique_ptr test_data, size_t batch_size, + float confidence_threshold, float iou_threshold) { + // Wrap the data_iterator to incorporate into a Combine pipeline. + auto iterator = std::make_shared(std::move(test_data), + batch_size, /* offset */ 0); + + // No labels to encode. Just pass the annotations through for potential + // evaluation. + auto trivial_encoder = [](InputBatch input_batch) { + EncodedInputBatch result; + result.iteration_id = input_batch.iteration_id; + result.images = std::move(input_batch.images); + result.annotations = std::move(input_batch.annotations); + result.image_sizes = std::move(input_batch.image_sizes); + return result; + }; + + // Wrap the model_backend. + auto predicter = + std::make_shared(backend_); + + // Construct the inference pipeline. + return iterator->AsPublisher() + ->Map(inference_augmenter_) + ->Map(trivial_encoder) + ->Map(predicter); +} + +InferenceOutputBatch DarknetYOLOModelTrainer::DecodeOutputBatch( + EncodedBatch batch, float confidence_threshold, float iou_threshold) { + return DecodeDarknetYOLOInference(std::move(batch), confidence_threshold, + iou_threshold); +} std::shared_ptr>> DarknetYOLOModelTrainer::AsCheckpointPublisher() { @@ -397,6 +514,8 @@ DarknetYOLOModelTrainer::AsCheckpointPublisher() { return checkpointer->AsPublisher(); } +// TODO: Remove this method. It is only called by the base class implementation +// of AsTrainingBatchPublisher we overrode above. std::shared_ptr> DarknetYOLOModelTrainer::AsTrainingBatchPublisher( std::shared_ptr> augmented_data) { @@ -410,7 +529,7 @@ DarknetYOLOModelTrainer::AsTrainingBatchPublisher( }; // Wrap the model_backend. - auto trainer = std::make_shared( + auto trainer = std::make_shared( backend_, BASE_LEARNING_RATE, config_.max_iterations); // Append the encoding function and the model backend to the pipeline. diff --git a/src/toolkits/object_detection/od_darknet_yolo_model_trainer.hpp b/src/toolkits/object_detection/od_darknet_yolo_model_trainer.hpp index cc124075d9..ebefe48ff2 100644 --- a/src/toolkits/object_detection/od_darknet_yolo_model_trainer.hpp +++ b/src/toolkits/object_detection/od_darknet_yolo_model_trainer.hpp @@ -23,7 +23,17 @@ namespace turi { namespace object_detection { -/** Configures an image_augmenter given darknet-yolo network parameters. */ +/** + * Configures an image_augmenter for inference given darknet-yolo network + * parameters. + */ +neural_net::image_augmenter::options DarknetYOLOInferenceAugmentationOptions( + int batch_size, int output_height, int output_width); + +/** + * Configures an image_augmenter for training given darknet-yolo network + * parameters. + */ neural_net::image_augmenter::options DarknetYOLOTrainingAugmentationOptions( int batch_size, int output_height, int output_width); @@ -35,6 +45,13 @@ EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch, size_t output_height, size_t output_width, size_t num_anchors, size_t num_classes); +/** + * Decodes the raw inference output into structured predictions. + */ +InferenceOutputBatch DecodeDarknetYOLOInference(EncodedBatch batch, + float confidence_threshold, + float iou_threshold); + /** * Wrapper that integrates a darknet-yolo model_backend into a training * pipeline. @@ -42,13 +59,14 @@ EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch, * \todo Once model_backend exposes support for explicit asynchronous * invocations, this class won't be able to simply use the Transform base class. */ -class DarknetYOLOTrainer +class DarknetYOLOBackendTrainingWrapper : public neural_net::Transform { public: // Uses base_learning_rate and max_iterations to determine the learning-rate // schedule. - DarknetYOLOTrainer(std::shared_ptr impl, - float base_learning_rate, int max_iterations) + DarknetYOLOBackendTrainingWrapper( + std::shared_ptr impl, float base_learning_rate, + int max_iterations) : impl_(std::move(impl)), base_learning_rate_(base_learning_rate), max_iterations_(max_iterations) {} @@ -63,6 +81,26 @@ class DarknetYOLOTrainer int max_iterations_ = 0; }; +/** + * Wrapper that integrates a darknet-yolo model_backend into an inference + * pipeline. + * + * \todo Once model_backend exposes support for explicit asynchronous + * invocations, this class won't be able to simply use the Transform base class. + */ +class DarknetYOLOBackendInferenceWrapper + : public neural_net::Transform { + public: + DarknetYOLOBackendInferenceWrapper( + std::shared_ptr impl) + : impl_(std::move(impl)) {} + + EncodedBatch Invoke(EncodedInputBatch input_batch) override; + + private: + std::shared_ptr impl_; +}; + /** * Wrapper for a darknet-yolo model_backend that publishes checkpoints. */ @@ -130,6 +168,19 @@ class DarknetYOLOModelTrainer : public ModelTrainer { DarknetYOLOModelTrainer(const DarknetYOLOCheckpoint& checkpoint, neural_net::compute_context* context); + std::shared_ptr> + AsTrainingBatchPublisher(std::unique_ptr training_data, + size_t batch_size, int offset) override; + + std::shared_ptr> + AsInferenceBatchPublisher(std::unique_ptr test_data, + size_t batch_size, float confidence_threshold, + float iou_threshold) override; + + InferenceOutputBatch DecodeOutputBatch(EncodedBatch batch, + float confidence_threshold, + float iou_threshold) override; + std::shared_ptr>> AsCheckpointPublisher() override; @@ -141,6 +192,8 @@ class DarknetYOLOModelTrainer : public ModelTrainer { private: Config config_; std::shared_ptr backend_; + std::shared_ptr training_augmenter_; + std::shared_ptr inference_augmenter_; }; } // namespace object_detection diff --git a/src/toolkits/object_detection/od_model_trainer.cpp b/src/toolkits/object_detection/od_model_trainer.cpp index d0bb8d8a78..05bd473c7e 100644 --- a/src/toolkits/object_detection/od_model_trainer.cpp +++ b/src/toolkits/object_detection/od_model_trainer.cpp @@ -11,6 +11,7 @@ namespace turi { namespace object_detection { using neural_net::image_augmenter; +using neural_net::labeled_image; using neural_net::Publisher; DataBatch DataIterator::Next() { @@ -21,13 +22,23 @@ DataBatch DataIterator::Next() { } InputBatch DataAugmenter::Invoke(DataBatch data_batch) { - image_augmenter::result result = - impl_->prepare_images(std::move(data_batch.examples)); - InputBatch batch; batch.iteration_id = data_batch.iteration_id; + + // Extract the image sizes from data_batch.examples before we move the + // examples into the augmenter. + batch.image_sizes.resize(data_batch.examples.size()); + auto extract_size = [](const labeled_image &example) { + return std::make_pair(example.image.m_height, example.image.m_width); + }; + std::transform(data_batch.examples.begin(), data_batch.examples.end(), + batch.image_sizes.begin(), extract_size); + + image_augmenter::result result = + impl_->prepare_images(std::move(data_batch.examples)); batch.images = std::move(result.image_batch); batch.annotations = std::move(result.annotations_batch); + return batch; } diff --git a/src/toolkits/object_detection/od_model_trainer.hpp b/src/toolkits/object_detection/od_model_trainer.hpp index 3915706f3a..b96d8a0cb7 100644 --- a/src/toolkits/object_detection/od_model_trainer.hpp +++ b/src/toolkits/object_detection/od_model_trainer.hpp @@ -44,20 +44,32 @@ struct InputBatch { /** The raw annotations from the DataBatch. */ std::vector> annotations; + + /** + * The original height and width of each image, used to scale bounding-box + * predictions. + */ + std::vector> image_sizes; }; /** Represents one batch of data, in a possibly model-specific format. */ struct EncodedInputBatch { int iteration_id = 0; + + // TODO: Migrate to neural_net::float_array_map neural_net::shared_float_array images; neural_net::shared_float_array labels; // The raw annotations are preserved to support evaluation, comparing raw // annotations against model predictions. std::vector> annotations; + + // The original image sizes are preserved to support prediction. + std::vector> image_sizes; }; /** Represents the raw output of an object-detection model. */ +// TODO: Adopt EncodedBatch instead. struct TrainingOutputBatch { int iteration_id = 0; neural_net::shared_float_array loss; @@ -69,6 +81,29 @@ struct TrainingProgress { float smoothed_loss = 0.f; }; +/** + * Represents the immediate (model-specific) input or output of a model backend, + * using the float_array_map representation. + */ +struct EncodedBatch { + int iteration_id = 0; + + neural_net::float_array_map encoded_data; + + std::vector> annotations; + std::vector> image_sizes; +}; + +/** Represents one batch of inference results, in a generic format. */ +struct InferenceOutputBatch { + int iteration_id = 0; + + std::vector> predictions; + + std::vector> annotations; + std::vector> image_sizes; +}; + /** Ostensibly model-agnostic parameters for object detection. */ struct Config { /** @@ -192,6 +227,8 @@ class ProgressUpdater */ class ModelTrainer { public: + ModelTrainer() : ModelTrainer(nullptr) {} + // TODO: This class should be responsible for producing the augmenter itself. ModelTrainer(std::unique_ptr augmenter); @@ -206,6 +243,27 @@ class ModelTrainer { AsTrainingBatchPublisher(std::unique_ptr training_data, size_t batch_size, int offset); + /** + * Given a data iterator, return a publisher of inference model outputs. + * + * \todo Publish InferenceOutputBatch instead of EncodedBatch. + */ + virtual std::shared_ptr> + AsInferenceBatchPublisher(std::unique_ptr test_data, + size_t batch_size, float confidence_threshold, + float iou_threshold) = 0; + + /** + * Convert the raw output of the inference batch publisher into structured + * predictions. + * + * \todo This conversion should be incorporated into the inference pipeline + * once the backends support proper asynchronous complete handlers. + */ + virtual InferenceOutputBatch DecodeOutputBatch(EncodedBatch batch, + float confidence_threshold, + float iou_threshold) = 0; + /** Returns a publisher that can be used to request checkpoints. */ virtual std::shared_ptr>> AsCheckpointPublisher() = 0; @@ -213,6 +271,8 @@ class ModelTrainer { protected: // Used by subclasses to produce the model-specific portions of the overall // training pipeline. + // TODO: Remove this method. Just let subclasses define the entire training + // pipeline. virtual std::shared_ptr> AsTrainingBatchPublisher( std::shared_ptr> augmented_data) = 0; diff --git a/test/unity/toolkits/object_detection/test_object_detector.cxx b/test/unity/toolkits/object_detection/test_object_detector.cxx index 1edc445aa6..86ab9c5a9a 100644 --- a/test/unity/toolkits/object_detection/test_object_detector.cxx +++ b/test/unity/toolkits/object_detection/test_object_detector.cxx @@ -94,6 +94,30 @@ class mock_data_iterator: public data_iterator { size_t num_instances_ = 0; }; +// Subclass of DarknetYOLOModelTrainer that mocks out the decoding of +// inference batches. +class TestDarknetYOLOModelTrainer : public DarknetYOLOModelTrainer { + public: + using decode_output_batch_call = std::function; + + TestDarknetYOLOModelTrainer(const DarknetYOLOCheckpoint& checkpoint, + neural_net::compute_context* context) + : DarknetYOLOModelTrainer(checkpoint, context) {} + + InferenceOutputBatch DecodeOutputBatch(EncodedBatch batch, + float confidence_threshold, + float iou_threshold) override { + TS_ASSERT(!decode_output_batch_calls_.empty()); + decode_output_batch_call expected_call = + std::move(decode_output_batch_calls_.front()); + decode_output_batch_calls_.pop_front(); + return expected_call(std::move(batch), confidence_threshold, iou_threshold); + } + + mutable std::deque decode_output_batch_calls_; +}; + // Subclass of object_detector that mocks out the methods that inject the // object_detector dependencies. class test_object_detector: public object_detector { @@ -109,6 +133,11 @@ class test_object_detector: public object_detector { const Config& config, const std::string& pretrained_model_path, int random_seed, std::unique_ptr context)>; + using create_inference_trainer_call = + std::function( + const Checkpoint& checkpoint, + std::unique_ptr context)>; + using perform_evaluation_call = std::function; @@ -161,6 +190,16 @@ class test_object_detector: public object_detector { std::move(context)); } + std::unique_ptr create_inference_trainer( + const Checkpoint& checkpoint, + std::unique_ptr context) const override { + TS_ASSERT(!create_inference_trainer_calls_.empty()); + create_inference_trainer_call expected_call = + std::move(create_inference_trainer_calls_.front()); + create_inference_trainer_calls_.pop_front(); + return expected_call(checkpoint, std::move(context)); + } + variant_type perform_evaluation(gl_sframe data, std::string metric, std::string output_type, float confidence_threshold, @@ -192,6 +231,8 @@ class test_object_detector: public object_detector { mutable std::deque create_iterator_calls_; mutable std::deque create_compute_context_calls_; mutable std::deque create_trainer_calls_; + mutable std::deque + create_inference_trainer_calls_; mutable std::deque perform_evaluation_calls_; mutable std::deque convert_yolo_to_annotations_calls_; @@ -318,6 +359,12 @@ BOOST_AUTO_TEST_CASE(test_object_detector_iterate_training) { [&](const image_augmenter::options& opts) { return std::move(mock_augmenter); }); + mock_context->create_augmenter_calls_.push_back( + [&](const image_augmenter::options& opts) { + // The ModelTrainer will instantiate an augmenter for inference, but + // we won't use it in this test. + return std::unique_ptr(new mock_image_augmenter); + }); mock_context->create_object_detector_calls_.push_back( [&](int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map& config, @@ -395,6 +442,12 @@ BOOST_AUTO_TEST_CASE(test_object_detector_init_training) { return std::move(mock_augmenter); }; mock_context->create_augmenter_calls_.push_back(create_augmenter_impl); + mock_context->create_augmenter_calls_.push_back( + [&](const image_augmenter::options& opts) { + // The ModelTrainer will instantiate an augmenter for inference, but + // we won't use it in this test. + return std::unique_ptr(new mock_image_augmenter); + }); // We'll provide this path for the "mlmodel_path" option. When the // object_detector attempts to initialize weights from that path, just return @@ -521,6 +574,12 @@ BOOST_AUTO_TEST_CASE(test_object_detector_finalize_training) { [&](const image_augmenter::options& opts) { return std::move(mock_augmenter); }); + mock_context->create_augmenter_calls_.push_back( + [&](const image_augmenter::options& opts) { + // The ModelTrainer will instantiate an augmenter for inference, but + // we won't use it in this test. + return std::unique_ptr(new mock_image_augmenter); + }); mock_context->create_object_detector_calls_.push_back( [&](int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map& config, @@ -698,6 +757,12 @@ BOOST_AUTO_TEST_CASE(test_object_detector_auto_split) { return std::move(mock_augmenter); }; mock_context->create_augmenter_calls_.push_back(create_augmenter_impl); + mock_context->create_augmenter_calls_.push_back( + [&](const image_augmenter::options& opts) { + // The ModelTrainer will instantiate an augmenter for inference, but + // we won't use it in this test. + return std::unique_ptr(new mock_image_augmenter); + }); // We'll provide this path for the "mlmodel_path" option. When the // object_detector attempts to initialize weights from that path, just return @@ -919,6 +984,12 @@ BOOST_AUTO_TEST_CASE(test_object_detector_predict) { return std::move(mock_augmenter); }; mock_context->create_augmenter_calls_.push_back(create_augmenter_impl); + mock_context->create_augmenter_calls_.push_back( + [&](const image_augmenter::options& opts) { + // The ModelTrainer will instantiate an augmenter for inference, but + // we won't use it in this test. + return std::unique_ptr(new mock_image_augmenter); + }); // We'll provide this path for the "mlmodel_path" option. When the // object_detector attempts to initialize weights from that path, just return @@ -984,6 +1055,12 @@ BOOST_AUTO_TEST_CASE(test_object_detector_predict) { model.create_compute_context_calls_.push_back(create_compute_context_impl); mock_augmenter.reset(new mock_image_augmenter); + mock_context->create_augmenter_calls_.push_back( + [&](const image_augmenter::options& opts) { + // The ModelTrainer will instantiate an augmenter for training, but + // we won't use it in this test. + return std::unique_ptr(new mock_image_augmenter); + }); mock_context->create_augmenter_calls_.push_back(create_augmenter_impl); mock_nn_model.reset(new mock_model_backend); @@ -1104,22 +1181,15 @@ BOOST_AUTO_TEST_CASE(test_object_detector_predict) { mock_augmenter->prepare_images_calls_.push_back(prepare_images_impl); mock_nn_model->predict_calls_.push_back(predict_impl); - auto empty_next_batch_impl = [=](size_t batch_size) { - std::vector result(0); - return result; - }; - // Send empty batch to match perform_predict() implementation - mock_iterator->next_batch_calls_.push_back(empty_next_batch_impl); - - auto convert_yolo_impl = - [](const neural_net::float_array& yolo_map, - const std::vector>& anchor_boxes, - float min_confidence) { - ASSERT_EQ(yolo_map.dim(), 3); + auto decode_output_batch_impl = + [](EncodedBatch batch, float confidence_threshold, float iou_threshold) { + const float_array& yolo_map = batch.encoded_data["output"]; + ASSERT_EQ(yolo_map.dim(), 4); const size_t* const shape = yolo_map.shape(); - const size_t output_height = shape[0]; - const size_t output_width = shape[1]; - const size_t num_channels = shape[2]; + const size_t batch_size = shape[0]; + const size_t output_height = shape[1]; + const size_t output_width = shape[2]; + const size_t num_channels = shape[3]; size_t num_anchor_boxes = 2; static constexpr size_t NUM_CLASSES = 2; @@ -1131,24 +1201,40 @@ BOOST_AUTO_TEST_CASE(test_object_detector_predict) { ASSERT_EQ(output_height, OUTPUT_GRID_SIZE); ASSERT_EQ(output_width, OUTPUT_GRID_SIZE); - std::vector result; - for (size_t j = 0; j < num_prediction_instances; ++j) { - // The actual contents of the image and the annotations are irrelevant - // for the purposes of this test. But encode the batch index and row - // index into the bounding box so that we can verify this data is - // passed into the image augmenter. - image_annotation annotation; - annotation.bounding_box.x = 0; - annotation.bounding_box.y = j; - result.push_back(annotation); + InferenceOutputBatch result; + result.predictions.resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_prediction_instances; ++j) { + // The actual contents of the image and the annotations are + // irrelevant for the purposes of this test. But encode the batch + // index and row index into the bounding box so that we can verify + // this data is passed into the image augmenter. + image_annotation annotation; + annotation.bounding_box.x = 0; + annotation.bounding_box.y = j; + result.predictions[i].push_back(annotation); + } + result.image_sizes.emplace_back(416, 416); } return result; }; + model.create_inference_trainer_calls_.emplace_back( + [=](const Checkpoint& checkpoint, + std::unique_ptr context) { + std::unique_ptr result; + DarknetYOLOCheckpoint darknet_yolo_checkpoint(checkpoint.config(), + checkpoint.weights()); + result.reset(new TestDarknetYOLOModelTrainer(darknet_yolo_checkpoint, + context.get())); + + // Two calls for two batches + for (size_t i = 0; i < test_batch_size; ++i) { + result->decode_output_batch_calls_.push_back( + decode_output_batch_impl); + } + return result; + }); - // Two calls for two batches - for (size_t i = 0; i < 2 * test_batch_size; ++i) { - model.convert_yolo_to_annotations_calls_.push_back(convert_yolo_impl); - } std::map opts{{"confidence_threshold",0.25}, {"iou_threshold",0.45}}; variant_type result_variant = model.predict(data, opts); gl_sarray result = variant_get_value(result_variant);