diff --git a/src/cudafeat/Makefile b/src/cudafeat/Makefile
index 92e266d96d3..900a458813d 100644
--- a/src/cudafeat/Makefile
+++ b/src/cudafeat/Makefile
@@ -8,10 +8,12 @@ ifeq ($(CUDA), true)
 TESTFILES =
 
 ifeq ($(CUDA), true)
-  OBJFILES += feature-window-cuda.o feature-spectral-cuda.o feature-online-cmvn-cuda.o \
-              online-ivector-feature-cuda-kernels.o online-ivector-feature-cuda.o \
-              online-cuda-feature-pipeline.o feature-online-batched-cmvn-cuda.o \
-              feature-online-batched-cmvn-cuda-kernels.o
+  OBJFILES += feature-window-cuda.o feature-spectral-cuda.o \
+              feature-online-cmvn-cuda.o feature-online-batched-spectral-cuda.o \
+              feature-spectral-batched-kernels.o \
+              online-ivector-feature-cuda-kernels.o online-ivector-feature-cuda.o \
+              online-cuda-feature-pipeline.o feature-online-batched-cmvn-cuda.o \
+              feature-online-batched-cmvn-cuda-kernels.o
 endif
 
 LIBNAME = kaldi-cudafeat
diff --git a/src/cudafeat/feature-online-batched-spectral-cuda.cc b/src/cudafeat/feature-online-batched-spectral-cuda.cc
new file mode 100644
index 00000000000..b8b23f87d56
--- /dev/null
+++ b/src/cudafeat/feature-online-batched-spectral-cuda.cc
@@ -0,0 +1,256 @@
+// cudafeat/feature-online-batched-spectral-cuda.cc
+//
+// Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens, Levi Barnes
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cudafeat/feature-online-batched-spectral-cuda.h"
+#include "cudafeat/feature-spectral-batched-kernels.h"
+
+namespace kaldi {
+
+CudaOnlineBatchedSpectralFeatures::CudaOnlineBatchedSpectralFeatures(
+    const CudaSpectralFeatureOptions &opts, int32_t max_chunk_frames,
+    int32_t num_channels, int32_t max_lanes)
+    : MfccComputer(opts.mfcc_opts),
+      cu_lifter_coeffs_(lifter_coeffs_),
+      cu_dct_matrix_(dct_matrix_),
+      window_function_(opts.mfcc_opts.frame_opts),
+      max_chunk_frames_(max_chunk_frames),
+      num_channels_(num_channels),
+      max_lanes_(max_lanes) {
+  KALDI_ASSERT(max_chunk_frames > 0);
+  const MelBanks *mel_banks = GetMelBanks(1.0);
+  const std::vector<std::pair<int32, Vector<BaseFloat>>> &bins =
+      mel_banks->GetBins();
+  int size = bins.size();
+  bin_size_ = size;
+  std::vector<int32> offsets(size), sizes(size);
+  std::vector<float *> vecs(size);
+  cu_vecs_ = new CuVector<float>[size];
+  for (int i = 0; i < bins.size(); i++) {
+    cu_vecs_[i].Resize(bins[i].second.Dim(), kUndefined);
+    cu_vecs_[i].CopyFromVec(bins[i].second);
+    vecs[i] = cu_vecs_[i].Data();
+    sizes[i] = cu_vecs_[i].Dim();
+    offsets[i] = bins[i].first;
+  }
+  offsets_ = static_cast<int32 *>(
+      CuDevice::Instantiate().Malloc(size * sizeof(int32)));
+  sizes_ = static_cast<int32 *>(
+      CuDevice::Instantiate().Malloc(size * sizeof(int32)));
+  vecs_ = static_cast<float **>(
+      CuDevice::Instantiate().Malloc(size * sizeof(float *)));
+
+  CU_SAFE_CALL(cudaMemcpyAsync(vecs_, &vecs[0], size * sizeof(float *),
+                               cudaMemcpyHostToDevice, cudaStreamPerThread));
+  CU_SAFE_CALL(cudaMemcpyAsync(offsets_, &offsets[0], size * sizeof(int32),
+                               cudaMemcpyHostToDevice, cudaStreamPerThread));
+  CU_SAFE_CALL(cudaMemcpyAsync(sizes_, &sizes[0], size * sizeof(int32),
+                               cudaMemcpyHostToDevice, cudaStreamPerThread));
+  CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread));
+
+  const FrameExtractionOptions frame_opts = opts.mfcc_opts.frame_opts;
+  frame_length_ = frame_opts.WindowSize();
+  padded_length_ = frame_opts.PaddedWindowSize();
+  fft_length_ = padded_length_ / 2;  // + 1;
+  fft_batch_size_ = 800;
+
+  // Placeholders to get strides for cufft. These will be resized correctly
+  // later. The +2 is for the cufft/fftw requirement of an extra element at
+  // the end. Strides are turned off because cufft seems buggy with a stride.
+  int32_t fft_num_frames =
+      max_chunk_frames +
+      (fft_batch_size_ - max_chunk_frames_ % fft_batch_size_);
+  cu_windows_.Resize(fft_num_frames * max_lanes_, padded_length_, kUndefined,
+                     kStrideEqualNumCols);
+  // The +2 matches the cufft/fftw requirement.
+  tmp_window_.Resize(fft_num_frames * max_lanes_, padded_length_ + 2,
+                     kUndefined, kStrideEqualNumCols);
+
+  // Pre-allocated memory for power spectra
+  power_spectrum_.Resize(max_chunk_frames_ * max_lanes_,
+                         padded_length_ / 2 + 1, kUndefined);
+  raw_log_energies_.Resize(max_lanes_, max_chunk_frames_, kUndefined);
+  cu_mel_energies_.Resize(max_chunk_frames_ * max_lanes_, bin_size_,
+                          kUndefined);
+  int32_t max_stash_size =
+      2 * (frame_opts.WindowSize() / 2 + frame_opts.WindowShift());
+  stash_.Resize(num_channels_, max_stash_size);
+
+  stride_ = cu_windows_.Stride();
+  tmp_stride_ = tmp_window_.Stride();
+
+  cufftPlanMany(&plan_, 1, &padded_length_, NULL, 1, stride_, NULL, 1,
+                tmp_stride_ / 2, CUFFT_R2C, fft_batch_size_);
+  cufftSetStream(plan_, cudaStreamPerThread);
+  cumfcc_opts_ = opts;
+}
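+
+// Illustrative note (not part of the original change): for a real-to-complex
+// FFT of length N, cufft writes N/2 + 1 complex values, i.e. N + 2 floats,
+// which is why each tmp_window_ row holds padded_length_ + 2 floats. A
+// minimal sketch of the layout, assuming N = 512:
+//   in  : 512 floats (one padded frame)
+//   out : 257 cufftComplex values = 514 floats (DC ... Nyquist)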
+
+// ExtractWindow extracts a windowed frame of waveform with a power-of-two,
+// padded size. It does mean subtraction, pre-emphasis and dithering as
+// requested.
+void CudaOnlineBatchedSpectralFeatures::ExtractWindowsBatched(
+    const LaneDesc *lanes, int32_t num_lanes,
+    const CuMatrixBase<BaseFloat> &wave) {
+  CU_SAFE_CALL(cudaGetLastError());
+  const FrameExtractionOptions &opts = GetFrameOptions();
+  cuda_extract_window(
+      lanes, num_lanes, max_chunk_frames_, opts.WindowShift(),
+      opts.WindowSize(), opts.PaddedWindowSize(), opts.snip_edges, wave.Data(),
+      wave.Stride(), cu_windows_.Data(), opts.WindowSize(),
+      cu_windows_.Stride(), stash_.Data(), stash_.NumCols(), stash_.Stride());
+}
+
+void CudaOnlineBatchedSpectralFeatures::ProcessWindowsBatched(
+    const LaneDesc *lanes, int32_t num_lanes,
+    const FrameExtractionOptions &opts,
+    CuMatrixBase<BaseFloat> *log_energy_pre_window) {
+  int fft_num_frames = cu_windows_.NumRows();
+  KALDI_ASSERT(fft_num_frames % fft_batch_size_ == 0);
+
+  cuda_process_window(
+      lanes, num_lanes, max_chunk_frames_, frame_length_, opts.dither,
+      std::numeric_limits<float>::epsilon(), opts.remove_dc_offset,
+      opts.preemph_coeff, NeedRawLogEnergy(), log_energy_pre_window->Data(),
+      log_energy_pre_window->Stride(), window_function_.cu_window.Data(),
+      tmp_window_.Data(), tmp_window_.Stride(), cu_windows_.Data(),
+      cu_windows_.Stride());
+
+  CU_SAFE_CALL(cudaGetLastError());
+}
+
+void CudaOnlineBatchedSpectralFeatures::UpdateStashBatched(
+    const LaneDesc *lanes, int32_t num_lanes,
+    const CuMatrixBase<BaseFloat> &wave) {
+  KALDI_ASSERT(stash_.NumCols() < 1024);
+
+  cuda_update_stash(lanes, num_lanes, wave.Data(), wave.Stride(),
+                    stash_.Data(), stash_.NumCols(), stash_.Stride());
+}
+
+void CudaOnlineBatchedSpectralFeatures::ComputeFinalFeaturesBatched(
+    const LaneDesc *lanes, int32_t num_lanes, BaseFloat vtln_warp,
+    CuMatrix<BaseFloat> *cu_signal_log_energy,
+    CuMatrix<BaseFloat> *cu_features) {
+  MfccOptions mfcc_opts = cumfcc_opts_.mfcc_opts;
+  Vector<float> tmp;
+  KALDI_ASSERT(mfcc_opts.htk_compat == false);
+
+  if (num_lanes == 0) return;
+
+  if (mfcc_opts.use_energy && !mfcc_opts.raw_energy) {
+    cuda_dot_log(max_chunk_frames_, num_lanes, cu_windows_.NumCols(),
+                 cu_windows_.Data(), cu_windows_.Stride(),
+                 cu_signal_log_energy->Data(), cu_signal_log_energy->Stride());
+    CU_SAFE_CALL(cudaGetLastError());
+  }
+
+  // make sure a reallocation hasn't changed these
+  KALDI_ASSERT(cu_windows_.Stride() == stride_);
+  KALDI_ASSERT(tmp_window_.Stride() == tmp_stride_);
+
+  // Perform FFTs in batches of fft_batch_size_. This reduces memory
+  // requirements.
+  for (int idx = 0; idx < max_chunk_frames_ * num_lanes;
+       idx += fft_batch_size_) {
+    CUFFT_SAFE_CALL(cufftExecR2C(
+        plan_, cu_windows_.Data() + cu_windows_.Stride() * idx,
+        (cufftComplex *)(tmp_window_.Data() + tmp_window_.Stride() * idx)));
+  }
+
+  // Compute power spectrum
+  cuda_power_spectrum(max_chunk_frames_, num_lanes, padded_length_,
+                      tmp_window_.Data(), tmp_window_.Stride(),
+                      power_spectrum_.Data(), power_spectrum_.Stride(),
+                      cumfcc_opts_.use_power);
+  CU_SAFE_CALL(cudaGetLastError());
+
+  // mel banks
+  int num_bins = bin_size_;
+  cuda_mel_banks_compute(lanes, num_lanes, max_chunk_frames_, num_bins,
+                         std::numeric_limits<float>::epsilon(), offsets_,
+                         sizes_, vecs_, power_spectrum_.Data(),
+                         power_spectrum_.Stride(), cu_mel_energies_.Data(),
+                         cu_mel_energies_.Stride(), cumfcc_opts_.use_log_fbank);
+  CU_SAFE_CALL(cudaGetLastError());
+
+  // dct transform
+  if (cumfcc_opts_.use_dct) {
+    if (cu_features->NumRows() > cu_mel_energies_.NumRows()) {
+      CuSubMatrix<BaseFloat> cu_feats_sub(*cu_features, 0,
+                                          cu_mel_energies_.NumRows(), 0,
+                                          cu_features->NumCols());
+      cu_feats_sub.AddMatMat(1.0, cu_mel_energies_, kNoTrans, cu_dct_matrix_,
+                             kTrans, 0.0);
+    } else {
+      cu_features->AddMatMat(1.0, cu_mel_energies_, kNoTrans, cu_dct_matrix_,
+                             kTrans, 0.0);
+    }
+    cuda_apply_lifter_and_floor_energy(
+        lanes, num_lanes, max_chunk_frames_, cu_features->NumCols(),
+        mfcc_opts.cepstral_lifter, mfcc_opts.use_energy,
+        mfcc_opts.energy_floor, cu_signal_log_energy->Data(),
+        cu_signal_log_energy->Stride(), cu_lifter_coeffs_.Data(),
+        cu_features->Data(), cu_features->Stride());
+  } else {
+    cudaMemcpyAsync(cu_features->Data(), cu_mel_energies_.Data(),
+                    sizeof(BaseFloat) * max_chunk_frames_ * num_lanes *
+                        cu_features->Stride(),
+                    cudaMemcpyDeviceToDevice, cudaStreamPerThread);
+  }
+  CU_SAFE_CALL(cudaGetLastError());
+}
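+
+// Illustrative note (not part of the original change): the FFT loop above
+// steps idx by fft_batch_size_ (800), and every cufftExecR2C call transforms
+// a full batch of 800 frames with the plan built in the constructor, so the
+// number of plan executions is ceil(max_chunk_frames_ * num_lanes / 800);
+// e.g. (hypothetical sizes) 630 frames -> 1 call, 3150 frames -> 4 calls.
+// This is also why the constructor rounds cu_windows_ up to a multiple of
+// fft_batch_size_ rows: the final call may run past the last real frame,
+// but it stays within the padded allocation.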
+
+void CudaOnlineBatchedSpectralFeatures::ComputeFeaturesBatched(
+    const LaneDesc *lanes, int32_t n_lanes,
+    const CuMatrixBase<BaseFloat> &cu_wave_in, BaseFloat sample_freq,
+    BaseFloat vtln_warp, CuMatrix<BaseFloat> *cu_feats_out) {
+  // Note: cu_features is actually a rank 3 tensor:
+  //   channels x frames x features
+  // It is currently represented as a matrix with n_channels * n_frames rows
+  // and n_features cols.
+  const FrameExtractionOptions &frame_opts = GetFrameOptions();
+
+  if (frame_opts.dither != 0.0f) {
+    // Calling cu-rand directly. The CuRand class works on CuMatrixBase and
+    // must assume that the matrix is part of a larger matrix. Doing this
+    // directly avoids unnecessary memory copies.
+    CURAND_SAFE_CALL(
+        curandGenerateNormal(GetCurandHandle(), tmp_window_.Data(),
+                             tmp_window_.NumRows() * tmp_window_.Stride(),
+                             0.0 /*mean*/, 1.0 /*stddev*/));
+  }
+
+  // Extract windows
+  ExtractWindowsBatched(lanes, n_lanes, cu_wave_in);
+
+  UpdateStashBatched(lanes, n_lanes, cu_wave_in);
+
+  // Process windows
+  ProcessWindowsBatched(lanes, n_lanes, frame_opts, &raw_log_energies_);
+
+  // Compute features
+  ComputeFinalFeaturesBatched(lanes, n_lanes, 1.0, &raw_log_energies_,
+                              cu_feats_out);
+}
+
+CudaOnlineBatchedSpectralFeatures::~CudaOnlineBatchedSpectralFeatures() {
+  delete[] cu_vecs_;
+  CuDevice::Instantiate().Free(vecs_);
+  CuDevice::Instantiate().Free(offsets_);
+  CuDevice::Instantiate().Free(sizes_);
+  cufftDestroy(plan_);
+}
+}  // namespace kaldi
diff --git a/src/cudafeat/feature-online-batched-spectral-cuda.h b/src/cudafeat/feature-online-batched-spectral-cuda.h
new file mode 100644
index 00000000000..e4549c7177c
--- /dev/null
+++ b/src/cudafeat/feature-online-batched-spectral-cuda.h
@@ -0,0 +1,113 @@
+// cudafeat/feature-online-batched-spectral-cuda.h
+//
+// Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens, Levi Barnes
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_CUDAFEAT_FEATURE_BATCHED_SPECTRAL_CUDA_H_
+#define KALDI_CUDAFEAT_FEATURE_BATCHED_SPECTRAL_CUDA_H_
+
+#if HAVE_CUDA == 1
+#include <cufft.h>
+#endif
+
+#include "cudafeat/feature-spectral-cuda.h"
+#include "cudafeat/feature-window-cuda.h"
+#include "cudafeat/lane-desc.h"
+#include "cudamatrix/cu-matrix.h"
+#include "cudamatrix/cu-vector.h"
+#include "feat/feature-fbank.h"
+#include "feat/feature-mfcc.h"
+
+namespace kaldi {
+// This class implements MFCC and Fbank computation in CUDA.
+// It handles batched input. It takes input from device memory
+// and outputs to device memory. It also does no synchronization.
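+//
+// A minimal usage sketch (illustrative, not part of the original header;
+// the variable names are hypothetical):
+//   CudaOnlineBatchedSpectralFeatures mfcc(opts, max_chunk_frames,
+//                                          num_channels, num_lanes);
+//   // per iteration: fill a LaneDesc per active lane, then process one
+//   // chunk of samples for every lane at once
+//   mfcc.ComputeFeaturesBatched(d_lanes, num_lanes, d_wave_chunks,
+//                               samp_freq, 1.0 /*vtln_warp*/, &d_feats);
+// See cudafeatbin/compute-mfcc-online-batched-cuda.cc below for a complete
+// driver.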
+class CudaOnlineBatchedSpectralFeatures : public MfccComputer {
+ public:
+  void ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,
+                       BaseFloat sample_freq, BaseFloat vtln_warp,
+                       CuMatrix<BaseFloat> *cu_features) {
+    // Non-batched processing is not allowed from
+    // CudaOnlineBatchedSpectralFeatures.
+    KALDI_ASSERT(false);
+  }
+
+  void ComputeFeaturesBatched(const LaneDesc *lanes, int32_t n_lanes,
+                              const CuMatrixBase<BaseFloat> &cu_wave_in,
+                              BaseFloat sample_freq, BaseFloat vtln_warp,
+                              CuMatrix<BaseFloat> *cu_feats_out);
+
+  CudaOnlineBatchedSpectralFeatures(const CudaSpectralFeatureOptions &opts,
+                                    int32_t max_chunk_frames,
+                                    int32_t num_channels, int32_t max_lanes);
+  ~CudaOnlineBatchedSpectralFeatures();
+  CudaSpectralFeatureOptions cumfcc_opts_;
+
+  // The dimension of the output is different for MFCC and Fbank.
+  // This returns the appropriate value depending on the feature
+  // extraction algorithm.
+  int32 Dim() {
+    if (cumfcc_opts_.feature_type == MFCC) return MfccComputer::Dim();
+    // If we're running fbank, we need to set the dimension right.
+    return cumfcc_opts_.mfcc_opts.mel_opts.num_bins +
+           (cumfcc_opts_.mfcc_opts.use_energy ? 1 : 0);
+  }
+
+ private:
+  void ExtractWindowsBatched(const LaneDesc *lanes, int32_t num_lanes,
+                             const CuMatrixBase<BaseFloat> &wave);
+
+  void UpdateStashBatched(const LaneDesc *lanes, int32_t num_lanes,
+                          const CuMatrixBase<BaseFloat> &wave);
+
+  void ProcessWindowsBatched(const LaneDesc *lanes, int32_t num_lanes,
+                             const FrameExtractionOptions &opts,
+                             CuMatrixBase<BaseFloat> *log_energy_pre_window);
+
+  void ComputeFinalFeaturesBatched(const LaneDesc *lanes, int32_t num_lanes,
+                                   BaseFloat vtln_warp,
+                                   CuMatrix<BaseFloat> *cu_signal_log_energy,
+                                   CuMatrix<BaseFloat> *cu_features);
+
+  CuVector<float> cu_lifter_coeffs_;
+  CuMatrix<float> cu_windows_;
+  CuMatrix<float> tmp_window_, cu_mel_energies_;
+  CuMatrix<float> cu_dct_matrix_;
+  CuMatrix<float> stash_;
+  CuMatrix<float> power_spectrum_;
+  CuMatrix<float> raw_log_energies_;
+
+  int frame_length_, padded_length_, fft_length_, fft_batch_size_;
+  cufftHandle plan_;
+  CudaFeatureWindowFunction window_function_;
+
+  int bin_size_;
+  int32 *offsets_, *sizes_;
+  CuVector<float> *cu_vecs_;
+  float **vecs_;
+
+  // for sanity checking cufft
+  int32_t stride_, tmp_stride_;
+
+  int32_t max_chunk_frames_;
+  int32_t num_channels_;
+  int32_t max_lanes_;
+};
+}  // namespace kaldi
+
+#endif  // KALDI_CUDAFEAT_FEATURE_BATCHED_SPECTRAL_CUDA_H_
diff --git a/src/cudafeat/feature-spectral-batched-kernels.cu b/src/cudafeat/feature-spectral-batched-kernels.cu
new file mode 100644
index 00000000000..488948f36d9
--- /dev/null
+++ b/src/cudafeat/feature-spectral-batched-kernels.cu
@@ -0,0 +1,553 @@
+// cudafeat/feature-spectral-batched-kernels.cu
+//
+// Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens, Levi Barnes
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if HAVE_CUDA == 1
+#include <nvToolsExt.h>
+#include <cub/cub.cuh>
+#endif
+
+#include "cudafeat/feature-spectral-batched-kernels.h"
+#include "cudafeat/feature-spectral-cuda.h"
+#include "cudafeat/lane-desc.h"
+#include "cudamatrix/cu-rand.h"
+
+namespace kaldi {
+
+// Mimics the functionality of mel_banks_compute_kernel
+// (found in feature-spectral-cuda.cu). The 3rd
+// dimension (z) of the block grid gives the hardware
+// "lane". lanes tells us which channel is in this lane,
+// what current frame and sample are processed in this
+// batch, etc.
+__global__ void batched_mel_banks_compute_kernel(
+    const LaneDesc *lanes, int32_t n_lanes, int32_t max_chunk_frames,
+    float energy_floor, int32 *offsets, int32 *sizes, float **vecs,
+    const float *feats, int32_t ldf, float *mels, int32_t ldm, bool use_log) {
+  // Specialize WarpReduce for type float
+  typedef cub::WarpReduce<float> WarpReduce;
+  // Allocate WarpReduce shared memory for 8 warps
+  __shared__ typename WarpReduce::TempStorage temp_storage[8];
+
+  // the warp will work together to compute the sum
+  int tid = threadIdx.x;
+  int wid = threadIdx.y;
+  // blocks in the x dimension take different bins
+  int bin = blockIdx.x;
+  // frame is a combination of blocks in the y dimension and threads in the
+  // y dimension
+  int frame = blockIdx.y * blockDim.y + threadIdx.y;
+  int lane = blockIdx.z;
+
+  LaneDesc desc = lanes[lane];
+  int num_frames = desc.num_chunk_frames;
+
+  // TODO: get offsets, sizes, and vecs from laneInfo?
+  int offset = offsets[bin];
+  int size = sizes[bin];
+  const float *v = vecs[bin];
+  const float *w = feats + frame * ldf + lane * max_chunk_frames * ldf + offset;
+
+  // perform local sum
+  float sum = 0;
+  if (frame < num_frames) {  // exclude frames beyond the end
+    for (int idx = tid; idx < size; idx += 32) {
+      sum += v[idx] * w[idx];
+    }
+  }
+
+  // Sum in cub
+  sum = WarpReduce(temp_storage[wid]).Sum(sum);
+  if (tid == 0 && frame < num_frames) {
+    if (use_log) {
+      // avoid log of zero
+      if (sum < energy_floor) sum = energy_floor;
+      float val = logf(sum);
+      mels[lane * max_chunk_frames * ldm + frame * ldm + bin] = val;
+    } else {
+      mels[lane * max_chunk_frames * ldm + frame * ldm + bin] = sum;
+    }
+  }
+}
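+
+// Illustrative note (not part of the original kernel): with the launch
+// configuration used below in cuda_mel_banks_compute -- Bl = (32, 8),
+// Gr = (num_bins, ceil(max_chunk_frames / 8), num_lanes) -- each warp of 32
+// threads owns one (bin, frame, lane) triple and reduces the dot product
+// v . w across the bin's width: thread tid sums elements tid, tid + 32,
+// tid + 64, ... before the WarpReduce combines the partial sums.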
+// Mimics the functionality of apply_lifter_and_floor_energy
+// (found in feature-spectral-cuda.cu) for a chunk of data
+// from several audio channels. The 2nd dimension
+// (y) of the block grid gives the hardware "lane".
+// The lanes array tells us which channel is in this lane,
+// what current frame and sample are processed in this
+// batch, etc.
+__global__ void batched_apply_lifter_and_floor_energy_kernel(
+    const LaneDesc *lanes, int32_t n_lanes, int32_t max_chunk_frames,
+    int num_cols, float cepstral_lifter, bool use_energy, float energy_floor,
+    float *log_energy, int32_t ldl, float *lifter_coeffs, float *features,
+    int32_t ldf) {
+  int thread_id = threadIdx.x;
+  int frame = blockIdx.x;
+  int lane = blockIdx.y;
+
+  LaneDesc desc = lanes[lane];
+  if (frame > desc.num_chunk_frames) return;
+
+  float *feats = features + frame * ldf + lane * max_chunk_frames * ldf;
+
+  // apply lifter coefficients
+  if (cepstral_lifter != 0.0f) {
+    for (int c = thread_id; c < num_cols; c += CU1DBLOCK) {
+      float lift = lifter_coeffs[c];
+      float f = feats[c];
+      feats[c] = f * lift;
+    }
+  }
+
+  // Thread 0 for each frame will apply energy
+  if (use_energy && thread_id == 0) {
+    float energy = log_energy[frame + ldl * lane];
+    float log_energy_floor = log(energy_floor);
+
+    if (energy_floor > 0.0f && energy < log_energy_floor) {
+      energy = log_energy_floor;
+    }
+    feats[0] = energy;
+  }
+}
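+
+// For reference (an added note, not part of this change): the lifter
+// coefficients applied above are computed on the host by Kaldi's
+// ComputeLifterCoeffs, i.e. for Q = cepstral_lifter,
+//   lifter_coeffs[i] = 1.0 + 0.5 * Q * sin(M_PI * i / Q);
+// note that coefficient 0 is exactly 1, so liftering leaves the C0/energy
+// term unchanged.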
+// Mimics the functionality of process_window_kernel
+// (found in feature-spectral-cuda.cu) for a chunk of data
+// from several audio channels. The 2nd dimension
+// (y) of the block grid gives the hardware "lane".
+// The lanes array tells us which channel is in this lane,
+// what current frame and sample are processed in this
+// batch, etc.
+__global__ void batched_process_window_kernel(
+    const LaneDesc *lanes, int32_t n_lanes, int32_t max_chunk_frames,
+    int frame_length, float dither, float energy_floor, bool remove_dc_offset,
+    float preemph_coeff, bool need_raw_log_energy,
+    float *log_energy_pre_window, int32_t lde, const float *windowing,
+    float *tmp_windows, int32_t ldt, float *windows, int32_t ldw) {
+  // Specialize BlockReduce for type float
+  typedef cub::BlockReduce<float, CU1DBLOCK> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  int thread_id = threadIdx.x;
+  int row = blockIdx.x;
+  int lane = blockIdx.y;
+
+  LaneDesc desc = lanes[lane];
+  if (row >= desc.num_chunk_frames) return;
+
+  float *tmp_window = tmp_windows + row * ldt + lane * max_chunk_frames * ldt;
+  float *window = windows + row * ldw + lane * max_chunk_frames * ldw;
+
+  __shared__ float ssum;
+
+  float sum = 0;
+  float wdot = 0;
+
+  for (int idx = thread_id; idx < frame_length; idx += CU1DBLOCK) {
+    // tmp_window contains optional dither. Apply that on read.
+    float wval = window[idx];
+    if (dither != 0.0f) {
+      wval += tmp_window[idx] * dither;
+    }
+    // compute local sum for removing the dc offset
+    sum += wval;
+    // compute dot product for log energy
+    wdot += wval * wval;
+
+    float windowing_mul = 1;
+    if (remove_dc_offset == false && preemph_coeff == 0.0f) {
+      // we are done here so set the windowing multiplication on write.
+      windowing_mul = windowing[idx];
+    }
+
+    // write dithered output
+    window[idx] = wval * windowing_mul;
+  }
+  __syncthreads();
+  if (remove_dc_offset) {
+    // we will recompute this below
+    wdot = 0.0f;
+    // use cub to reduce
+    sum = BlockReduce(temp_storage).Sum(sum);
+
+    // broadcast sum to the entire block
+    if (thread_id == 0) ssum = sum;
+    __syncthreads();
+
+    sum = -ssum / frame_length;
+    for (int idx = thread_id; idx < frame_length; idx += CU1DBLOCK) {
+      float windowing_mul = 1;
+      float *out = window;
+      if (preemph_coeff == 0.0f) {
+        // we are done here so apply windowing
+        windowing_mul = windowing[idx];
+      } else {
+        // write to the temp window as we will copy back into window
+        // when doing pre-emphasis
+        out = tmp_window;
+      }
+      // updated window value
+      float wval = window[idx] + sum;
+
+      // compute the new dot product with the dc offset removed
+      wdot += wval * wval;
+
+      // write output
+      out[idx] = wval * windowing_mul;
+    }
+  }
+  __syncthreads();
+
+  // if the pointer is not NULL we will set energy to either
+  // the computed energy or 0 depending on need_raw_log_energy
+  if (log_energy_pre_window != NULL) {
+    float energy = 0.0f;
+
+    if (need_raw_log_energy) {
+      // must sync to reuse temp_storage
+      if (remove_dc_offset) __syncthreads();
+      // use cub to reduce
+      wdot = BlockReduce(temp_storage).Sum(wdot);
+
+      energy = max(wdot, energy_floor);
+    }
+
+    if (thread_id == 0) {
+      log_energy_pre_window[row + lane * lde] = log(energy);
+    }
+  }
+
+  // TODO: this could be more efficient using shared memory instead of
+  // tmp_window.
+  if (preemph_coeff != 0.0f) {
+    // wait for tmp_window to be computed
+    __threadfence();
+    __syncthreads();
+    // starting thread idx at 0 to keep writes aligned.
+    // unaligned reads are less painful than unaligned writes
+    for (int idx = thread_id; idx < frame_length; idx += CU1DBLOCK) {
+      float wval = tmp_window[idx];
+      float prev_window = wval;
+      if (idx > 0) {
+        prev_window = tmp_window[idx - 1];
+      }
+      // use __fmul_rn to match the CPU implementation:
+      //   window[idx] = (wval - preemph_coeff * prev_window) * windowing[idx];
+      window[idx] =
+          (wval - __fmul_rn(preemph_coeff, prev_window)) * windowing[idx];
+    }
+  }
+}
+
+__host__ __device__ inline int32 FirstSampleOfFrame(int32 frame,
+                                                    int32 frame_shift,
+                                                    int32 window_size,
+                                                    bool snip_edges) {
+  if (snip_edges) {
+    return frame * frame_shift;
+  } else {
+    int32 midpoint_of_frame = frame_shift * frame + frame_shift / 2,
+          beginning_of_frame = midpoint_of_frame - window_size / 2;
+    return beginning_of_frame;
+  }
+}
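+
+// Worked example (illustrative, not part of the original change): with
+// snip_edges == false, frame_shift == 160 and window_size == 400, frame 0
+// is centered on sample 80, so FirstSampleOfFrame returns 80 - 200 = -120.
+// Negative start samples are exactly the case the extract-window kernel
+// below services from the per-channel stash (or, on the very first chunk,
+// by reflection).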
+// Mimics the functionality of extract_window_kernel
+// (found in feature-spectral-cuda.cu) for a chunk of data
+// from several audio channels. The 2nd dimension
+// (y) of the block grid gives the hardware "lane".
+// The lanes array tells us which channel is in this lane,
+// what current frame and sample are processed in this
+// batch, etc.
+// Extra samples not processed in this chunk are moved to
+// "stash" where they'll be pre-pended to the next chunk
+// from this channel.
+__global__ void batched_extract_window_kernel(
+    const LaneDesc *lanes, int32_t num_lanes, int32 frame_shift,
+    int32 frame_length, int32 frame_length_padded, bool snip_edges,
+    const BaseFloat *__restrict__ wave, int32_t ldw,
+    BaseFloat *__restrict__ windows, int32_t window_size, int32_t wlda,
+    BaseFloat *stash, int32_t ssize, int32_t lds) {
+  // local frame number
+  int32_t fidx = blockIdx.x;
+  int32_t tidx = threadIdx.x;
+  int32_t lane = blockIdx.y;
+
+  const LaneDesc desc = lanes[lane];
+  ChannelId channel = desc.channel;
+  // This is the current sample that is pointed to by wave
+  int32_t current_sample = desc.current_sample;
+  // current frame we are computing in global space
+  int32_t current_frame = desc.current_frame;
+
+  // global frame number computed by this block
+  int32_t global_frame = current_frame + fidx;
+
+  int32_t num_chunk_samples = desc.num_chunk_samples;
+
+  if (fidx > desc.num_chunk_frames) return;
+
+  // offset input/output by channels or lanes
+  stash = stash + channel * lds;
+  wave = wave + lane * ldw;
+  BaseFloat *window = windows + fidx * wlda + gridDim.x * lane * wlda;
+
+  // This is the first sample needed to compute this frame
+  int32_t start_sample =
+      FirstSampleOfFrame(global_frame, frame_shift, window_size, snip_edges);
+
+  // Sample offset is how much we have to offset our index
+  // into the input wave.
+  int32_t wave_start = start_sample - current_sample;
+
+  // wave_start and wave_end are start and end indexes into 'wave', for the
+  // piece of wave that we're trying to extract.
+  int32_t wave_end = wave_start + frame_length;
+
+  // wave_start will be negative on successive chunks as we need
+  // to grab context from the stash.
+  if ((current_frame > 0 || wave_start >= 0) && wave_end <= num_chunk_samples) {
+    // the normal case -- no edge effects to consider.
+    for (int i = tidx; i < frame_length; i += blockDim.x) {
+      int32_t widx = wave_start + i;
+      BaseFloat val;
+      if (widx >= 0) {
+        val = wave[widx];
+      } else {
+        // widx is negative. Add it to the stash size
+        // to get the correct index into the stash.
+        int32_t sidx = ssize + widx;
+        val = stash[sidx];
+      }
+      window[i] = val;
+    }
+  } else {
+    // Deal with any end effects by reflection, if needed. This code will only
+    // be reached for about two frames per utterance, so we don't concern
+    // ourselves excessively with efficiency.
+    for (int s = tidx; s < frame_length; s += blockDim.x) {
+      int32 s_in_wave = wave_start + s;
+      while (s_in_wave < 0 || s_in_wave >= num_chunk_samples) {
+        // reflect around the beginning or end of the wave.
+        // e.g. -1 -> 0, -2 -> 1.
+        // dim -> dim - 1, dim + 1 -> dim - 2.
+        // the code supports repeated reflections, although this
+        // would only be needed in pathological cases.
+        if (s_in_wave < 0)
+          s_in_wave = -s_in_wave - 1;
+        else
+          s_in_wave = 2 * num_chunk_samples - 1 - s_in_wave;
+      }
+      window[s] = wave[s_in_wave];
+    }
+  }
+
+  if (frame_length_padded > frame_length) {
+    for (int i = frame_length + tidx; i < frame_length_padded;
+         i += blockDim.x) {
+      window[i] = 0.0f;
+    }
+  }
+}
+// For each frame, compute
+//   logf(dot(signal_frame, signal_frame)).
+// This is the batched version. The y-dimension of the grid
+// gives the lane number.
+__global__ void batched_dot_log_kernel(int32_t max_chunk_frames,
+                                       int32_t frame_length,
+                                       float *signal_frame, int32_t lds,
+                                       float *signal_log_energy, int32_t lde) {
+  // Specialize BlockReduce for type float
+  typedef cub::BlockReduce<float, CU1DBLOCK> BlockReduce;
+  // Allocate BlockReduce shared memory
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  int32_t frame = blockIdx.x;
+  int32_t tid = threadIdx.x;
+  int32_t lane = blockIdx.y;
+
+  float *in = signal_frame + frame * lds + max_chunk_frames * lane * lds;
+  float sum = 0;
+
+  // perform local dot product
+  for (int32_t i = tid; i < frame_length; i += blockDim.x) {
+    float val = in[i];
+    sum += val * val;
+  }
+
+  // reduce using cub
+  sum = BlockReduce(temp_storage).Sum(sum);
+
+  if (threadIdx.x == 0) {
+    signal_log_energy[frame + lane * lde] = logf(sum);
+  }
+}
+
+__global__ void batched_update_stash_kernel(const LaneDesc *lanes,
+                                            int32_t num_lanes,
+                                            const BaseFloat *wave, int32_t ldw,
+                                            BaseFloat *stash, int32_t num_stash,
+                                            int32_t lds) {
+  int32_t lane = blockIdx.x;
+  LaneDesc desc = lanes[lane];
+  int32_t channel = desc.channel;
+  int32_t num_chunk_samples = desc.num_chunk_samples;
+
+  // offset memory by lane or channel
+  wave = wave + lane * ldw;
+  stash = stash + channel * lds;
+
+  int32_t sample_offset = num_chunk_samples - num_stash;
+  for (int i = threadIdx.x; i < num_stash; i += blockDim.x) {
+    int32_t idx = sample_offset + i;
+
+    float val;
+    if (idx < 0) {
+      // data must come from the old stash
+      val = stash[idx + num_stash];
+    } else {
+      // data comes from the new wave
+      val = wave[idx];
+    }
+
+    // ensure all reads of the old stash complete before overwriting it
+    __syncthreads();
+
+    stash[i] = val;
+  }
+}
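+
+// Illustrative note (not part of the original change): the stash is a
+// per-channel buffer holding the trailing num_stash samples of everything
+// seen so far, so the next chunk's first frames can reach back across the
+// chunk boundary. When a chunk is shorter than num_stash (the idx < 0 case
+// above), part of the previous stash is shifted forward instead; e.g. with
+// num_stash = 10 and a 6-sample chunk, the new stash is the old stash's
+// last 4 samples followed by the 6 new ones.
+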
+// Each threadblock computes a different row of the matrix.
+// Threads in the same block compute the row collaboratively.
+// This kernel must be called out of place (A_in != A_out).
+__global__ void power_spectrum_kernel(int row_length, const float *A_in,
+                                      int32_t ldi, float *A_out, int32_t ldo,
+                                      bool use_power) {
+  int thread_id = threadIdx.x;
+  int block_id = blockIdx.x;
+  const float *Ar = A_in + block_id * ldi;
+  float *Aw = A_out + block_id * ldo;
+
+  int half_length = row_length / 2;
+  for (int idx = thread_id; idx < half_length; idx += CU1DBLOCK) {
+    // ignore the special case
+    if (idx == 0) continue;
+
+    float2 val = reinterpret_cast<const float2 *>(Ar)[idx];
+    float ret = val.x * val.x + val.y * val.y;
+    if (use_power) {
+      Aw[idx] = ret;
+    } else {
+      Aw[idx] = sqrtf(ret);
+    }
+  }
+
+  // handle the special case
+  if (threadIdx.x == 0) {
+    float real = Ar[0];
+    // cufft puts the Nyquist component at the end; this is different from
+    // what kaldi does with its own internal implementation
+    float im = Ar[row_length];
+
+    if (use_power) {
+      Aw[0] = real * real;
+      Aw[half_length] = im * im;
+    } else {
+      Aw[0] = fabs(real);
+      Aw[half_length] = fabs(im);
+    }
+  }
+}
+
+void cuda_power_spectrum(int32_t max_chunk_frames, int32_t num_lanes,
+                         int row_length, const float *A_in, int32_t ldi,
+                         float *A_out, int32_t ldo, bool use_power) {
+  power_spectrum_kernel<<<max_chunk_frames * num_lanes, CU1DBLOCK, 0,
+                          cudaStreamPerThread>>>(row_length, A_in, ldi, A_out,
+                                                 ldo, use_power);
+}
+
+void cuda_mel_banks_compute(const LaneDesc *lanes, int32_t num_lanes,
+                            int32_t max_chunk_frames, int32_t num_bins,
+                            float energy_floor, int32 *offsets, int32 *sizes,
+                            float **vecs, const float *feats, int32_t ldf,
+                            float *mels, int32_t ldm, bool use_log) {
+  dim3 Bl(32, 8);
+  dim3 Gr(num_bins, (max_chunk_frames + Bl.y - 1) / Bl.y, num_lanes);
+  batched_mel_banks_compute_kernel<<<Gr, Bl, 0, cudaStreamPerThread>>>(
+      lanes, num_lanes, max_chunk_frames, energy_floor, offsets, sizes, vecs,
+      feats, ldf, mels, ldm, use_log);
+}
+
+void cuda_apply_lifter_and_floor_energy(const LaneDesc *lanes,
+                                        int32_t num_lanes,
+                                        int32_t max_chunk_frames, int num_cols,
+                                        float cepstral_lifter, bool use_energy,
+                                        float energy_floor, float *log_energy,
+                                        int32_t ldl, float *lifter_coeffs,
+                                        float *features, int32_t ldf) {
+  dim3 Gr(max_chunk_frames, num_lanes);
+  batched_apply_lifter_and_floor_energy_kernel<<<Gr, CU1DBLOCK, 0,
+                                                 cudaStreamPerThread>>>(
+      lanes, num_lanes, max_chunk_frames, num_cols, cepstral_lifter,
+      use_energy, energy_floor, log_energy, ldl, lifter_coeffs, features, ldf);
+}
+
+void cuda_process_window(const LaneDesc *lanes, int32_t num_lanes,
+                         int32_t max_chunk_frames, int frame_length,
+                         float dither, float energy_floor,
+                         bool remove_dc_offset, float preemph_coeff,
+                         bool need_raw_log_energy,
+                         float *log_energy_pre_window, int32_t lde,
+                         const float *windowing, float *tmp_windows,
+                         int32_t ldt, float *windows, int32_t ldw) {
+  dim3 Gr(max_chunk_frames, num_lanes);
+  int Bl = CU1DBLOCK;
+  batched_process_window_kernel<<<Gr, Bl, 0, cudaStreamPerThread>>>(
+      lanes, num_lanes, max_chunk_frames, frame_length, dither, energy_floor,
+      remove_dc_offset, preemph_coeff, need_raw_log_energy,
+      log_energy_pre_window, lde, windowing, tmp_windows, ldt, windows, ldw);
+}
+
+void cuda_extract_window(const LaneDesc *lanes, int32_t num_lanes,
+                         int32_t max_chunk_frames, int32 frame_shift,
+                         int32 frame_length, int32 frame_length_padded,
+                         bool snip_edges, const float *wave, int32_t ldw,
+                         float *windows, int32_t window_size, int32_t wlda,
+                         BaseFloat *stash, int32_t ssize, int32_t lds) {
+  dim3 Gr(max_chunk_frames, num_lanes);
+  int Bl = CU1DBLOCK;
+  batched_extract_window_kernel<<<Gr, Bl, 0, cudaStreamPerThread>>>(
+      lanes, num_lanes, frame_shift, frame_length, frame_length_padded,
+      snip_edges, wave, ldw, windows, window_size, wlda, stash, ssize, lds);
+}
+
+void cuda_dot_log(int32_t max_chunk_frames, int32_t num_lanes,
+                  int32_t frame_length, float *signal_frame, int32_t lds,
+                  float *signal_log_energy, int32_t lde) {
+  dim3 Gr(max_chunk_frames, num_lanes);
+  batched_dot_log_kernel<<<Gr, CU1DBLOCK, 0, cudaStreamPerThread>>>(
+      max_chunk_frames, frame_length, signal_frame, lds, signal_log_energy,
+      lde);
+}
+
+void cuda_update_stash(const LaneDesc *lanes, int32_t num_lanes,
+                       const BaseFloat *wave, int32_t ldw, BaseFloat *stash,
+                       int32_t num_stash, int32_t lds) {
+  int Gr = num_lanes;
+  int Bl = 1024;
+  batched_update_stash_kernel<<<Gr, Bl, 0, cudaStreamPerThread>>>(
+      lanes, num_lanes, wave, ldw, stash, num_stash, lds);
+}
+}  // namespace kaldi
diff --git a/src/cudafeat/feature-spectral-batched-kernels.h b/src/cudafeat/feature-spectral-batched-kernels.h
new file mode 100644
index 00000000000..e25f86c23c2
--- /dev/null
+++ b/src/cudafeat/feature-spectral-batched-kernels.h
@@ -0,0 +1,69 @@
+// cudafeat/feature-spectral-batched-kernels.h
+//
+// Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens, Levi Barnes
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_CUDAFEAT_FEATURE_SPECTRAL_BATCHED_KERNELS_H_
+#define KALDI_CUDAFEAT_FEATURE_SPECTRAL_BATCHED_KERNELS_H_
+
+#include "cudafeat/lane-desc.h"
+
+namespace kaldi {
+
+void cuda_power_spectrum(int32_t max_chunk_frames, int32_t num_lanes,
+                         int row_length, const float *A_in, int32_t ldi,
+                         float *A_out, int32_t ldo, bool use_power);
+
+void cuda_mel_banks_compute(const LaneDesc *lanes, int32_t n_lanes,
+                            int32_t max_chunk_frames, int32_t num_bins,
+                            float energy_floor, int32_t *offsets,
+                            int32_t *sizes, float **vecs, const float *feats,
+                            int32_t ldf, float *mels, int32_t ldm,
+                            bool use_log);
+
+void cuda_apply_lifter_and_floor_energy(const LaneDesc *lanes,
+                                        int32_t num_lanes,
+                                        int32_t max_chunk_frames, int num_cols,
+                                        float cepstral_lifter, bool use_energy,
+                                        float energy_floor, float *log_energy,
+                                        int32_t ldl, float *lifter_coeffs,
+                                        float *features, int32_t ldf);
+
+void cuda_process_window(const LaneDesc *lanes, int32_t num_lanes,
+                         int32_t max_chunk_frames, int frame_length,
+                         float dither, float energy_floor,
+                         bool remove_dc_offset, float preemph_coeff,
+                         bool need_raw_log_energy,
+                         float *log_energy_pre_window, int32_t lde,
+                         const float *windowing, float *tmp_windows,
+                         int32_t ldt, float *windows, int32_t ldw);
+
+void cuda_extract_window(const LaneDesc *lanes, int32_t num_lanes,
+                         int32_t max_chunk_frames, int32_t frame_shift,
+                         int32_t frame_length, int32_t frame_length_padded,
+                         bool snip_edges, const float *wave, int32_t ldw,
+                         float *windows, int32_t window_size, int32_t wlda,
+                         float *stash, int32_t ssize, int32_t lds);
+
+void cuda_dot_log(int32_t max_chunk_frames, int32_t num_lanes,
+                  int32_t frame_length, float *signal_frame, int32_t lds,
+                  float *signal_log_energy, int32_t lde);
+
+void cuda_update_stash(const LaneDesc *lanes, int32_t num_lanes,
+                       const float *wave, int32_t ldw, float *stash,
+                       int32_t num_stash, int32_t lds);
+
+}  // namespace kaldi
+#endif  // KALDI_CUDAFEAT_FEATURE_SPECTRAL_BATCHED_KERNELS_H_
diff --git a/src/cudafeatbin/Makefile b/src/cudafeatbin/Makefile
index 150b41f087e..7ed9abf18fc 100644
--- a/src/cudafeatbin/Makefile
+++ b/src/cudafeatbin/Makefile
@@ -10,7 +10,7 @@ BINFILES =
 
 ifeq ($(CUDA), true)
   BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda compute-online-feats-cuda compute-fbank-feats-cuda \
-              apply-batched-cmvn-online-cuda
+              apply-batched-cmvn-online-cuda compute-mfcc-online-batched-cuda
 endif
 
 OBJFILES =
diff --git a/src/cudafeatbin/compute-mfcc-online-batched-cuda.cc b/src/cudafeatbin/compute-mfcc-online-batched-cuda.cc
new file mode 100644
index 00000000000..dbe86aa3d68
--- /dev/null
+++ b/src/cudafeatbin/compute-mfcc-online-batched-cuda.cc
@@ -0,0 +1,377 @@
+// cudafeatbin/compute-mfcc-online-batched-cuda.cc
+//
+// Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens, Levi Barnes
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda_profiler_api.h>
+
+#include <string>
+#include <vector>
+
+#include "base/kaldi-common.h"
+#include "cudafeat/feature-online-batched-spectral-cuda.h"
+#include "cudamatrix/cu-matrix.h"
+#include "cudamatrix/cu-vector.h"
+#include "feat/feature-window.h"
+#include "feat/wave-reader.h"
+#include "util/common-utils.h"
+
+using namespace kaldi;
+
+// This class stores data for input and output for this binary.
+// We will read/write slices of this input/output in an online
+// fashion.
+struct UtteranceDataHandle {
+  std::string utt;
+  WaveData wave_data_in;
+  Matrix<BaseFloat> feats_out;
+  Vector<BaseFloat> ivector_out;
+  int32_t num_samples;
+  int32_t current_sample;
+  int32_t num_frames;
+  int32_t current_frame;
+
+  UtteranceDataHandle(const std::string &utt, WaveData &wave_data,
+                      const FrameExtractionOptions &opts, int32_t feat_dim)
+      : utt(utt) {
+    current_sample = 0;
+    current_frame = 0;
+    num_samples = wave_data.Data().NumCols();
+
+    wave_data_in = wave_data;
+
+    num_frames = NumFrames(num_samples, opts, true);
+    feats_out.Resize(num_frames, feat_dim);
+  }
+};
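+
+// Illustrative note (not part of the original change): feats_out is sized up
+// front with NumFrames(num_samples, opts, true /*flush*/). E.g. with the
+// default 16 kHz options (400-sample window, 160-sample shift,
+// snip_edges == true), a 10000-sample utterance yields
+//   1 + (10000 - 400) / 160 = 61 frames.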
+
+int main(int argc, char *argv[]) {
+  try {
+    typedef kaldi::int32 int32;
+    using namespace kaldi;
+    const char *usage =
+        "Compute online MFCC features.\n\n"
+        "This binary processes the audio in chunks of samples. "
+        "In addition, the computation is batched and done in CUDA. "
+        "This binary is not intended to demonstrate how to achieve "
+        "maximum performance. Instead it is intended to demonstrate "
+        "how to use the class CudaOnlineBatchedSpectralFeatures and provide "
+        "a mechanism to test this class independently.\n\n"
+        "Usage: ./compute-mfcc-online-batched-cuda --batch-size=50 "
+        "<wave-rspecifier> <feature-wspecifier>\n";
+
+    int32_t num_channels = 50;
+    int32_t num_lanes = 10;
+    int32_t max_chunk_length_samples = 10000;
+    BaseFloat sample_freq = -1;
+    BaseFloat vtln_warp = 1.0;
+
+    ParseOptions po(usage);
+    MfccOptions feature_opts;
+    feature_opts.Register(&po);
+
+    po.Register("num-channels", &num_channels,
+                "The number of channels used for compute");
+    po.Register("batch-size", &num_lanes,
+                "The number of chunks from audio cuts processed in a "
+                "single batch");
+    po.Register("chunk-length", &max_chunk_length_samples,
+                "The length of a chunk of audio in terms of samples.");
+
+    CuDevice::RegisterDeviceOptions(&po);
+    RegisterCuAllocatorOptions(&po);
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 2) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    KALDI_ASSERT(num_channels >= num_lanes);
+
+    g_cuda_allocator.SetOptions(g_allocator_options);
+    CuDevice::Instantiate().SelectGpuId("yes");
+    CuDevice::Instantiate().AllowMultithreading();
+
+    LaneDesc *d_lanes = (LaneDesc *)CuDevice::Instantiate().Malloc(
+        sizeof(LaneDesc) * num_lanes);
+
+    std::string wav_rspecifier = po.GetArg(1),
+                feature_wspecifier = po.GetArg(2);
+
+    SequentialTableReader<WaveHolder> reader(wav_rspecifier);
+    BaseFloatMatrixWriter feature_writer;
+
+    if (!feature_writer.Open(feature_wspecifier)) {
+      KALDI_ERR << "Could not initialize feature_writer with wspecifier "
+                << feature_wspecifier;
+      exit(1);
+    }
+
+    std::vector<int32_t> free_channels;
+
+    // list of audio handles to be processed
+    std::vector<UtteranceDataHandle> data_handles;
+    // maps currently active channels to their handle index
+    std::map<int32_t, int32_t> channel_to_handle_idx;
+    // Index of the next unprocessed audio file
+    int32_t not_done_idx = 0;
+    int32_t num_done = 0, tot_t = 0;
+
+    int32_t feat_dim = feature_opts.num_ceps;
+
+    // compute the maximum chunk length in frames
+    const FrameExtractionOptions &frame_opts = feature_opts.frame_opts;
+
+    // span between consecutive features in output
+    int32_t shift = frame_opts.WindowShift();
+    int32_t max_chunk_frames = (max_chunk_length_samples + shift - 1) / shift;
+
+    int32_t ldf = max_chunk_frames;
+
+    CudaOnlineBatchedSpectralFeatures mfcc(feature_opts, max_chunk_frames,
+                                           num_channels, num_lanes);
+
+    CuMatrix<BaseFloat> d_batch_wav_in(num_lanes, max_chunk_length_samples,
+                                       kUndefined, kStrideEqualNumCols);
+    CuMatrix<BaseFloat> d_batch_feats_out(num_lanes * ldf, feat_dim,
+                                          kUndefined, kStrideEqualNumCols);
+
+    // host matrices for staging data in pinned memory before copying down
+    Matrix<BaseFloat> h_batch_wav_in(num_lanes, max_chunk_length_samples,
+                                     kUndefined, kStrideEqualNumCols);
+    Matrix<BaseFloat> h_batch_feats_out(num_lanes * ldf, feat_dim, kUndefined,
+                                        kStrideEqualNumCols);
+
+    size_t wave_in_size =
+        num_lanes * max_chunk_length_samples * sizeof(BaseFloat);
+    size_t feats_out_size = num_lanes * ldf * feat_dim * sizeof(BaseFloat);
+
+    cudaHostRegister(h_batch_wav_in.Data(), wave_in_size, 0);
+    cudaHostRegister(h_batch_feats_out.Data(), feats_out_size, 0);
+
+    CU_SAFE_CALL(cudaGetLastError());
+
+    std::vector<int32_t> num_frames_computed(num_lanes);
+
+    std::vector<LaneDesc> lanes;
+
+    for (int32_t i = 0; i < num_channels; i++) {
+      free_channels.push_back(i);
+    }
+
+    sample_freq = frame_opts.samp_freq;
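+
+    // Illustrative note (not part of the original change): max_chunk_frames
+    // is a ceiling division, so it upper-bounds the frames any chunk can
+    // produce; e.g. a 10000-sample chunk with a 160-sample shift gives
+    // (10000 + 159) / 160 = 63 rows per lane in d_batch_feats_out.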
+
+    double duration = 0.0;
+    // preload data for batching
+    for (; !reader.Done(); reader.Next()) {
+      std::string utt = reader.Key();
+      WaveData &wave_data = reader.Value();
+      duration += wave_data.Duration();
+      data_handles.emplace_back(utt, wave_data, frame_opts, feat_dim);
+    }
+
+    // Timing just compute; we don't want to include
+    // disk I/O in this timer.
+    Timer timer;
+    // A single pass through this loop will fill the
+    // batch with new work if any is available.
+    // Then process a single iteration of batched feature extraction.
+    // At exit each processed handle should have valid data
+    // in feats_out.
+    while (true) {
+      // This loop will fill the batch by pulling from the
+      // data_handles vector for new work
+      while (lanes.size() < num_lanes && not_done_idx < data_handles.size()) {
+        UtteranceDataHandle &handle = data_handles[not_done_idx];
+        int32_t num_samples = handle.num_samples;
+        num_samples = std::min(max_chunk_length_samples, num_samples);
+
+        // grab a free channel
+        int32_t channel = free_channels.back();
+        free_channels.pop_back();
+
+        LaneDesc desc;
+        desc.channel = channel;
+        desc.current_sample = 0;
+        desc.num_chunk_samples = num_samples;
+        desc.first = true;
+        desc.last = num_samples == handle.num_samples;
+        desc.current_frame = 0;
+        desc.num_chunk_frames = NumFrames(num_samples, frame_opts, desc.last);
+        lanes.push_back(desc);
+
+        channel_to_handle_idx[channel] = not_done_idx;
+        not_done_idx++;
+      }
+
+      // No work in lanes; this means the corpus is finished
+      if (lanes.size() == 0) break;
+
+      cudaMemcpyAsync(d_lanes, &lanes[0], sizeof(LaneDesc) * lanes.size(),
+                      cudaMemcpyHostToDevice, cudaStreamPerThread);
+
+      // This loop copies a slice from each active audio cut
+      // down to the device for processing
+      for (int32_t lane = 0; lane < lanes.size(); lane++) {
+        LaneDesc &desc = lanes[lane];
+        int32_t channel = desc.channel;
+        UtteranceDataHandle &handle =
+            data_handles[channel_to_handle_idx[channel]];
+
+        int32_t current_sample = handle.current_sample;
+        int32_t num_samples = desc.num_chunk_samples;
+
+        // Create a subvector for this slice of data
+        SubVector<BaseFloat> p_wave(
+            h_batch_wav_in.Row(lane).Range(0, num_samples));
+
+        SubVector<BaseFloat> h_wave(handle.wave_data_in.Data().Row(0).Range(
+            current_sample, num_samples));
+
+        // Copy the slice into pinned memory
+        p_wave.CopyFromVec(h_wave);
+      }
+
+      // use a single memcpy here to avoid a possible 2D memcpy which is
+      // very slow
+      cudaMemcpyAsync(d_batch_wav_in.Data(), h_batch_wav_in.Data(),
+                      wave_in_size, cudaMemcpyHostToDevice,
+                      cudaStreamPerThread);
+      CU_SAFE_CALL(cudaGetLastError());
+
+      // process batch
+      mfcc.ComputeFeaturesBatched(d_lanes, lanes.size(), d_batch_wav_in,
+                                  sample_freq, vtln_warp, &d_batch_feats_out);
+
+      // copy feats to host
+      cudaMemcpyAsync(h_batch_feats_out.Data(), d_batch_feats_out.Data(),
+                      feats_out_size, cudaMemcpyDeviceToHost,
+                      cudaStreamPerThread);
+      CU_SAFE_CALL(cudaGetLastError());
+
+      // Wait for the copy to host to complete before copying to the final
+      // location. For additional optimization you should double buffer the
+      // h_batch_* arrays so that the GPU isn't idle while the CPU
+      // is copying data into the final destination. We don't envision
+      // people using this binary directly and thus won't do that
+      // here, to keep the API example more concise.
+      cudaStreamSynchronize(cudaStreamPerThread);
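+
+      // A hedged sketch of the double-buffering mentioned above (not
+      // implemented here; names are hypothetical): keep two pinned staging
+      // buffers and alternate between them,
+      //   Matrix<BaseFloat> h_feats[2];  // both cudaHostRegister'd
+      //   int cur = 0;
+      //   cudaMemcpyAsync(h_feats[cur].Data(), d_batch_feats_out.Data(),
+      //                   feats_out_size, cudaMemcpyDeviceToHost,
+      //                   cudaStreamPerThread);
+      //   // the CPU drains h_feats[1 - cur] while this copy is in flight;
+      //   // synchronize only before reusing a buffer, then: cur = 1 - cur;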
+
+      // At this time the batch is computed. We now need to copy each slice
+      // into the appropriate output buffer.
+      for (int lane = 0; lane < lanes.size(); lane++) {
+        LaneDesc &desc = lanes[lane];
+        ChannelId channel = desc.channel;
+
+        int32_t current_frame = desc.current_frame;
+        int32_t num_chunk_frames = desc.num_chunk_frames;
+        if (num_chunk_frames == 0) continue;
+
+        UtteranceDataHandle &handle =
+            data_handles[channel_to_handle_idx[channel]];
+
+        // Copy the slice back up
+        CuSubMatrix<BaseFloat> A(d_batch_feats_out.Range(
+            lane * max_chunk_frames, num_chunk_frames, 0, feat_dim));
+        SubMatrix<BaseFloat> B(handle.feats_out.Range(
+            current_frame, num_chunk_frames, 0, feat_dim));
+
+        B.CopyFromMat(A);
+      }  // end copy to host loop
+
+      // For each lane check if compute is done.
+      // If completed, remove it from the channel list and
+      // free the channel.
+      for (int32_t lane = 0; lane < lanes.size();) {
+        LaneDesc &desc = lanes[lane];
+        ChannelId channel = desc.channel;
+        UtteranceDataHandle &handle =
+            data_handles[channel_to_handle_idx[channel]];
+
+        int32_t &chunk_samples = desc.num_chunk_samples;
+        // advance by the samples processed in the last chunk
+        handle.current_sample += chunk_samples;
+
+        desc.current_sample += desc.num_chunk_samples;
+        desc.num_chunk_samples =
+            std::min(max_chunk_length_samples,
+                     handle.num_samples - desc.current_sample);
+        desc.current_frame = NumFrames(desc.current_sample, frame_opts, false);
+        int32_t num_samples = desc.current_sample + desc.num_chunk_samples;
+        int32_t num_frames = NumFrames(num_samples, frame_opts, desc.last);
+        desc.num_chunk_frames =
+            std::min(max_chunk_frames, num_frames - desc.current_frame);
+        // we are finished if we already said the last chunk was last
+        bool finished = desc.last;
+
+        // compute the next batch of samples
+        int32_t num_remaining_samples =
+            handle.num_samples - handle.current_sample;
+        chunk_samples =
+            std::min(max_chunk_length_samples, num_remaining_samples);
+
+        int32_t num_total_samples = handle.current_sample + chunk_samples;
+
+        desc.last = num_total_samples == handle.num_samples;
+        desc.first = false;
+
+        if (finished) {
+          // free this channel
+          free_channels.push_back(channel);
+          // Move the last lane to this lane
+          lanes[lane] = lanes.back();
+          lanes.pop_back();
+
+          num_done++;
+        } else {
+          lane++;
+        }
+      }  // end check if done loop
+    }    // end while(true)
+    double total_time = timer.Elapsed();
+
+    // Output all utterances. In an efficient implementation
+    // this would be done on demand in a threaded manner. This
+    // binary is purely for checking correctness and demonstrating
+    // usage, and thus this type of optimization is not done.
+    for (int i = 0; i < data_handles.size(); i++) {
+      UtteranceDataHandle &handle = data_handles[i];
+
+      tot_t += handle.feats_out.NumRows();
+      feature_writer.Write(handle.utt, handle.feats_out);
+    }
+
+    KALDI_LOG << "Computed Online Features for " << num_done << " files, and "
+              << tot_t << " frames.";
+
+    KALDI_LOG << "Total Audio: " << duration
+              << " seconds, Total Time: " << total_time
+              << " seconds, RTFX: " << duration / total_time;
+
+    cudaHostUnregister(h_batch_wav_in.Data());
+    cudaHostUnregister(h_batch_feats_out.Data());
+
+    cudaDeviceSynchronize();
+    cudaProfilerStop();
+
+    return 0;
+  } catch (const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}