Skip to content

Commit

Permalink
training conv
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Feb 2, 2024
1 parent 6fde115 commit 6d34a93
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const
}

template <typename T_Perf>
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32) {
perf_results.resize(1);
perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
if (args.params.data_type == CUDNN_DATA_HALF) {
perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
} else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) {
perf_results[0].mathType = CUDNN_FMA_MATH;
} else {
perf_results[0].mathType = CUDNN_DEFAULT_MATH;
}
Expand All @@ -256,7 +258,7 @@ Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const

std::vector<T_Perf> perf_results;
ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
? OnlyDefaultAlgorithm(args_, perf_results)
? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32())
: AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
for (auto& algo_perf : perf_results) {
if (f(algo_perf) == Status::OK()) {
Expand Down
2 changes: 1 addition & 1 deletion orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AlgoIterator {
Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
std::function<Status(const T_Perf& perf)> f);

static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32);

private:
const ConvArgs& args_;
Expand Down

0 comments on commit 6d34a93

Please sign in to comment.