diff --git a/.github/PERFORMANCE.md b/.github/PERFORMANCE.md index b29677c..5df6643 100644 --- a/.github/PERFORMANCE.md +++ b/.github/PERFORMANCE.md @@ -54,3 +54,10 @@ sys 3m28.465s ``` More than 2x faster for 4 threads. This is inspired by the parallelism strategy used in . + +V3 is a faster algorithm and the mt variant (with 4 threads) runs in 2.5 min: +``` +real 2m35.737s +user 10m28.019s +sys 2m42.292s +``` diff --git a/.github/SDR_scores.md b/.github/SDR_scores.md index 9ea0449..f4f80d4 100644 --- a/.github/SDR_scores.md +++ b/.github/SDR_scores.md @@ -13,10 +13,10 @@ other ==> SDR: 7.421 SIR: 11.289 ISR: 14.241 SAR: 8.179 ``` CPP inference (this codebase): ``` -vocals ==> SDR: 8.339 SIR: 18.276 ISR: 15.836 SAR: 8.346 -drums ==> SDR: 10.058 SIR: 18.596 ISR: 17.019 SAR: 10.810 -bass ==> SDR: 3.919 SIR: 12.436 ISR: 6.931 SAR: 3.182 -other ==> SDR: 7.421 SIR: 11.286 ISR: 14.252 SAR: 8.183 +vocals ==> SDR: 8.370 SIR: 18.188 ISR: 15.924 SAR: 8.475 +drums ==> SDR: 10.002 SIR: 18.571 ISR: 17.027 SAR: 10.645 +bass ==> SDR: 4.021 SIR: 12.407 ISR: 7.031 SAR: 3.223 +other ==> SDR: 7.469 SIR: 11.367 ISR: 14.186 SAR: 8.182 ``` *n.b.* for the above results, the random shift in the beginning of the song was fixed to 1337 in both PyTorch and C++. 
@@ -33,10 +33,10 @@ other ==> SDR: 0.168 SIR: 11.449 ISR: 0.411 SAR: -2.720 ``` CPP inference (this codebase): ``` -vocals ==> SDR: 8.395 SIR: 18.699 ISR: 16.076 SAR: 8.576 -drums ==> SDR: 9.927 SIR: 17.921 ISR: 17.518 SAR: 10.635 -bass ==> SDR: 4.519 SIR: 10.458 ISR: 8.606 SAR: 4.370 -other ==> SDR: 0.164 SIR: 11.443 ISR: 0.409 SAR: -2.713 +vocals ==> SDR: 8.395 SIR: 18.581 ISR: 16.101 SAR: 8.579 +drums ==> SDR: 9.922 SIR: 18.013 ISR: 17.477 SAR: 10.669 +bass ==> SDR: 4.523 SIR: 10.482 ISR: 8.567 SAR: 4.336 +other ==> SDR: 0.167 SIR: 11.145 ISR: 0.448 SAR: -1.238 ``` *n.b.* the "other" score will be artificially low because of the extra guitar + piano separation where there are no stems to compare to @@ -54,10 +54,36 @@ other ==> SDR: 7.384 SIR: 12.812 ISR: 12.977 SAR: 7.798 ``` CPP inference (this codebase, `demucs_ft.cpp`) ``` -vocals ==> SDR: 8.594 SIR: 19.045 ISR: 16.313 SAR: 8.617 -drums ==> SDR: 10.463 SIR: 19.782 ISR: 17.144 SAR: 11.132 -bass ==> SDR: 4.584 SIR: 9.359 ISR: 9.068 SAR: 4.885 -other ==> SDR: 7.426 SIR: 12.793 ISR: 12.975 SAR: 7.830 +vocals ==> SDR: 8.679 SIR: 18.861 ISR: 16.611 SAR: 8.664 +drums ==> SDR: 10.480 SIR: 19.898 ISR: 17.125 SAR: 11.053 +bass ==> SDR: 4.590 SIR: 9.516 ISR: 9.102 SAR: 4.935 +other ==> SDR: 7.370 SIR: 12.853 ISR: 12.926 SAR: 7.805 +``` + +### Performance of v3 (hdemucs_mmi) model + +Track 'Zeno - Signs' from MUSDB18-HQ test set + +PyTorch inference (using v3-mmi default segment length + LSTM max length of 200): +``` +vocals ==> SDR: 8.328 SIR: 18.943 ISR: 16.097 SAR: 8.563 +drums ==> SDR: 9.284 SIR: 18.123 ISR: 16.230 SAR: 10.125 +bass ==> SDR: 3.612 SIR: 10.313 ISR: 6.958 SAR: 3.077 +other ==> SDR: 7.122 SIR: 11.391 ISR: 14.363 SAR: 7.910 +``` +PyTorch inference (using v4 7.8s segment length + LSTM max length of 336): +``` +vocals ==> SDR: 8.304 SIR: 18.916 ISR: 16.087 SAR: 8.557 +drums ==> SDR: 9.279 SIR: 18.149 ISR: 16.203 SAR: 10.109 +bass ==> SDR: 3.601 SIR: 10.350 ISR: 6.971 SAR: 3.076 +other ==> SDR: 7.123 SIR: 
11.373 ISR: 14.373 SAR: 7.907 +``` +CPP inference (this codebase, `demucs_v3.cpp`): +``` +vocals ==> SDR: 8.332 SIR: 18.889 ISR: 16.083 SAR: 8.557 +drums ==> SDR: 9.285 SIR: 18.242 ISR: 16.194 SAR: 10.140 +bass ==> SDR: 3.668 SIR: 10.040 ISR: 7.056 SAR: 3.210 +other ==> SDR: 7.130 SIR: 11.440 ISR: 14.257 SAR: 7.860 ``` ### Performance of multi-threaded inference diff --git a/CMakeLists.txt b/CMakeLists.txt index d9a0726..6184f79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,16 @@ target_include_directories(demucs_ft_mt.cpp.main PRIVATE vendor/libnyquist/inclu target_include_directories(demucs_ft_mt.cpp.main PRIVATE cli-apps) target_link_libraries(demucs_ft_mt.cpp.main demucs.cpp.lib libnyquist) +add_executable(demucs_v3.cpp.main "cli-apps/demucs_v3.cpp") +target_include_directories(demucs_v3.cpp.main PRIVATE vendor/libnyquist/include) +target_include_directories(demucs_v3.cpp.main PRIVATE cli-apps) +target_link_libraries(demucs_v3.cpp.main demucs.cpp.lib libnyquist) + +add_executable(demucs_v3_mt.cpp.main "cli-apps/demucs_v3_mt.cpp") +target_include_directories(demucs_v3_mt.cpp.main PRIVATE vendor/libnyquist/include) +target_include_directories(demucs_v3_mt.cpp.main PRIVATE cli-apps) +target_link_libraries(demucs_v3_mt.cpp.main demucs.cpp.lib libnyquist) + file(GLOB SOURCES_TO_LINT "src/*.cpp" "src/*.hpp" "cli-apps/*.cpp" "cli-apps/*.hpp") # add target to run standard lints and formatters diff --git a/README.md b/README.md index e35a937..2ed0f40 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # demucs.cpp -C++17 library that implements the inference of the [Demucs v4 hybrid transformer model](https://github.com/facebookresearch/demucs), a PyTorch neural network for music demixing. 
+C++17 library that implements inference for the [Demucs v4 hybrid transformer](https://github.com/facebookresearch/demucs) and [Demucs v3 hybrid](https://github.com/facebookresearch/demucs/tree/v3) models, which are high-performance PyTorch neural networks for music source separation. -It uses only the standard library and the header-only library [Eigen](https://eigen.tuxfamily.org/index.php?title=Main_Page) as dependencies, making it suitable to compile and run on many platforms. It was designed for low-memory environments by sacrificing the speed of the Torch implementation. +It uses only the standard library (C++17) and the header-only library [Eigen](https://eigen.tuxfamily.org/index.php?title=Main_Page) as dependencies, making it suitable to compile and run on many platforms. It was designed for low-memory environments by sacrificing the speed of the Torch implementation. Demucs.cpp powers my websites (, ) and now my new Android app [Music Demixer](https://play.google.com/store/apps/details?id=com.freemusicdemixer.pro) to bring Demucs to your pocket! @@ -12,9 +12,11 @@ See my other project [umx.cpp](https://github.com/sevagh/umx.cpp) for a similar ### Library design -It uses [libnyquist](https://github.com/ddiakopoulos/libnyquist) to load audio files, the [ggml](https://github.com/ggerganov/ggml) file format to serialize the PyTorch weights of `htdemucs`, `htdemucs_6s`, and `htdemucs_ft` (4-source, 6-source, fine-tuned) to a binary file format, and [Eigen](https://eigen.tuxfamily.org/index.php?title=Main_Page) (+ OpenMP) to implement the inference. There are also programs for multi-threaded Demucs inference using C++11's `std::thread`. 
+The inference library (in `src/`) uses the [ggml](https://github.com/ggerganov/ggml) file format to serialize the PyTorch weights of `hdemucs_mmi`, `htdemucs`, `htdemucs_6s`, and `htdemucs_ft` (v3, v4 4-source, v4 6-source, v4 fine-tuned) to a binary file format, and [Eigen](https://eigen.tuxfamily.org/index.php?title=Main_Page) to implement the inference (with OpenMP as a requirement). -**All Hybrid-Transformer weights** (4-source, 6-source, fine-tuned) are supported. See the [Convert weights](#convert-weights) section below. Demixing quality is nearly identical to PyTorch as shown in the [SDR scores doc](./.github/SDR_scores.md). +The cli programs (in `cli-apps/`) additionally use [libnyquist](https://github.com/ddiakopoulos/libnyquist) to read and write audio files, and the multithreaded cli programs use C++11's `std::thread`. + +**All Hybrid-Transformer weights** (4-source, 6-source, fine-tuned) are supported. See the [Convert weights](#convert-weights) section below. Inference for the **Demucs v3 Hybrid model weights** `hdemucs_mmi` is also supported. Demixing quality is practically identical to PyTorch as shown in the [SDR scores doc](./.github/SDR_scores.md). ### Directory structure @@ -23,8 +25,10 @@ It uses [libnyquist](https://github.com/ddiakopoulos/libnyquist) to load audio f 1. `demucs_ft.cpp.main`: run all four fine-tuned models for `htdemucs_ft` inference, same as the BagOfModels idea of PyTorch Demucs 1. `demucs_mt.cpp.main`: run a single model, multi-threaded 1. `demucs_ft_mt.cpp.main`: run all four fine-tuned models, multi-threaded +1. `demucs_v3.cpp.main`: run a single model for v3 `hdemucs_mmi` +1. `demucs_v3_mt.cpp.main`: run a single model for v3 `hdemucs_mmi`, multi-threaded -See the [PERFORMANCE doc](./.github/PERFORMANCE.md) for details on multi-threading, external BLAS libraries, etc.. +See the [PERFORMANCE doc](./.github/PERFORMANCE.md) for time measurements, benchmarks, details on multi-threading, external BLAS libraries, etc. 
## Instructions @@ -45,10 +49,6 @@ $ sudo apt-get install gcc g++ cmake clang-tools libopenblas0-openmp libopenblas Compile with CMake: ``` $ mkdir -p build && cd build && cmake .. && make -j16 -libdemucs.cpp.lib.a <--- library -demucs.cpp.main <--- single-model (4s, 6s, ft) -demucs_ft.cpp.main <--- bag of ft models -demucs.cpp.test <--- unit tests ``` ### Convert weights @@ -62,7 +62,7 @@ $ mamba activate demucscpp $ python -m pip install -r ./scripts/requirements.txt ``` -Dump Demucs weights to ggml file, with flag `--six-source` for the 6-source variant, and all of `--ft-drums, --ft-vocals, --ft-bass, --ft-other` for the fine-tuned models: +Dump Demucs weights to ggml file, with flag `--six-source` for the 6-source variant, all of `--ft-drums, --ft-vocals, --ft-bass, --ft-other` for the fine-tuned models, and `--v3` for the v3 model: ``` $ python ./scripts/convert-pth-to-ggml.py ./ggml-demucs ... @@ -76,14 +76,15 @@ Done. Output file: ggml-demucs/ggml-model-htdemucs-4s-f16.bin All supported models would look like this: ``` -$ ls ../ggml-demucs/ -total 133M - 81M Jan 10 22:40 ggml-model-htdemucs-4s-f16.bin - 53M Jan 10 22:41 ggml-model-htdemucs-6s-f16.bin - 81M Jan 10 22:41 ggml-model-htdemucs_ft_drums-4s-f16.bin - 81M Jan 10 22:43 ggml-model-htdemucs_ft_bass-4s-f16.bin - 81M Jan 10 22:43 ggml-model-htdemucs_ft_other-4s-f16.bin - 81M Jan 10 22:43 ggml-model-htdemucs_ft_vocals-4s-f16.bin +$ ls ./ggml-demucs/ +total 613M +160M May 5 14:38 ggml-model-hdemucs_mmi-v3-f16.bin + 53M May 5 16:50 ggml-model-htdemucs-6s-f16.bin + 81M May 5 16:50 ggml-model-htdemucs_ft_vocals-4s-f16.bin + 81M May 5 16:50 ggml-model-htdemucs_ft_bass-4s-f16.bin + 81M May 5 16:50 ggml-model-htdemucs_ft_drums-4s-f16.bin + 81M May 5 16:50 ggml-model-htdemucs_ft_other-4s-f16.bin + 81M May 5 16:51 ggml-model-htdemucs-4s-f16.bin ``` ### Run demucs.cpp diff --git a/cli-apps/demucs_v3.cpp b/cli-apps/demucs_v3.cpp new file mode 100644 index 0000000..0e9077b --- /dev/null +++ b/cli-apps/demucs_v3.cpp 
@@ -0,0 +1,232 @@ +#include "dsp.hpp" +#include "model.hpp" +#include "tensor.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace demucscpp_v3; +using namespace demucscpp; +using namespace nqr; + +static Eigen::MatrixXf load_audio_file(std::string filename) +{ + // load a wav file with libnyquist + std::shared_ptr fileData = std::make_shared(); + + NyquistIO loader; + + loader.Load(fileData.get(), filename); + + if (fileData->sampleRate != demucscpp::SUPPORTED_SAMPLE_RATE) + { + std::cerr << "[ERROR] demucs.cpp only supports the following sample " + "rate (Hz): " + << SUPPORTED_SAMPLE_RATE << std::endl; + exit(1); + } + + std::cout << "Input samples: " + << fileData->samples.size() / fileData->channelCount << std::endl; + std::cout << "Length in seconds: " << fileData->lengthSeconds << std::endl; + std::cout << "Number of channels: " << fileData->channelCount << std::endl; + + if (fileData->channelCount != 2 && fileData->channelCount != 1) + { + std::cerr << "[ERROR] demucs.cpp only supports mono and stereo audio" + << std::endl; + exit(1); + } + + // number of samples per channel + size_t N = fileData->samples.size() / fileData->channelCount; + + // create a struct to hold two float vectors for left and right channels + Eigen::MatrixXf ret(2, N); + + if (fileData->channelCount == 1) + { + // Mono case + for (size_t i = 0; i < N; ++i) + { + ret(0, i) = fileData->samples[i]; // left channel + ret(1, i) = fileData->samples[i]; // right channel + } + } + else + { + // Stereo case + for (size_t i = 0; i < N; ++i) + { + ret(0, i) = fileData->samples[2 * i]; // left channel + ret(1, i) = fileData->samples[2 * i + 1]; // right channel + } + } + + return ret; +} + +// write a function to write a StereoWaveform to a wav file +static void write_audio_file(const Eigen::MatrixXf &waveform, + std::string filename) +{ + // create a struct to hold the audio data + std::shared_ptr 
fileData = std::make_shared(); + + // set the sample rate + fileData->sampleRate = SUPPORTED_SAMPLE_RATE; + + // set the number of channels + fileData->channelCount = 2; + + // set the number of samples + fileData->samples.resize(waveform.cols() * 2); + + // write the left channel + for (long int i = 0; i < waveform.cols(); ++i) + { + fileData->samples[2 * i] = waveform(0, i); + fileData->samples[2 * i + 1] = waveform(1, i); + } + + int encoderStatus = + encode_wav_to_disk({fileData->channelCount, PCM_FLT, DITHER_TRIANGLE}, + fileData.get(), filename); + std::cout << "Encoder Status: " << encoderStatus << std::endl; +} + +int main(int argc, const char **argv) +{ + if (argc != 4) + { + std::cerr << "Usage: " << argv[0] + << " " << std::endl; + exit(1); + } + + std::cout << "demucs.cpp Main driver program" << std::endl; + + // load model passed as argument + std::string model_file = argv[1]; + + // load audio passed as argument + std::string wav_file = argv[2]; + + // output dir passed as argument + std::string out_dir = argv[3]; + + Eigen::MatrixXf audio = load_audio_file(wav_file); + Eigen::Tensor3dXf out_targets; + + // initialize a struct demucs_model + struct demucs_v3_model model + { + }; + + // debug some members of model + auto ret = load_demucs_v3_model(model_file, &model); + std::cout << "demucs_model_load returned " << (ret ? 
"true" : "false") + << std::endl; + if (!ret) + { + std::cerr << "Error loading model" << std::endl; + exit(1); + } + + const int nb_sources = 4; + + std::cout << "Starting Demucs v3 MMI inference" << std::endl; + + // set output precision to 3 decimal places + std::cout << std::fixed << std::setprecision(3); + + demucscpp::ProgressCallback progressCallback = + [](float progress, const std::string &log_message) + { + std::cout << "(" << std::setw(3) << std::setfill(' ') + << progress * 100.0f << "%) " << log_message << std::endl; + }; + + // create 4 audio matrix same size, to hold output + Eigen::Tensor3dXf audio_targets = + demucscpp_v3::demucs_v3_inference(model, audio, progressCallback); + + out_targets = audio_targets; + + const int nb_out_sources = nb_sources; + + for (int target = 0; target < nb_out_sources; ++target) + { + // now write the 4 audio waveforms to files in the output dir + // using libnyquist + // join out_dir with "/target_0.wav" + // using std::filesystem::path; + + std::filesystem::path p = out_dir; + // make sure the directory exists + std::filesystem::create_directories(p); + + auto p_target = p / "target_0.wav"; + + // target 0,1,2,3 map to drums,bass,other,vocals + + std::string target_name; + + switch (target) + { + case 0: + target_name = "drums"; + break; + case 1: + target_name = "bass"; + break; + case 2: + target_name = "other"; + break; + case 3: + target_name = "vocals"; + break; + case 4: + target_name = "guitar"; + break; + case 5: + target_name = "piano"; + break; + default: + std::cerr << "Error: target " << target << " not supported" + << std::endl; + exit(1); + } + + // insert target_name into the path after the digit + // e.g. 
target_name_0_drums.wav + p_target.replace_filename("target_" + std::to_string(target) + "_" + + target_name + ".wav"); + + std::cout << "Writing wav file " << p_target << std::endl; + + Eigen::MatrixXf target_waveform(2, audio.cols()); + + // copy the input stereo wav file into all 4 targets + for (int channel = 0; channel < 2; ++channel) + { + for (int sample = 0; sample < audio.cols(); ++sample) + { + target_waveform(channel, sample) = + out_targets(target, channel, sample); + } + } + + write_audio_file(target_waveform, p_target); + } +} diff --git a/cli-apps/demucs_v3_mt.cpp b/cli-apps/demucs_v3_mt.cpp new file mode 100644 index 0000000..61bd88a --- /dev/null +++ b/cli-apps/demucs_v3_mt.cpp @@ -0,0 +1,227 @@ +#include "dsp.hpp" +#include "model.hpp" +#include "tensor.hpp" +#include "threaded_inference.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace demucscpp_v3; +using namespace demucscpp; +using namespace nqr; + +static Eigen::MatrixXf load_audio_file(std::string filename) +{ + // load a wav file with libnyquist + std::shared_ptr fileData = std::make_shared(); + + NyquistIO loader; + + loader.Load(fileData.get(), filename); + + if (fileData->sampleRate != demucscpp::SUPPORTED_SAMPLE_RATE) + { + std::cerr << "[ERROR] demucs_mt.cpp only supports the following sample " + "rate (Hz): " + << SUPPORTED_SAMPLE_RATE << std::endl; + exit(1); + } + + std::cout << "Input samples: " + << fileData->samples.size() / fileData->channelCount << std::endl; + std::cout << "Length in seconds: " << fileData->lengthSeconds << std::endl; + std::cout << "Number of channels: " << fileData->channelCount << std::endl; + + if (fileData->channelCount != 2 && fileData->channelCount != 1) + { + std::cerr << "[ERROR] demucs_mt.cpp only supports mono and stereo audio" + << std::endl; + exit(1); + } + + // number of samples per channel + size_t N = fileData->samples.size() / 
fileData->channelCount; + + // create a struct to hold two float vectors for left and right channels + Eigen::MatrixXf ret(2, N); + + if (fileData->channelCount == 1) + { + // Mono case + for (size_t i = 0; i < N; ++i) + { + ret(0, i) = fileData->samples[i]; // left channel + ret(1, i) = fileData->samples[i]; // right channel + } + } + else + { + // Stereo case + for (size_t i = 0; i < N; ++i) + { + ret(0, i) = fileData->samples[2 * i]; // left channel + ret(1, i) = fileData->samples[2 * i + 1]; // right channel + } + } + + return ret; +} + +// write a function to write a StereoWaveform to a wav file +static void write_audio_file(const Eigen::MatrixXf &waveform, + std::string filename) +{ + // create a struct to hold the audio data + std::shared_ptr fileData = std::make_shared(); + + // set the sample rate + fileData->sampleRate = SUPPORTED_SAMPLE_RATE; + + // set the number of channels + fileData->channelCount = 2; + + // set the number of samples + fileData->samples.resize(waveform.cols() * 2); + + // write the left channel + for (long int i = 0; i < waveform.cols(); ++i) + { + fileData->samples[2 * i] = waveform(0, i); + fileData->samples[2 * i + 1] = waveform(1, i); + } + + int encoderStatus = + encode_wav_to_disk({fileData->channelCount, PCM_FLT, DITHER_TRIANGLE}, + fileData.get(), filename); + std::cout << "Encoder Status: " << encoderStatus << std::endl; +} + +int main(int argc, const char **argv) +{ + if (argc != 5) + { + std::cerr << "Usage: " << argv[0] + << " " + << std::endl; + exit(1); + } + + std::cout << "demucs_mt.cpp (Multi-threaded) driver program" << std::endl; + + // load model passed as argument + std::string model_file = argv[1]; + + // load audio passed as argument + std::string wav_file = argv[2]; + + // output dir passed as argument + std::string out_dir = argv[3]; + + // get num threads from user parameter argv[4] + // cast it to int + int num_threads = std::stoi(argv[4]); + + Eigen::MatrixXf audio = load_audio_file(wav_file); + 
Eigen::Tensor3dXf out_targets; + + // initialize a struct demucs_model + struct demucs_v3_model model + { + }; + + // debug some members of model + auto ret = load_demucs_v3_model(model_file, &model); + std::cout << "demucs_model_load returned " << (ret ? "true" : "false") + << std::endl; + if (!ret) + { + std::cerr << "Error loading model" << std::endl; + exit(1); + } + + const int nb_sources = 4; + + std::cout << "Starting Demucs v3 MMI inference" << std::endl; + + // create 4 audio matrix same size, to hold output + Eigen::Tensor3dXf audio_targets = + demucscppthreaded_v3::threaded_inference(model, audio, num_threads); + + out_targets = audio_targets; + + for (int target = 0; target < nb_sources; ++target) + { + // now write the 4 audio waveforms to files in the output dir + // using libnyquist + // join out_dir with "/target_0.wav" + // using std::filesystem::path; + + std::filesystem::path p = out_dir; + // make sure the directory exists + std::filesystem::create_directories(p); + + auto p_target = p / "target_0.wav"; + + // target 0,1,2,3 map to drums,bass,other,vocals + + std::string target_name; + + switch (target) + { + case 0: + target_name = "drums"; + break; + case 1: + target_name = "bass"; + break; + case 2: + target_name = "other"; + break; + case 3: + target_name = "vocals"; + break; + case 4: + target_name = "guitar"; + break; + case 5: + target_name = "piano"; + break; + default: + std::cerr << "Error: target " << target << " not supported" + << std::endl; + exit(1); + } + + // insert target_name into the path after the digit + // e.g. 
target_name_0_drums.wav + p_target.replace_filename("target_" + std::to_string(target) + "_" + + target_name + ".wav"); + + std::cout << "Writing wav file " << p_target << std::endl; + + Eigen::MatrixXf target_waveform(2, audio.cols()); + + // copy the input stereo wav file into all 4 targets + for (int channel = 0; channel < 2; ++channel) + { + for (int sample = 0; sample < audio.cols(); ++sample) + { + target_waveform(channel, sample) = + out_targets(target, channel, sample); + } + } + + write_audio_file(target_waveform, p_target); + } +} diff --git a/cli-apps/threaded_inference.hpp b/cli-apps/threaded_inference.hpp index 53a2313..d0adf1d 100644 --- a/cli-apps/threaded_inference.hpp +++ b/cli-apps/threaded_inference.hpp @@ -181,7 +181,9 @@ threaded_inference(const struct demucscpp::demucs_model &model, { // account for summing per-target by dividing by n targets, // 2 channels - final_output(t, ch, i) /= (sum_weight(i) / (2.0f * static_cast(nb_out_sources))); + final_output(t, ch, i) /= + (sum_weight(i) / + (2.0f * static_cast(nb_out_sources))); } } } @@ -190,3 +192,178 @@ threaded_inference(const struct demucscpp::demucs_model &model, return final_output; } }; // namespace demucscppthreaded + +namespace demucscppthreaded_v3 +{ +// bigger overlap from free-music-demixer +const int SAMPLE_RATE = 44100; +const float OVERLAP = 0.75; +const int OVERLAP_SAMPLES = ::floorf(SAMPLE_RATE * OVERLAP); + +Eigen::Tensor3dXf +threaded_inference(const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::MatrixXf &full_audio, int num_threads, + const std::string &prefix = "") +{ + // set output precision to 3 decimal places + std::cout << std::fixed << std::setprecision(3); + + // create vector of progresscallbacks per-thread + std::vector cbs; + for (int i = 0; i < num_threads; ++i) + { + cbs.push_back( + [i, prefix](float progress, const std::string &log_message) + { + std::cout << prefix << "[THREAD " << i << "] (" << std::setw(3) + << std::setfill(' ') << progress * 
100.0f << "%) " + << log_message << std::endl; + }); + } + + // calculate segment length by dividing n_samples by num_threads + int total_length = full_audio.cols(); + int segment_length = ::ceilf((float)total_length / (float)num_threads); + + std::vector segments; + // split the full audio into segments + for (int i = 0; i < num_threads; ++i) + { + int start = i * segment_length; + int end = std::min(total_length, start + segment_length); + + // Create a new segment with padding for overlap + Eigen::MatrixXf segment = + Eigen::MatrixXf::Zero(2, end - start + 2 * OVERLAP_SAMPLES); + + // Overlap-padding for the left and right channels + // For the first segment, no padding at the start + if (i == 0) + { + segment.block(0, 0, 2, OVERLAP_SAMPLES).colwise() = + full_audio.col(0); + } + else + { + segment.block(0, 0, 2, OVERLAP_SAMPLES) = full_audio.block( + 0, start - OVERLAP_SAMPLES, 2, OVERLAP_SAMPLES); + } + + // For the last segment, no padding at the end + if (i == num_threads - 1) + { + int remaining_samples = total_length - end; + segment.block(0, end - start + OVERLAP_SAMPLES, 2, + remaining_samples) = + full_audio.block(0, end, 2, remaining_samples); + } + else + { + segment.block(0, end - start + OVERLAP_SAMPLES, 2, + OVERLAP_SAMPLES) = + full_audio.block(0, end, 2, OVERLAP_SAMPLES); + } + + // Assign the original segment data + segment.block(0, OVERLAP_SAMPLES, 2, end - start) = + full_audio.block(0, start, 2, end - start); + segments.push_back(segment); + } + + // insert parallel processing here + // pretend like segment_outs contains: + // (4, 2, segment_samples) + // which are 4 targets, stereo/2 channels, and the above segment length + // and we want this to be recombined into a single tensor + // i.e. 
Eigen::Tensor3dXf(4, 2, total_length) + std::vector segment_outs(num_threads); + + // This vector will hold the threads + std::vector threads; + + for (int i = 0; i < num_threads; ++i) + { + threads.emplace_back( + [&model, &segments, &segment_outs, i, &cbs]() + { + segment_outs[i] = demucscpp_v3::demucs_v3_inference( + model, segments[i], cbs[i]); + }); + } + + // Wait for all threads to finish + for (auto &thread : threads) + { + thread.join(); + } + + const int nb_out_sources = 4; + + // Calculate total output size and create the output tensor + Eigen::Tensor3dXf final_output(nb_out_sources, 2, total_length); + final_output.setZero(); + + Eigen::VectorXf ramp(segment_length); + for (int i = 0; i < segment_length; ++i) + { + ramp(i) = std::min(i + 1, segment_length - i); + } + ramp /= ramp.maxCoeff(); // Normalize the ramp + + Eigen::VectorXf sum_weight = Eigen::VectorXf::Zero(total_length); + + for (size_t i = 0; i < segment_outs.size(); ++i) + { + int segment_start = i * segment_length; + for (int t = 0; t < nb_out_sources; ++t) + { // For each target + for (int ch = 0; ch < 2; ++ch) + { // For each channel + for (int j = 0; j < segment_length + 2 * OVERLAP_SAMPLES; ++j) + { + int global_idx = segment_start + j - OVERLAP_SAMPLES; + if (global_idx >= 0 && global_idx < total_length) + { + float weight = 1.0; + // Apply ramp weights at the beginning and end of the + // segment + if (j < OVERLAP_SAMPLES) + { + weight = ramp(j); + } + else if (j >= segment_length) + { + weight = ramp(segment_length + 2 * OVERLAP_SAMPLES - + j - 1); + } + final_output(t, ch, global_idx) += + segment_outs[i](t, ch, j) * weight; + sum_weight(global_idx) += weight; + } + } + } + } + } + + // Normalize the output by the sum of weights + for (int t = 0; t < nb_out_sources; ++t) + { + for (int ch = 0; ch < 2; ++ch) + { + for (int i = 0; i < total_length; ++i) + { + if (sum_weight(i) > 0) + { + // account for summing per-target by dividing by n targets, + // 2 channels + final_output(t, ch, 
i) /= + (sum_weight(i) / + (2.0f * static_cast(nb_out_sources))); + } + } + } + } + + return final_output; +} +}; // namespace demucscppthreaded_v3 diff --git a/scripts/convert-pth-to-ggml.py b/scripts/convert-pth-to-ggml.py index ed76fc6..7aac848 100644 --- a/scripts/convert-pth-to-ggml.py +++ b/scripts/convert-pth-to-ggml.py @@ -12,6 +12,7 @@ DEMUCS_MODEL = "htdemucs" DEMUCS_MODEL_6S = "htdemucs_6s" +DEMUCS_V3_MMI = "hdemucs_mmi" DEMUCS_MODEL_FT = "htdemucs_ft" DEMUCS_MODEL_FT_DRUMS = "htdemucs_ft_drums" DEMUCS_MODEL_FT_BASS = "htdemucs_ft_bass" @@ -20,6 +21,7 @@ HT_HUB_PATH = "955717e8-8726e21a.th" HT_HUB_PATH_6S = "5c90dfd2-34c22ccb.th" +V3_HUB_PATH = "75fc33f5-1941ce65.th" HT_HUB_PATH_FT_DRUMS = "f7e0c4bc-ba3fe64a.th" HT_HUB_PATH_FT_BASS = "d12395a8-e57c48e6.th" HT_HUB_PATH_FT_OTHER = "92cfc3b6-ef3bcb9c.th" @@ -32,6 +34,7 @@ parser = argparse.ArgumentParser(description='Convert Demucs PyTorch models to GGML') parser.add_argument("dest_dir", type=str, help="destination path for the converted model") parser.add_argument("--six-source", default=False, action="store_true", help="convert 6s model (default: 4s)") + parser.add_argument("--v3", default=False, action="store_true", help="convert demucs v3-mmi model (default: 4s)") parser.add_argument("--ft-drums", default=False, action="store_true", help="convert fine-tuned drum model") parser.add_argument("--ft-bass", default=False, action="store_true", help="convert fine-tuned bass model") parser.add_argument("--ft-other", default=False, action="store_true", help="convert fine-tuned other model") @@ -57,13 +60,20 @@ model_name = DEMUCS_MODEL_FT_OTHER elif args.ft_vocals: model_name = DEMUCS_MODEL_FT_VOCALS + elif args.v3: + model = get_model(DEMUCS_V3_MMI) + model_name = DEMUCS_V3_MMI print(model) # get torchub path torchhub_path = Path(torch.hub.get_dir()) / "checkpoints" - suffix = "-6s" if args.six_source else "-4s" + suffix = "-4s" + if args.six_source: + suffix = "-6s" + if args.v3: + suffix = "-v3" dest_name = 
dir_out / f"ggml-model-{model_name}{suffix}-f16.bin" fname_inp = torchhub_path / HT_HUB_PATH @@ -77,6 +87,8 @@ fname_inp = torchhub_path / HT_HUB_PATH_FT_OTHER elif args.ft_vocals: fname_inp = torchhub_path / HT_HUB_PATH_FT_VOCALS + elif args.v3: + fname_inp = torchhub_path / V3_HUB_PATH # try to load PyTorch binary data # even though we loaded it above to print its info @@ -100,6 +112,8 @@ magic = 0x646d6334 if args.six_source: magic = 0x646d6336 + if args.v3: + magic = 0x646d6333 # fine-tuned has same magic diff --git a/scripts/demucs_pytorch_inference.py b/scripts/demucs_pytorch_inference.py index 2c01d53..c80dc86 100644 --- a/scripts/demucs_pytorch_inference.py +++ b/scripts/demucs_pytorch_inference.py @@ -1,6 +1,6 @@ #!/usr/bin/env python from demucs.apply import apply_model -from demucs.utils import debug_tensor_demucscpp +#from demucs.utils import debug_tensor_demucscpp from demucs.pretrained import get_model from demucs.pretrained import SOURCES import torch @@ -20,8 +20,9 @@ parser = argparse.ArgumentParser(description='Demucs') parser.add_argument('input_file', type=str, help='path to input wav file') parser.add_argument('--dest-dir', type=str, default=None, help='path to write output files') - parser.add_argument("--six-source", default=False, action="store_true", help="convert 6s model (default: 4s)") - parser.add_argument("--fine-tuned", default=False, action="store_true", help="convert 6s model (default: 4s)") + parser.add_argument("--six-source", default=False, action="store_true", help="use 6s model (default: 4s)") + parser.add_argument("--fine-tuned", default=False, action="store_true", help="use ft model (default: 4s)") + parser.add_argument("--v3", default=False, action="store_true", help="use v3 (hdemucs_mmi) model (default: 4s)") args = parser.parse_args() @@ -36,18 +37,20 @@ model_name += '_6s' elif args.fine_tuned: model_name = 'htdemucs_ft' + elif args.v3: + model_name = 'hdemucs_mmi' # demucs v4 hybrid transformer model = 
get_model(model_name) nb_out_sources = 6 if args.six_source else 4 print(model) - debug_tensor_demucscpp(audio, "input audio") + #debug_tensor_demucscpp(audio, "input audio") ref = audio.mean(0) audio = (audio - ref.mean()) / ref.std() - debug_tensor_demucscpp(audio, "audio post-normalization") + #debug_tensor_demucscpp(audio, "audio post-normalization") sources = apply_model(model, audio[None])[0] sources = sources * ref.std() + ref.mean() @@ -58,7 +61,7 @@ print(f"Saving target {target_name}") out_audio = sources[target_idx] - debug_tensor_demucscpp(out_audio, f"target {target_name}") + #debug_tensor_demucscpp(out_audio, f"target {target_name}") # write to file in directory if args.dest_dir is not None: diff --git a/scripts/demucs_pytorch_layer_test_v3.py b/scripts/demucs_pytorch_layer_test_v3.py new file mode 100644 index 0000000..3cffe32 --- /dev/null +++ b/scripts/demucs_pytorch_layer_test_v3.py @@ -0,0 +1,555 @@ +#!/usr/bin/env python +from demucs.apply import apply_model +from demucs.pretrained import get_model +from demucs.pretrained import SOURCES +import torch +import torchaudio.backend.sox_io_backend +import torchaudio +import argparse +import numpy as np +import os +from einops import rearrange +import sys + + +def debug_tensor_demucscpp(x, name): + #check if x is of type TensorChunk + if hasattr(x, 'tensor'): + # split into subchunk from self.offset:self.offset+self.length + x = x.tensor[..., x.offset:x.offset+x.length] + + print(f"Debugging tensor!: {name}") + print(f"\tshape: {tuple(x.shape)}") + x_min, x_min_idx = torch.min(x.reshape(-1), dim=0) + x_max, x_max_idx = torch.max(x.reshape(-1), dim=0) + x_mean = torch.mean(x) + x_stddev = torch.std(x) + x_sum = torch.sum(x) + print(f"\tmin: {x_min.item()}") + print(f"\tmax: {x_max.item()}") + print(f"\tmean: {x_mean.item()}") + print(f"\tstddev: {x_stddev.item()}") + print(f"\tsum: {x_sum.item()}") + print(f"\tmin idx: {tuple(np.unravel_index(x_min_idx.item(), x.shape))}") + print(f"\tmax idx: 
{tuple(np.unravel_index(x_max_idx.item(), x.shape))}") + print(f"FINISHED DEBUG FOR TENSOR {name}") + + +if __name__ == '__main__': + #input_file = "./test/data/gspi_stereo_short.wav" + + ## load audio file and resample to 44100 Hz + #metadata = torchaudio.info(input_file) + #print(metadata) + #audio, rate = torchaudio.load(input_file) + #print(rate) + + # demucs v3 hybrid + model = get_model('hdemucs_mmi') + print(model) + + try: + test_name = sys.argv[1] + except IndexError: + test_name = "all" + + if test_name == "all" or test_name == "freq-enc": + # get the henclayer + henclayer_0 = model.models[0].encoder[0] + + # create a fake tensor of shape (1, 4, 2048, 336) + x = torch.ones((1, 4, 2048, 336)) + + # set alternating odd index values to -1 + x[..., ::2] = -1 + + debug_tensor_demucscpp(x, "x") + + x_enc_0 = henclayer_0(x) + + debug_tensor_demucscpp(x_enc_0, "x_enc_0") + + # continue for the rest of the encoder layers + # generate tensors for each layer + # shapes are: + # (96, 128, 336) -> (192, 32, 336) -> (384, 8, 336) + # continue with x_enc_1,2,3 + + henclayer_1 = model.models[0].encoder[1] + x_enc_1 = henclayer_1(x_enc_0) + + debug_tensor_demucscpp(x_enc_1, "x_enc_1") + + henclayer_2 = model.models[0].encoder[2] + x_enc_2 = henclayer_2(x_enc_1) + + debug_tensor_demucscpp(x_enc_2, "x_enc_2") + + henclayer_3 = model.models[0].encoder[3] + x_enc_3 = henclayer_3(x_enc_2) + + debug_tensor_demucscpp(x_enc_3, "x_enc_3") + + if test_name == "all" or test_name == "time-enc": + # create fake xt tensor of shape (1, 2, 343980) + xt = torch.ones((1, 2, 343980)) + xt[..., ::2] = -1 + + htenclayer_0 = model.models[0].tencoder[0] + + debug_tensor_demucscpp(xt, "xt") + + xt_enc_0 = htenclayer_0(xt) + + debug_tensor_demucscpp(xt_enc_0, "xt_enc_0") + + htenclayer_1 = model.models[0].tencoder[1] + xt_enc_1 = htenclayer_1(xt_enc_0) + + debug_tensor_demucscpp(xt_enc_1, "xt_enc_1") + + htenclayer_2 = model.models[0].tencoder[2] + xt_enc_2 = htenclayer_2(xt_enc_1) + + 
debug_tensor_demucscpp(xt_enc_2, "xt_enc_2") + + htenclayer_3 = model.models[0].tencoder[3] + xt_enc_3 = htenclayer_3(xt_enc_2) + + debug_tensor_demucscpp(xt_enc_3, "xt_enc_3") + + if test_name == "all" or test_name == "encoder45": + # get the henclayer + henclayer_0 = model.models[0].encoder[0] + + # create a fake tensor of shape (1, 4, 2048, 336) + x = torch.ones((1, 4, 2048, 336)) + + # set alternating odd index values to -1 + x[..., ::2] = -1 + + debug_tensor_demucscpp(x, "x") + + x_enc_0 = henclayer_0(x) + + debug_tensor_demucscpp(x_enc_0, "x_enc_0") + + # continue for the rest of the encoder layers + # generate tensors for each layer + # shapes are: + # (96, 128, 336) -> (192, 32, 336) -> (384, 8, 336) + # continue with x_enc_1,2,3 + + henclayer_1 = model.models[0].encoder[1] + x_enc_1 = henclayer_1(x_enc_0) + + debug_tensor_demucscpp(x_enc_1, "x_enc_1") + + henclayer_2 = model.models[0].encoder[2] + x_enc_2 = henclayer_2(x_enc_1) + + debug_tensor_demucscpp(x_enc_2, "x_enc_2") + + henclayer_3 = model.models[0].encoder[3] + x_enc_3 = henclayer_3(x_enc_2) + + debug_tensor_demucscpp(x_enc_3, "x_enc_3") + + # create fake xt tensor of shape (1, 2, 343980) + xt = torch.ones((1, 2, 343980)) + xt[..., ::2] = -1 + + htenclayer_0 = model.models[0].tencoder[0] + + debug_tensor_demucscpp(xt, "xt") + + xt_enc_0 = htenclayer_0(xt) + + debug_tensor_demucscpp(xt_enc_0, "xt_enc_0") + + htenclayer_1 = model.models[0].tencoder[1] + xt_enc_1 = htenclayer_1(xt_enc_0) + + debug_tensor_demucscpp(xt_enc_1, "xt_enc_1") + + htenclayer_2 = model.models[0].tencoder[2] + xt_enc_2 = htenclayer_2(xt_enc_1) + + debug_tensor_demucscpp(xt_enc_2, "xt_enc_2") + + htenclayer_3 = model.models[0].tencoder[3] + xt_enc_3 = htenclayer_3(xt_enc_2) + + debug_tensor_demucscpp(xt_enc_3, "xt_enc_3") + + htenclayer_4 = model.models[0].tencoder[4] + xt_enc_4 = htenclayer_4(xt_enc_3) + + debug_tensor_demucscpp(xt_enc_4, "xt_enc_4") + + henclayer_4 = model.models[0].encoder[4] + x_enc_4 = henclayer_4(x_enc_3, 
inject=xt_enc_4) + + debug_tensor_demucscpp(x_enc_4, "x_enc_4") + + henclayer_5 = model.models[0].encoder[5] + x_shared_enc_5 = henclayer_5(x_enc_4) + + debug_tensor_demucscpp(x_shared_enc_5, "x_shared_enc_5") + + if test_name == "all" or test_name == "decoder01": + x_fake_shared_enc_5 = torch.ones((1, 1536, 168)) + skip_fake_dec_4 = torch.ones((768, 1, 336)) + + # set even index values to -1 + x_fake_shared_enc_5[..., ::2] = -1 + + # for the skip, set even index values to 0.5, odd to -0.5 + skip_fake_dec_4[..., ::2] = 0.5 + skip_fake_dec_4[..., 1::2] = -0.5 + + debug_tensor_demucscpp(x_fake_shared_enc_5, "x_fake_shared_enc_5") + debug_tensor_demucscpp(skip_fake_dec_4, "skip_fake_dec_4") + + hdecoder_0 = model.models[0].decoder[0] + x_empty = torch.zeros((1, 1536, 168)) + x_fake_dec_4, pre_t_unused = hdecoder_0(x_empty, x_fake_shared_enc_5, 336) + + debug_tensor_demucscpp(x_fake_dec_4, "x_fake_dec_4") + debug_tensor_demucscpp(pre_t_unused, "pre_t_unused") + + hdecoder_1 = model.models[0].decoder[1] + x_fake_dec_3, pre_t = hdecoder_1(x_fake_dec_4, skip_fake_dec_4, 336) + + debug_tensor_demucscpp(x_fake_dec_3, "x_fake_dec_3") + debug_tensor_demucscpp(pre_t, "pre_t") + + tdecoder_0 = model.models[0].tdecoder[0] + pre_t = pre_t[:, :, 0] + debug_tensor_demucscpp(pre_t, "pre_t") + xt_fake_dec_3, _ = tdecoder_0(pre_t, None, 1344) + + debug_tensor_demucscpp(xt_fake_dec_3, "xt_fake_dec_3") + + if test_name == "all" or test_name == "decoder1isolated": + x_fake_dec_4 = torch.ones((1, 768, 336)) + skip_fake_dec_4 = torch.ones((768, 1, 336)) + + # set even index values to -1 + x_fake_dec_4[..., ::2] = -1 + + # for the skip, set even index values to 0.5, odd to -0.5 + skip_fake_dec_4[..., ::2] = 0.5 + skip_fake_dec_4[..., 1::2] = -0.5 + + hdecoder_1 = model.models[0].decoder[1] + x_fake_dec_3, pre_t = hdecoder_1(x_fake_dec_4, skip_fake_dec_4, 336) + + debug_tensor_demucscpp(x_fake_dec_3, "x_fake_dec_3") + debug_tensor_demucscpp(pre_t, "pre_t") + + tdecoder_0 = 
model.models[0].tdecoder[0] + pre_t = pre_t[:, :, 0] + debug_tensor_demucscpp(pre_t, "pre_t") + xt_fake_dec_3, _ = tdecoder_0(pre_t, None, 1344) + + debug_tensor_demucscpp(xt_fake_dec_3, "xt_fake_dec_3") + + if test_name == "all" or test_name == "alldecoders": + x_fake_shared_enc_5 = torch.ones((1, 1536, 168)) + skip_fake_dec_4 = torch.ones((768, 1, 336)) + + # set even index values to -1 + x_fake_shared_enc_5[..., ::2] = -1 + + # for the skip, set even index values to 0.5, odd to -0.5 + skip_fake_dec_4[..., ::2] = 0.5 + skip_fake_dec_4[..., 1::2] = -0.5 + + debug_tensor_demucscpp(x_fake_shared_enc_5, "x_fake_shared_enc_5") + debug_tensor_demucscpp(skip_fake_dec_4, "skip_fake_dec_4") + + x_fake_dec_4 = torch.ones((1, 768, 336)) + + # set even index values to -1 + x_fake_dec_4[..., ::2] = -1 + + skip_fake_dec_3 = torch.ones((384, 8, 336)) + + # 0.5, -0.5 again + skip_fake_dec_3[..., ::2] = 0.5 + skip_fake_dec_3[..., 1::2] = -0.5 + + skip_fake_dec_2 = torch.ones((192, 32, 336)) + + # 0.5, -0.5 again + skip_fake_dec_2[..., ::2] = 0.5 + skip_fake_dec_2[..., 1::2] = -0.5 + + skip_fake_dec_1 = torch.ones((96, 128, 336)) + + # 0.5, -0.5 again + skip_fake_dec_1[..., ::2] = 0.5 + skip_fake_dec_1[..., 1::2] = -0.5 + + skip_fake_dec_0 = torch.ones((48, 512, 336)) + + # 0.5, -0.5 again + skip_fake_dec_0[..., ::2] = 0.5 + skip_fake_dec_0[..., 1::2] = -0.5 + + skip_fake_tdec_3 = torch.ones((1, 384, 1344)) + + # 0.5, -0.5 again + skip_fake_tdec_3[..., ::2] = 0.5 + skip_fake_tdec_3[..., 1::2] = -0.5 + + skip_fake_tdec_2 = torch.ones((1, 192, 5375)) + + # 0.5, -0.5 again + skip_fake_tdec_2[..., ::2] = 0.5 + skip_fake_tdec_2[..., 1::2] = -0.5 + + skip_fake_tdec_1 = torch.ones((1, 96, 21499)) + + # 0.5, -0.5 again + skip_fake_tdec_1[..., ::2] = 0.5 + skip_fake_tdec_1[..., 1::2] = -0.5 + + skip_fake_tdec_0 = torch.ones((1, 48, 85995)) + + # 0.5, -0.5 again + skip_fake_tdec_0[..., ::2] = 0.5 + skip_fake_tdec_0[..., 1::2] = -0.5 + + hdecoder_0 = model.models[0].decoder[0] + x_empty = 
torch.zeros((1, 1536, 168)) + x_fake_dec_4, pre_t_unused = hdecoder_0(x_empty, x_fake_shared_enc_5, 336) + + debug_tensor_demucscpp(x_fake_dec_4, "x_fake_dec_4") + debug_tensor_demucscpp(pre_t_unused, "pre_t_unused") + + hdecoder_1 = model.models[0].decoder[1] + x_fake_dec_3, pre_t = hdecoder_1(x_fake_dec_4, skip_fake_dec_4, 336) + + debug_tensor_demucscpp(x_fake_dec_3, "x_fake_dec_3") + debug_tensor_demucscpp(pre_t, "pre_t") + + hdecoder_1 = model.models[0].decoder[1] + x_fake_dec_3, pre_t = hdecoder_1(x_fake_dec_4, skip_fake_dec_4, 336) + + debug_tensor_demucscpp(x_fake_dec_3, "x_fake_dec_3") + debug_tensor_demucscpp(pre_t, "pre_t") + + tdecoder_0 = model.models[0].tdecoder[0] + pre_t = pre_t[:, :, 0] + debug_tensor_demucscpp(pre_t, "pre_t") + xt_fake_dec_3, _ = tdecoder_0(pre_t, None, 1344) + + debug_tensor_demucscpp(xt_fake_dec_3, "xt_fake_dec_3") + + hdecoder_2 = model.models[0].decoder[2] + x_fake_dec_2, _ = hdecoder_2(x_fake_dec_3, skip_fake_dec_3, 1344) + + tdecoder_1 = model.models[0].tdecoder[1] + xt_fake_dec_2, _ = tdecoder_1(xt_fake_dec_3, skip_fake_tdec_3, 5375) + + debug_tensor_demucscpp(x_fake_dec_2, "x_fake_dec_2") + debug_tensor_demucscpp(xt_fake_dec_2, "xt_fake_dec_2") + + if test_name == "all" or test_name == "end2end": + # get the henclayer + henclayer_0 = model.models[0].encoder[0] + + # create a fake tensor of shape (1, 4, 2048, 336) + x = torch.ones((1, 4, 2048, 336)) + + # set alternating odd index values to -1 + x[..., ::2] = -1 + + debug_tensor_demucscpp(x, "x") + + x_enc_0 = henclayer_0(x) + + debug_tensor_demucscpp(x_enc_0, "x_enc_0") + + # continue for the rest of the encoder layers + # generate tensors for each layer + # shapes are: + # (96, 128, 336) -> (192, 32, 336) -> (384, 8, 336) + # continue with x_enc_1,2,3 + + henclayer_1 = model.models[0].encoder[1] + x_enc_1 = henclayer_1(x_enc_0) + + debug_tensor_demucscpp(x_enc_1, "x_enc_1") + + henclayer_2 = model.models[0].encoder[2] + x_enc_2 = henclayer_2(x_enc_1) + + 
debug_tensor_demucscpp(x_enc_2, "x_enc_2") + + henclayer_3 = model.models[0].encoder[3] + x_enc_3 = henclayer_3(x_enc_2) + + debug_tensor_demucscpp(x_enc_3, "x_enc_3") + + # create fake xt tensor of shape (1, 2, 343980) + xt = torch.ones((1, 2, 343980)) + xt[..., ::2] = -1 + + htenclayer_0 = model.models[0].tencoder[0] + + debug_tensor_demucscpp(xt, "xt") + + xt_enc_0 = htenclayer_0(xt) + + debug_tensor_demucscpp(xt_enc_0, "xt_enc_0") + + htenclayer_1 = model.models[0].tencoder[1] + xt_enc_1 = htenclayer_1(xt_enc_0) + + debug_tensor_demucscpp(xt_enc_1, "xt_enc_1") + + htenclayer_2 = model.models[0].tencoder[2] + xt_enc_2 = htenclayer_2(xt_enc_1) + + debug_tensor_demucscpp(xt_enc_2, "xt_enc_2") + + htenclayer_3 = model.models[0].tencoder[3] + xt_enc_3 = htenclayer_3(xt_enc_2) + + debug_tensor_demucscpp(xt_enc_3, "xt_enc_3") + + htenclayer_4 = model.models[0].tencoder[4] + xt_enc_4 = htenclayer_4(xt_enc_3) + + debug_tensor_demucscpp(xt_enc_4, "xt_enc_4") + + henclayer_4 = model.models[0].encoder[4] + x_enc_4 = henclayer_4(x_enc_3, inject=xt_enc_4) + + debug_tensor_demucscpp(x_enc_4, "x_enc_4") + + henclayer_5 = model.models[0].encoder[5] + x_shared_enc_5 = henclayer_5(x_enc_4, inject=None) + + debug_tensor_demucscpp(x_shared_enc_5, "x_shared_enc_5") + + skip_fake_dec_4 = torch.ones((768, 1, 336)) + + # for the skip, set even index values to 0.5, odd to -0.5 + skip_fake_dec_4[..., ::2] = 0.5 + skip_fake_dec_4[..., 1::2] = -0.5 + + debug_tensor_demucscpp(skip_fake_dec_4, "skip_fake_dec_4") + + x_fake_dec_4 = torch.ones((1, 768, 336)) + + # set even index values to -1 + x_fake_dec_4[..., ::2] = -1 + + skip_fake_dec_3 = torch.ones((384, 8, 336)) + + # 0.5, -0.5 again + skip_fake_dec_3[..., ::2] = 0.5 + skip_fake_dec_3[..., 1::2] = -0.5 + + skip_fake_dec_2 = torch.ones((192, 32, 336)) + + # 0.5, -0.5 again + skip_fake_dec_2[..., ::2] = 0.5 + skip_fake_dec_2[..., 1::2] = -0.5 + + skip_fake_dec_1 = torch.ones((96, 128, 336)) + + # 0.5, -0.5 again + skip_fake_dec_1[..., ::2] 
= 0.5 + skip_fake_dec_1[..., 1::2] = -0.5 + + skip_fake_dec_0 = torch.ones((48, 512, 336)) + + # 0.5, -0.5 again + skip_fake_dec_0[..., ::2] = 0.5 + skip_fake_dec_0[..., 1::2] = -0.5 + + skip_fake_tdec_3 = torch.ones((1, 384, 1344)) + + # 0.5, -0.5 again + skip_fake_tdec_3[..., ::2] = 0.5 + skip_fake_tdec_3[..., 1::2] = -0.5 + + skip_fake_tdec_2 = torch.ones((1, 192, 5375)) + + # 0.5, -0.5 again + skip_fake_tdec_2[..., ::2] = 0.5 + skip_fake_tdec_2[..., 1::2] = -0.5 + + skip_fake_tdec_1 = torch.ones((1, 96, 21499)) + + # 0.5, -0.5 again + skip_fake_tdec_1[..., ::2] = 0.5 + skip_fake_tdec_1[..., 1::2] = -0.5 + + skip_fake_tdec_0 = torch.ones((1, 48, 85995)) + + # 0.5, -0.5 again + skip_fake_tdec_0[..., ::2] = 0.5 + skip_fake_tdec_0[..., 1::2] = -0.5 + + hdecoder_0 = model.models[0].decoder[0] + x_empty = torch.zeros((1, 1536, 168)) + x_fake_dec_4, pre_t_unused = hdecoder_0(x_empty, x_shared_enc_5, 336) + + debug_tensor_demucscpp(x_fake_dec_4, "x_fake_dec_4") + debug_tensor_demucscpp(pre_t_unused, "pre_t_unused") + + hdecoder_1 = model.models[0].decoder[1] + x_fake_dec_3, pre_t = hdecoder_1(x_fake_dec_4, skip_fake_dec_4, 336) + + debug_tensor_demucscpp(x_fake_dec_3, "x_fake_dec_3") + debug_tensor_demucscpp(pre_t, "pre_t") + + tdecoder_1 = model.models[0].tdecoder[0] + pre_t = pre_t[:, :, 0] + debug_tensor_demucscpp(pre_t, "pre_t") + xt_fake_dec_3, _ = tdecoder_1(pre_t, None, 1344) + + debug_tensor_demucscpp(xt_fake_dec_3, "xt_fake_dec_3") + + hdecoder_2 = model.models[0].decoder[2] + x_fake_dec_2, _ = hdecoder_2(x_fake_dec_3, skip_fake_dec_3, 336) + + tdecoder_2 = model.models[0].tdecoder[1] + xt_fake_dec_2, _ = tdecoder_2(xt_fake_dec_3, skip_fake_tdec_3, 5375) + + debug_tensor_demucscpp(x_fake_dec_2, "x_fake_dec_2") + debug_tensor_demucscpp(xt_fake_dec_2, "xt_fake_dec_2") + + hdecoder_3 = model.models[0].decoder[3] + x_fake_dec_1, _ = hdecoder_3(x_fake_dec_2, skip_fake_dec_2, 336) + + tdecoder_3 = model.models[0].tdecoder[2] + xt_fake_dec_1, _ = 
tdecoder_3(xt_fake_dec_2, skip_fake_tdec_2, 21499) + + debug_tensor_demucscpp(x_fake_dec_1, "x_fake_dec_1") + debug_tensor_demucscpp(xt_fake_dec_1, "xt_fake_dec_1") + + hdecoder_4 = model.models[0].decoder[4] + x_fake_dec_0, _ = hdecoder_4(x_fake_dec_1, skip_fake_dec_1, 336) + + tdecoder_4 = model.models[0].tdecoder[3] + xt_fake_dec_0, _ = tdecoder_4(xt_fake_dec_1, skip_fake_tdec_1, 85995) + + debug_tensor_demucscpp(x_fake_dec_0, "x_fake_dec_0") + debug_tensor_demucscpp(xt_fake_dec_0, "xt_fake_dec_0") + + hdecoder_5 = model.models[0].decoder[5] + x_out, _ = hdecoder_5(x_fake_dec_0, skip_fake_dec_0, 336) + + + tdecoder_5 = model.models[0].tdecoder[4] + xt_out, _ = tdecoder_5(xt_fake_dec_0, skip_fake_tdec_0, 343980) + + debug_tensor_demucscpp(x_out, "x_out") + debug_tensor_demucscpp(xt_out, "xt_out") diff --git a/src/conv.hpp b/src/conv.hpp index 0d10b1f..b967048 100644 --- a/src/conv.hpp +++ b/src/conv.hpp @@ -17,18 +17,22 @@ inline Eigen::MatrixXf im2col(const Eigen::Tensor3dXf &input) { // Adjust the calculation of height_col and width_col for dilation int in_channels = input.dimension(0); - int height_col = (input.dimension(1) + 2 * pad_height - - dilation_height * (kernel_height - 1) - 1) / - stride_height + - 1; - int width_col = (input.dimension(2) + 2 * pad_width - - dilation_width * (kernel_width - 1) - 1) / - stride_width + - 1; - int in_height = input.dimension(1); int in_width = input.dimension(2); + // Apply floating point division before ceiling to correctly calculate + // dimensions + int height_col = + static_cast(std::ceil((in_height + 2 * pad_height - + dilation_height * (kernel_height - 1) - 1) / + float(stride_height)) + + 1); + int width_col = + static_cast(std::ceil((in_width + 2 * pad_width - + dilation_width * (kernel_width - 1) - 1) / + float(stride_width)) + + 1); + Eigen::MatrixXf output(height_col * width_col, in_channels * kernel_height * kernel_width); output.setZero(); @@ -67,21 +71,22 @@ inline Eigen::MatrixXf im2col(const 
Eigen::Tensor3dXf &input) template -Eigen::Tensor3dXf conv2d_gemm(const Eigen::Tensor3dXf &x, - const Eigen::Tensor4dXf &w, - const Eigen::Tensor1dXf &b) +Eigen::Tensor3dXf conv2d(const Eigen::Tensor3dXf &x, const Eigen::Tensor4dXf &w, + const Eigen::Tensor1dXf &b) { int in_height = x.dimension(1); int in_width = x.dimension(2); - // Calculate output dimensions - int out_height = static_cast(std::floor(in_height + 2 * pad_height - - kernel_height) / - stride_height) + - 1; + // Calculate output dimensions with the correct application of ceil + int out_height = + static_cast(std::ceil( + (float)(in_height + 2 * pad_height - (kernel_height - 1) - 1) / + stride_height)) + + 1; int out_width = - static_cast(std::floor(in_width + 2 * pad_width - kernel_width) / - stride_width) + + static_cast(std::ceil( + (float)(in_width + 2 * pad_width - (kernel_width - 1) - 1) / + stride_width)) + 1; // Apply im2col @@ -136,21 +141,23 @@ Eigen::Tensor3dXf conv2d_gemm(const Eigen::Tensor3dXf &x, template -Eigen::Tensor3dXf conv2d_gemm_fused_gelu(const Eigen::Tensor3dXf &x, - const Eigen::Tensor4dXf &w, - const Eigen::Tensor1dXf &b) +Eigen::Tensor3dXf conv2d_fused_gelu(const Eigen::Tensor3dXf &x, + const Eigen::Tensor4dXf &w, + const Eigen::Tensor1dXf &b) { int in_height = x.dimension(1); int in_width = x.dimension(2); - // Calculate output dimensions - int out_height = static_cast(std::floor(in_height + 2 * pad_height - - kernel_height) / - stride_height) + - 1; + // Calculate output dimensions with the correct application of ceil + int out_height = + static_cast(std::ceil( + (float)(in_height + 2 * pad_height - (kernel_height - 1) - 1) / + stride_height)) + + 1; int out_width = - static_cast(std::floor(in_width + 2 * pad_width - kernel_width) / - stride_width) + + static_cast(std::ceil( + (float)(in_width + 2 * pad_width - (kernel_width - 1) - 1) / + stride_width)) + 1; // Apply im2col @@ -204,17 +211,6 @@ Eigen::Tensor3dXf conv2d_gemm_fused_gelu(const Eigen::Tensor3dXf &x, return 
y_out; } -template -Eigen::Tensor3dXf conv2d(const Eigen::Tensor3dXf &x, const Eigen::Tensor4dXf &w, - const Eigen::Tensor1dXf &b) -{ - return conv2d_gemm(x, w, b); -} - template Eigen::Tensor3dXf conv1d(const Eigen::Tensor3dXf &x, const Eigen::Tensor3dXf &w, @@ -255,9 +251,9 @@ Eigen::Tensor3dXf conv1d_fused_gelu(const Eigen::Tensor3dXf &x, // do 2d convolution inference here // treating the in_freq dimension as a width dimension with a no-op kernel Eigen::Tensor3dXf y_out = - demucscpp::conv2d_gemm_fused_gelu(x_shuff, w_4d, b); + demucscpp::conv2d_fused_gelu(x_shuff, + w_4d, b); // move end axis to the front Eigen::Tensor3dXf y_out_shuf = @@ -274,9 +270,17 @@ Eigen::MatrixXf im2col_transposed(const Eigen::Tensor3dXf &input) int input_height = input.dimension(1); int input_width = input.dimension(2); - // Calculate the expanded output height and width - int expanded_height = (input_height - 1) * stride_height + kernel_height; - int expanded_width = (input_width - 1) * stride_width + kernel_width; + // Calculate the effective kernel size after dilation + int effective_kernel_height = + kernel_height + (kernel_height - 1) * (dilation_height - 1); + int effective_kernel_width = + kernel_width + (kernel_width - 1) * (dilation_width - 1); + + // Calculate the expanded output height and width considering dilation + int expanded_height = + (input_height - 1) * stride_height + effective_kernel_height; + int expanded_width = + (input_width - 1) * stride_width + effective_kernel_width; // Initialize the output matrix Eigen::MatrixXf output = @@ -294,9 +298,11 @@ Eigen::MatrixXf im2col_transposed(const Eigen::Tensor3dXf &input) { for (int w = 0; w < input_width; ++w) { - // Calculate the position in the expanded output - int expanded_h = h * stride_height + kh - pad_height; - int expanded_w = w * stride_width + kw - pad_width; + // Adjust calculation for dilation + int expanded_h = h * stride_height + + (kh * dilation_height) - pad_height; + int expanded_w = w * stride_width 
+ + (kw * dilation_width) - pad_width; // Check if the indices are within the bounds of the // expanded output @@ -321,28 +327,29 @@ Eigen::MatrixXf im2col_transposed(const Eigen::Tensor3dXf &input) template -Eigen::Tensor3dXf conv2d_tr_gemm(const Eigen::Tensor3dXf &x, - const Eigen::Tensor4dXf &w, - const Eigen::Tensor1dXf &b) +Eigen::Tensor3dXf conv2d_tr(const Eigen::Tensor3dXf &x, + const Eigen::Tensor4dXf &w, + const Eigen::Tensor1dXf &b) { int in_height = x.dimension(1); int in_width = x.dimension(2); - // Calculate the output dimensions - int out_height = - (in_height - 1) * stride_height - 2 * pad_height + kernel_height; + int effective_kernel_height = + kernel_height + (kernel_height - 1) * (dilation_height - 1); + int effective_kernel_width = + kernel_width + (kernel_width - 1) * (dilation_width - 1); + + int out_height = (in_height - 1) * stride_height + effective_kernel_height - + 2 * pad_height; int out_width = - (in_width - 1) * stride_width - 2 * pad_width + kernel_width; + (in_width - 1) * stride_width + effective_kernel_width - 2 * pad_width; - // demucscppdebug::debug_tensor_3dxf(x, "x input"); // Apply an adapted im2col for transposed convolution Eigen::MatrixXf im2col_matrix = im2col_transposed(x); - // demucscppdebug::debug_matrix_xf(im2col_matrix, "x post-im2col"); - // demucscppdebug::debug_tensor_4dxf(w, "weights"); // Reshape and prepare the weights as in conv2d_gemm // keeping in mind transpose weights are stored as (Cin, Cout, Kh, Kw) (not // Cout, Cin, Kh, Kw) @@ -355,11 +362,8 @@ Eigen::Tensor3dXf conv2d_tr_gemm(const Eigen::Tensor3dXf &x, reshaped_weights_tensor.data(), reshaped_weights_tensor.dimension(0), reshaped_weights_tensor.dimension(1)); - // demucscppdebug::debug_matrix_xf(reshaped_weights, "reshaped weights"); - // Perform matrix multiplication with GEMM Eigen::MatrixXf result = im2col_matrix * reshaped_weights.transpose(); - // demucscppdebug::debug_matrix_xf(result, "result of gemm-conv-tr"); // Add bias to result for (int 
chout = 0; chout < out_channels; ++chout) @@ -367,9 +371,6 @@ Eigen::Tensor3dXf conv2d_tr_gemm(const Eigen::Tensor3dXf &x, result.col(chout).array() += b(chout); } - // demucscppdebug::debug_matrix_xf(result, "result conv2d-tr-gemm - // post-bias!"); - Eigen::Tensor3dXf y_out(out_channels, out_height, out_width); y_out.setZero(); @@ -391,35 +392,35 @@ Eigen::Tensor3dXf conv2d_tr_gemm(const Eigen::Tensor3dXf &x, } } - // demucscppdebug::debug_tensor_3dxf(y_out, "y_out"); return y_out; } template -Eigen::Tensor3dXf conv2d_tr_gemm_fused_gelu(const Eigen::Tensor3dXf &x, - const Eigen::Tensor4dXf &w, - const Eigen::Tensor1dXf &b) +Eigen::Tensor3dXf conv2d_tr_fused_gelu(const Eigen::Tensor3dXf &x, + const Eigen::Tensor4dXf &w, + const Eigen::Tensor1dXf &b) { int in_height = x.dimension(1); int in_width = x.dimension(2); - // Calculate the output dimensions - int out_height = - (in_height - 1) * stride_height - 2 * pad_height + kernel_height; + int effective_kernel_height = + kernel_height + (kernel_height - 1) * (dilation_height - 1); + int effective_kernel_width = + kernel_width + (kernel_width - 1) * (dilation_width - 1); + + int out_height = (in_height - 1) * stride_height + effective_kernel_height - + 2 * pad_height; int out_width = - (in_width - 1) * stride_width - 2 * pad_width + kernel_width; + (in_width - 1) * stride_width + effective_kernel_width - 2 * pad_width; - // demucscppdebug::debug_tensor_3dxf(x, "x input"); // Apply an adapted im2col for transposed convolution Eigen::MatrixXf im2col_matrix = im2col_transposed(x); - // demucscppdebug::debug_matrix_xf(im2col_matrix, "x post-im2col"); - // demucscppdebug::debug_tensor_4dxf(w, "weights"); // Reshape and prepare the weights as in conv2d_gemm // keeping in mind transpose weights are stored as (Cin, Cout, Kh, Kw) (not // Cout, Cin, Kh, Kw) @@ -432,11 +433,8 @@ Eigen::Tensor3dXf conv2d_tr_gemm_fused_gelu(const Eigen::Tensor3dXf &x, reshaped_weights_tensor.data(), reshaped_weights_tensor.dimension(0), 
reshaped_weights_tensor.dimension(1)); - // demucscppdebug::debug_matrix_xf(reshaped_weights, "reshaped weights"); - // Perform matrix multiplication with GEMM Eigen::MatrixXf result = im2col_matrix * reshaped_weights.transpose(); - // demucscppdebug::debug_matrix_xf(result, "result of gemm-conv-tr"); // Add bias to result for (int chout = 0; chout < out_channels; ++chout) @@ -444,9 +442,6 @@ Eigen::Tensor3dXf conv2d_tr_gemm_fused_gelu(const Eigen::Tensor3dXf &x, result.col(chout).array() += b(chout); } - // demucscppdebug::debug_matrix_xf(result, "result conv2d-tr-gemm - // post-bias!"); - Eigen::Tensor3dXf y_out(out_channels, out_height, out_width); y_out.setZero(); @@ -475,22 +470,9 @@ Eigen::Tensor3dXf conv2d_tr_gemm_fused_gelu(const Eigen::Tensor3dXf &x, } } - // demucscppdebug::debug_tensor_3dxf(y_out, "y_out"); return y_out; } -template -Eigen::Tensor3dXf conv2d_tr(const Eigen::Tensor3dXf &x, - const Eigen::Tensor4dXf &w, - const Eigen::Tensor1dXf &b) -{ - return conv2d_tr_gemm(x, w, b); -} - template Eigen::Tensor3dXf conv1d_tr(const Eigen::Tensor3dXf &x, @@ -506,8 +488,8 @@ Eigen::Tensor3dXf conv1d_tr(const Eigen::Tensor3dXf &x, // Call the 2D transposed convolution function Eigen::Tensor3dXf y_out = - conv2d_tr_gemm(x_shuff, w_4d, b); + conv2d_tr(x_shuff, w_4d, b); // Move end axis to the front Eigen::Tensor3dXf y_out_shuf = @@ -531,9 +513,8 @@ Eigen::Tensor3dXf conv1d_tr_fused_gelu(const Eigen::Tensor3dXf &x, // Call the 2D transposed convolution function Eigen::Tensor3dXf y_out = - conv2d_tr_gemm_fused_gelu(x_shuff, w_4d, - b); + conv2d_tr_fused_gelu(x_shuff, w_4d, b); // Move end axis to the front Eigen::Tensor3dXf y_out_shuf = diff --git a/src/encdec.cpp b/src/encdec.cpp index 4b4265c..5f4f2b5 100644 --- a/src/encdec.cpp +++ b/src/encdec.cpp @@ -5,7 +5,6 @@ #include #include -// forward declaration to apply a frequency encoder void demucscpp::apply_freq_encoder(const struct demucscpp::demucs_model &model, int encoder_idx, const Eigen::Tensor3dXf 
&x_in, @@ -80,7 +79,6 @@ void demucscpp::apply_freq_encoder(const struct demucscpp::demucs_model &model, x_out = demucscpp::glu(y_shuff, 0); } -// forward declaration to apply a time encoder void demucscpp::apply_time_encoder(const struct demucscpp::demucs_model &model, int tencoder_idx, const Eigen::Tensor3dXf &xt_in, @@ -165,7 +163,6 @@ void demucscpp::apply_time_encoder(const struct demucscpp::demucs_model &model, xt_out = demucscpp::glu(yt, 1); } -// forward declaration to apply a frequency decoder void demucscpp::apply_freq_decoder(const struct demucscpp::demucs_model &model, int decoder_idx, const Eigen::Tensor3dXf &x_in, @@ -220,20 +217,17 @@ void demucscpp::apply_freq_decoder(const struct demucscpp::demucs_model &model, switch (decoder_idx) { case 0: - y = demucscpp::conv2d_tr_gemm_fused_gelu<384, 192, 8, 1, 4, 1, 0, 0, 1, - 1>( + y = demucscpp::conv2d_tr_fused_gelu<384, 192, 8, 1, 4, 1, 0, 0, 1, 1>( y_shuff_2, model.decoder_conv_tr_weight[decoder_idx], model.decoder_conv_tr_bias[decoder_idx]); break; case 1: - y = demucscpp::conv2d_tr_gemm_fused_gelu<192, 96, 8, 1, 4, 1, 0, 0, 1, - 1>( + y = demucscpp::conv2d_tr_fused_gelu<192, 96, 8, 1, 4, 1, 0, 0, 1, 1>( y_shuff_2, model.decoder_conv_tr_weight[decoder_idx], model.decoder_conv_tr_bias[decoder_idx]); break; case 2: - y = demucscpp::conv2d_tr_gemm_fused_gelu<96, 48, 8, 1, 4, 1, 0, 0, 1, - 1>( + y = demucscpp::conv2d_tr_fused_gelu<96, 48, 8, 1, 4, 1, 0, 0, 1, 1>( y_shuff_2, model.decoder_conv_tr_weight[decoder_idx], model.decoder_conv_tr_bias[decoder_idx]); break; @@ -261,7 +255,6 @@ void demucscpp::apply_freq_decoder(const struct demucscpp::demucs_model &model, {y.dimension(0), y_dim1_end, y.dimension(2)})); } -// forward declaration to apply a time decoder void demucscpp::apply_time_decoder(const struct demucscpp::demucs_model &model, int tdecoder_idx, const Eigen::Tensor3dXf &xt_in, @@ -366,3 +359,505 @@ void demucscpp::apply_time_decoder(const struct demucscpp::demucs_model &model, Eigen::array( 
{yt.dimension(0), yt.dimension(1), out_length})); } + +void demucscpp_v3::apply_freq_encoder_v3( + const struct demucscpp_v3::demucs_v3_model &model, int encoder_idx, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out) +{ + Eigen::Tensor3dXf x_shuf = x_in.shuffle(Eigen::array({2, 0, 1})); + + // 2D Convolution operation + Eigen::Tensor3dXf y; + + switch (encoder_idx) + { + case 0: + y = demucscpp::conv1d_fused_gelu<4, 48, 8, 4, 2, 1>( + x_shuf, model.encoder_conv_weight[encoder_idx], + model.encoder_conv_bias[encoder_idx]); + break; + case 1: + y = demucscpp::conv1d_fused_gelu<48, 96, 8, 4, 2, 1>( + x_shuf, model.encoder_conv_weight[encoder_idx], + model.encoder_conv_bias[encoder_idx]); + break; + case 2: + y = demucscpp::conv1d_fused_gelu<96, 192, 8, 4, 2, 1>( + x_shuf, model.encoder_conv_weight[encoder_idx], + model.encoder_conv_bias[encoder_idx]); + break; + case 3: + y = demucscpp::conv1d_fused_gelu<192, 384, 8, 4, 2, 1>( + x_shuf, model.encoder_conv_weight[encoder_idx], + model.encoder_conv_bias[encoder_idx]); + break; + }; + + // reverse all dims + Eigen::Tensor3dXf y_shuff = y.shuffle(Eigen::array({2, 1, 0})); + demucscpp_v3::apply_dconv_v3(model, y_shuff, 0, encoder_idx, + y_shuff.dimension(2)); + + // swap back from H,C,W to C,H,W + // then put W in front to use conv1d function for width=1 conv2d + y = y_shuff.shuffle(Eigen::array({2, 1, 0})); + + // need rewrite, norm2, glu + switch (encoder_idx) + { + case 0: + y = demucscpp::conv1d<48, 96, 1, 1, 0, 1>( + y, model.encoder_rewrite_weight[encoder_idx], + model.encoder_rewrite_bias[encoder_idx]); + break; + case 1: + y = demucscpp::conv1d<96, 192, 1, 1, 0, 1>( + y, model.encoder_rewrite_weight[encoder_idx], + model.encoder_rewrite_bias[encoder_idx]); + break; + case 2: + y = demucscpp::conv1d<192, 384, 1, 1, 0, 1>( + y, model.encoder_rewrite_weight[encoder_idx], + model.encoder_rewrite_bias[encoder_idx]); + break; + case 3: + y = demucscpp::conv1d<384, 768, 1, 1, 0, 1>( + y, 
model.encoder_rewrite_weight[encoder_idx], + model.encoder_rewrite_bias[encoder_idx]); + break; + }; + + y_shuff = y.shuffle(Eigen::array({1, 2, 0})); + + // copy into x_out + x_out = demucscpp::glu(y_shuff, 0); +} + +void demucscpp_v3::apply_time_encoder_v3( + const struct demucscpp_v3::demucs_v3_model &model, int tencoder_idx, + const Eigen::Tensor3dXf &xt_in, Eigen::Tensor3dXf &xt_out) +{ + int crop = demucscpp::TIME_BRANCH_LEN_0; + // switch case for tencoder_idx + switch (tencoder_idx) + { + case 0: + break; + case 1: + crop = demucscpp::TIME_BRANCH_LEN_1; + break; + case 2: + crop = demucscpp::TIME_BRANCH_LEN_2; + break; + case 3: + crop = demucscpp::TIME_BRANCH_LEN_3; + break; + } + + // now implement the forward pass + // first, apply the convolution + // Conv1d(2, 48, kernel_size=(8,), stride=(4,), padding=(2,)) + Eigen::Tensor3dXf yt; + + switch (tencoder_idx) + { + case 0: + yt = demucscpp::conv1d_fused_gelu<2, 48, 8, 4, 2, 1>( + xt_in, model.tencoder_conv_weight[tencoder_idx], + model.tencoder_conv_bias[tencoder_idx]); + break; + case 1: + yt = demucscpp::conv1d_fused_gelu<48, 96, 8, 4, 2, 1>( + xt_in, model.tencoder_conv_weight[tencoder_idx], + model.tencoder_conv_bias[tencoder_idx]); + break; + case 2: + yt = demucscpp::conv1d_fused_gelu<96, 192, 8, 4, 2, 1>( + xt_in, model.tencoder_conv_weight[tencoder_idx], + model.tencoder_conv_bias[tencoder_idx]); + break; + case 3: + yt = demucscpp::conv1d_fused_gelu<192, 384, 8, 4, 2, 1>( + xt_in, model.tencoder_conv_weight[tencoder_idx], + model.tencoder_conv_bias[tencoder_idx]); + break; + }; + + // now dconv time + demucscpp_v3::apply_dconv_v3(model, yt, 1, tencoder_idx, crop); + + // end of dconv? 
+ + // need rewrite, norm2, glu + switch (tencoder_idx) + { + case 0: + yt = demucscpp::conv1d<48, 96, 1, 1, 0, 1>( + yt, model.tencoder_rewrite_weight[tencoder_idx], + model.tencoder_rewrite_bias[tencoder_idx]); + break; + case 1: + yt = demucscpp::conv1d<96, 192, 1, 1, 0, 1>( + yt, model.tencoder_rewrite_weight[tencoder_idx], + model.tencoder_rewrite_bias[tencoder_idx]); + break; + case 2: + yt = demucscpp::conv1d<192, 384, 1, 1, 0, 1>( + yt, model.tencoder_rewrite_weight[tencoder_idx], + model.tencoder_rewrite_bias[tencoder_idx]); + break; + case 3: + yt = demucscpp::conv1d<384, 768, 1, 1, 0, 1>( + yt, model.tencoder_rewrite_weight[tencoder_idx], + model.tencoder_rewrite_bias[tencoder_idx]); + break; + }; + + xt_out = demucscpp::glu(yt, 1); +} + +void demucscpp_v3::apply_time_encoder_4( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &xt_in, Eigen::Tensor3dXf &xt_out) +{ + // now implement the forward pass + // first, apply the convolution + // Conv1d(2, 48, kernel_size=(8,), stride=(4,), padding=(2,)) + Eigen::Tensor3dXf yt = demucscpp::conv1d<384, 768, 8, 4, 2, 1>( + xt_in, model.tencoder_4_conv_weight, model.tencoder_4_conv_bias); + + xt_out = yt; +} + +void demucscpp_v3::apply_freq_encoder_4( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, const Eigen::Tensor3dXf &x_inject, + Eigen::Tensor3dXf &x_out, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers) +{ + const int encoder_idx = 0; + + // 2D Convolution operation + Eigen::Tensor3dXf y; + Eigen::Tensor3dXf y_shuff; + + y = demucscpp::conv2d<384, 768, 8, 1, 4, 1, 0, 0, 1, 1>( + x_in, model.encoder_4_conv_weight, + model.encoder_4_5_conv_bias[encoder_idx]); + // encoder 4 (i.e. 
encoder_idx 0) has the inject param + // swap first two dims of x_inject + + y_shuff = y.shuffle(Eigen::array({1, 0, 2})) + x_inject; + + // apply groupnorm + y = groupnorm::group_norm_fused_gelu( + y_shuff, model.encoder_4_5_norm1_weight[encoder_idx], + model.encoder_4_5_norm1_bias[encoder_idx], 4, 1e-05); + + // special dconv with bilstm + local attn + demucscpp_v3::apply_dconv_v3_encoder_4_5(model, y, encoder_idx, + y_shuff.dimension(2), buffers); + + y = demucscpp::conv1d<768, 1536, 1, 1, 0, 1>( + y, model.encoder_4_5_rewrite_weight[encoder_idx], + model.encoder_4_5_rewrite_bias[encoder_idx]); + + // apply groupnorm + y = demucscpp::group_norm(y, model.encoder_4_5_norm2_weight[encoder_idx], + model.encoder_4_5_norm2_bias[encoder_idx], 4, + 1e-05); + + // copy into x_out + y_shuff = y.shuffle(Eigen::array({1, 0, 2})); + x_out = demucscpp::glu(y_shuff, 0); +} + +void demucscpp_v3::apply_shared_encoder_5( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers) +{ + // 2D Convolution operation + Eigen::Tensor3dXf y; + Eigen::Tensor3dXf y_shuff; + + const int encoder_idx = 1; + + // shuffle first two dims of x_in + // first assign y to x_in with first two dims swapped + y = x_in.shuffle(Eigen::array({1, 0, 2})); + + y = demucscpp::conv1d<768, 1536, 4, 2, 1, 1>( + y, model.encoder_5_conv_weight, + model.encoder_4_5_conv_bias[encoder_idx]); + + // apply groupnorm + y = groupnorm::group_norm_fused_gelu( + y, model.encoder_4_5_norm1_weight[encoder_idx], + model.encoder_4_5_norm1_bias[encoder_idx], 4, 1e-05); + + // special dconv with bilstm + local attn + demucscpp_v3::apply_dconv_v3_encoder_4_5(model, y, encoder_idx, + y.dimension(2), buffers); + + // need rewrite, norm2, glu + y = demucscpp::conv1d<1536, 3072, 1, 1, 0, 1>( + y, model.encoder_4_5_rewrite_weight[encoder_idx], + model.encoder_4_5_rewrite_bias[encoder_idx]); + + // apply groupnorm + y = 
demucscpp::group_norm(y, model.encoder_4_5_norm2_weight[encoder_idx], + model.encoder_4_5_norm2_bias[encoder_idx], 4, + 1e-05); + + // copy into x_out + x_out = demucscpp::glu(y, 1); +} + +Eigen::Tensor3dXf demucscpp_v3::apply_shared_decoder_0( + const struct demucscpp_v3::demucs_v3_model &model, Eigen::Tensor3dXf &x_out, + const Eigen::Tensor3dXf &skip) +{ + const int decoder_idx = 0; + + // input is empty, so we use skip directly + Eigen::Tensor3dXf y = skip; + + y = demucscpp::conv1d<1536, 3072, 3, 1, 1, 1>( + y, model.decoder_0_rewrite_weight, + model.decoder_0_1_rewrite_bias[decoder_idx]); + + // apply groupnorm1 with norm1 weights + y = groupnorm::group_norm(y, model.decoder_0_1_norm1_weight[decoder_idx], + model.decoder_0_1_norm1_bias[decoder_idx], 4, + 1e-05); + + y = demucscpp::glu(y, 1); + + // return pre, to be used optionally (as first input to time decoder) + Eigen::Tensor3dXf pre_ret = y; + + // no dconv for decoders + // simply conv_tr -> norm2 + y = demucscpp::conv1d_tr<1536, 768, 4, 2, 0, 1>( + y, model.decoder_0_conv_tr_weight, + model.decoder_0_1_conv_tr_bias[decoder_idx]); + + y = groupnorm::group_norm_fused_gelu( + y, model.decoder_0_1_norm2_weight[decoder_idx], + model.decoder_0_1_norm2_bias[decoder_idx], 4, 1e-05); + + // remove extra elems equivalent to `1:337` + x_out = y.slice(Eigen::array({0, 0, 1}), + Eigen::array( + {y.dimension(0), y.dimension(1), FREQ_BRANCH_LEN})); + + return pre_ret; +} + +Eigen::Tensor3dXf demucscpp_v3::apply_freq_decoder_1( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out, + const Eigen::Tensor3dXf &skip) +{ + const int decoder_idx = 1; + + Eigen::Tensor3dXf y = x_in.shuffle(Eigen::array({1, 0, 2})) + skip; + + // first glu(norm1(rewrite)) + y = demucscpp::conv2d<768, 1536, 3, 3, 1, 1, 1, 1, 1, 1>( + y, model.decoder_1_rewrite_weight, + model.decoder_0_1_rewrite_bias[decoder_idx]); + + // apply groupnorm1 with norm1 weights + y = 
groupnorm::group_norm_2(y, model.decoder_0_1_norm1_weight[decoder_idx], + model.decoder_0_1_norm1_bias[decoder_idx], 4, + 1e-05); + + y = demucscpp::glu(y, 0); + + // return pre, to be used optionally (as first input to time decoder) + Eigen::Tensor3dXf pre_ret = y; + + // no dconv for decoders + // simply conv_tr -> norm2 + + // 2D Convolution operation + y = demucscpp::conv2d_tr<768, 384, 8, 1, 4, 1, 0, 0, 1, 1>( + y, model.decoder_1_conv_tr_weight, + model.decoder_0_1_conv_tr_bias[decoder_idx]); + + y = groupnorm::group_norm_fused_gelu_2( + y, model.decoder_0_1_norm2_weight[decoder_idx], + model.decoder_0_1_norm2_bias[decoder_idx], 4, 1e-05); + + // no slicing for this one + + x_out = y; + return pre_ret; +} + +void demucscpp_v3::apply_time_decoder_0( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out) +{ + // simple decoder + // rewrite and conv_tr, no group norms + // swap first two dims + Eigen::Tensor3dXf y_shuff = x_in.shuffle(Eigen::array({1, 0, 2})); + + // no norm1, rewrite, dconv for tdecoder0 + // simply conv_tr -> norm2 + + // 2D Convolution operation + Eigen::Tensor3dXf y = demucscpp::conv1d_tr<768, 384, 8, 4, 0, 1>( + y_shuff, model.tdecoder_0_conv_tr_weight, + model.tdecoder_0_conv_tr_bias); + + // now apply groupnorm2 with norm2 weights + y = groupnorm::group_norm_fused_gelu(y, model.tdecoder_0_norm2_weight, + model.tdecoder_0_norm2_bias, 4, 1e-05); + + // for time branch, crop to length + int out_length = x_out.dimension(2); + x_out = y.slice(Eigen::array({0, 0, 2}), + Eigen::array( + {y.dimension(0), y.dimension(1), out_length})); +} + +void demucscpp_v3::apply_common_decoder( + const struct demucscpp_v3::demucs_v3_model &model, + const int freq_or_time_idx, const int decoder_idx, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out, + const Eigen::Tensor3dXf &skip) +{ + // simple decoder + // rewrite and conv_tr, no group norms + + Eigen::Tensor3dXf y = x_in + skip; + + // first 
glu(norm1(rewrite)) + if ((freq_or_time_idx == 0) && (decoder_idx == 0)) + { + y = demucscpp::conv2d<384, 768, 3, 3, 1, 1, 1, 1, 1, 1>( + y, model.freq_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 0)) + { + y = demucscpp::conv1d<384, 768, 3, 1, 1, 1>( + y, model.time_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 0) && (decoder_idx == 1)) + { + y = demucscpp::conv2d<192, 384, 3, 3, 1, 1, 1, 1, 1, 1>( + y, model.freq_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 1)) + { + y = demucscpp::conv1d<192, 384, 3, 1, 1, 1>( + y, model.time_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 0) && (decoder_idx == 2)) + { + y = demucscpp::conv2d<96, 192, 3, 3, 1, 1, 1, 1, 1, 1>( + y, model.freq_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 2)) + { + y = demucscpp::conv1d<96, 192, 3, 1, 1, 1>( + y, model.time_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 0) && (decoder_idx == 3)) + { + y = demucscpp::conv2d<48, 96, 3, 3, 1, 1, 1, 1, 1, 1>( + y, model.freq_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 3)) + { + y = demucscpp::conv1d<48, 96, 3, 1, 1, 1>( + y, model.time_decoders_rewrite_weight[decoder_idx], + model.decoders_rewrite_bias[freq_or_time_idx][decoder_idx]); + } + + y = demucscpp::glu(y, freq_or_time_idx); + + // no dconv for decoders, no norm2 + // simply 
conv_tr + + // 2D Convolution operation + if ((freq_or_time_idx == 0) && (decoder_idx == 0)) + { + y = demucscpp::conv2d_tr_fused_gelu<384, 192, 8, 1, 4, 1, 0, 0, 1, 1>( + y, model.freq_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 0)) + { + y = demucscpp::conv1d_tr_fused_gelu<384, 192, 8, 4, 0, 1>( + y, model.time_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 0) && (decoder_idx == 1)) + { + y = demucscpp::conv2d_tr_fused_gelu<192, 96, 8, 1, 4, 1, 0, 0, 1, 1>( + y, model.freq_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 1)) + { + y = demucscpp::conv1d_tr_fused_gelu<192, 96, 8, 4, 0, 1>( + y, model.time_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 0) && (decoder_idx == 2)) + { + y = demucscpp::conv2d_tr_fused_gelu<96, 48, 8, 1, 4, 1, 0, 0, 1, 1>( + y, model.freq_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 2)) + { + y = demucscpp::conv1d_tr_fused_gelu<96, 48, 8, 4, 0, 1>( + y, model.time_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 0) && (decoder_idx == 3)) + { + y = demucscpp::conv2d_tr<48, 16, 8, 1, 4, 1, 0, 0, 1, 1>( + y, model.freq_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + } + else if ((freq_or_time_idx == 1) && (decoder_idx == 3)) + { + y = demucscpp::conv1d_tr<48, 8, 8, 4, 0, 1>( + y, model.time_decoders_conv_tr_weight[decoder_idx], + model.decoders_conv_tr_bias[freq_or_time_idx][decoder_idx]); + 
} + + if (freq_or_time_idx == 1) + { + // for time branch, crop to length + int out_length = x_out.dimension(2); + x_out = y.slice(Eigen::array({0, 0, 2}), + Eigen::array( + {y.dimension(0), y.dimension(1), out_length})); + } + else + { + // for freq branch + int desired_dim1_length = x_out.dimension(1); + + // remove 2 elements from begin and end of y along dimension 1 (0, 1, 2) + x_out = + y.slice(Eigen::array({0, 2, 0}), + Eigen::array( + {y.dimension(0), desired_dim1_length, y.dimension(2)})); + } +} diff --git a/src/encdec.hpp b/src/encdec.hpp index 23a4fd6..e4a29ab 100644 --- a/src/encdec.hpp +++ b/src/encdec.hpp @@ -29,4 +29,55 @@ void apply_time_decoder(const struct demucscpp::demucs_model &model, const Eigen::Tensor3dXf &skip); } // namespace demucscpp +namespace demucscpp_v3 +{ +void apply_freq_encoder_v3(const struct demucscpp_v3::demucs_v3_model &model, + int encoder_idx, const Eigen::Tensor3dXf &x_in, + Eigen::Tensor3dXf &x_out); + +// forward declaration to apply a time encoder +void apply_time_encoder_v3(const struct demucscpp_v3::demucs_v3_model &model, + int encoder_idx, const Eigen::Tensor3dXf &xt_in, + Eigen::Tensor3dXf &xt_out); + +// unique time encoder 4 +void apply_time_encoder_4(const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &xt_in, + Eigen::Tensor3dXf &xt_out); + +// freq encoder 4, shared encoder 5 +// uniquely contain bilstm, localattn +void apply_freq_encoder_4( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, const Eigen::Tensor3dXf &x_inject, + Eigen::Tensor3dXf &x_out, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers); + +void apply_shared_encoder_5( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers); + +Eigen::Tensor3dXf +apply_shared_decoder_0(const struct demucscpp_v3::demucs_v3_model &model, + Eigen::Tensor3dXf &x_out, const 
Eigen::Tensor3dXf &skip); + +Eigen::Tensor3dXf +apply_freq_decoder_1(const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, Eigen::Tensor3dXf &x_out, + const Eigen::Tensor3dXf &skip); + +void apply_time_decoder_0(const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::Tensor3dXf &x_in, + Eigen::Tensor3dXf &x_out); + +// forward declaration to apply a common freq or time decoder +void apply_common_decoder(const struct demucscpp_v3::demucs_v3_model &model, + const int freq_or_time_idx, const int decoder_idx, + const Eigen::Tensor3dXf &x_in, + Eigen::Tensor3dXf &x_out, + const Eigen::Tensor3dXf &skip); + +} // namespace demucscpp_v3 + #endif // ENCDEC_HPP diff --git a/src/layers.cpp b/src/layers.cpp index 01239aa..c13d023 100644 --- a/src/layers.cpp +++ b/src/layers.cpp @@ -1,4 +1,6 @@ #include "layers.hpp" +#include "conv.hpp" +#include "lstm.hpp" #include "model.hpp" #include "tensor.hpp" #include @@ -527,3 +529,585 @@ void demucscpp::common_encoder_layer( q = q_shuf; } + +void demucscpp_v3::local_attention( + Eigen::Tensor3dXf &x, // x = frequency, time, or combined + // input tensor [B, C, T] + const Eigen::Tensor3dXf &content_weight, + const Eigen::Tensor1dXf &content_bias, + const Eigen::Tensor3dXf &query_weight, const Eigen::Tensor1dXf &query_bias, + const Eigen::Tensor3dXf &key_weight, const Eigen::Tensor1dXf &key_bias, + const Eigen::Tensor3dXf &query_decay_weight, + const Eigen::Tensor1dXf &query_decay_bias, + const Eigen::Tensor2dXf &query_decay_kernel, + const Eigen::Tensor3dXf &proj_weight, const Eigen::Tensor1dXf &proj_bias, + const int hidden_size) +{ + // local-attention block + + int B = x.dimension(0); + int C = x.dimension(1); + int T = x.dimension(2); + + const int num_heads = demucscpp_v3::LOCAL_ATTN_N_HEADS; + + // apply query conv1d on x + Eigen::Tensor3dXf queries; + Eigen::Tensor3dXf query_decays; + Eigen::Tensor3dXf keys; + Eigen::Tensor3dXf content; + + if (hidden_size == 192) + { + queries = 
demucscpp::conv1d<192, 192, 1, 1, 0, 1>(x, query_weight, + query_bias); + keys = demucscpp::conv1d<192, 192, 1, 1, 0, 1>(x, key_weight, key_bias); + query_decays = demucscpp::conv1d<192, 16, 1, 1, 0, 1>( + x, query_decay_weight, query_decay_bias); + content = demucscpp::conv1d<192, 192, 1, 1, 0, 1>(x, content_weight, + content_bias); + } + else + { + queries = demucscpp::conv1d<384, 384, 1, 1, 0, 1>(x, query_weight, + query_bias); + keys = demucscpp::conv1d<384, 384, 1, 1, 0, 1>(x, key_weight, key_bias); + query_decays = demucscpp::conv1d<384, 16, 1, 1, 0, 1>( + x, query_decay_weight, query_decay_bias); + content = demucscpp::conv1d<384, 384, 1, 1, 0, 1>(x, content_weight, + content_bias); + } + + // so far, this is correct and matches pytorch + + int features_per_head = C / num_heads; + + // now implement dots calculation + + // Initialize the dots tensor + Eigen::Tensor4dXf dots(B, num_heads, T, T); + dots.setZero(); + + // Precompute the square root of features_per_head + float sqrt_features_per_head = std::sqrt(features_per_head); + + // apply a sigmoid activation with a 1/2 incorporated + query_decays = query_decays.unaryExpr( + [](float v) { return 0.5f / (1.0f + std::exp(-v)); }); + + // Initialize the weights tensor for softmax + Eigen::Tensor4dXf weights(B, num_heads, T, T); + + // Loop structure to compute both dot products and apply decay + // simultaneously + for (int b = 0; b < B; ++b) + { + for (int h = 0; h < num_heads; ++h) + { + for (int t = 0; t < T; ++t) + { + for (int s = 0; s < T; ++s) + { + float dot_product = 0.0f; + float decay_effect = 0.0f; + + // Compute the standard dot product + for (int c = 0; c < features_per_head; ++c) + { + int channel_index = h * features_per_head + c; + dot_product += queries(b, channel_index, s) * + keys(b, channel_index, t); + } + dots(b, h, t, s) = dot_product / sqrt_features_per_head; + + // Calculate decay effect for this dot product + for (int n = 0; n < LOCAL_ATTN_N_DECAY; ++n) + { + int decay_index = 
std::abs( + t - s); // Assuming decay_kernel is indexed by delta + float decay_kernel_value = + query_decay_kernel(n, decay_index); + + // Transform query_decay by applying sigmoid directly + // here + float decay_query_value = + query_decays(b, h * LOCAL_ATTN_N_DECAY + n, s); + + decay_effect += decay_kernel_value * decay_query_value; + } + + // Apply decay effect directly to the dot product + if (t != s) + { + dots(b, h, t, s) += decay_effect; + } + else + { + dots(b, h, t, s) = -100.0f; + } + } + } + + for (int t = 0; t < T; ++t) + { + float max_val = -std::numeric_limits::infinity(); + for (int s = 0; s < T; ++s) + { + if (dots(b, h, s, t) > max_val) + { + max_val = dots(b, h, s, t); + } + } + + float sum_exp = 0.0f; + // Calculate the exponentials and sum them + for (int s = 0; s < T; ++s) + { + weights(b, h, s, t) = std::exp(dots(b, h, s, t) - max_val); + sum_exp += weights(b, h, s, t); + } + + // Normalize the weights to form a proper probability + // distribution + for (int s = 0; s < T; ++s) + { + weights(b, h, s, t) /= sum_exp; + } + } + } + } + + // Initialize the reshaped result tensor directly + Eigen::Tensor3dXf reshaped_result(B, C, T); + reshaped_result.setZero(); + + // Merge computation of result tensor and reshaping + for (int b = 0; b < B; ++b) + { + for (int h = 0; h < num_heads; ++h) + { + for (int c = 0; c < C / num_heads; ++c) + { + for (int s = 0; s < T; ++s) + { + for (int t = 0; t < T; ++t) + { + // Directly update the reshaped_result tensor + int full_channel_index = h * (C / num_heads) + c; + reshaped_result(b, full_channel_index, s) += + weights(b, h, t, s) * + content(b, h * (C / num_heads) + c, t); + } + } + } + } + } + + // Apply projection layer + Eigen::Tensor3dXf projected_result; + if (hidden_size == 192) + { + projected_result = demucscpp::conv1d<192, 192, 1, 1, 0, 1>( + reshaped_result, proj_weight, proj_bias); + } + else + { + projected_result = demucscpp::conv1d<384, 384, 1, 1, 0, 1>( + reshaped_result, proj_weight, 
proj_bias); + } + + // Add x to projected_result + x += projected_result; +} + +void demucscpp_v3::apply_dconv_v3( + const struct demucscpp_v3::demucs_v3_model &model, Eigen::Tensor3dXf &y, + int freq_idx, int layer_idx, int mid_crop) +{ + // store another copy of y to sum back later + Eigen::Tensor3dXf y_copy = y; + + // now dconv time + + switch (layer_idx) + { + case 0: + y = demucscpp::conv1d<48, 12, 3, 1, 1, 1>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][0]); + break; + case 1: + y = demucscpp::conv1d<96, 24, 3, 1, 1, 1>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][0]); + break; + case 2: + y = demucscpp::conv1d<192, 48, 3, 1, 1, 1>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][0]); + break; + case 3: + y = demucscpp::conv1d<384, 96, 3, 1, 1, 1>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][0]); + break; + }; + + // y = demucscpp_v3::groupnorm::group_norm_fused_gelu( + // y, + // model.dconv_layers_1_groupnorm_weight[freq_idx][layer_idx] + // [0], + // model.dconv_layers_1_groupnorm_bias[freq_idx][layer_idx][0], + // 1, + // 1e-05); + + y = demucscpp::group_norm_fused_gelu( + y, model.dconv_layers_1_groupnorm_weight[freq_idx][layer_idx][0], + model.dconv_layers_1_groupnorm_bias[freq_idx][layer_idx][0], 1e-05); + + switch (layer_idx) + { + case 0: + y = demucscpp::conv1d<12, 96, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][0]); + break; + case 1: + y = demucscpp::conv1d<24, 192, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][0]); + break; + case 2: + y = demucscpp::conv1d<48, 384, 1, 1, 
0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][0]); + break; + case 3: + y = demucscpp::conv1d<96, 768, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][0], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][0]); + break; + }; + + y = demucscpp::group_norm( + y, model.dconv_layers_4_groupnorm_weight[freq_idx][layer_idx][0], + model.dconv_layers_4_groupnorm_bias[freq_idx][layer_idx][0], 1, 1e-05); + + y = demucscpp::glu(y, 1); + + y = demucscpp::layer_scale( + y, model.dconv_layers_6_scale[freq_idx][layer_idx][0]); + + // now we add y to itself + y = y + y_copy; + + // store another copy of y to sum back later + y_copy = y; + + // NEXT ENTIRE SUBSEQUENCE OF DCONV WITH SLIGHTLY DIFFERENT PARAMS + + // Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,)) + switch (layer_idx) + { + case 0: + y = demucscpp::conv1d<48, 12, 3, 1, 2, 2>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][1]); + break; + case 1: + y = demucscpp::conv1d<96, 24, 3, 1, 2, 2>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][1]); + break; + case 2: + y = demucscpp::conv1d<192, 48, 3, 1, 2, 2>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][1]); + break; + case 3: + y = demucscpp::conv1d<384, 96, 3, 1, 2, 2>( + y, model.dconv_layers_0_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_0_conv1d_bias[freq_idx][layer_idx][1]); + break; + }; + + Eigen::Tensor3dXf y_cropped = + y.slice(Eigen::array({0, 0, 0}), + Eigen::array( + {y.dimension(0), y.dimension(1), mid_crop})); + + y = y_cropped; + + y = demucscpp::group_norm_fused_gelu( + y, model.dconv_layers_1_groupnorm_weight[freq_idx][layer_idx][1], + 
model.dconv_layers_1_groupnorm_bias[freq_idx][layer_idx][1], 1e-05); + + // Conv1d(6, 96, kernel_size=(1,), stride=(1,)) + switch (layer_idx) + { + case 0: + y = demucscpp::conv1d<12, 96, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][1]); + break; + case 1: + y = demucscpp::conv1d<24, 192, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][1]); + break; + case 2: + y = demucscpp::conv1d<48, 384, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][1]); + break; + case 3: + y = demucscpp::conv1d<96, 768, 1, 1, 0, 1>( + y, model.dconv_layers_3_conv1d_weight[freq_idx][layer_idx][1], + model.dconv_layers_3_conv1d_bias[freq_idx][layer_idx][1]); + break; + }; + + y = demucscpp::group_norm( + y, model.dconv_layers_4_groupnorm_weight[freq_idx][layer_idx][1], + model.dconv_layers_4_groupnorm_bias[freq_idx][layer_idx][1], 1, 1e-05); + + y = demucscpp::glu(y, 1); + y = demucscpp::layer_scale( + y, model.dconv_layers_6_scale[freq_idx][layer_idx][1]); + + // if y_copy is shorter than y in the last dim + // pad the last dim with zeros to match + + if (y_copy.dimension(2) < y.dimension(2)) + { + // pad the last dim with zeros to match + Eigen::Tensor3dXf padded_tensor_copy( + y_copy.dimension(0), y_copy.dimension(1), y.dimension(2)); + padded_tensor_copy.setZero(); + padded_tensor_copy.slice(Eigen::array({0, 0, 0}), + y_copy.dimensions()) = y_copy; + y_copy = padded_tensor_copy; + } + + // now sum with itself + y = y + y_copy; +} + +void demucscpp_v3::apply_dconv_v3_encoder_4_5( + const struct demucscpp_v3::demucs_v3_model &model, Eigen::Tensor3dXf &y, + int encoder_idx, int mid_crop, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers) +{ + int lstm_hidden_size = encoder_idx == 0 ? 
demucscpp_v3::LSTM_HIDDEN_SIZE_0 + : demucscpp_v3::LSTM_HIDDEN_SIZE_1; + + // store another copy of y to sum back later + Eigen::Tensor3dXf y_copy = y; + + // now dconv time + + switch (encoder_idx) + { + case 0: + y = demucscpp::conv1d<768, 192, 3, 1, 1, 1>( + y, model.encoder_4_5_dconv_layers_0_conv1d_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_0_conv1d_bias[encoder_idx][0]); + break; + case 1: + y = demucscpp::conv1d<1536, 384, 3, 1, 1, 1>( + y, model.encoder_4_5_dconv_layers_0_conv1d_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_0_conv1d_bias[encoder_idx][0]); + break; + }; + + y = demucscpp::group_norm_fused_gelu( + y, model.encoder_4_5_dconv_layers_1_groupnorm_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_1_groupnorm_bias[encoder_idx][0], 1e-05); + + // transpose it to put time seq last + Eigen::MatrixXf y_mat = + Eigen::Map(y.data(), y.dimension(1), y.dimension(2)) + .transpose(); + + // then, bilstm + demucscpp_v3::lstm_forward(model, encoder_idx, 0, y_mat, buffers, + lstm_hidden_size); + + // access last element of the last dim which is the output of the bilstm + Eigen::MatrixXf lstm_out_0 = buffers.lstm_output[encoder_idx][0][1]; + + // set lstm state to 0 + demucscpp_v3::lstm_reset_zero(encoder_idx, 0, buffers); + + // apply the linear layer on the lstm_out_0 + lstm_out_0 = (lstm_out_0 * + model.encoder_4_5_dconv_layers_3_linear_weight[encoder_idx][0] + .transpose()) + .rowwise() + + model.encoder_4_5_dconv_layers_3_linear_bias[encoder_idx][0] + .transpose(); + + // then apply skip connection + lstm_out_0 = lstm_out_0 + y_mat; + + // copy it to a original 3d tensor + y = Eigen::TensorMap( + lstm_out_0.data(), lstm_out_0.rows(), 1, lstm_out_0.cols()); + + // swap dims from 0,1,2 to 1,2,0 + Eigen::Tensor3dXf y_shuff = y.shuffle(Eigen::array({1, 2, 0})); + + // then, localattn + demucscpp_v3::local_attention( + y_shuff, + model.encoder_4_5_dconv_layers_4_content_weight[encoder_idx][0], + 
model.encoder_4_5_dconv_layers_4_content_bias[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_query_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_query_bias[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_key_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_key_bias[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_query_decay_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_query_decay_bias[encoder_idx][0], + buffers.local_attn_decay_kernel, + model.encoder_4_5_dconv_layers_4_proj_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_4_proj_bias[encoder_idx][0], + lstm_hidden_size); + + y = y_shuff; + + switch (encoder_idx) + { + case 0: + y = demucscpp::conv1d<192, 1536, 1, 1, 0, 1>( + y, model.encoder_4_5_dconv_layers_5_conv1d_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_5_conv1d_bias[encoder_idx][0]); + break; + case 1: + y = demucscpp::conv1d<384, 3072, 1, 1, 0, 1>( + y, model.encoder_4_5_dconv_layers_5_conv1d_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_5_conv1d_bias[encoder_idx][0]); + break; + }; + + y = demucscpp::group_norm( + y, model.encoder_4_5_dconv_layers_6_groupnorm_weight[encoder_idx][0], + model.encoder_4_5_dconv_layers_6_groupnorm_bias[encoder_idx][0], 1, + 1e-05); + + y = demucscpp::glu(y, 1); + + y = demucscpp::layer_scale( + y, model.encoder_4_5_dconv_layers_8_scale[encoder_idx][0]); + + // now we add y to itself + y = y + y_copy; + + // store another copy of y to sum back later + y_copy = y; + + // NEXT ENTIRE SUBSEQUENCE OF DCONV WITH SLIGHTLY DIFFERENT PARAMS + // now dconv time + + switch (encoder_idx) + { + case 0: + y = demucscpp::conv1d<768, 192, 3, 1, 2, 2>( + y, model.encoder_4_5_dconv_layers_0_conv1d_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_0_conv1d_bias[encoder_idx][1]); + break; + case 1: + y = demucscpp::conv1d<1536, 384, 3, 1, 2, 2>( + y, model.encoder_4_5_dconv_layers_0_conv1d_weight[encoder_idx][1], + 
model.encoder_4_5_dconv_layers_0_conv1d_bias[encoder_idx][1]); + break; + }; + + Eigen::Tensor3dXf y_cropped = + y.slice(Eigen::array({0, 0, 0}), + Eigen::array( + {y.dimension(0), y.dimension(1), mid_crop})); + + y = y_cropped; + + y = demucscpp::group_norm_fused_gelu( + y, model.encoder_4_5_dconv_layers_1_groupnorm_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_1_groupnorm_bias[encoder_idx][1], 1e-05); + + // transpose it to put time seq last + y_mat = + Eigen::Map(y.data(), y.dimension(1), y.dimension(2)) + .transpose(); + + // then, bilstm + demucscpp_v3::lstm_forward(model, encoder_idx, 1, y_mat, buffers, + lstm_hidden_size); + + // access last element of the last dim which is the output of the bilstm + lstm_out_0 = buffers.lstm_output[encoder_idx][1][1]; + + // reset lstm state to 0 + demucscpp_v3::lstm_reset_zero(encoder_idx, 1, buffers); + + // apply the linear layer on the lstm_out_0 + lstm_out_0 = (lstm_out_0 * + model.encoder_4_5_dconv_layers_3_linear_weight[encoder_idx][1] + .transpose()) + .rowwise() + + model.encoder_4_5_dconv_layers_3_linear_bias[encoder_idx][1] + .transpose(); + + // then apply skip connection + lstm_out_0 = lstm_out_0 + y_mat; + + // copy it to a original 3d tensor + y = Eigen::TensorMap( + lstm_out_0.data(), lstm_out_0.rows(), 1, lstm_out_0.cols()); + + // swap dims from 0,1,2 to 1,2,0 + y_shuff = y.shuffle(Eigen::array({1, 2, 0})); + + // then, localattn + demucscpp_v3::local_attention( + y_shuff, + model.encoder_4_5_dconv_layers_4_content_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_content_bias[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_query_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_query_bias[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_key_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_key_bias[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_query_decay_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_query_decay_bias[encoder_idx][1], + 
buffers.local_attn_decay_kernel, + model.encoder_4_5_dconv_layers_4_proj_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_4_proj_bias[encoder_idx][1], + lstm_hidden_size); + + y = y_shuff; + + switch (encoder_idx) + { + case 0: + y = demucscpp::conv1d<192, 1536, 1, 1, 0, 1>( + y, model.encoder_4_5_dconv_layers_5_conv1d_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_5_conv1d_bias[encoder_idx][1]); + break; + case 1: + y = demucscpp::conv1d<384, 3072, 1, 1, 0, 1>( + y, model.encoder_4_5_dconv_layers_5_conv1d_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_5_conv1d_bias[encoder_idx][1]); + break; + }; + + y = demucscpp::group_norm( + y, model.encoder_4_5_dconv_layers_6_groupnorm_weight[encoder_idx][1], + model.encoder_4_5_dconv_layers_6_groupnorm_bias[encoder_idx][1], 1, + 1e-05); + + y = demucscpp::glu(y, 1); + + y = demucscpp::layer_scale( + y, model.encoder_4_5_dconv_layers_8_scale[encoder_idx][1]); + + // now sum with itself + y = y + y_copy; +} diff --git a/src/layers.hpp b/src/layers.hpp index d84e804..47725c8 100644 --- a/src/layers.hpp +++ b/src/layers.hpp @@ -93,7 +93,138 @@ inline float calculate_variance(const Eigen::Tensor1dXf &tensor, float mean) float variance = sum_squares(0) / (tensor.size() - 1); return variance; } - } // namespace demucscpp +namespace demucscpp_v3 +{ +void apply_dconv_v3(const struct demucscpp_v3::demucs_v3_model &model, + Eigen::Tensor3dXf &y, int freq_idx, int layer_idx, + int mid_crop); + +void apply_dconv_v3_encoder_4_5( + const struct demucscpp_v3::demucs_v3_model &model, Eigen::Tensor3dXf &y, + int encoder_idx, int mid_crop, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers); + +// new function for LocalState, a local attention layer used +// in demucs v3 +void local_attention(Eigen::Tensor3dXf &x, // x = frequency, time, or combined + const Eigen::Tensor3dXf &content_weight, + const Eigen::Tensor1dXf &content_bias, + const Eigen::Tensor3dXf &query_weight, + const Eigen::Tensor1dXf 
&query_bias, + const Eigen::Tensor3dXf &key_weight, + const Eigen::Tensor1dXf &key_bias, + const Eigen::Tensor3dXf &query_decay_weight, + const Eigen::Tensor1dXf &query_decay_bias, + const Eigen::Tensor2dXf &query_decay_kernel, + const Eigen::Tensor3dXf &proj_weight, + const Eigen::Tensor1dXf &proj_bias, const int hidden_size); + +// this shit is complicated, give it its own namespace +namespace groupnorm +{ +using ActivationFunc = std::function; + +// this applies the group norm on the second dimension +template +inline Eigen::Tensor3dXf generalized_group_norm(const Eigen::Tensor3dXf &x, + const Eigen::Tensor1dXf &weight, + const Eigen::Tensor1dXf &bias, + int num_groups, float eps, + ActivationFunc activation_func) +{ + int freq = x.dimension(0); + int channels = x.dimension(1); + int width = x.dimension(2); + + Eigen::Tensor3dXf y_out(freq, channels, width); + y_out.setZero(); + + int group_size = channels / num_groups; + + for (int g = 0; g < num_groups; ++g) + { + int start = g * group_size; + int end = (g + 1) * group_size; + + Eigen::Tensor3dXf slice = + x.slice(Eigen::array{0, start, 0}, + Eigen::array{freq, group_size, width}); + + Eigen::Tensor mean_tensor = slice.mean(); + float mean = mean_tensor(0); + float var = demucscpp::calculate_variance(slice, mean); + + for (int i = 0; i < freq; ++i) + { + for (int c = start; c < end; ++c) + { + for (int w = 0; w < width; ++w) + { + float norm_val = (x(i, c, w) - mean) / std::sqrt(var + eps); + norm_val = norm_val * weight(c) + bias(c); + y_out(i, c, w) = activation_func(norm_val); + } + } + } + } + + return y_out; +} + +inline float gelu(float a) +{ + return 0.5f * a * (1.0f + std::erf(a / std::sqrt(2.0f))); +} + +inline Eigen::Tensor3dXf group_norm(const Eigen::Tensor3dXf &x, + const Eigen::Tensor1dXf &weight, + const Eigen::Tensor1dXf &bias, + int num_groups, float eps) +{ + return generalized_group_norm(x, weight, bias, num_groups, eps, + [](float x) { return x; }); +} + +inline Eigen::Tensor3dXf 
group_norm_fused_gelu(const Eigen::Tensor3dXf &x, + const Eigen::Tensor1dXf &weight, + const Eigen::Tensor1dXf &bias, + int num_groups, float eps) +{ + return generalized_group_norm(x, weight, bias, num_groups, eps, + [](float x) + { return demucscpp_v3::groupnorm::gelu(x); }); +} + +inline Eigen::Tensor3dXf group_norm_2(const Eigen::Tensor3dXf &x, + const Eigen::Tensor1dXf &weight, + const Eigen::Tensor1dXf &bias, + int num_groups, float eps) +{ + Eigen::array shuffle_dims = {1, 0, 2}; // Shuffle dimensions + + // Shuffle, apply group norm, and unshuffle + Eigen::Tensor3dXf x_shuffled = x.shuffle(shuffle_dims); + Eigen::Tensor3dXf y_shuffled = generalized_group_norm( + x_shuffled, weight, bias, num_groups, eps, [](float x) { return x; }); + return y_shuffled.shuffle(shuffle_dims); +} + +inline Eigen::Tensor3dXf group_norm_fused_gelu_2( + const Eigen::Tensor3dXf &x, const Eigen::Tensor1dXf &weight, + const Eigen::Tensor1dXf &bias, int num_groups, float eps) +{ + Eigen::array shuffle_dims = {1, 0, 2}; // Shuffle dimensions + + // Shuffle, apply group norm with GELU, and unshuffle + Eigen::Tensor3dXf x_shuffled = x.shuffle(shuffle_dims); + Eigen::Tensor3dXf y_shuffled = generalized_group_norm( + x_shuffled, weight, bias, num_groups, eps, + [](float x) { return demucscpp_v3::groupnorm::gelu(x); }); + return y_shuffled.shuffle(shuffle_dims); +} +} // namespace groupnorm +} // namespace demucscpp_v3 + #endif // LAYERS_HPP diff --git a/src/lstm.cpp b/src/lstm.cpp new file mode 100644 index 0000000..a87aba3 --- /dev/null +++ b/src/lstm.cpp @@ -0,0 +1,147 @@ +#include "lstm.hpp" +#include "Eigen/Dense" +#include "model.hpp" +#include + +// preliminary shapes: +// +// input of shape (batch, input_size) or (input_size): tensor containing +// input features +// +// h_0 of shape (batch, hidden_size) or (hidden_size): tensor containing +// the initial hidden state c_0 of shape (batch, hidden_size) or +// (hidden_size): tensor containing the initial cell state +// +// weight_ih 
(torch.Tensor) – the learnable input-hidden weights, of +// shape (4*hidden_size, input_size) +// presumably consisting of: W_ii, W_if, W_ig, W_io +// weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of +// shape (4*hidden_size, hidden_size) +// presumably consisting of: W_hi, W_hf, W_hg, W_ho +// +// similarly for biases: +// bias_ih (torch.Tensor) – the learnable input-hidden bias, of +// shape (4*hidden_size) bias_hh (torch.Tensor) – the learnable +// hidden-hidden bias, of shape (4*hidden_size) +// +// it = sigmoid(W_ii x_t + b_ii + W_hi h_t + b_hi) +// ft = sigmoid(W_if x_t + b_if + W_hf h_t + b_hf) +// gt = tanh(W_ig x_t + b_ig + W_hg h_t + b_hg) +// ot = sigmoid(W_io x_t + b_io + W_ho h_t + b_ho) +// ct = f * c + i * g +// eigen's array() multiplication is element-wise multiplication +// i.e. Hadamard product +// ht = o * tanh(c) + +static Eigen::MatrixXf sigmoid(const Eigen::MatrixXf &x) +{ + return 1.0 / (1.0 + (-x).array().exp()); +} + +void demucscpp_v3::lstm_reset_zero( + const int encoder_idx, const int dconv_idx, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers) +{ + for (int lstm_layer = 0; lstm_layer < 2; ++lstm_layer) + { + for (int direction = 0; direction < 2; ++direction) + { + // Reset hidden and cell states to zero + buffers.lstm_hidden[encoder_idx][dconv_idx][lstm_layer][direction] + .setZero(); + buffers.lstm_cell[encoder_idx][dconv_idx][lstm_layer][direction] + .setZero(); + + // Assuming lstm_output_per_direction and lstm_output are structured + // similarly and need to be reset as well. Adjust dimensions as + // necessary. 
+ buffers + .lstm_output_per_direction[encoder_idx][dconv_idx][lstm_layer] + [direction] + .setZero(); + } + // Reset the concatenated output buffer for each layer + buffers.lstm_output[encoder_idx][dconv_idx][lstm_layer].setZero(); + } +} + +void demucscpp_v3::lstm_forward(const demucs_v3_model &model, + const int encoder_idx, const int dconv_idx, + const Eigen::MatrixXf &input, + demucs_v3_segment_buffers &buffers, + int hidden_size) +{ + int seq_len = input.rows(); // Time sequence is now along the rows + int hidden_state_size = hidden_size; + + Eigen::MatrixXf loop_input = input; + + for (int lstm_layer = 0; lstm_layer < 2; ++lstm_layer) + { + for (int direction = 0; direction < 2; ++direction) + { + // Initialize hidden and cell states to zero if not already done + // buffers.lstm_hidden[encoder_idx][dconv_idx][lstm_layer][direction].setZero(hidden_state_size, + // 1); + // buffers.lstm_cell[encoder_idx][dconv_idx][lstm_layer][direction].setZero(hidden_state_size, + // 1); + + for (int t = (direction == 0 ? 0 : seq_len - 1); + (direction == 0 ? t < seq_len : t > -1); + t += (direction == 0 ? 
1 : -1)) + { + Eigen::MatrixXf gates = + model.encoder_4_5_dconv_layers_3_lstm_ih_w + [encoder_idx][dconv_idx][lstm_layer][direction] * + loop_input.row(t).transpose() + + model.encoder_4_5_dconv_layers_3_lstm_ih_b + [encoder_idx][dconv_idx][lstm_layer][direction] + + model.encoder_4_5_dconv_layers_3_lstm_hh_w + [encoder_idx][dconv_idx][lstm_layer][direction] * + buffers.lstm_hidden[encoder_idx][dconv_idx][lstm_layer] + [direction] + + model.encoder_4_5_dconv_layers_3_lstm_hh_b + [encoder_idx][dconv_idx][lstm_layer][direction]; + + Eigen::MatrixXf i_t = + sigmoid(gates.block(0, 0, hidden_state_size, 1)); + Eigen::MatrixXf f_t = sigmoid( + gates.block(hidden_state_size, 0, hidden_state_size, 1)); + Eigen::MatrixXf g_t = + gates.block(2 * hidden_state_size, 0, hidden_state_size, 1) + .array() + .tanh(); + Eigen::MatrixXf o_t = sigmoid(gates.block( + 3 * hidden_state_size, 0, hidden_state_size, 1)); + + Eigen::MatrixXf c_t = + f_t.array() * buffers + .lstm_cell[encoder_idx][dconv_idx] + [lstm_layer][direction] + .array() + + i_t.array() * g_t.array(); + Eigen::MatrixXf h_t = o_t.array() * c_t.array().tanh(); + + buffers.lstm_hidden[encoder_idx][dconv_idx][lstm_layer] + [direction] = h_t; + buffers + .lstm_cell[encoder_idx][dconv_idx][lstm_layer][direction] = + c_t; + + buffers + .lstm_output_per_direction[encoder_idx][dconv_idx] + [lstm_layer][direction] + .row(t) = h_t.transpose(); // Adjusted for transposed output + } + } + + // Concatenate the outputs from both directions + buffers.lstm_output[encoder_idx][dconv_idx][lstm_layer] + << buffers.lstm_output_per_direction[encoder_idx][dconv_idx] + [lstm_layer][0], + buffers.lstm_output_per_direction[encoder_idx][dconv_idx] + [lstm_layer][1]; + + loop_input = buffers.lstm_output[encoder_idx][dconv_idx][lstm_layer]; + } +} diff --git a/src/lstm.hpp b/src/lstm.hpp new file mode 100644 index 0000000..fe74b07 --- /dev/null +++ b/src/lstm.hpp @@ -0,0 +1,21 @@ +#ifndef LSTM_HPP +#define LSTM_HPP + +#include "model.hpp" 
+#include + +namespace demucscpp_v3 +{ + +void lstm_forward(const struct demucscpp_v3::demucs_v3_model &model, + const int encoder_idx, const int dconv_idx, + const Eigen::MatrixXf &input, + struct demucscpp_v3::demucs_v3_segment_buffers &data, + int hidden_size); + +void lstm_reset_zero(const int encoder_idx, const int dconv_idx, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers); + +}; // namespace demucscpp_v3 + +#endif // LSTM_HPP diff --git a/src/model.hpp b/src/model.hpp index 0e92b33..630ef4e 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -591,9 +591,8 @@ struct demucs_segment_buffers Eigen::Tensor3dXf x_3_channel_upsampled; // time branch - Eigen::Tensor3dXf xt; // input - Eigen::Tensor3dXf xt_out; // output - Eigen::Tensor3dXf xt_decoded_out; // hold time decoder output + Eigen::Tensor3dXf xt; // input + Eigen::Tensor3dXf xt_out; // output Eigen::Tensor3dXf xt_0; Eigen::Tensor3dXf xt_1; Eigen::Tensor3dXf xt_2; @@ -636,9 +635,8 @@ struct demucs_segment_buffers x_3_channel_upsampled(512, 8, FREQ_BRANCH_LEN), xt(1, nb_channels, segment_samples), xt_out(1, nb_sources * nb_channels, segment_samples), - xt_decoded_out(1, 8, segment_samples), xt_0(1, 48, TIME_BRANCH_LEN_0), - xt_1(1, 96, TIME_BRANCH_LEN_1), xt_2(1, 192, TIME_BRANCH_LEN_2), - xt_3(1, 384, TIME_BRANCH_LEN_3), + xt_0(1, 48, TIME_BRANCH_LEN_0), xt_1(1, 96, TIME_BRANCH_LEN_1), + xt_2(1, 192, TIME_BRANCH_LEN_2), xt_3(1, 384, TIME_BRANCH_LEN_3), xt_3_channel_upsampled(1, 512, TIME_BRANCH_LEN_3), saved_0(48, 512, FREQ_BRANCH_LEN), saved_1(96, 128, FREQ_BRANCH_LEN), saved_2(192, 32, FREQ_BRANCH_LEN), saved_3(384, 8, FREQ_BRANCH_LEN), @@ -668,4 +666,753 @@ void model_inference(const struct demucs_model &model, float segment_progress); } // namespace demucscpp +// V3 Hybrid time-frequency model (no transformer) +namespace demucscpp_v3 +{ + +const int FREQ_BRANCH_LEN = 336; +const int TIME_BRANCH_LEN_IN = 343980; +const int TIME_BRANCH_LEN_0 = 85995; +const int TIME_BRANCH_LEN_1 = 21499; +const 
int TIME_BRANCH_LEN_2 = 5375; +const int TIME_BRANCH_LEN_3 = 1344; +const int TIME_BRANCH_LEN_4 = 336; + +const int SHARED_BRANCH_LEN = 168; + +// dconv lstm constants +// the seq len is 336, the final encoded time branch length +// (for both time and frequency) +const int LSTM_HIDDEN_SIZE_0 = 192; +const int LSTM_HIDDEN_SIZE_1 = 384; + +// dconv localstate +const int LOCAL_ATTN_N_HEADS = 4; +const int LOCAL_ATTN_N_FREQS = 0; +const int LOCAL_ATTN_N_DECAY = 4; +const int LOCAL_ATTN_CHANNELS = 192; + +struct demucs_v3_model +{ + // Encoder convolution layers + Eigen::Tensor3dXf encoder_conv_weight[4] = { + Eigen::Tensor3dXf(48, 4, 8), Eigen::Tensor3dXf(96, 48, 8), + Eigen::Tensor3dXf(192, 96, 8), Eigen::Tensor3dXf(384, 192, 8)}; + + Eigen::Tensor1dXf encoder_conv_bias[4] = { + Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(192), + Eigen::Tensor1dXf(384)}; + + // Encoder rewrite layers + Eigen::Tensor3dXf encoder_rewrite_weight[4] = { + Eigen::Tensor3dXf(96, 48, 1), Eigen::Tensor3dXf(192, 96, 1), + Eigen::Tensor3dXf(384, 192, 1), Eigen::Tensor3dXf(768, 384, 1)}; + + Eigen::Tensor1dXf encoder_rewrite_bias[4] = { + Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(384), + Eigen::Tensor1dXf(768)}; + + // TEncoder 0-3 + Eigen::Tensor3dXf tencoder_conv_weight[4] = { + Eigen::Tensor3dXf(48, 2, 8), Eigen::Tensor3dXf(96, 48, 8), + Eigen::Tensor3dXf(192, 96, 8), Eigen::Tensor3dXf(384, 192, 8)}; + + Eigen::Tensor1dXf tencoder_conv_bias[4] = { + Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(192), + Eigen::Tensor1dXf(384)}; + + Eigen::Tensor3dXf tencoder_rewrite_weight[4] = { + Eigen::Tensor3dXf(96, 48, 1), Eigen::Tensor3dXf(192, 96, 1), + Eigen::Tensor3dXf(384, 192, 1), Eigen::Tensor3dXf(768, 384, 1)}; + + Eigen::Tensor1dXf tencoder_rewrite_bias[4] = { + Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(384), + Eigen::Tensor1dXf(768)}; + + // DConv layers + // first index: time or frequency + // second index: 
enc/dec layer number + // third index: dconv 0 or 1 + // this takes care of 4 freq encoders and 4 time encoders + // each with 2 dconv layers + Eigen::Tensor3dXf dconv_layers_0_conv1d_weight[2][4][2]{ + {{Eigen::Tensor3dXf(12, 48, 3), Eigen::Tensor3dXf(12, 48, 3)}, + {Eigen::Tensor3dXf(24, 96, 3), Eigen::Tensor3dXf(24, 96, 3)}, + {Eigen::Tensor3dXf(48, 192, 3), Eigen::Tensor3dXf(48, 192, 3)}, + {Eigen::Tensor3dXf(96, 384, 3), Eigen::Tensor3dXf(96, 384, 3)}}, + {{Eigen::Tensor3dXf(12, 48, 3), Eigen::Tensor3dXf(12, 48, 3)}, + {Eigen::Tensor3dXf(24, 96, 3), Eigen::Tensor3dXf(24, 96, 3)}, + {Eigen::Tensor3dXf(48, 192, 3), Eigen::Tensor3dXf(48, 192, 3)}, + {Eigen::Tensor3dXf(96, 384, 3), Eigen::Tensor3dXf(96, 384, 3)}}}; + + Eigen::Tensor1dXf dconv_layers_0_conv1d_bias[2][4][2]{ + {{Eigen::Tensor1dXf(12), Eigen::Tensor1dXf(12)}, + {Eigen::Tensor1dXf(24), Eigen::Tensor1dXf(24)}, + {Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}}, + {{Eigen::Tensor1dXf(12), Eigen::Tensor1dXf(12)}, + {Eigen::Tensor1dXf(24), Eigen::Tensor1dXf(24)}, + {Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}}}; + + Eigen::Tensor1dXf dconv_layers_1_groupnorm_weight[2][4][2]{ + {{Eigen::Tensor1dXf(12), Eigen::Tensor1dXf(12)}, + {Eigen::Tensor1dXf(24), Eigen::Tensor1dXf(24)}, + {Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}}, + {{Eigen::Tensor1dXf(12), Eigen::Tensor1dXf(12)}, + {Eigen::Tensor1dXf(24), Eigen::Tensor1dXf(24)}, + {Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}}}; + + Eigen::Tensor1dXf dconv_layers_1_groupnorm_bias[2][4][2]{ + {{Eigen::Tensor1dXf(12), Eigen::Tensor1dXf(12)}, + {Eigen::Tensor1dXf(24), Eigen::Tensor1dXf(24)}, + {Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}}, + {{Eigen::Tensor1dXf(12), Eigen::Tensor1dXf(12)}, + {Eigen::Tensor1dXf(24), 
Eigen::Tensor1dXf(24)}, + {Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}}}; + + Eigen::Tensor3dXf dconv_layers_3_conv1d_weight[2][4][2]{ + {{Eigen::Tensor3dXf(96, 12, 1), Eigen::Tensor3dXf(96, 12, 1)}, + {Eigen::Tensor3dXf(192, 24, 1), Eigen::Tensor3dXf(192, 24, 1)}, + {Eigen::Tensor3dXf(384, 48, 1), Eigen::Tensor3dXf(384, 48, 1)}, + {Eigen::Tensor3dXf(768, 96, 1), Eigen::Tensor3dXf(768, 96, 1)}}, + {{Eigen::Tensor3dXf(96, 12, 1), Eigen::Tensor3dXf(96, 12, 1)}, + {Eigen::Tensor3dXf(192, 24, 1), Eigen::Tensor3dXf(192, 24, 1)}, + {Eigen::Tensor3dXf(384, 48, 1), Eigen::Tensor3dXf(384, 48, 1)}, + {Eigen::Tensor3dXf(768, 96, 1), Eigen::Tensor3dXf(768, 96, 1)}}}; + + Eigen::Tensor1dXf dconv_layers_3_conv1d_bias[2][4][2]{ + {{Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}, + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}}, + {{Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}, + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}}}; + + Eigen::Tensor1dXf dconv_layers_4_groupnorm_weight[2][4][2]{ + {{Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}, + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}}, + {{Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}, + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}}}; + + Eigen::Tensor1dXf dconv_layers_4_groupnorm_bias[2][4][2]{ + {{Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}, + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}}, + {{Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + 
{Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}, + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}}}; + + Eigen::Tensor1dXf dconv_layers_6_scale[2][4][2]{ + {{Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}, + {{Eigen::Tensor1dXf(48), Eigen::Tensor1dXf(48)}, + {Eigen::Tensor1dXf(96), Eigen::Tensor1dXf(96)}, + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}}; + + // time encoder 4 is super simple, just 1 conv + Eigen::Tensor3dXf tencoder_4_conv_weight{Eigen::Tensor3dXf(768, 384, 8)}; + + Eigen::Tensor1dXf tencoder_4_conv_bias{Eigen::Tensor1dXf(768)}; + + // freq encoder 4 and shared encoder 5 + // have the bilistm and localattention layers, similar to each other + // index of two to hold both + Eigen::Tensor4dXf encoder_4_conv_weight{Eigen::Tensor4dXf(768, 384, 8, 1)}; + Eigen::Tensor3dXf encoder_5_conv_weight{Eigen::Tensor3dXf(1536, 768, 4)}; + + Eigen::Tensor1dXf encoder_4_5_conv_bias[2]{Eigen::Tensor1dXf(768), + Eigen::Tensor1dXf(1536)}; + + Eigen::Tensor1dXf encoder_4_5_norm1_weight[2]{Eigen::Tensor1dXf(768), + Eigen::Tensor1dXf(1536)}; + + Eigen::Tensor1dXf encoder_4_5_norm1_bias[2]{Eigen::Tensor1dXf(768), + Eigen::Tensor1dXf(1536)}; + + Eigen::Tensor3dXf encoder_4_5_rewrite_weight[2]{ + Eigen::Tensor3dXf(1536, 768, 1), Eigen::Tensor3dXf(3072, 1536, 1)}; + + Eigen::Tensor1dXf encoder_4_5_rewrite_bias[2]{Eigen::Tensor1dXf(1536), + Eigen::Tensor1dXf(3072)}; + + Eigen::Tensor1dXf encoder_4_5_norm2_weight[2]{Eigen::Tensor1dXf(1536), + Eigen::Tensor1dXf(3072)}; + + Eigen::Tensor1dXf encoder_4_5_norm2_bias[2]{Eigen::Tensor1dXf(1536), + Eigen::Tensor1dXf(3072)}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_0_conv1d_weight[2][2]{ + {Eigen::Tensor3dXf(192, 768, 3), Eigen::Tensor3dXf(192, 768, 3)}, + 
{Eigen::Tensor3dXf(384, 1536, 3), Eigen::Tensor3dXf(384, 1536, 3)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_0_conv1d_bias[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_1_groupnorm_weight[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_1_groupnorm_bias[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + // 2 encoders, 2 dconv layers, 2 layer bi-lstm (2 layers 2 directions) => + // [2][2][2][2] first index = encoder, second index = dconv layer, third + // index = layer, fourth index = direction + Eigen::MatrixXf encoder_4_5_dconv_layers_3_lstm_ih_w[2][2][2][2]{ + // encoder 4 + { + // dconv layer 0 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(768, 192), Eigen::MatrixXf(768, 192)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(768, 384), Eigen::MatrixXf(768, 384)}, + }, + // dconv layer 1 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(768, 192), Eigen::MatrixXf(768, 192)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(768, 384), Eigen::MatrixXf(768, 384)}, + }, + }, + // encoder 5 + { + // dconv layer 0 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(1536, 384), Eigen::MatrixXf(1536, 384)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(1536, 768), Eigen::MatrixXf(1536, 768)}, + }, + // dconv layer 1 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(1536, 384), Eigen::MatrixXf(1536, 384)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(1536, 768), Eigen::MatrixXf(1536, 768)}, + }, + }}; + + Eigen::MatrixXf encoder_4_5_dconv_layers_3_lstm_ih_b[2][2][2][2]{ + // encoder 4 + { + // dconv layer 0 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + }, + // dconv layer 1 
+ { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + }, + }, + // encoder 5 + { + // dconv layer 0 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + }, + // dconv layer 1 + { + // ih_l0, ih_l0_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + // ih_l1, ih_l1_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + }, + }, + }; + + Eigen::MatrixXf encoder_4_5_dconv_layers_3_lstm_hh_w[2][2][2][2]{ + // encoder 4 + { + // dconv layer 0 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(768, 192), Eigen::MatrixXf(768, 192)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(768, 192), Eigen::MatrixXf(768, 192)}, + }, + // dconv layer 1 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(768, 192), Eigen::MatrixXf(768, 192)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(768, 192), Eigen::MatrixXf(768, 192)}, + }, + }, + // encoder 5 + { + // dconv layer 0 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(1536, 384), Eigen::MatrixXf(1536, 384)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(1536, 384), Eigen::MatrixXf(1536, 384)}, + }, + // dconv layer 1 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(1536, 384), Eigen::MatrixXf(1536, 384)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(1536, 384), Eigen::MatrixXf(1536, 384)}, + }, + }, + }; + + Eigen::MatrixXf encoder_4_5_dconv_layers_3_lstm_hh_b[2][2][2][2]{ + // encoder 4 + { + // dconv layer 0 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + }, + // dconv layer 1 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(768, 1), Eigen::MatrixXf(768, 1)}, + }, + }, + // encoder 5 + { + //
dconv layer 0 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + }, + // dconv layer 1 + { + // hh_l0, hh_l0_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + // hh_l1, hh_l1_reverse + {Eigen::MatrixXf(1536, 1), Eigen::MatrixXf(1536, 1)}, + }, + }, + }; + + Eigen::MatrixXf encoder_4_5_dconv_layers_3_linear_weight[2][2]{ + {Eigen::MatrixXf(192, 384), Eigen::MatrixXf(192, 384)}, + {Eigen::MatrixXf(384, 768), Eigen::MatrixXf(384, 768)}}; + + Eigen::VectorXf encoder_4_5_dconv_layers_3_linear_bias[2][2]{ + {Eigen::VectorXf(192), Eigen::VectorXf(192)}, + {Eigen::VectorXf(384), Eigen::VectorXf(384)}}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_4_content_weight[2][2]{ + {Eigen::Tensor3dXf(192, 192, 1), Eigen::Tensor3dXf(192, 192, 1)}, + {Eigen::Tensor3dXf(384, 384, 1), Eigen::Tensor3dXf(384, 384, 1)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_4_content_bias[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_4_query_weight[2][2]{ + {Eigen::Tensor3dXf(192, 192, 1), Eigen::Tensor3dXf(192, 192, 1)}, + {Eigen::Tensor3dXf(384, 384, 1), Eigen::Tensor3dXf(384, 384, 1)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_4_query_bias[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_4_key_weight[2][2]{ + {Eigen::Tensor3dXf(192, 192, 1), Eigen::Tensor3dXf(192, 192, 1)}, + {Eigen::Tensor3dXf(384, 384, 1), Eigen::Tensor3dXf(384, 384, 1)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_4_key_bias[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_4_query_decay_weight[2][2]{ + {Eigen::Tensor3dXf(16, 192, 1),
Eigen::Tensor3dXf(16, 192, 1)}, + {Eigen::Tensor3dXf(16, 384, 1), Eigen::Tensor3dXf(16, 384, 1)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_4_query_decay_bias[2][2]{ + {Eigen::Tensor1dXf(16), Eigen::Tensor1dXf(16)}, + {Eigen::Tensor1dXf(16), Eigen::Tensor1dXf(16)}}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_4_proj_weight[2][2]{ + {Eigen::Tensor3dXf(192, 192, 1), Eigen::Tensor3dXf(192, 192, 1)}, + {Eigen::Tensor3dXf(384, 384, 1), Eigen::Tensor3dXf(384, 384, 1)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_4_proj_bias[2][2]{ + {Eigen::Tensor1dXf(192), Eigen::Tensor1dXf(192)}, + {Eigen::Tensor1dXf(384), Eigen::Tensor1dXf(384)}}; + + Eigen::Tensor3dXf encoder_4_5_dconv_layers_5_conv1d_weight[2][2]{ + {Eigen::Tensor3dXf(1536, 192, 1), Eigen::Tensor3dXf(1536, 192, 1)}, + {Eigen::Tensor3dXf(3072, 384, 1), Eigen::Tensor3dXf(3072, 384, 1)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_5_conv1d_bias[2][2]{ + {Eigen::Tensor1dXf(1536), Eigen::Tensor1dXf(1536)}, + {Eigen::Tensor1dXf(3072), Eigen::Tensor1dXf(3072)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_6_groupnorm_weight[2][2]{ + {Eigen::Tensor1dXf(1536), Eigen::Tensor1dXf(1536)}, + {Eigen::Tensor1dXf(3072), Eigen::Tensor1dXf(3072)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_6_groupnorm_bias[2][2]{ + {Eigen::Tensor1dXf(1536), Eigen::Tensor1dXf(1536)}, + {Eigen::Tensor1dXf(3072), Eigen::Tensor1dXf(3072)}}; + + Eigen::Tensor1dXf encoder_4_5_dconv_layers_8_scale[2][2]{ + {Eigen::Tensor1dXf(768), Eigen::Tensor1dXf(768)}, + {Eigen::Tensor1dXf(1536), Eigen::Tensor1dXf(1536)}}; + + // now we need 8 decoders that have a simple similar structure + // conv_tr (weight + bias), rewrite (weight + bias) + // first array dim is [2] for freq, time + // next is layer for (2,3,4,5) for freq, (1,2,3,4) for time + // Reshaped struct arrays to [2][4] + Eigen::Tensor4dXf freq_decoders_conv_tr_weight[4]{ + Eigen::Tensor4dXf(384, 192, 8, 1), // decoder.2.conv_tr.weight + Eigen::Tensor4dXf(192, 96, 8, 1), // 
decoder.3.conv_tr.weight + Eigen::Tensor4dXf(96, 48, 8, 1), // decoder.4.conv_tr.weight + Eigen::Tensor4dXf(48, 16, 8, 1) // decoder.5.conv_tr.weight + }; + + Eigen::Tensor3dXf time_decoders_conv_tr_weight[4]{ + Eigen::Tensor3dXf(384, 192, 8), // tdecoder.1.conv_tr.weight + Eigen::Tensor3dXf(192, 96, 8), // tdecoder.2.conv_tr.weight + Eigen::Tensor3dXf(96, 48, 8), // tdecoder.3.conv_tr.weight + Eigen::Tensor3dXf(48, 8, 8) // tdecoder.4.conv_tr.weight + }; + + Eigen::Tensor1dXf decoders_conv_tr_bias[2][4]{ + { + Eigen::Tensor1dXf(192), // decoder.2.conv_tr.bias + Eigen::Tensor1dXf(96), // decoder.3.conv_tr.bias + Eigen::Tensor1dXf(48), // decoder.4.conv_tr.bias + Eigen::Tensor1dXf(16) // decoder.5.conv_tr.bias + }, + { + Eigen::Tensor1dXf(192), // tdecoder.1.conv_tr.bias + Eigen::Tensor1dXf(96), // tdecoder.2.conv_tr.bias + Eigen::Tensor1dXf(48), // tdecoder.3.conv_tr.bias + Eigen::Tensor1dXf(8) // tdecoder.4.conv_tr.bias + }}; + + Eigen::Tensor4dXf freq_decoders_rewrite_weight[4]{ + Eigen::Tensor4dXf(768, 384, 3, 3), // decoder.2.rewrite.weight + Eigen::Tensor4dXf(384, 192, 3, 3), // decoder.3.rewrite.weight + Eigen::Tensor4dXf(192, 96, 3, 3), // decoder.4.rewrite.weight + Eigen::Tensor4dXf(96, 48, 3, 3) // decoder.5.rewrite.weight + }; + + Eigen::Tensor3dXf time_decoders_rewrite_weight[4]{ + Eigen::Tensor3dXf(768, 384, 3), // tdecoder.1.rewrite.weight + Eigen::Tensor3dXf(384, 192, 3), // tdecoder.2.rewrite.weight + Eigen::Tensor3dXf(192, 96, 3), // tdecoder.3.rewrite.weight + Eigen::Tensor3dXf(96, 48, 3) // tdecoder.4.rewrite.weight + }; + + Eigen::Tensor1dXf decoders_rewrite_bias[2][4]{ + { + Eigen::Tensor1dXf(768), // decoder.2.rewrite.bias + Eigen::Tensor1dXf(384), // decoder.3.rewrite.bias + Eigen::Tensor1dXf(192), // decoder.4.rewrite.bias + Eigen::Tensor1dXf(96) // decoder.5.rewrite.bias + }, + { + Eigen::Tensor1dXf(768), // tdecoder.1.rewrite.bias + Eigen::Tensor1dXf(384), // tdecoder.2.rewrite.bias + Eigen::Tensor1dXf(192), // tdecoder.3.rewrite.bias + 
Eigen::Tensor1dXf(96) // tdecoder.4.rewrite.bias + }}; + + // Frequency Decoders + Eigen::Tensor3dXf decoder_0_conv_tr_weight{ + Eigen::Tensor3dXf(1536, 768, 4)}; // decoder.0.conv_tr.weight + + Eigen::Tensor4dXf decoder_1_conv_tr_weight{ + Eigen::Tensor4dXf(768, 384, 8, 1) // decoder.1.conv_tr.weight + }; + + Eigen::Tensor1dXf decoder_0_1_conv_tr_bias[2]{ + Eigen::Tensor1dXf(768), // decoder.0.conv_tr.bias + Eigen::Tensor1dXf(384) // decoder.1.conv_tr.bias + }; + + Eigen::Tensor1dXf decoder_0_1_norm2_weight[2]{ + Eigen::Tensor1dXf(768), // decoder.0.norm2.weight + Eigen::Tensor1dXf(384) // decoder.1.norm2.weight + }; + + Eigen::Tensor1dXf decoder_0_1_norm2_bias[2]{ + Eigen::Tensor1dXf(768), // decoder.0.norm2.bias + Eigen::Tensor1dXf(384) // decoder.1.norm2.bias + }; + + Eigen::Tensor3dXf decoder_0_rewrite_weight{ + Eigen::Tensor3dXf(3072, 1536, 3)}; + + Eigen::Tensor4dXf decoder_1_rewrite_weight{ + Eigen::Tensor4dXf(1536, 768, 3, 3)}; + + Eigen::Tensor1dXf decoder_0_1_rewrite_bias[2]{ + Eigen::Tensor1dXf(3072), // decoder.0.rewrite.bias + Eigen::Tensor1dXf(1536) // decoder.1.rewrite.bias + }; + + Eigen::Tensor1dXf decoder_0_1_norm1_weight[2]{ + Eigen::Tensor1dXf(3072), // decoder.0.norm1.weight + Eigen::Tensor1dXf(1536) // decoder.1.norm1.weight + }; + + Eigen::Tensor1dXf decoder_0_1_norm1_bias[2]{ + Eigen::Tensor1dXf(3072), // decoder.0.norm1.bias + Eigen::Tensor1dXf(1536) // decoder.1.norm1.bias + }; + + // Unique tdecoder 0 + Eigen::Tensor3dXf tdecoder_0_conv_tr_weight{ + Eigen::Tensor3dXf(768, 384, 8)}; // tdecoder.0.conv_tr.weight + Eigen::Tensor1dXf tdecoder_0_conv_tr_bias{ + Eigen::Tensor1dXf(384)}; // tdecoder.0.conv_tr.bias + Eigen::Tensor1dXf tdecoder_0_norm2_weight{ + Eigen::Tensor1dXf(384)}; // tdecoder.0.norm2.weight + Eigen::Tensor1dXf tdecoder_0_norm2_bias{ + Eigen::Tensor1dXf(384)}; // tdecoder.0.norm2.bias + + // freq_emb + Eigen::MatrixXf freq_emb_embedding_weight{Eigen::MatrixXf(512, 48)}; +}; + +struct demucs_v3_segment_buffers +{ + int 
segment_samples; + int le; + int pad; + int pad_end; + int padded_segment_samples; + int nb_stft_frames; + int nb_stft_bins; + + Eigen::MatrixXf mix; + Eigen::Tensor3dXf targets_out; + Eigen::MatrixXf padded_mix; + Eigen::Tensor3dXcf z; + + // freq branch, one for each encoded representation + Eigen::Tensor3dXf x; // input + Eigen::Tensor3dXf x_out; // output + Eigen::Tensor3dXf x_0; + Eigen::Tensor3dXf x_1; + Eigen::Tensor3dXf x_2; + Eigen::Tensor3dXf x_3; + Eigen::Tensor3dXf x_4; + + // shared after encoder 5 + Eigen::Tensor3dXf x_shared_5; + + // time branch + Eigen::Tensor3dXf xt; // input + Eigen::Tensor3dXf xt_out; // output + Eigen::Tensor3dXf xt_decoded_out; // hold time decoder output + Eigen::Tensor3dXf xt_0; + Eigen::Tensor3dXf xt_1; + Eigen::Tensor3dXf xt_2; + Eigen::Tensor3dXf xt_3; + Eigen::Tensor3dXf xt_4; + + // skip conns for frequency and time + // easier as hardcoded matrix sizes + Eigen::Tensor3dXf saved_0; + Eigen::Tensor3dXf saved_1; + Eigen::Tensor3dXf saved_2; + Eigen::Tensor3dXf saved_3; + Eigen::Tensor3dXf saved_4; + + Eigen::Tensor3dXf savedt_0; + Eigen::Tensor3dXf savedt_1; + Eigen::Tensor3dXf savedt_2; + Eigen::Tensor3dXf savedt_3; + + // LSTM data + // 2 encoders, 2 dconv layers, 2 layers, 2 directions + // per-direction buffers + Eigen::MatrixXf lstm_output_per_direction[2][2][2][2]; + Eigen::MatrixXf lstm_hidden[2][2][2][2]; + Eigen::MatrixXf lstm_cell[2][2][2][2]; + // out-of-direction buffers + Eigen::MatrixXf lstm_output[2][2][2]; + + // LocalAttention structs + Eigen::VectorXi local_attn_index; + Eigen::MatrixXi local_attn_delta; + Eigen::Tensor1dXf local_attn_decays; + + Eigen::Tensor2dXf local_attn_decay_kernel; + + // constructor for demucs_v3_segment_buffers that takes int parameters + + // let's do pesky precomputing of the signal repadding to 1/4 hop + // for time and frequency alignment + demucs_v3_segment_buffers(int nb_channels, int segment_samples, + int nb_sources) + : segment_samples(segment_samples), +
le(int(std::ceil((float)segment_samples / + (float)demucscpp::FFT_HOP_SIZE))), + pad(std::floor((float)demucscpp::FFT_HOP_SIZE / 2.0f) * 3), + pad_end(pad + le * demucscpp::FFT_HOP_SIZE - segment_samples), + padded_segment_samples(segment_samples + pad + pad_end), + nb_stft_frames(segment_samples / demucscpp::FFT_HOP_SIZE + 1), + nb_stft_bins(demucscpp::FFT_WINDOW_SIZE / 2 + 1), + mix(nb_channels, segment_samples), + targets_out(nb_sources, nb_channels, segment_samples), + padded_mix(nb_channels, padded_segment_samples), + z(nb_channels, nb_stft_bins, nb_stft_frames), + // complex-as-channels implies 2*nb_channels for real+imag + x(2 * nb_channels, nb_stft_bins - 1, nb_stft_frames), + x_out(nb_sources * 2 * nb_channels, nb_stft_bins - 1, nb_stft_frames), + x_0(48, 512, FREQ_BRANCH_LEN), x_1(96, 128, FREQ_BRANCH_LEN), + x_2(192, 32, FREQ_BRANCH_LEN), x_3(384, 8, FREQ_BRANCH_LEN), + x_4(768, 1, FREQ_BRANCH_LEN), + x_shared_5(1, 1536, SHARED_BRANCH_LEN), // merged freq and time + xt(1, nb_channels, segment_samples), + xt_out(1, nb_sources * nb_channels, segment_samples), + xt_0(1, 48, TIME_BRANCH_LEN_0), xt_1(1, 96, TIME_BRANCH_LEN_1), + xt_2(1, 192, TIME_BRANCH_LEN_2), xt_3(1, 384, TIME_BRANCH_LEN_3), + xt_4(1, 768, TIME_BRANCH_LEN_4), saved_0(48, 512, FREQ_BRANCH_LEN), + saved_1(96, 128, FREQ_BRANCH_LEN), saved_2(192, 32, FREQ_BRANCH_LEN), + saved_3(384, 8, FREQ_BRANCH_LEN), saved_4(768, 1, FREQ_BRANCH_LEN), + savedt_0(1, 48, TIME_BRANCH_LEN_0), + savedt_1(1, 96, TIME_BRANCH_LEN_1), + savedt_2(1, 192, TIME_BRANCH_LEN_2), + savedt_3(1, 384, TIME_BRANCH_LEN_3), + local_attn_index(FREQ_BRANCH_LEN), + local_attn_delta(FREQ_BRANCH_LEN, FREQ_BRANCH_LEN), + local_attn_decays(LOCAL_ATTN_N_DECAY), + local_attn_decay_kernel(LOCAL_ATTN_N_DECAY, FREQ_BRANCH_LEN) + { + // initialize lstm buffers + int hidden_size = -1; + int cell_size = -1; + int lstm_seq_len = -1; + + // encoder layer + for (int i = 0; i < 2; i++) + { + if (i == 0) + { + hidden_size = LSTM_HIDDEN_SIZE_0; + 
cell_size = LSTM_HIDDEN_SIZE_0; + lstm_seq_len = FREQ_BRANCH_LEN; + } + else + { + hidden_size = LSTM_HIDDEN_SIZE_1; + cell_size = LSTM_HIDDEN_SIZE_1; + lstm_seq_len = SHARED_BRANCH_LEN; + } + + // dconv layer + for (int j = 0; j < 2; j++) + { + // lstm layer + for (int k = 0; k < 2; k++) + { + // lstm direction + for (int l = 0; l < 2; l++) + { + lstm_output_per_direction[i][j][k][l] = + Eigen::MatrixXf::Zero(lstm_seq_len, hidden_size); + lstm_hidden[i][j][k][l] = + Eigen::MatrixXf::Zero(hidden_size, 1); + lstm_cell[i][j][k][l] = + Eigen::MatrixXf::Zero(cell_size, 1); + } + + lstm_output[i][j][k] = + Eigen::MatrixXf::Zero(lstm_seq_len, 2 * hidden_size); + } + } + } + // initialize local attn stuff + for (int i = 0; i < FREQ_BRANCH_LEN; ++i) + { + local_attn_index(i) = i; + } + + // delta = indexes[:, None] - indexes[None, :] + for (int i = 0; i < FREQ_BRANCH_LEN; ++i) + { + for (int j = 0; j < FREQ_BRANCH_LEN; ++j) + { + local_attn_delta(i, j) = + local_attn_index(i) - local_attn_index(j); + } + } + + // Decay levels from 1 to ndecay + for (int i = 0; i < LOCAL_ATTN_N_DECAY; ++i) + { + local_attn_decays(i) = i + 1; + } + + for (int d = 0; d < LOCAL_ATTN_N_DECAY; ++d) + { + for (int t = 0; t < FREQ_BRANCH_LEN; ++t) + { + local_attn_decay_kernel(d, t) = + -local_attn_decays(d) * std::abs(local_attn_delta(0, t)) / + std::sqrt(LOCAL_ATTN_N_DECAY); + } + } + }; +}; + +bool load_demucs_v3_model(const std::string &model_dir, + struct demucs_v3_model *model); + +const float SEGMENT_LEN_SECS = 7.8; // 7.8 seconds, the demucs v3 chunk size +const float SEGMENT_OVERLAP_SECS = 0.25; // 0.25 overlap +const float MAX_SHIFT_SECS = 0.5; // max shift +const float OVERLAP = 0.25; // overlap between segments +const float TRANSITION_POWER = 1.0; // transition between segments + +Eigen::Tensor3dXf +demucs_v3_inference(const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::MatrixXf &full_audio, + demucscpp::ProgressCallback cb); + +void model_v3_inference(const struct 
demucs_v3_model &model, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers, + struct demucscpp::stft_buffers &stft_buf, + demucscpp::ProgressCallback cb, float current_progress, + float segment_progress); +} // namespace demucscpp_v3 + #endif // MODEL_HPP diff --git a/src/model_apply.cpp b/src/model_apply.cpp index bef25de..d7e5262 100644 --- a/src/model_apply.cpp +++ b/src/model_apply.cpp @@ -286,3 +286,250 @@ static Eigen::Tensor3dXf segment_inference( return chunk_out; } + +// forward declaration of inner fns +static Eigen::Tensor3dXf +shift_inference(const struct demucscpp_v3::demucs_v3_model &model, + Eigen::MatrixXf &full_audio, demucscpp::ProgressCallback cb); + +static Eigen::Tensor3dXf +split_inference(const struct demucscpp_v3::demucs_v3_model &model, + Eigen::MatrixXf &full_audio, demucscpp::ProgressCallback cb); + +static Eigen::Tensor3dXf segment_inference( + const struct demucscpp_v3::demucs_v3_model &model, Eigen::MatrixXf chunk, + int segment_sample, struct demucscpp_v3::demucs_v3_segment_buffers &buffers, + struct demucscpp::stft_buffers &stft_buf, demucscpp::ProgressCallback cb, + float current_progress, float segment_progress); + +Eigen::Tensor3dXf demucscpp_v3::demucs_v3_inference( + const struct demucscpp_v3::demucs_v3_model &model, + const Eigen::MatrixXf &audio, demucscpp::ProgressCallback cb) +{ + // working copy to modify + Eigen::MatrixXf full_audio = audio; + + // first, normalize the audio to mean and std + // ref = wav.mean(0) + // wav = (wav - ref.mean()) / ref.std() + // Calculate the overall mean and standard deviation + // Compute the mean and standard deviation separately for each channel + Eigen::VectorXf ref_mean_0 = full_audio.colwise().mean(); + + float ref_mean = ref_mean_0.mean(); + float ref_std = std::sqrt((ref_mean_0.array() - ref_mean).square().sum() / + (ref_mean_0.size() - 1)); + + // Normalize the audio + Eigen::MatrixXf normalized_audio = + (full_audio.array() - ref_mean) / ref_std; + + full_audio = 
normalized_audio; + + Eigen::Tensor3dXf waveform_outputs = shift_inference(model, full_audio, cb); + + // now inverse the normalization in Eigen C++ + // sources = sources * ref.std() + ref.mean() + waveform_outputs = (waveform_outputs * ref_std).eval() + ref_mean; + + return waveform_outputs; +} + +static Eigen::Tensor3dXf +shift_inference(const struct demucscpp_v3::demucs_v3_model &model, + Eigen::MatrixXf &full_audio, demucscpp::ProgressCallback cb) +{ + // first, apply shifts for time invariance + // we simply only support shift=1, the demucs default + // shifts (int): if > 0, will shift in time `mix` by a random amount between + // 0 and 0.5 sec + // and apply the opposite shift to the output. This is repeated + // `shifts` times and all predictions are averaged. This effectively + // makes the model time equivariant and improves SDR by up to 0.2 + // points. + int max_shift = + (int)(demucscpp::MAX_SHIFT_SECS * demucscpp::SUPPORTED_SAMPLE_RATE); + + int length = full_audio.cols(); + + Eigen::MatrixXf padded_mix(2, length + 2 * max_shift); + + symmetric_zero_padding(padded_mix, full_audio, 2 * max_shift); + + int offset = rand() % max_shift; + // int offset = 1337; + + std::cout << "1., apply model w/ shift, offset: " << offset << std::endl; + + Eigen::MatrixXf shifted_audio = + padded_mix.block(0, offset, 2, length + max_shift - offset); + + Eigen::Tensor3dXf waveform_outputs = + split_inference(model, shifted_audio, cb); + + const int nb_out_sources = 4; + + // trim the output to the original length + // waveform_outputs = waveform_outputs[..., max_shift:max_shift + length] + Eigen::Tensor3dXf trimmed_waveform_outputs = + waveform_outputs + .reshape(Eigen::array( + {nb_out_sources, 2, + static_cast(waveform_outputs.dimension(2))})) + .slice(Eigen::array({0, 0, max_shift - offset}), + Eigen::array({nb_out_sources, 2, length})); + + return trimmed_waveform_outputs; +} + +static Eigen::Tensor3dXf +split_inference(const struct demucscpp_v3::demucs_v3_model 
&model, + Eigen::MatrixXf &full_audio, demucscpp::ProgressCallback cb) +{ + // calculate segment in samples + int segment_samples = + (int)(demucscpp::SEGMENT_LEN_SECS * demucscpp::SUPPORTED_SAMPLE_RATE); + + const int nb_out_sources = 4; + + // let's create reusable buffers with padded sizes + struct demucscpp_v3::demucs_v3_segment_buffers buffers(2, segment_samples, + nb_out_sources); + struct demucscpp::stft_buffers stft_buf(buffers.padded_segment_samples); + + // next, use splits with weighted transition and overlap + // split (bool): if True, the input will be broken down in 8 seconds + // extracts + // and predictions will be performed individually on each and + // concatenated. Useful for model with large memory footprint like + // Tasnet. + + int stride_samples = (int)((1 - demucscpp::OVERLAP) * segment_samples); + + int length = full_audio.cols(); + + // create an output tensor of zeros for four source waveforms + Eigen::Tensor3dXf out = Eigen::Tensor3dXf(nb_out_sources, 2, length); + out.setZero(); + + // create weight tensor + Eigen::VectorXf weight(segment_samples); + weight.setZero(); + + weight.head(segment_samples / 2) = + Eigen::VectorXf::LinSpaced(segment_samples / 2, 1, segment_samples / 2); + weight.tail(segment_samples / 2) = + weight.head(segment_samples / 2).reverse(); + weight /= weight.maxCoeff(); + weight = weight.array().pow(demucscpp::TRANSITION_POWER); + + Eigen::VectorXf sum_weight(length); + sum_weight.setZero(); + + // i prefer using `std::ceilf` but :shrug: + int total_chunks = ::ceilf((float)length / (float)stride_samples); + float increment_per_chunk = 1.0f / (float)total_chunks; + float inference_progress = 0.0f; + + for (int offset = 0; offset < length; offset += stride_samples) + { + // create a chunk of the padded_full_audio + int chunk_end = std::min(segment_samples, length - offset); + Eigen::MatrixXf chunk = full_audio.block(0, offset, 2, chunk_end); + int chunk_length = chunk.cols(); + + std::cout << "2., apply model w/ 
split, offset: " << offset + << ", chunk shape: (" << chunk.rows() << ", " << chunk.cols() + << ")" << std::endl; + + Eigen::Tensor3dXf chunk_out = + segment_inference(model, chunk, segment_samples, buffers, stft_buf, + cb, inference_progress, increment_per_chunk); + + // add the weighted chunk to the output + // out[..., offset:offset + segment] += (weight[:chunk_length] * + // chunk_out).to(mix.device) + for (int i = 0; i < nb_out_sources; ++i) + { + for (int j = 0; j < 2; ++j) + { + for (int k = 0; k < chunk_length; ++k) + { + if (offset + k >= length) + { + break; + } + out(i, j, offset + k) += + weight(k % chunk_length) * chunk_out(i, j, k); + } + } + } + + // sum_weight[offset:offset + segment] += + // weight[:chunk_length].to(mix.device) + for (int k = 0; k < chunk_length; ++k) + { + if (offset + k >= length) + { + break; + } + sum_weight(offset + k) += weight(k % chunk_length); + } + + inference_progress += increment_per_chunk; + } + + for (int i = 0; i < nb_out_sources; ++i) + { + for (int j = 0; j < 2; ++j) + { + for (int k = 0; k < length; ++k) + { + out(i, j, k) /= sum_weight[k]; + } + } + } + return out; +} + +static Eigen::Tensor3dXf +segment_inference(const struct demucscpp_v3::demucs_v3_model &model, + Eigen::MatrixXf chunk, int segment_samples, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers, + struct demucscpp::stft_buffers &stft_buf, + demucscpp::ProgressCallback cb, float current_progress, + float segment_progress) +{ + int chunk_length = chunk.cols(); + + // copy chunk into buffers.mix with symmetric zero-padding + // assign two ints to tuple return value + std::tuple padding = symmetric_zero_padding( + buffers.mix, chunk, segment_samples - chunk_length); + + // apply demucs inference + demucscpp_v3::model_v3_inference(model, buffers, stft_buf, cb, + current_progress, segment_progress); + + const int nb_out_sources = 4; + + // copy from buffers.targets_out into chunk_out with center trimming + Eigen::Tensor3dXf chunk_out = + 
Eigen::Tensor3dXf(nb_out_sources, 2, chunk_length); + chunk_out.setZero(); + + for (int i = 0; i < nb_out_sources; ++i) + { + for (int j = 0; j < 2; ++j) + { + for (int k = 0; k < chunk_length; ++k) + { + // undoing center_trim + chunk_out(i, j, k) = + buffers.targets_out(i, j, k + std::get<0>(padding)); + } + } + } + + return chunk_out; +} diff --git a/src/model_inference.cpp b/src/model_inference.cpp index 418b360..abffdab 100644 --- a/src/model_inference.cpp +++ b/src/model_inference.cpp @@ -473,3 +473,384 @@ void demucscpp::model_inference( } } } + +void demucscpp_v3::model_v3_inference( + const struct demucscpp_v3::demucs_v3_model &model, + struct demucscpp_v3::demucs_v3_segment_buffers &buffers, + struct demucscpp::stft_buffers &stft_buf, demucscpp::ProgressCallback cb, + float current_progress, float segment_progress) +{ + // apply demucs inference + std::ostringstream ss; + ss << "3., apply_model mix shape: (" << buffers.mix.rows() << ", " + << buffers.mix.cols() << ")"; + cb(current_progress + 0.0f, ss.str()); + ss.str(""); + + // pad buffers.pad on the left, reflect + // pad buffers.pad_end on the right, reflect + // copy buffers.mix into buffers.padded_mix with reflect padding as above + reflect_padding(buffers.padded_mix, buffers.mix, buffers.pad, + buffers.pad_end); + + // copy buffers.padded_mix into stft_buf.waveform + stft_buf.waveform = buffers.padded_mix; + + // let's get a stereo complex spectrogram first + demucscpp::stft(stft_buf); + + // remove 2: 2 + le of stft + // same behavior as _spec in the python apply.py code + buffers.z = stft_buf.spec.slice( + Eigen::array{0, 0, 2}, + Eigen::array{2, (int)stft_buf.spec.dimension(1), + (int)stft_buf.spec.dimension(2) - 4}); + + // print z shape + ss << "buffers.z: " << buffers.z.dimension(0) << ", " + << buffers.z.dimension(1) << ", " << buffers.z.dimension(2); + cb(current_progress + 0.0f, ss.str()); + ss.str(""); + + // x = mag = z.abs(), but for CaC we're simply stacking the complex + // 
spectrogram along the channel dimension + for (int i = 0; i < buffers.z.dimension(0); ++i) + { + // limiting to j-1 because we're dropping 2049 to 2048 bins + for (int j = 0; j < buffers.z.dimension(1) - 1; ++j) + { + for (int k = 0; k < buffers.z.dimension(2); ++k) + { + buffers.x(2 * i, j, k) = buffers.z(i, j, k).real(); + buffers.x(2 * i + 1, j, k) = buffers.z(i, j, k).imag(); + } + } + } + + // x shape is complex*chan, nb_frames, nb_bins (2048) + // using CaC (complex-as-channels) + // print x shape + ss << "buffers.x: " << buffers.x.dimension(0) << ", " + << buffers.x.dimension(1) << ", " << buffers.x.dimension(2); + cb(current_progress + 0.0f, ss.str()); + ss.str(""); + + // apply following pytorch operations to buffers.x in Eigen C++ code: + // mean = x.mean(dim=(1, 2, 3), keepdim=True) + // std = x.std(dim=(1, 2, 3), keepdim=True) + // x = (x - mean) / (1e-5 + std) + + // Compute mean and standard deviation using Eigen + Eigen::Tensor mean_tensor = buffers.x.mean(); + float mean = mean_tensor(0); + float variance = demucscpp::calculate_variance(buffers.x, mean); + float std_ = std::sqrt(variance); + + // Normalize x + const float epsilon = 1e-5; + + // buffers.x will be the freq branch input + buffers.x = (buffers.x - mean) / (std_ + epsilon); + + // prepare time branch input by copying buffers.mix into buffers.xt(0, ...) 
+ for (int i = 0; i < buffers.mix.rows(); ++i) + { + for (int j = 0; j < buffers.mix.cols(); ++j) + { + buffers.xt(0, i, j) = buffers.mix(i, j); + } + } + + cb(current_progress + 0.0f, "Freq branch: normalized"); + + // apply similar mean, std normalization as above using 2d mean, std + Eigen::Tensor meant_tensor = buffers.xt.mean(); + float meant = meant_tensor(0); + float variancet = demucscpp::calculate_variance(buffers.xt, meant); + float stdt = std::sqrt(variancet); + + // Normalize xt + buffers.xt = (buffers.xt - meant) / (stdt + epsilon); + + cb(current_progress + 0.0f, "Time branch: normalized"); + + // buffers.xt will be the time branch input + + /* HEART OF INFERENCE CODE HERE !! */ + // ITERATION 0 + + // apply tenc, enc + + float float_steps = 22.0f; + + demucscpp_v3::apply_time_encoder_v3(model, 0, buffers.xt, buffers.xt_0); + cb(current_progress + segment_progress * 1.0f / float_steps, + "Time encoder 0"); + + demucscpp_v3::apply_freq_encoder_v3(model, 0, buffers.x, buffers.x_0); + cb(current_progress + segment_progress * 2.0f / float_steps, + "Freq encoder 0"); + + // absorb both scaling factors in one expression + // i.e. 
eliminate const float freq_emb_scale = 0.2f; + const float emb_scale = 10.0f * 0.2f; + + Eigen::MatrixXf emb = + model.freq_emb_embedding_weight.transpose() * emb_scale; + + // apply embedding to buffers.x_0 + for (int i = 0; i < 48; ++i) + { + for (int j = 0; j < 512; ++j) + { + for (int k = 0; k < buffers.x_0.dimension(2); ++k) + { + // implicit broadcasting + buffers.x_0(i, j, k) += emb(i, j); + } + } + } + + buffers.saved_0 = buffers.x_0; + buffers.savedt_0 = buffers.xt_0; + + cb(current_progress + segment_progress * 2.0f / float_steps, + "Freq branch: applied frequency embedding"); + + apply_time_encoder_v3(model, 1, buffers.xt_0, buffers.xt_1); + cb(current_progress + segment_progress * 3.0f / float_steps, + "Time encoder 1"); + + apply_freq_encoder_v3(model, 1, buffers.x_0, buffers.x_1); + cb(current_progress + segment_progress * 4.0f / float_steps, + "Freq encoder 1"); + + buffers.saved_1 = buffers.x_1; + buffers.savedt_1 = buffers.xt_1; + + apply_time_encoder_v3(model, 2, buffers.xt_1, buffers.xt_2); + cb(current_progress + segment_progress * 5.0f / float_steps, + "Time encoder 2"); + + apply_freq_encoder_v3(model, 2, buffers.x_1, buffers.x_2); + cb(current_progress + segment_progress * 6.0f / float_steps, + "Freq encoder 2"); + + buffers.saved_2 = buffers.x_2; + buffers.savedt_2 = buffers.xt_2; + + apply_time_encoder_v3(model, 3, buffers.xt_2, buffers.xt_3); + cb(current_progress + segment_progress * 7.0f / float_steps, + "Time encoder 3"); + + apply_freq_encoder_v3(model, 3, buffers.x_2, buffers.x_3); + cb(current_progress + segment_progress * 8.0f / float_steps, + "Freq encoder 3"); + + buffers.saved_3 = buffers.x_3; + buffers.savedt_3 = buffers.xt_3; + + // t/time branch: unique tencoder 4 + apply_time_encoder_4(model, buffers.xt_3, buffers.xt_4); + cb(current_progress + segment_progress * 9.0f / float_steps, + "Time encoder 4"); + + // possible this is not used, since it is the "inject" parameter + // buffers.savedt_4 = buffers.xt_4; + + // z/spec 
branch: unique encoder 4 (bilstm, local attn) + // merge time and frequency with the inject parameter + apply_freq_encoder_4(model, buffers.x_3, buffers.xt_4, buffers.x_4, + buffers); + cb(current_progress + segment_progress * 10.0f / float_steps, + "Freq encoder 4"); + + buffers.saved_4 = buffers.x_4; + + // shared: unique encoder 5 (bilstm, local attn) + apply_shared_encoder_5(model, buffers.x_4, buffers.x_shared_5, buffers); + cb(current_progress + segment_progress * 11.0f / float_steps, + "Shared encoder 5"); + + // now decoder time! + + // shared decoder 5, which is one of the two unique decoder_0_1 + + // start from 0 tensors + + Eigen::Tensor3dXf pre_t_unused = + apply_shared_decoder_0(model, buffers.x_4, buffers.x_shared_5); + cb(current_progress + segment_progress * 12.0f / float_steps, + "Shared decoder 0"); + + Eigen::Tensor3dXf pre_t = + apply_freq_decoder_1(model, buffers.x_4, buffers.x_3, buffers.saved_4); + cb(current_progress + segment_progress * 13.0f / float_steps, + "Freq decoder 1"); + + // we're skipping the inject branch i.e. 
xt_4, leapfrogging to xt_3 + apply_time_decoder_0(model, pre_t, buffers.xt_3); + cb(current_progress + segment_progress * 14.0f / float_steps, + "Time decoder 1"); + + apply_common_decoder(model, 0, 0, buffers.x_3, buffers.x_2, + buffers.saved_3); + cb(current_progress + segment_progress * 15.0f / float_steps, + "Freq decoder 2"); + + apply_common_decoder(model, 1, 0, buffers.xt_3, buffers.xt_2, + buffers.savedt_3); + cb(current_progress + segment_progress * 16.0f / float_steps, + "Time decoder 2"); + + apply_common_decoder(model, 0, 1, buffers.x_2, buffers.x_1, + buffers.saved_2); + cb(current_progress + segment_progress * 17.0f / float_steps, + "Freq decoder 3"); + + apply_common_decoder(model, 1, 1, buffers.xt_2, buffers.xt_1, + buffers.savedt_2); + cb(current_progress + segment_progress * 18.0f / float_steps, + "Time decoder 3"); + + apply_common_decoder(model, 0, 2, buffers.x_1, buffers.x_0, + buffers.saved_1); + cb(current_progress + segment_progress * 19.0f / float_steps, + "Freq decoder 4"); + + apply_common_decoder(model, 1, 2, buffers.xt_1, buffers.xt_0, + buffers.savedt_1); + cb(current_progress + segment_progress * 20.0f / float_steps, + "Time decoder 4"); + + apply_common_decoder(model, 0, 3, buffers.x_0, buffers.x_out, + buffers.saved_0); + cb(current_progress + segment_progress * 21.0f / float_steps, + "Freq decoder 5"); + + apply_common_decoder(model, 1, 3, buffers.xt_0, buffers.xt_out, + buffers.savedt_0); + cb(current_progress + segment_progress * 22.0f / float_steps, + "Time decoder 5"); + + cb(current_progress + segment_progress, "Mask + istft"); + + // xt dim 1 is a fake dim of 1 + // so we could have symmetry between the tensor3dxf of the freq and time + // branches + + const int nb_out_sources = 4; + + // 4 sources, 2 channels * 2 complex channels (real+imag), F bins, T frames + Eigen::Tensor4dXf x_4d = Eigen::Tensor4dXf( + nb_out_sources, 4, buffers.x.dimension(1), buffers.x.dimension(2)); + + // 4 sources, 2 channels, N samples + 
std::vector xt_3d = { + Eigen::MatrixXf(2, buffers.xt.dimension(2)), + Eigen::MatrixXf(2, buffers.xt.dimension(2)), + Eigen::MatrixXf(2, buffers.xt.dimension(2)), + Eigen::MatrixXf(2, buffers.xt.dimension(2))}; + + // distribute the channels of buffers.x into x_4d + // in pytorch it's (16, 2048, 336) i.e. (chan, freq, time) + // then apply `.view(4, -1, freq, time) + + // implement above logic in Eigen C++ + // copy buffers.x into x_4d + // apply opposite of + // buffers.x(i, j, k) = (buffers.x(i, j, k) - mean) / (epsilon + std_); + for (int s = 0; s < nb_out_sources; ++s) + { // loop over 4 sources + for (int i = 0; i < 4; ++i) + { + for (int j = 0; j < buffers.x.dimension(1); ++j) + { + for (int k = 0; k < buffers.x.dimension(2); ++k) + { + x_4d(s, i, j, k) = + std_ * buffers.x_out(s * 4 + i, j, k) + mean; + } + } + } + } + + // let's also copy (and denormalize) buffers.xt_out into xt_3d + for (int s = 0; s < nb_out_sources; ++s) + { // loop over 4 sources + for (int i = 0; i < 2; ++i) + { + for (int j = 0; j < buffers.xt.dimension(2); ++j) + { + xt_3d[s](i, j) = stdt * buffers.xt_out(0, s * 2 + i, j) + meant; + } + } + } + + // If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. 
+ // undo complex-as-channels by splitting the 2nd dim of x_4d into (2, 2) + for (int source = 0; source < nb_out_sources; ++source) + { + Eigen::Tensor3dXcf x_target = Eigen::Tensor3dXcf( + 2, buffers.x.dimension(1), buffers.x.dimension(2)); + + // in the CaC case, we're simply unstacking the complex + // spectrogram from the channel dimension + for (int i = 0; i < buffers.z.dimension(0); ++i) + { + // limiting to j-1 because we're dropping 2049 to 2048 bins + for (int j = 0; j < buffers.z.dimension(1) - 1; ++j) + { + for (int k = 0; k < buffers.z.dimension(2); ++k) + { + // buffers.x(2*i, j, k) = buffers.z(i, j, k).real(); + // buffers.x(2*i + 1, j, k) = buffers.z(i, j, k).imag(); + x_target(i, j, k) = + std::complex(x_4d(source, 2 * i, j, k), + x_4d(source, 2 * i + 1, j, k)); + } + } + } + + // need to re-pad 2: 2 + le on spectrogram + // opposite of this + // buffers.z = stft_buf.spec.slice(Eigen::array{0, 0, 2}, + // Eigen::array{2, (int)stft_buf.spec.dimension(1), + // (int)stft_buf.spec.dimension(2) - 4}); + // Add padding to spectrogram + + Eigen::array, 3> paddings = { + std::make_pair(0, 0), std::make_pair(0, 1), std::make_pair(2, 2)}; + Eigen::Tensor3dXcf x_target_padded = + x_target.pad(paddings, std::complex(0.0f, 0.0f)); + + stft_buf.spec = x_target_padded; + + demucscpp::istft(stft_buf); + + // now we have waveform from istft(x), the frequency branch + // that we need to sum with xt, the time branch + Eigen::MatrixXf padded_waveform = stft_buf.waveform; + + // undo the reflect pad 1d by copying padded_mix into mix + // from range buffers.pad:buffers.pad + buffers.segment_samples + Eigen::MatrixXf unpadded_waveform = + padded_waveform.block(0, buffers.pad, 2, buffers.segment_samples); + + // sum with xt + unpadded_waveform += xt_3d[source]; + + ss << "mix: " << buffers.mix.rows() << ", " << buffers.mix.cols(); + cb(current_progress + segment_progress, ss.str()); + ss.str(""); + + // copy target waveform into all 4 dims of targets_out + for (int j = 0; 
j < 2; ++j) + { + for (int k = 0; k < buffers.mix.cols(); ++k) + { + buffers.targets_out(source, j, k) = unpadded_waveform(j, k); + } + } + } +} diff --git a/src/model_load.cpp b/src/model_load.cpp index 849258c..6cc4e86 100644 --- a/src/model_load.cpp +++ b/src/model_load.cpp @@ -26,23 +26,23 @@ static void my_fprintf(const std::FILE *stream, const char *format, ...) } // forward declaration -static size_t load_single_tensor1d(FILE *f, std::string &name, +static size_t load_single_tensor1d(FILE *f, const std::string &name, Eigen::Tensor1dXf &matrix, int *ne, int32_t nelements); -static size_t load_single_vector(FILE *f, std::string &name, +static size_t load_single_vector(FILE *f, const std::string &name, Eigen::VectorXf &matrix, int *ne, int32_t nelements); -static size_t load_single_matrix(FILE *f, std::string &name, +static size_t load_single_matrix(FILE *f, const std::string &name, Eigen::MatrixXf &matrix, int *ne, int32_t nelements); -static size_t load_single_tensor3d(FILE *f, std::string &name, +static size_t load_single_tensor3d(FILE *f, const std::string &name, Eigen::Tensor3dXf &tensor, int *ne, int32_t nelements); -static size_t load_single_tensor4d(FILE *f, std::string &name, +static size_t load_single_tensor4d(FILE *f, const std::string &name, Eigen::Tensor4dXf &tensor, int *ne, int32_t nelements); @@ -1089,7 +1089,7 @@ bool demucscpp::load_demucs_model(const std::string &model_file, return true; } -static size_t load_single_matrix(FILE *f, std::string &name, +static size_t load_single_matrix(FILE *f, const std::string &name, Eigen::MatrixXf &matrix, int *ne, int32_t nelements) { @@ -1133,7 +1133,7 @@ static size_t load_single_matrix(FILE *f, std::string &name, return nbytes_tensor; } -static size_t load_single_tensor3d(FILE *f, std::string &name, +static size_t load_single_tensor3d(FILE *f, const std::string &name, Eigen::Tensor3dXf &tensor, int *ne, int32_t nelements) { @@ -1178,7 +1178,7 @@ static size_t load_single_tensor3d(FILE *f, std::string 
&name, return nbytes_tensor; } -static size_t load_single_tensor4d(FILE *f, std::string &name, +static size_t load_single_tensor4d(FILE *f, const std::string &name, Eigen::Tensor4dXf &tensor, int *ne, int32_t nelements) { @@ -1229,7 +1229,7 @@ static size_t load_single_tensor4d(FILE *f, std::string &name, return nbytes_tensor; } -static size_t load_single_tensor1d(FILE *f, std::string &name, +static size_t load_single_tensor1d(FILE *f, const std::string &name, Eigen::Tensor1dXf &tensor, int *ne, int32_t nelements) { @@ -1264,7 +1264,7 @@ static size_t load_single_tensor1d(FILE *f, std::string &name, return nbytes_tensor; } -static size_t load_single_vector(FILE *f, std::string &name, +static size_t load_single_vector(FILE *f, const std::string &name, Eigen::VectorXf &vector, int *ne, int32_t nelements) { @@ -1298,3 +1298,869 @@ static size_t load_single_vector(FILE *f, std::string &name, return nbytes_vector; } + +bool demucscpp_v3::load_demucs_v3_model( + const std::string &model_file, struct demucscpp_v3::demucs_v3_model *model) +{ + my_fprintf(stderr, "%s: loading model\n", __func__); + + // compute t_start_us using C++ std::chrono + const auto t_start_us = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + std::cout << "Loading model_file... " << std::endl; + + FILE *f = fopen(model_file.c_str(), "rb"); + if (!f) + { + my_fprintf(stderr, "%s: failed to open %s\n", __func__, + model_file.c_str()); + return false; + } + + // verify magic + uint32_t magic; + + std::cout << "Checking the magic of model_file" << std::endl; + + // read the size of uint32_t bytes from f into magic + fread(&magic, sizeof(uint32_t), 1, f); + + if (magic != 0x646d6333) // dmc3 = v3 mmi + { + fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__); + fclose(f); + return false; + } + + std::cout << "Model magic is Demucs V3 MMI" << std::endl; + + std::cout << "Loading demucs model... 
" << std::endl; + + // we dont need to prepare memory for the weights + // they come preallocated in the hardcoded model + + size_t total_size = 0; + uint32_t n_loaded = 0; + + // equivalent of with open(...) as f on each model_file + std::cout << "Loading weights from model_file" << std::endl; + + // load weights from the file one tensor at a time + + for (;;) + { + int32_t n_dims; + int32_t length; + + fread(&n_dims, sizeof(int32_t), 1, f); + fread(&length, sizeof(int32_t), 1, f); + + int32_t nelements = 1; + + // we are loading up to 4d tensors, so allocate 4 dims + int32_t ne[4] = {1, 1, 1, 1}; + for (int i = 0; i < n_dims; ++i) + { + fread(&ne[i], sizeof(int32_t), 1, f); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + fread(&tmp[0], sizeof(char), tmp.size(), f); // read to buffer + name.assign(&tmp[0], tmp.size()); + + // check if we reached eof of the open file f + if (feof(f)) + { + break; + } + + // std::cout << "Loading tensor " << name << " with shape [" << ne[0] + // << ", " << ne[1] << ", " << ne[2] << ", " << ne[3] << "]" + // << std::endl; + + // match the tensor name to the correct tensor in the model + size_t loaded_size = 0; + + // 4 Encoders + for (int i = 0; i < 4; ++i) + { + if (name == "encoder." + std::to_string(i) + ".conv.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->encoder_conv_weight[i], ne, nelements); + } + else if (name == "encoder." + std::to_string(i) + ".conv.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_conv_bias[i], ne, nelements); + } + else if (name == "encoder." + std::to_string(i) + ".rewrite.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->encoder_rewrite_weight[i], ne, nelements); + } + else if (name == "encoder." 
+ std::to_string(i) + ".rewrite.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_rewrite_bias[i], ne, nelements); + } + + // each sub-dconv is a stack of 2 + for (int j = 0; j < 2; ++j) + { + if (name == "encoder." + std::to_string(i) + ".dconv.layers." + + std::to_string(j) + ".0.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->dconv_layers_0_conv1d_weight[0][i][j], + ne, nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".0.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_0_conv1d_bias[0][i][j], ne, + nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".1.weight") + { + loaded_size = load_single_tensor1d( + f, name, + model->dconv_layers_1_groupnorm_weight[0][i][j], ne, + nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".1.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_1_groupnorm_bias[0][i][j], + ne, nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".3.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->dconv_layers_3_conv1d_weight[0][i][j], + ne, nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".3.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_3_conv1d_bias[0][i][j], ne, + nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".4.weight") + { + loaded_size = load_single_tensor1d( + f, name, + model->dconv_layers_4_groupnorm_weight[0][i][j], ne, + nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." 
+ std::to_string(j) + + ".4.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_4_groupnorm_bias[0][i][j], + ne, nelements); + } + else if (name == "encoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".6.scale") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_6_scale[0][i][j], ne, + nelements); + } + } + } + + // Loop over encoders 4 and 5 non-dconv layers + // dconv will be treated specially in the next block + for (int encoder_index = 4; encoder_index <= 5; ++encoder_index) + { + int array_index = encoder_index - 4; // Maps 4 to 0, 5 to 1 + + if (name == + "encoder." + std::to_string(encoder_index) + ".conv.weight") + { + if (encoder_index == 4) + { + loaded_size = load_single_tensor4d( + f, name, model->encoder_4_conv_weight, ne, nelements); + } + else + { + loaded_size = load_single_tensor3d( + f, name, model->encoder_5_conv_weight, ne, nelements); + } + } + else if (name == + "encoder." + std::to_string(encoder_index) + ".conv.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_4_5_conv_bias[array_index], ne, + nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".norm1.weight") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_4_5_norm1_weight[array_index], ne, + nelements); + } + else if (name == + "encoder." + std::to_string(encoder_index) + ".norm1.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_4_5_norm1_bias[array_index], ne, + nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".rewrite.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->encoder_4_5_rewrite_weight[array_index], ne, + nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".rewrite.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_4_5_rewrite_bias[array_index], ne, + nelements); + } + else if (name == "encoder." 
+ std::to_string(encoder_index) + + ".norm2.weight") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_4_5_norm2_weight[array_index], ne, + nelements); + } + else if (name == + "encoder." + std::to_string(encoder_index) + ".norm2.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->encoder_4_5_norm2_bias[array_index], ne, + nelements); + } + + // dconv time: 2 per layer + for (int dconv_index = 0; dconv_index < 2; ++dconv_index) + { + if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + std::to_string(dconv_index) + + ".0.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_0_conv1d_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".0.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_0_conv1d_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".1.weight") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_1_groupnorm_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".1.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_1_groupnorm_bias + [array_index][dconv_index], + ne, nelements); + } + + // dconv lstm for encoder 4, 5 + for (int lstm_index = 0; lstm_index < 2; ++lstm_index) + { + for (int direction = 0; direction < 2; ++direction) + { + std::string direction_suffix = + direction == 0 ? "" : "_reverse"; + std::string layer_suffix = + "l" + std::to_string(lstm_index); + + if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." 
+ + std::to_string(dconv_index) + + ".3.lstm.weight_ih_" + layer_suffix + + direction_suffix) + { + loaded_size = load_single_matrix( + f, name, + model->encoder_4_5_dconv_layers_3_lstm_ih_w + [array_index][dconv_index][lstm_index] + [direction], + ne, nelements); + } + else if (name == "encoder." + + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".3.lstm.weight_hh_" + + layer_suffix + direction_suffix) + { + loaded_size = load_single_matrix( + f, name, + model->encoder_4_5_dconv_layers_3_lstm_hh_w + [array_index][dconv_index][lstm_index] + [direction], + ne, nelements); + } + else if (name == "encoder." + + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".3.lstm.bias_ih_" + layer_suffix + + direction_suffix) + { + loaded_size = load_single_matrix( + f, name, + model->encoder_4_5_dconv_layers_3_lstm_ih_b + [array_index][dconv_index][lstm_index] + [direction], + ne, nelements); + } + else if (name == "encoder." + + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".3.lstm.bias_hh_" + layer_suffix + + direction_suffix) + { + loaded_size = load_single_matrix( + f, name, + model->encoder_4_5_dconv_layers_3_lstm_hh_b + [array_index][dconv_index][lstm_index] + [direction], + ne, nelements); + } + } + } + + // continue after the lstm with the attn etc. + if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + std::to_string(dconv_index) + + ".3.linear.weight") + { + loaded_size = load_single_matrix( + f, name, + model->encoder_4_5_dconv_layers_3_linear_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".3.linear.bias") + { + loaded_size = load_single_vector( + f, name, + model->encoder_4_5_dconv_layers_3_linear_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." 
+ std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.content.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_4_content_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.content.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_4_content_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.query.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_4_query_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.query.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_4_query_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.key.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_4_key_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.key.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_4_key_bias[array_index] + [dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." 
+ + std::to_string(dconv_index) + + ".4.query_decay.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_4_query_decay_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.query_decay.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_4_query_decay_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.proj.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_4_proj_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + + ".4.proj.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model + ->encoder_4_5_dconv_layers_4_proj_bias[array_index] + [dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".5.weight") + { + loaded_size = load_single_tensor3d( + f, name, + model->encoder_4_5_dconv_layers_5_conv1d_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".5.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_5_conv1d_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".6.weight") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_6_groupnorm_weight + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." 
+ std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".6.bias") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_6_groupnorm_bias + [array_index][dconv_index], + ne, nelements); + } + else if (name == "encoder." + std::to_string(encoder_index) + + ".dconv.layers." + + std::to_string(dconv_index) + ".8.scale") + { + loaded_size = load_single_tensor1d( + f, name, + model->encoder_4_5_dconv_layers_8_scale[array_index] + [dconv_index], + ne, nelements); + } + } + } + + // 4 TEncoders + for (int i = 0; i < 4; ++i) + { + if (name == "tencoder." + std::to_string(i) + ".conv.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->tencoder_conv_weight[i], ne, nelements); + } + else if (name == "tencoder." + std::to_string(i) + ".conv.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->tencoder_conv_bias[i], ne, nelements); + } + else if (name == + "tencoder." + std::to_string(i) + ".rewrite.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->tencoder_rewrite_weight[i], ne, nelements); + } + else if (name == "tencoder." + std::to_string(i) + ".rewrite.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->tencoder_rewrite_bias[i], ne, nelements); + } + + // each sub-dconv is a stack of 2 + for (int j = 0; j < 2; ++j) + { + if (name == "tencoder." + std::to_string(i) + ".dconv.layers." + + std::to_string(j) + ".0.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->dconv_layers_0_conv1d_weight[1][i][j], + ne, nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".0.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_0_conv1d_bias[1][i][j], ne, + nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." 
+ std::to_string(j) + + ".1.weight") + { + loaded_size = load_single_tensor1d( + f, name, + model->dconv_layers_1_groupnorm_weight[1][i][j], ne, + nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".1.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_1_groupnorm_bias[1][i][j], + ne, nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".3.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->dconv_layers_3_conv1d_weight[1][i][j], + ne, nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".3.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_3_conv1d_bias[1][i][j], ne, + nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".4.weight") + { + loaded_size = load_single_tensor1d( + f, name, + model->dconv_layers_4_groupnorm_weight[1][i][j], ne, + nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." + std::to_string(j) + + ".4.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_4_groupnorm_bias[1][i][j], + ne, nelements); + } + else if (name == "tencoder." + std::to_string(i) + + ".dconv.layers." 
+ std::to_string(j) + + ".6.scale") + { + loaded_size = load_single_tensor1d( + f, name, model->dconv_layers_6_scale[1][i][j], ne, + nelements); + } + } + } + + // 5th unique tencoder_4 + if (name == "tencoder.4.conv.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->tencoder_4_conv_weight, ne, nelements); + } + else if (name == "tencoder.4.conv.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->tencoder_4_conv_bias, ne, nelements); + } + + // start with decoder 0,1 for frequency which + // has its own arrays + + // next, tdecoder_0 which is unique + + // finally, decoder 2,3,4,5 + // and tdecoder 1,2,3,4 + // which are all grouped together in struct members + // with arrays of size [8] + + // start with decoder 0,1 for frequency which has its own arrays + for (int i = 0; i < 2; ++i) + { + if (name == "decoder." + std::to_string(i) + ".conv_tr.weight") + { + if (i == 0) + { + loaded_size = load_single_tensor3d( + f, name, model->decoder_0_conv_tr_weight, ne, + nelements); + } + else + { + loaded_size = load_single_tensor4d( + f, name, model->decoder_1_conv_tr_weight, ne, + nelements); + } + } + else if (name == "decoder." + std::to_string(i) + ".conv_tr.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->decoder_0_1_conv_tr_bias[i], ne, nelements); + } + else if (name == "decoder." + std::to_string(i) + ".norm2.weight") + { + loaded_size = load_single_tensor1d( + f, name, model->decoder_0_1_norm2_weight[i], ne, nelements); + } + else if (name == "decoder." + std::to_string(i) + ".norm2.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->decoder_0_1_norm2_bias[i], ne, nelements); + } + else if (name == "decoder." 
+ std::to_string(i) + ".rewrite.weight") + { + if (i == 0) + { + loaded_size = load_single_tensor3d( + f, name, model->decoder_0_rewrite_weight, ne, + nelements); + } + else + { + loaded_size = load_single_tensor4d( + f, name, model->decoder_1_rewrite_weight, ne, + nelements); + } + } + else if (name == "decoder." + std::to_string(i) + ".rewrite.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->decoder_0_1_rewrite_bias[i], ne, nelements); + } + else if (name == "decoder." + std::to_string(i) + ".norm1.weight") + { + loaded_size = load_single_tensor1d( + f, name, model->decoder_0_1_norm1_weight[i], ne, nelements); + } + else if (name == "decoder." + std::to_string(i) + ".norm1.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->decoder_0_1_norm1_bias[i], ne, nelements); + } + } + + // next, tdecoder_0 which is unique + if (name == "tdecoder.0.conv_tr.weight") + { + loaded_size = load_single_tensor3d( + f, name, model->tdecoder_0_conv_tr_weight, ne, nelements); + } + else if (name == "tdecoder.0.conv_tr.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->tdecoder_0_conv_tr_bias, ne, nelements); + } + else if (name == "tdecoder.0.norm2.weight") + { + loaded_size = load_single_tensor1d( + f, name, model->tdecoder_0_norm2_weight, ne, nelements); + } + else if (name == "tdecoder.0.norm2.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->tdecoder_0_norm2_bias, ne, nelements); + } + + // finally, decoder 2,3,4,5 and tdecoder 1,2,3,4 which are all grouped + // together in struct members with arrays of size [8] Loop over the + // first dimension [2] for freq, time + for (int freq_time = 0; freq_time < 2; ++freq_time) + { + // Loop over the second dimension [4] for layers + for (int layer = 0; layer < 4; ++layer) + { + // Construct the base name for current decoder/tdecoder + std::string base_name = + (freq_time == 0 ? "decoder." : "tdecoder.") + + std::to_string(layer + (freq_time == 0 ? 
2 : 1)); + + // Load conv_tr.weight + if (name == base_name + ".conv_tr.weight") + { + if (freq_time == 0) + { + loaded_size = load_single_tensor4d( + f, name, model->freq_decoders_conv_tr_weight[layer], + ne, nelements); + } + else + { + loaded_size = load_single_tensor3d( + f, name, model->time_decoders_conv_tr_weight[layer], + ne, nelements); + } + } + // Load conv_tr.bias + else if (name == base_name + ".conv_tr.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->decoders_conv_tr_bias[freq_time][layer], + ne, nelements); + } + // Load rewrite.weight + else if (name == base_name + ".rewrite.weight") + { + if (freq_time == 0) + { + loaded_size = load_single_tensor4d( + f, name, model->freq_decoders_rewrite_weight[layer], + ne, nelements); + } + else + { + loaded_size = load_single_tensor3d( + f, name, model->time_decoders_rewrite_weight[layer], + ne, nelements); + } + } + // Load rewrite.bias + else if (name == base_name + ".rewrite.bias") + { + loaded_size = load_single_tensor1d( + f, name, model->decoders_rewrite_bias[freq_time][layer], + ne, nelements); + } + } + } + + if (name == "freq_emb.embedding.weight") + { + loaded_size = load_single_matrix( + f, name, model->freq_emb_embedding_weight, ne, nelements); + } + + if (loaded_size == 0) + { + my_fprintf(stderr, "%s: failed to load %s\n", __func__, + name.c_str()); + return false; + } + total_size += loaded_size; + n_loaded++; + } + + fclose(f); + + // compute finish time in microseconds using std::chrono + + const auto t_end_us = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + // print load time in seconds + my_fprintf(stdout, "Loaded model (%u tensors, %6.2f MB) in %f s\n", + n_loaded, total_size / 1024.0 / 1024.0, + (float)(t_end_us - t_start_us) / 1000000.0f); + + return true; +} diff --git a/test/test_layers.cpp b/test/test_layers.cpp index b211784..3d78773 100644 --- a/test/test_layers.cpp +++ b/test/test_layers.cpp @@ -12,7 +12,7 @@ 
#include #include -namespace demucscppdebug +namespace demucscppdebug_test { inline void assert_(bool condition) @@ -756,18 +756,18 @@ static void setUpTestSuite() // } // } // -// demucscppdebug::debug_tensor_3dxf(x_fake, "x_fake"); +// demucscppdebug_test::debug_tensor_3dxf(x_fake, "x_fake"); // // // use kernel=(1,1), stride=(1,1), padding=(0,0), dilation=(1,1) // Eigen::MatrixXf x_im2col = demucscpp::im2col(x_fake, 1, 1, 1, 1, 0, 0, 1, // 1); // -// demucscppdebug::debug_matrix_xf(x_im2col, "x_im2col"); +// demucscppdebug_test::debug_matrix_xf(x_im2col, "x_im2col"); // // Eigen::Tensor3dXf x_col2im = demucscpp::col2im(x_im2col, 2, 3, 3, 1, 1, // 1, 1, 0, 0, 1, 1); // -// demucscppdebug::debug_tensor_3dxf(x_col2im, "x_col2im"); +// demucscppdebug_test::debug_tensor_3dxf(x_col2im, "x_col2im"); // // Eigen::Tensor3dXf x_fake_2(2, 4, 4); // @@ -781,15 +781,15 @@ static void setUpTestSuite() // } // } // -// demucscppdebug::debug_tensor_3dxf(x_fake_2, "x_fake_2"); +// demucscppdebug_test::debug_tensor_3dxf(x_fake_2, "x_fake_2"); // // // Parameters: kernel=(2,2), stride=(2,2), padding=(0,0) // Eigen::MatrixXf x_im2col_2 = demucscpp::im2col(x_fake_2, 2, 2, 2, 2, 0, -// 0, 1, 1); demucscppdebug::debug_matrix_xf(x_im2col_2, "x_im2col_2"); +// 0, 1, 1); demucscppdebug_test::debug_matrix_xf(x_im2col_2, "x_im2col_2"); // // // Reverse im2col // Eigen::Tensor3dXf x_col2im_2 = demucscpp::col2im(x_im2col_2, 2, 4, 4, 2, -// 2, 2, 2, 0, 0, 1, 1); demucscppdebug::debug_tensor_3dxf(x_col2im_2, +// 2, 2, 2, 0, 0, 1, 1); demucscppdebug_test::debug_tensor_3dxf(x_col2im_2, // "x_col2im_2"); // // Eigen::Tensor3dXf x_fake_3(2, 4, 4); @@ -804,15 +804,15 @@ static void setUpTestSuite() // } // } // -// demucscppdebug::debug_tensor_3dxf(x_fake_3, "x_fake_3"); +// demucscppdebug_test::debug_tensor_3dxf(x_fake_3, "x_fake_3"); // // // Parameters: kernel=(2,2), stride=(2,2), padding=(0,0) // Eigen::MatrixXf x_im2col_3 = demucscpp::im2col(x_fake_3, 2, 2, 2, 2, 1, -// 1, 1, 1); 
demucscppdebug::debug_matrix_xf(x_im2col_3, "x_im2col_3"); +// 1, 1, 1); demucscppdebug_test::debug_matrix_xf(x_im2col_3, "x_im2col_3"); // // // Reverse im2col // Eigen::Tensor3dXf x_col2im_3 = demucscpp::col2im(x_im2col_3, 2, 4, 4, 2, -// 2, 2, 2, 1, 1, 1, 1); demucscppdebug::debug_tensor_3dxf(x_col2im_3, +// 2, 2, 2, 1, 1, 1, 1); demucscppdebug_test::debug_tensor_3dxf(x_col2im_3, // "x_col2im_3"); // // // dilation test @@ -828,7 +828,7 @@ static void setUpTestSuite() // } // } // -// demucscppdebug::debug_tensor_3dxf(x_dilation_test, "x_dilation_test"); +// demucscppdebug_test::debug_tensor_3dxf(x_dilation_test, "x_dilation_test"); // // // Dilation test parameters // int kernel_height = 3; @@ -844,13 +844,13 @@ static void setUpTestSuite() // Eigen::MatrixXf x_im2col_dilated = demucscpp::im2col(x_dilation_test, // kernel_height, kernel_width, stride_height, stride_width, pad_height, // pad_width, dilation_height, dilation_width); -// demucscppdebug::debug_matrix_xf(x_im2col_dilated, "x_im2col_dilated"); +// demucscppdebug_test::debug_matrix_xf(x_im2col_dilated, "x_im2col_dilated"); // // // Reverse with col2im // Eigen::Tensor3dXf x_col2im_dilated = demucscpp::col2im(x_im2col_dilated, // 2, 6, 6, kernel_height, kernel_width, stride_height, stride_width, // pad_height, pad_width, dilation_height, dilation_width); -// demucscppdebug::debug_tensor_3dxf(x_col2im_dilated, "x_col2im_dilated"); +// demucscppdebug_test::debug_tensor_3dxf(x_col2im_dilated, "x_col2im_dilated"); // } TEST(DemucsCPPLayers, GemmConv) @@ -886,7 +886,7 @@ TEST(DemucsCPPLayers, GemmConv) } } - // demucscppdebug::debug_tensor_3dxf(x, "x"); + // demucscppdebug_test::debug_tensor_3dxf(x, "x"); Eigen::Tensor4dXf w(out_channels, in_channels, kernel_height, kernel_width); int counter = 1; @@ -925,7 +925,7 @@ TEST(DemucsCPPLayers, GemmConv) } } - // demucscppdebug::debug_tensor_4dxf(w, "w"); + // demucscppdebug_test::debug_tensor_4dxf(w, "w"); // bias: 4 out channels Eigen::Tensor1dXf 
b(out_channels); @@ -944,11 +944,11 @@ TEST(DemucsCPPLayers, GemmConv) // } } - // demucscppdebug::debug_tensor_1dxf(b, "b"); + // demucscppdebug_test::debug_tensor_1dxf(b, "b"); // apply conv2d_gemm with some params Eigen::Tensor3dXf y_gemm = - demucscpp::conv2d_gemm(x, w, b); // apply regular conv2d with some params @@ -1032,8 +1032,8 @@ TEST(DemucsCPPLayers, GemmConv) } // compare y_gemm and y_conv2d - demucscppdebug::debug_tensor_3dxf(y_gemm, "y_gemm"); - demucscppdebug::debug_tensor_3dxf(y_conv2d, "y_conv2d"); + demucscppdebug_test::debug_tensor_3dxf(y_gemm, "y_gemm"); + demucscppdebug_test::debug_tensor_3dxf(y_conv2d, "y_conv2d"); } TEST(DemucsCPPLayers, GemmConv2) @@ -1069,7 +1069,7 @@ TEST(DemucsCPPLayers, GemmConv2) } } - // demucscppdebug::debug_tensor_3dxf(x, "x"); + // demucscppdebug_test::debug_tensor_3dxf(x, "x"); Eigen::Tensor4dXf w(out_channels, in_channels, kernel_height, kernel_width); int counter = 1; @@ -1108,7 +1108,7 @@ TEST(DemucsCPPLayers, GemmConv2) } } - // demucscppdebug::debug_tensor_4dxf(w, "w"); + // demucscppdebug_test::debug_tensor_4dxf(w, "w"); // bias: 4 out channels Eigen::Tensor1dXf b(out_channels); @@ -1127,11 +1127,11 @@ TEST(DemucsCPPLayers, GemmConv2) // } } - // demucscppdebug::debug_tensor_1dxf(b, "b"); + // demucscppdebug_test::debug_tensor_1dxf(b, "b"); // apply conv2d_gemm with some params Eigen::Tensor3dXf y_gemm = - demucscpp::conv2d_gemm(x, w, b); // apply regular conv2d with some params @@ -1215,8 +1215,8 @@ TEST(DemucsCPPLayers, GemmConv2) } // compare y_gemm and y_conv2d - demucscppdebug::debug_tensor_3dxf(y_gemm, "y_gemm"); - demucscppdebug::debug_tensor_3dxf(y_conv2d, "y_conv2d"); + demucscppdebug_test::debug_tensor_3dxf(y_gemm, "y_gemm"); + demucscppdebug_test::debug_tensor_3dxf(y_conv2d, "y_conv2d"); } TEST(DemucsCPPLayers, GemmTrConv) @@ -1252,7 +1252,7 @@ TEST(DemucsCPPLayers, GemmTrConv) } } - // demucscppdebug::debug_tensor_3dxf(x, "x"); + // demucscppdebug_test::debug_tensor_3dxf(x, "x"); Eigen::Tensor4dXf 
w(in_channels, out_channels, kernel_height, kernel_width); int counter = 1; @@ -1291,7 +1291,7 @@ TEST(DemucsCPPLayers, GemmTrConv) } } - // demucscppdebug::debug_tensor_4dxf(w, "w"); + // demucscppdebug_test::debug_tensor_4dxf(w, "w"); // bias: 4 out channels Eigen::Tensor1dXf b(out_channels); @@ -1310,11 +1310,11 @@ TEST(DemucsCPPLayers, GemmTrConv) // } } - // demucscppdebug::debug_tensor_1dxf(b, "b"); + // demucscppdebug_test::debug_tensor_1dxf(b, "b"); // apply conv2d_gemm with some params Eigen::Tensor3dXf y_gemm = - demucscpp::conv2d_tr_gemm(x, w, b); // apply regular conv2d with some params @@ -1382,8 +1382,8 @@ TEST(DemucsCPPLayers, GemmTrConv) } // compare y_gemm and y_conv2d - demucscppdebug::debug_tensor_3dxf(y_gemm, "y_gemm"); - // demucscppdebug::debug_tensor_3dxf(y_conv2d, "y_conv2d"); + demucscppdebug_test::debug_tensor_3dxf(y_gemm, "y_gemm"); + // demucscppdebug_test::debug_tensor_3dxf(y_conv2d, "y_conv2d"); } // write a basic test case for a stereo file @@ -1417,20 +1417,20 @@ TEST(DemucsCPPLayers, FreqEncoders) Eigen::Tensor3dXf x_fake_enc_0(48, 512, 336); demucscpp::apply_freq_encoder(model, 0, x_fake, x_fake_enc_0); - demucscppdebug::debug_tensor_3dxf(x_fake, "x_fake"); - demucscppdebug::debug_tensor_3dxf(x_fake_enc_0, "x_fake_enc_0"); + demucscppdebug_test::debug_tensor_3dxf(x_fake, "x_fake"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_enc_0, "x_fake_enc_0"); Eigen::Tensor3dXf x_fake_enc_1(96, 128, 336); demucscpp::apply_freq_encoder(model, 1, x_fake_enc_0, x_fake_enc_1); - demucscppdebug::debug_tensor_3dxf(x_fake_enc_1, "x_fake_enc_1"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_enc_1, "x_fake_enc_1"); Eigen::Tensor3dXf x_fake_enc_2(192, 32, 336); demucscpp::apply_freq_encoder(model, 2, x_fake_enc_1, x_fake_enc_2); - demucscppdebug::debug_tensor_3dxf(x_fake_enc_2, "x_fake_enc_2"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_enc_2, "x_fake_enc_2"); Eigen::Tensor3dXf x_fake_enc_3(384, 8, 336); demucscpp::apply_freq_encoder(model, 
3, x_fake_enc_2, x_fake_enc_3); - demucscppdebug::debug_tensor_3dxf(x_fake_enc_3, "x_fake_enc_3"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_enc_3, "x_fake_enc_3"); } TEST(DemucsCPPLayers, FreqDecoders) @@ -1526,23 +1526,23 @@ TEST(DemucsCPPLayers, FreqDecoders) demucscpp::apply_freq_decoder(model, 0, x_fake_dec_0, x_fake_dec_1, skip_fake_dec_0); - demucscppdebug::debug_tensor_3dxf(x_fake_dec_0, "x_fake_dec_0"); - demucscppdebug::debug_tensor_3dxf(x_fake_dec_1, "x_fake_dec_1"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_dec_0, "x_fake_dec_0"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_dec_1, "x_fake_dec_1"); demucscpp::apply_freq_decoder(model, 1, x_fake_dec_1, x_fake_dec_2, skip_fake_dec_1); - demucscppdebug::debug_tensor_3dxf(x_fake_dec_2, "x_fake_dec_2"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_dec_2, "x_fake_dec_2"); demucscpp::apply_freq_decoder(model, 2, x_fake_dec_2, x_fake_dec_3, skip_fake_dec_2); - demucscppdebug::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3"); demucscpp::apply_freq_decoder(model, 3, x_fake_dec_3, x_fake_dec_4, skip_fake_dec_3); - demucscppdebug::debug_tensor_3dxf(x_fake_dec_4, "x_fake_dec_4"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_dec_4, "x_fake_dec_4"); } // write a basic test case for a stereo file @@ -1573,27 +1573,27 @@ TEST(DemucsCPPLayers, TimeEncoders) } } - demucscppdebug::debug_tensor_3dxf(xt_fake, "xt_fake"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt_fake"); Eigen::Tensor3dXf xt_fake_enc_0(1, 48, 85995); demucscpp::apply_time_encoder(model, 0, xt_fake, xt_fake_enc_0); - demucscppdebug::debug_tensor_3dxf(xt_fake_enc_0, "xt_fake_enc_0"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_enc_0, "xt_fake_enc_0"); Eigen::Tensor3dXf xt_fake_enc_1(1, 96, 21499); demucscpp::apply_time_encoder(model, 1, xt_fake_enc_0, xt_fake_enc_1); - demucscppdebug::debug_tensor_3dxf(xt_fake_enc_1, "xt_fake_enc_1"); + 
demucscppdebug_test::debug_tensor_3dxf(xt_fake_enc_1, "xt_fake_enc_1"); Eigen::Tensor3dXf xt_fake_enc_2(1, 192, 5375); demucscpp::apply_time_encoder(model, 2, xt_fake_enc_1, xt_fake_enc_2); - demucscppdebug::debug_tensor_3dxf(xt_fake_enc_2, "xt_fake_enc_2"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_enc_2, "xt_fake_enc_2"); Eigen::Tensor3dXf xt_fake_enc_3(1, 384, 1344); demucscpp::apply_time_encoder(model, 3, xt_fake_enc_2, xt_fake_enc_3); - demucscppdebug::debug_tensor_3dxf(xt_fake_enc_3, "xt_fake_enc_3"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_enc_3, "xt_fake_enc_3"); } TEST(DemucsCPPLayers, TimeDecoders) @@ -1677,23 +1677,23 @@ TEST(DemucsCPPLayers, TimeDecoders) demucscpp::apply_time_decoder(model, 0, xt_fake_dec_0, xt_fake_dec_1, skipt_fake_dec_0); - demucscppdebug::debug_tensor_3dxf(xt_fake_dec_0, "xt_fake_dec_0"); - demucscppdebug::debug_tensor_3dxf(xt_fake_dec_1, "xt_fake_dec_1"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_dec_0, "xt_fake_dec_0"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_dec_1, "xt_fake_dec_1"); demucscpp::apply_time_decoder(model, 1, xt_fake_dec_1, xt_fake_dec_2, skipt_fake_dec_1); - demucscppdebug::debug_tensor_3dxf(xt_fake_dec_2, "xt_fake_dec_2"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_dec_2, "xt_fake_dec_2"); demucscpp::apply_time_decoder(model, 2, xt_fake_dec_2, xt_fake_dec_3, skipt_fake_dec_2); - demucscppdebug::debug_tensor_3dxf(xt_fake_dec_3, "xt_fake_dec_3"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_dec_3, "xt_fake_dec_3"); demucscpp::apply_time_decoder(model, 3, xt_fake_dec_3, xt_fake_dec_4, skipt_fake_dec_3); - demucscppdebug::debug_tensor_3dxf(xt_fake_dec_4, "xt_fake_dec_4"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_dec_4, "xt_fake_dec_4"); // compare first and last element of waveform_outputs and normalized_audio // EXPECT_NEAR(waveform_outputs(0, 0, 0), normalized_audio(0, 0), @@ -1746,8 +1746,8 @@ TEST(DemucsCPPLayers, CrossTransformer) } } - 
demucscppdebug::debug_tensor_3dxf(x_fake, "x_fake"); - demucscppdebug::debug_tensor_3dxf(xt_fake, "xt_fake"); + demucscppdebug_test::debug_tensor_3dxf(x_fake, "x_fake"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt_fake"); // Reshape buffers.x_3 into x_3_reshaped // Apply the conv1d function @@ -1775,9 +1775,9 @@ TEST(DemucsCPPLayers, CrossTransformer) xt_fake, ct_4s->channel_upsampler_t_weight, ct_4s->channel_upsampler_t_bias); - demucscppdebug::debug_tensor_3dxf(x_fake_upsampled, + demucscppdebug_test::debug_tensor_3dxf(x_fake_upsampled, "x pre-crosstransformer"); - demucscppdebug::debug_tensor_3dxf(xt_fake_upsampled, + demucscppdebug_test::debug_tensor_3dxf(xt_fake_upsampled, "xt pre-crosstransformer"); /*************************/ @@ -1793,9 +1793,9 @@ TEST(DemucsCPPLayers, CrossTransformer) // buffers.x_3_channel_upsampled.reshape(Eigen::array({1, 512, // 8*336})); buffers.x_3_channel_upsampled = x_3_reshaped_upsampled_2; - demucscppdebug::debug_tensor_3dxf(x_fake_upsampled, + demucscppdebug_test::debug_tensor_3dxf(x_fake_upsampled, "x post-crosstransformer"); - demucscppdebug::debug_tensor_3dxf(xt_fake_upsampled, + demucscppdebug_test::debug_tensor_3dxf(xt_fake_upsampled, "xt post-crosstransformer"); // then apply the conv1d_2d function @@ -1812,8 +1812,8 @@ TEST(DemucsCPPLayers, CrossTransformer) xt_fake_upsampled, ct_4s->channel_downsampler_t_weight, ct_4s->channel_downsampler_t_bias); - demucscppdebug::debug_tensor_3dxf(x_fake_downsampled, "x post-downsampler"); - demucscppdebug::debug_tensor_3dxf(xt_fake_downsampled, + demucscppdebug_test::debug_tensor_3dxf(x_fake_downsampled, "x post-downsampler"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_downsampled, "xt post-downsampler"); } @@ -1858,8 +1858,8 @@ TEST(DemucsCPPLayers, CrossTransformerNoUpsamp) } } - demucscppdebug::debug_tensor_3dxf(x_fake, "x pre-crosstransformer"); - demucscppdebug::debug_tensor_3dxf(xt_fake, "xt pre-crosstransformer"); + demucscppdebug_test::debug_tensor_3dxf(x_fake, 
"x pre-crosstransformer"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt pre-crosstransformer"); /*************************/ /* CROSS-TRANSFORMER! */ @@ -1867,8 +1867,8 @@ TEST(DemucsCPPLayers, CrossTransformerNoUpsamp) demucscpp::ProgressCallback callback = [](float progress, const std::string &msg) {}; demucscpp::apply_crosstransformer(model, x_fake, xt_fake, callback, 0.0f, 0.0f); - demucscppdebug::debug_tensor_3dxf(x_fake, "x post-crosstransformer"); - demucscppdebug::debug_tensor_3dxf(xt_fake, "xt post-crosstransformer"); + demucscppdebug_test::debug_tensor_3dxf(x_fake, "x post-crosstransformer"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt post-crosstransformer"); } TEST(DemucsCPPLayers, Upsamplers) @@ -1916,8 +1916,8 @@ TEST(DemucsCPPLayers, Upsamplers) } } - demucscppdebug::debug_tensor_3dxf(x_fake, "x_fake"); - demucscppdebug::debug_tensor_3dxf(xt_fake, "xt_fake"); + demucscppdebug_test::debug_tensor_3dxf(x_fake, "x_fake"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt_fake"); auto *ct_4s = static_cast( model.crosstransformer.get()); @@ -1945,8 +1945,8 @@ TEST(DemucsCPPLayers, Upsamplers) xt_fake, ct_4s->channel_upsampler_t_weight, ct_4s->channel_upsampler_t_bias); - demucscppdebug::debug_tensor_3dxf(x_fake_upsampled, "x upsampled"); - demucscppdebug::debug_tensor_3dxf(xt_fake_upsampled, "xt upsampled"); + demucscppdebug_test::debug_tensor_3dxf(x_fake_upsampled, "x upsampled"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_upsampled, "xt upsampled"); // reshape x_fake_upsampled to 1, 512, 2688 @@ -1965,8 +1965,8 @@ TEST(DemucsCPPLayers, Upsamplers) xt_fake_upsampled, ct_4s->channel_downsampler_t_weight, ct_4s->channel_downsampler_t_bias); - demucscppdebug::debug_tensor_3dxf(x_fake_downsampled, "x post-downsampler"); - demucscppdebug::debug_tensor_3dxf(xt_fake_downsampled, + demucscppdebug_test::debug_tensor_3dxf(x_fake_downsampled, "x post-downsampler"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake_downsampled, "xt 
post-downsampler"); } @@ -2008,8 +2008,8 @@ TEST(DemucsCPPLayers, CTLayers) } } - // demucscppdebug::debug_tensor_3dxf(x_fake, "x pre-crosstransformer"); - // demucscppdebug::debug_tensor_3dxf(xt_fake, "xt pre-crosstransformer"); + // demucscppdebug_test::debug_tensor_3dxf(x_fake, "x pre-crosstransformer"); + // demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt pre-crosstransformer"); // make copies of each Eigen::Tensor3dXf x_fake_copy = x_fake; @@ -2123,17 +2123,17 @@ TEST(DemucsCPPLayers, CTLayers) [weight_idx], 8, eps); - demucscppdebug::debug_tensor_3dxf(x_fake, "x post-layer-0"); - demucscppdebug::debug_tensor_3dxf(xt_fake, "xt post-tlayer-0"); + demucscppdebug_test::debug_tensor_3dxf(x_fake, "x post-layer-0"); + demucscppdebug_test::debug_tensor_3dxf(xt_fake, "xt post-tlayer-0"); - // demucscppdebug::debug_tensor_1dxf(model.crosstransformer_norm_in_weight, + // demucscppdebug_test::debug_tensor_1dxf(model.crosstransformer_norm_in_weight, // "norm_in weight"); - // demucscppdebug::debug_tensor_1dxf(model.crosstransformer_norm_in_bias, + // demucscppdebug_test::debug_tensor_1dxf(model.crosstransformer_norm_in_bias, // "norm_in bias"); - // demucscppdebug::debug_tensor_1dxf(model.crosstransformer_norm_in_t_weight, + // demucscppdebug_test::debug_tensor_1dxf(model.crosstransformer_norm_in_t_weight, // "norm_in_t weight"); - // demucscppdebug::debug_tensor_1dxf(model.crosstransformer_norm_in_t_bias, + // demucscppdebug_test::debug_tensor_1dxf(model.crosstransformer_norm_in_t_bias, // "norm_in_t bias"); // Eigen::Tensor3dXf x_norm_in = demucscpp::layer_norm( @@ -2149,11 +2149,11 @@ TEST(DemucsCPPLayers, CTLayers) // x_fake_copy, model.crosstransformer_norm_in_t_weight, // model.crosstransformer_norm_in_t_bias, eps); - // demucscppdebug::debug_tensor_3dxf(x_norm_in, "x norm-in"); - // demucscppdebug::debug_tensor_3dxf(xt_norm_in, "xt norm-in-t"); + // demucscppdebug_test::debug_tensor_3dxf(x_norm_in, "x norm-in"); + // 
demucscppdebug_test::debug_tensor_3dxf(xt_norm_in, "xt norm-in-t"); - // demucscppdebug::debug_tensor_3dxf(x_norm_in_t, "x norm-in_t"); - // demucscppdebug::debug_tensor_3dxf(xt_norm_in_f, "xt norm-in-t_f"); + // demucscppdebug_test::debug_tensor_3dxf(x_norm_in_t, "x norm-in_t"); + // demucscppdebug_test::debug_tensor_3dxf(xt_norm_in_f, "xt norm-in-t_f"); } TEST(DemucsCPPLayers, LayerNormBasic) @@ -2177,13 +2177,13 @@ TEST(DemucsCPPLayers, LayerNormBasic) b(1) = -0.25; b(2) = 0.75; - demucscppdebug::debug_tensor_3dxf(x, "x"); - demucscppdebug::debug_tensor_1dxf(w, "w"); - demucscppdebug::debug_tensor_1dxf(b, "b"); + demucscppdebug_test::debug_tensor_3dxf(x, "x"); + demucscppdebug_test::debug_tensor_1dxf(w, "w"); + demucscppdebug_test::debug_tensor_1dxf(b, "b"); Eigen::Tensor3dXf x_out = demucscpp::layer_norm(x, w, b, 1e-5); - demucscppdebug::debug_tensor_3dxf(x_out, "x_out"); + demucscppdebug_test::debug_tensor_3dxf(x_out, "x_out"); } TEST(DemucsCPPLayers, LayerNormBigger) @@ -2224,11 +2224,11 @@ TEST(DemucsCPPLayers, LayerNormBigger) } } - demucscppdebug::debug_tensor_3dxf(x, "x"); - demucscppdebug::debug_tensor_1dxf(w, "w"); - demucscppdebug::debug_tensor_1dxf(b, "b"); + demucscppdebug_test::debug_tensor_3dxf(x, "x"); + demucscppdebug_test::debug_tensor_1dxf(w, "w"); + demucscppdebug_test::debug_tensor_1dxf(b, "b"); Eigen::Tensor3dXf x_out = demucscpp::layer_norm(x, w, b, 1e-5); - demucscppdebug::debug_tensor_3dxf(x_out, "x_out"); + demucscppdebug_test::debug_tensor_3dxf(x_out, "x_out"); } diff --git a/test/test_layers_v3.cpp b/test/test_layers_v3.cpp new file mode 100644 index 0000000..11022c8 --- /dev/null +++ b/test/test_layers_v3.cpp @@ -0,0 +1,1714 @@ +// use gtest to test the load_audio_for_kissfft function + +#include "conv.hpp" +#include "dsp.hpp" +#include "encdec.hpp" +#include "layers.hpp" +#include "model.hpp" +#include "lstm.hpp" +#include "tensor.hpp" +#include +#include +#include +#include + +namespace demucscppdebug_test_v3 +{ + +inline void 
assert_(bool condition) +{ + if (!condition) + { + std::cout << "Assertion failed!" << std::endl; + std::exit(1); + } +} + +inline void debug_tensor_4dxf(const Eigen::Tensor4dXf &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1) + << ", " << x.dimension(2) << ", " << x.dimension(3) << ")" + << std::endl; + + float x_min = 100000000.0f; + float x_max = -100000000.0f; + float x_sum = 0.0f; + float x_mean = 0.0f; + float x_stddev = 0.0f; + + // store dimension index to save index of min/max + int x_min_idx_0 = -1; + int x_min_idx_1 = -1; + int x_min_idx_2 = -1; + int x_min_idx_3 = -1; + int x_max_idx_0 = -1; + int x_max_idx_1 = -1; + int x_max_idx_2 = -1; + int x_max_idx_3 = -1; + + // loop over tensor and find min/max/sum + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + for (int l = 0; l < x.dimension(3); ++l) + { + float val = x(i, j, k, l); + x_sum += val; + if (val < x_min) + { + x_min = val; + x_min_idx_0 = i; + x_min_idx_1 = j; + x_min_idx_2 = k; + x_min_idx_3 = l; + } + if (val > x_max) + { + x_max = val; + x_max_idx_0 = i; + x_max_idx_1 = j; + x_max_idx_2 = k; + x_max_idx_3 = l; + } + } + } + } + } + + // compute mean and standard deviation + x_mean = x_sum / (x.dimension(0) * x.dimension(1) * x.dimension(2) * + x.dimension(3)); + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + for (int l = 0; l < x.dimension(3); ++l) + { + float val = x(i, j, k, l); + x_stddev += (val - x_mean) * (val - x_mean); + } + } + } + } + x_stddev = std::sqrt(x_stddev / (x.dimension(0) * x.dimension(1) * + x.dimension(2) * x.dimension(3))); + + // now print min, max, mean, stddev, and indices + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << 
std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ", " + << x_min_idx_2 << ", " << x_min_idx_3 << ")" << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ", " << x_max_idx_1 << ", " + << x_max_idx_2 << ", " << x_max_idx_3 << ")" << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; + ; +} + +// write function to debug a tensor and pause execution +inline void debug_tensor_3dxcf(const Eigen::Tensor3dXcf &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1) + << ", " << x.dimension(2) << ")" << std::endl; + + float x_min_real = 100000000.0f; + float x_max_real = -100000000.0f; + float x_min_imag = 100000000.0f; + float x_max_imag = -100000000.0f; + float x_sum_real = 0.0f; + float x_sum_imag = 0.0f; + float x_mean_real = 0.0f; + float x_mean_imag = 0.0f; + float x_stddev_real = 0.0f; + float x_stddev_imag = 0.0f; + + // store dimension index to save index of min/max + int x_min_real_idx_0 = -1; + int x_min_real_idx_1 = -1; + int x_min_real_idx_2 = -1; + int x_max_real_idx_0 = -1; + int x_max_real_idx_1 = -1; + int x_max_real_idx_2 = -1; + + int x_min_imag_idx_0 = -1; + int x_min_imag_idx_1 = -1; + int x_min_imag_idx_2 = -1; + int x_max_imag_idx_0 = -1; + int x_max_imag_idx_1 = -1; + int x_max_imag_idx_2 = -1; + + // loop over tensor and find min/max/sum + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + float real = std::real(x(i, j, k)); + float imag = std::imag(x(i, j, k)); + x_sum_real += real; + x_sum_imag += imag; + if (real < x_min_real) + { + x_min_real = real; + x_min_real_idx_0 = i; + x_min_real_idx_1 = j; + x_min_real_idx_2 = 
k; + } + if (real > x_max_real) + { + x_max_real = real; + x_max_real_idx_0 = i; + x_max_real_idx_1 = j; + x_max_real_idx_2 = k; + } + if (imag < x_min_imag) + { + x_min_imag = imag; + x_min_imag_idx_0 = i; + x_min_imag_idx_1 = j; + x_min_imag_idx_2 = k; + } + if (imag > x_max_imag) + { + x_max_imag = imag; + x_max_imag_idx_0 = i; + x_max_imag_idx_1 = j; + x_max_imag_idx_2 = k; + } + } + } + } + + // compute mean and standard deviation + x_mean_real = + x_sum_real / (x.dimension(0) * x.dimension(1) * x.dimension(2)); + x_mean_imag = + x_sum_imag / (x.dimension(0) * x.dimension(1) * x.dimension(2)); + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + float real = std::real(x(i, j, k)); + float imag = std::imag(x(i, j, k)); + x_stddev_real += (real - x_mean_real) * (real - x_mean_real); + x_stddev_imag += (imag - x_mean_imag) * (imag - x_mean_imag); + } + } + } + x_stddev_real = std::sqrt( + x_stddev_real / (x.dimension(0) * x.dimension(1) * x.dimension(2))); + x_stddev_imag = std::sqrt( + x_stddev_imag / (x.dimension(0) * x.dimension(1) * x.dimension(2))); + + // now print min, max, mean, stddev, and indices for real and imaginary + // parts + std::cout << "\tmin real: " << x_min_real << std::endl; + std::cout << "\tmax real: " << x_max_real << std::endl; + std::cout << "\tmean real: " << x_mean_real << std::endl; + std::cout << "\tstddev real: " << x_stddev_real << std::endl; + std::cout << "\tmin real idx: (" << x_min_real_idx_0 << ", " + << x_min_real_idx_1 << ", " << x_min_real_idx_2 << ")" + << std::endl; + std::cout << "\tmax real idx: (" << x_max_real_idx_0 << ", " + << x_max_real_idx_1 << ", " << x_max_real_idx_2 << ")" + << std::endl; + std::cout << "\tsum real: " << x_sum_real << std::endl; + + std::cout << "\tmin imag: " << x_min_imag << std::endl; + std::cout << "\tmax imag: " << x_max_imag << std::endl; + std::cout << "\tmean imag: " << x_mean_imag << std::endl; + 
std::cout << "\tstddev imag: " << x_stddev_imag << std::endl; + std::cout << "\tmin imag idx: (" << x_min_imag_idx_0 << ", " + << x_min_imag_idx_1 << ", " << x_min_imag_idx_2 << ")" + << std::endl; + std::cout << "\tmax imag idx: (" << x_max_imag_idx_0 << ", " + << x_max_imag_idx_1 << ", " << x_max_imag_idx_2 << ")" + << std::endl; + std::cout << "\tsum imag: " << x_sum_imag << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +// For Tensor3dXf +inline void debug_tensor_3dxf(const Eigen::Tensor3dXf &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1) + << ", " << x.dimension(2) << ")" << std::endl; + + auto x_min = x.minimum(); + auto x_max = x.maximum(); + Eigen::Tensor x_sum_tensor = x.sum(); + float x_sum = x_sum_tensor(0); + Eigen::Tensor x_mean_tensor = x.mean(); + float x_mean = x_mean_tensor(0); + Eigen::Tensor x_stddev_tensor = + ((x - x_mean).square()).mean().sqrt(); + float x_stddev = x_stddev_tensor(0); + + // You might need to keep the existing loop for this purpose, or use other + // methods Re-inserting the loop for finding indices of min and max + int x_min_idx_0 = -1, x_min_idx_1 = -1, x_min_idx_2 = -1; + int x_max_idx_0 = -1, x_max_idx_1 = -1, x_max_idx_2 = -1; + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + float val = x(i, j, k); + if (val < min_val) + { + min_val = val; + x_min_idx_0 = i; + x_min_idx_1 = j; + x_min_idx_2 = k; + } + if (val > max_val) + { + max_val = val; + x_max_idx_0 = i; + x_max_idx_1 = j; + x_max_idx_2 = k; + } + } + } + } + + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout 
<< "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ", " + << x_min_idx_2 << ")" << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ", " << x_max_idx_1 << ", " + << x_max_idx_2 << ")" << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +// For Tensor3dXf +inline void debug_tensor_2dxf(const Eigen::Tensor &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1) << ")" + << std::endl; + + auto x_min = x.minimum(); + auto x_max = x.maximum(); + Eigen::Tensor x_sum_tensor = x.sum(); + float x_sum = x_sum_tensor(0); + Eigen::Tensor x_mean_tensor = x.mean(); + float x_mean = x_mean_tensor(0); + Eigen::Tensor x_stddev_tensor = + ((x - x_mean).square()).mean().sqrt(); + float x_stddev = x_stddev_tensor(0); + + // You might need to keep the existing loop for this purpose, or use other + // methods Re-inserting the loop for finding indices of min and max + int x_min_idx_0 = -1, x_min_idx_1 = -1; + int x_max_idx_0 = -1, x_max_idx_1 = -1; + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + float val = x(i, j); + if (val < min_val) + { + min_val = val; + x_min_idx_0 = i; + x_min_idx_1 = j; + } + if (val > max_val) + { + max_val = val; + x_max_idx_0 = i; + x_max_idx_1 = j; + } + } + } + + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ")" + << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ", 
" << x_max_idx_1 << ")" + << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +// For Tensor1dXf +inline void debug_tensor_1dxf(const Eigen::Tensor1dXf &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ")" << std::endl; + + float x_min = 100000000.0f; + float x_max = -100000000.0f; + float x_sum = 0.0f; + float x_mean = 0.0f; + float x_stddev = 0.0f; + + // store dimension index to save index of min/max + int x_min_idx_0 = -1; + int x_max_idx_0 = -1; + + // loop over tensor and find min/max/sum + for (int i = 0; i < x.dimension(0); ++i) + { + float val = x(i); + x_sum += val; + if (val < x_min) + { + x_min = val; + x_min_idx_0 = i; + } + if (val > x_max) + { + x_max = val; + x_max_idx_0 = i; + } + } + + // compute mean and standard deviation + x_mean = x_sum / x.dimension(0); + for (int i = 0; i < x.dimension(0); ++i) + { + float val = x(i); + x_stddev += (val - x_mean) * (val - x_mean); + } + x_stddev = std::sqrt(x_stddev / x.dimension(0)); + + // now print min, max, mean, stddev, and indices + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ")" << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ")" << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +// For Tensor4dXh +inline void debug_tensor_4dxh(const Eigen::Tensor4dXh &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1) + << ", " << x.dimension(2) << ", " << x.dimension(3) << ")" + << std::endl; + + Eigen::half x_min = Eigen::half(100000000.0f); + Eigen::half 
x_max = Eigen::half(-100000000.0f); + float x_sum = 0.0f; + float x_mean = 0.0f; + float x_stddev = 0.0f; + + // store dimension index to save index of min/max + int x_min_idx_0 = -1; + int x_min_idx_1 = -1; + int x_min_idx_2 = -1; + int x_min_idx_3 = -1; + int x_max_idx_0 = -1; + int x_max_idx_1 = -1; + int x_max_idx_2 = -1; + int x_max_idx_3 = -1; + + // loop over tensor and find min/max/sum + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + for (int l = 0; l < x.dimension(3); ++l) + { + Eigen::half val = x(i, j, k, l); + x_sum += val; + if (val < x_min) + { + x_min = val; + x_min_idx_0 = i; + x_min_idx_1 = j; + x_min_idx_2 = k; + x_min_idx_3 = l; + } + if (val > x_max) + { + x_max = val; + x_max_idx_0 = i; + x_max_idx_1 = j; + x_max_idx_2 = k; + x_max_idx_3 = l; + } + } + } + } + } + + // compute mean and standard deviation + x_mean = x_sum / (x.dimension(0) * x.dimension(1) * x.dimension(2) * + x.dimension(3)); + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + for (int l = 0; l < x.dimension(3); ++l) + { + Eigen::half val = x(i, j, k, l); + x_stddev += (val - x_mean) * (val - x_mean); + } + } + } + } + x_stddev = std::sqrt(x_stddev / (x.dimension(0) * x.dimension(1) * + x.dimension(2) * x.dimension(3))); + + // now print min, max, mean, stddev, and indices + std::cout << "\tmin: " << Eigen::half(x_min) << std::endl; + std::cout << "\tmax: " << Eigen::half(x_max) << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ", " + << x_min_idx_2 << ", " << x_min_idx_3 << ")" << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ", " << x_max_idx_1 << ", " + << x_max_idx_2 << ", " << x_max_idx_3 << ")" 
<< std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +// For MatrixXf +inline void debug_matrix_xf(const Eigen::MatrixXf &x, const std::string &name) +{ + // return; + std::cout << "Debugging matrix!: " << name << std::endl; + std::cout << "\tshape: (" << x.rows() << ", " << x.cols() << ")" + << std::endl; + + float x_min = 100000000.0f; + float x_max = -100000000.0f; + float x_sum = 0.0f; + float x_mean = 0.0f; + float x_stddev = 0.0f; + + // store dimension index to save index of min/max + int x_min_idx_0 = -1; + int x_min_idx_1 = -1; + int x_max_idx_0 = -1; + int x_max_idx_1 = -1; + + // loop over matrix and find min/max/sum + for (int i = 0; i < x.rows(); ++i) + { + for (int j = 0; j < x.cols(); ++j) + { + float value = x(i, j); + x_sum += value; + if (value < x_min) + { + x_min = value; + x_min_idx_0 = i; + x_min_idx_1 = j; + } + if (value > x_max) + { + x_max = value; + x_max_idx_0 = i; + x_max_idx_1 = j; + } + } + } + + // compute mean and standard deviation + x_mean = x_sum / (x.rows() * x.cols()); + for (int i = 0; i < x.rows(); ++i) + { + for (int j = 0; j < x.cols(); ++j) + { + float value = x(i, j); + x_stddev += (value - x_mean) * (value - x_mean); + } + } + x_stddev = std::sqrt(x_stddev / (x.rows() * x.cols())); + + // now print min, max, mean, stddev, median, and indices + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ")" + << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ", " << x_max_idx_1 << ")" + << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +// debug VectorXf +inline void debug_vector_xf(const Eigen::VectorXf &x, const std::string &name) +{ + // return; + std::cout << "Debugging vector!: " 
<< name << std::endl; + std::cout << "\tshape: (" << x.size() << ")" << std::endl; + + float x_min = 100000000.0f; + float x_max = -100000000.0f; + float x_sum = 0.0f; + float x_mean = 0.0f; + float x_stddev = 0.0f; + + // store dimension index to save index of min/max + int x_min_idx_0 = -1; + int x_max_idx_0 = -1; + + // loop over vector and find min/max/sum + for (int i = 0; i < x.size(); ++i) + { + float value = static_cast(x(i)); + x_sum += value; + if (value < x_min) + { + x_min = value; + x_min_idx_0 = i; + } + if (value > x_max) + { + x_max = value; + x_max_idx_0 = i; + } + } + + // compute mean and standard deviation + x_mean = x_sum / x.size(); + for (int i = 0; i < x.size(); ++i) + { + float value = static_cast(x(i)); + x_stddev += (value - x_mean) * (value - x_mean); + } + x_stddev = std::sqrt(x_stddev / x.size()); + + // now print min, max, mean, stddev, median, and indices + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ")" << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ")" << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} + +inline void debug_tensor_3dxd(const Eigen::Tensor &x, + const std::string &name) +{ + // return; + std::cout << "Debugging tensor!: " << name << std::endl; + std::cout << "\tshape: (" << x.dimension(0) << ", " << x.dimension(1) + << ", " << x.dimension(2) << ")" << std::endl; + + auto x_min = x.minimum(); + auto x_max = x.maximum(); + Eigen::Tensor x_sum_tensor = x.sum(); + double x_sum = x_sum_tensor(0); + Eigen::Tensor x_mean_tensor = x.mean(); + double x_mean = x_mean_tensor(0); + Eigen::Tensor x_stddev_tensor = + ((x - x_mean).square()).mean().sqrt(); + double x_stddev = x_stddev_tensor(0); + + // You might need to keep the existing 
loop for this purpose, or use other + // methods Re-inserting the loop for finding indices of min and max + int x_min_idx_0 = -1, x_min_idx_1 = -1, x_min_idx_2 = -1; + int x_max_idx_0 = -1, x_max_idx_1 = -1, x_max_idx_2 = -1; + double min_val = std::numeric_limits::max(); + double max_val = std::numeric_limits::lowest(); + + for (int i = 0; i < x.dimension(0); ++i) + { + for (int j = 0; j < x.dimension(1); ++j) + { + for (int k = 0; k < x.dimension(2); ++k) + { + double val = x(i, j, k); + if (val < min_val) + { + min_val = val; + x_min_idx_0 = i; + x_min_idx_1 = j; + x_min_idx_2 = k; + } + if (val > max_val) + { + max_val = val; + x_max_idx_0 = i; + x_max_idx_1 = j; + x_max_idx_2 = k; + } + } + } + } + + std::cout << "\tmin: " << x_min << std::endl; + std::cout << "\tmax: " << x_max << std::endl; + std::cout << "\tmean: " << x_mean << std::endl; + std::cout << "\tstddev: " << x_stddev << std::endl; + std::cout << "\tsum: " << x_sum << std::endl; + std::cout << "\tmin idx: (" << x_min_idx_0 << ", " << x_min_idx_1 << ", " + << x_min_idx_2 << ")" << std::endl; + std::cout << "\tmax idx: (" << x_max_idx_0 << ", " << x_max_idx_1 << ", " + << x_max_idx_2 << ")" << std::endl; + + std::cout << "FINISHED DEBUG for tensor: " << name << std::endl; +} +} // namespace demucscppdebug + +#define NEAR_TOLERANCE 1e-4 + +// initialize a struct demucs_model +static struct demucscpp_v3::demucs_v3_model model +{ +}; +static bool loaded = false; + +// google test global setup for model before all tests +static void setUpTestSuite() +{ + if (loaded) + { + return; + } + + // load model from "../ggml-demucs/ggml-model-htdemucs-f16.bin" + std::string model_file = "../ggml-demucs/ggml-model-hdemucs_mmi-v3-f16.bin"; + + auto ret = demucscpp_v3::load_demucs_v3_model(model_file, &model); + std::cout << "demucs_model_load returned " << (ret ? 
"true" : "false") + << std::endl; + if (!ret) + { + std::cerr << "Error loading model" << std::endl; + exit(1); + } + + loaded = true; +} + +// write a basic test case for a stereo file +TEST(DemucsCPP_V3_Layers, FreqEncoders03) +{ + setUpTestSuite(); + + std::cout << std::fixed << std::setprecision(20) << std::endl; + + Eigen::Tensor3dXf x_fake(4, 2048, 336); + + // fill with -1, 1 alternating + for (long i = 0; i < 4; ++i) + { + for (long j = 0; j < 2048; ++j) + { + for (long k = 0; k < 336; ++k) + { + if (k % 2 == 0) + { + x_fake(i, j, k) = -1.0; + } + else + { + x_fake(i, j, k) = 1.0; + } + } + } + } + + Eigen::Tensor3dXf x_fake_enc_0(48, 512, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 0, x_fake, x_fake_enc_0); + + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake, "x_fake"); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_0, "x_fake_enc_0"); + + Eigen::Tensor3dXf x_fake_enc_1(96, 128, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 1, x_fake_enc_0, x_fake_enc_1); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_1, "x_fake_enc_1"); + + Eigen::Tensor3dXf x_fake_enc_2(192, 32, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 2, x_fake_enc_1, x_fake_enc_2); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_2, "x_fake_enc_2"); + + Eigen::Tensor3dXf x_fake_enc_3(384, 8, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 3, x_fake_enc_2, x_fake_enc_3); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_3, "x_fake_enc_3"); +} + +// write a basic test case for a stereo file +TEST(DemucsCPP_V3_Layers, TimeEncoders03) +{ + setUpTestSuite(); + std::cout << std::fixed << std::setprecision(20) << std::endl; + + Eigen::Tensor3dXf xt_fake(1, 2, 343980); + + // fill with -1, 1 alternating + for (long i = 0; i < 1; ++i) + { + for (long j = 0; j < 2; ++j) + { + for (long k = 0; k < 343980; ++k) + { + if (k % 2 == 0) + { + xt_fake(i, j, k) = -1.0; + } + else + { + xt_fake(i, j, k) = 1.0; + } + } + } + } + + 
demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake, "xt_fake"); + + Eigen::Tensor3dXf xt_fake_enc_0(1, 48, 85995); + demucscpp_v3::apply_time_encoder_v3(model, 0, xt_fake, xt_fake_enc_0); + + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_0, "xt_fake_enc_0"); + + Eigen::Tensor3dXf xt_fake_enc_1(1, 96, 21499); + demucscpp_v3::apply_time_encoder_v3(model, 1, xt_fake_enc_0, xt_fake_enc_1); + + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_1, "xt_fake_enc_1"); + + Eigen::Tensor3dXf xt_fake_enc_2(1, 192, 5375); + + demucscpp_v3::apply_time_encoder_v3(model, 2, xt_fake_enc_1, xt_fake_enc_2); + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_2, "xt_fake_enc_2"); + + Eigen::Tensor3dXf xt_fake_enc_3(1, 384, 1344); + + demucscpp_v3::apply_time_encoder_v3(model, 3, xt_fake_enc_2, xt_fake_enc_3); + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_3, "xt_fake_enc_3"); +} + +TEST(DemucsCPP_V3_Layers, Encoders45) +{ + setUpTestSuite(); + + std::cout << std::fixed << std::setprecision(20) << std::endl; + + Eigen::Tensor3dXf x_fake(4, 2048, 336); + + // fill with -1, 1 alternating + for (long i = 0; i < 4; ++i) + { + for (long j = 0; j < 2048; ++j) + { + for (long k = 0; k < 336; ++k) + { + if (k % 2 == 0) + { + x_fake(i, j, k) = -1.0; + } + else + { + x_fake(i, j, k) = 1.0; + } + } + } + } + + Eigen::Tensor3dXf xt_fake(1, 2, 343980); + + // fill with -1, 1 alternating + for (long i = 0; i < 1; ++i) + { + for (long j = 0; j < 2; ++j) + { + for (long k = 0; k < 343980; ++k) + { + if (k % 2 == 0) + { + xt_fake(i, j, k) = -1.0; + } + else + { + xt_fake(i, j, k) = 1.0; + } + } + } + } + + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake, "x_fake"); + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake, "xt_fake"); + + // first 4 freq encoders + Eigen::Tensor3dXf x_fake_enc_0(48, 512, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 0, x_fake, x_fake_enc_0); + + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake, "x_fake"); + 
demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_0, "x_fake_enc_0"); + + Eigen::Tensor3dXf x_fake_enc_1(96, 128, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 1, x_fake_enc_0, x_fake_enc_1); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_1, "x_fake_enc_1"); + + Eigen::Tensor3dXf x_fake_enc_2(192, 32, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 2, x_fake_enc_1, x_fake_enc_2); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_2, "x_fake_enc_2"); + + Eigen::Tensor3dXf x_fake_enc_3(384, 8, 336); + demucscpp_v3::apply_freq_encoder_v3(model, 3, x_fake_enc_2, x_fake_enc_3); + demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_3, "x_fake_enc_3"); + + // calculate segment in samples + int segment_samples = + (int)(demucscpp::SEGMENT_LEN_SECS * demucscpp::SUPPORTED_SAMPLE_RATE); + + // let's create reusable buffers with padded sizes + struct demucscpp_v3::demucs_v3_segment_buffers buffers(2, segment_samples, + 4); + + // then 4 time encoders + Eigen::Tensor3dXf xt_fake_enc_0(1, 48, 85995); + demucscpp_v3::apply_time_encoder_v3(model, 0, xt_fake, xt_fake_enc_0); + + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_0, "xt_fake_enc_0"); + + Eigen::Tensor3dXf xt_fake_enc_1(1, 96, 21499); + demucscpp_v3::apply_time_encoder_v3(model, 1, xt_fake_enc_0, xt_fake_enc_1); + + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_1, "xt_fake_enc_1"); + + Eigen::Tensor3dXf xt_fake_enc_2(1, 192, 5375); + + demucscpp_v3::apply_time_encoder_v3(model, 2, xt_fake_enc_1, xt_fake_enc_2); + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_2, "xt_fake_enc_2"); + + Eigen::Tensor3dXf xt_fake_enc_3(1, 384, 1344); + + demucscpp_v3::apply_time_encoder_v3(model, 3, xt_fake_enc_2, xt_fake_enc_3); + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_3, "xt_fake_enc_3"); + + Eigen::Tensor3dXf xt_fake_enc_4(1, 768, 336); + demucscpp_v3::apply_time_encoder_4(model, xt_fake_enc_3, xt_fake_enc_4); + + demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_4, 
        "xt_fake_enc_4"); // completes the xt_fake_enc_4 dump begun above

    // now apply the shared encoders with time inject

    Eigen::Tensor3dXf x_fake_enc_4(768, 1, 336);
    demucscpp_v3::apply_freq_encoder_4(model, x_fake_enc_3, xt_fake_enc_4,
                                       x_fake_enc_4, buffers);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_4, "x_fake_enc_4");

    Eigen::Tensor3dXf x_fake_shared_enc_5(1536, 1, 168);
    demucscpp_v3::apply_shared_encoder_5(model, x_fake_enc_4, x_fake_shared_enc_5, buffers);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_shared_enc_5, "x_fake_shared_enc_5");
}

// Drives the innermost decoders (shared decoder 0, freq decoder 1, time
// decoder 0) on deterministic fake inputs and prints the resulting tensors
// so they can be diffed against the PyTorch reference implementation.
// There are no assertions: this is a debug/inspection test.
TEST(DemucsCPP_V3_Layers, Decoders01)
{
    setUpTestSuite();

    std::cout << std::fixed << std::setprecision(20) << std::endl;

    // calculate segment in samples
    int segment_samples =
        (int)(demucscpp::SEGMENT_LEN_SECS * demucscpp::SUPPORTED_SAMPLE_RATE);

    // let's create reusable buffers with padded sizes
    struct demucscpp_v3::demucs_v3_segment_buffers buffers(2, segment_samples,
                                                           4);

    // NOTE(review): laid out here as (1, 1536, 168) whereas the encoder test
    // above produces x_fake_shared_enc_5 as (1536, 1, 168) — presumably the
    // decoder expects this transposed layout; confirm against
    // apply_shared_decoder_0's indexing.
    Eigen::Tensor3dXf x_fake_shared_enc_5(1, 1536, 168);

    // fill with -1 (even time index k), +1 (odd k) alternating
    for (long j = 0; j < 1536; ++j)
    {
        for (long k = 0; k < 168; ++k)
        {
            if (k % 2 == 0)
            {
                x_fake_shared_enc_5(0, j, k) = -1.0;
            }
            else
            {
                x_fake_shared_enc_5(0, j, k) = 1.0;
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_4(768, 1, 336);
    // fill with +0.5 (even time index k), -0.5 (odd k) alternating

    for (long j = 0; j < 768; ++j)
    {
        for (long k = 0; k < 336; ++k)
        {
            if (k % 2 == 0)
            {
                skip_fake_dec_4(j, 0, k) = 0.5;
            }
            else
            {
                skip_fake_dec_4(j, 0, k) = -0.5;
            }
        }
    }

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_shared_enc_5, "x_fake_shared_enc_5");
    demucscppdebug_test_v3::debug_tensor_3dxf(skip_fake_dec_4, "skip_fake_dec_4");

    Eigen::Tensor3dXf x_fake_dec_4(768, 1, 336);
    // pre-transpose output is unused here; only x_fake_dec_4 is inspected
    Eigen::Tensor3dXf pre_t_unused = demucscpp_v3::apply_shared_decoder_0(model, x_fake_dec_4, x_fake_shared_enc_5);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_4, "x_fake_dec_4");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t_unused, "pre_t_unused");

    Eigen::Tensor3dXf x_fake_dec_3(384, 8, 336);

    Eigen::Tensor3dXf pre_t = demucscpp_v3::apply_freq_decoder_1(
        model, x_fake_dec_4, x_fake_dec_3, skip_fake_dec_4);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t, "pre_t");

    // remember we leapfrog xt_fake_dec_4
    // NOTE(review): declared (1, 768, 336) here but (1, 384, 1344) in the
    // End2End test below — presumably apply_time_decoder_0 resizes or fully
    // overwrites its output argument; confirm.
    Eigen::Tensor3dXf xt_fake_dec_3(1, 768, 336);

    demucscpp_v3::apply_time_decoder_0(model, pre_t, xt_fake_dec_3);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_3, "xt_fake_dec_3");
}

// Same as the tail of Decoders01, but feeds freq decoder 1 a synthetic
// x_fake_dec_4 directly (instead of one produced by shared decoder 0) so the
// layer can be compared against the reference in isolation.
TEST(DemucsCPP_V3_Layers, Decoder1Isolated)
{
    setUpTestSuite();

    std::cout << std::fixed << std::setprecision(20) << std::endl;

    // calculate segment in samples
    int segment_samples =
        (int)(demucscpp::SEGMENT_LEN_SECS * demucscpp::SUPPORTED_SAMPLE_RATE);

    // let's create reusable buffers with padded sizes
    struct demucscpp_v3::demucs_v3_segment_buffers buffers(2, segment_samples,
                                                           4);

    Eigen::Tensor3dXf x_fake_dec_4(1, 768, 336);

    // fill with -1 (even time index k), +1 (odd k) alternating
    for (long j = 0; j < 768; ++j)
    {
        for (long k = 0; k < 336; ++k)
        {
            if (k % 2 == 0)
            {
                x_fake_dec_4(0, j, k) = -1.0;
            }
            else
            {
                x_fake_dec_4(0, j, k) = 1.0;
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_4(768, 1, 336);
    // fill with +0.5 (even time index k), -0.5 (odd k) alternating

    for (long j = 0; j < 768; ++j)
    {
        for (long k = 0; k < 336; ++k)
        {
            if (k % 2 == 0)
            {
                skip_fake_dec_4(j, 0, k) = 0.5;
            }
            else
            {
                skip_fake_dec_4(j, 0, k) = -0.5;
            }
        }
    }

    Eigen::Tensor3dXf x_fake_dec_3(384, 8, 336);

    Eigen::Tensor3dXf pre_t = demucscpp_v3::apply_freq_decoder_1(
        model, x_fake_dec_4, x_fake_dec_3, skip_fake_dec_4);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t, "pre_t");

    // remember we leapfrog xt_fake_dec_4
    // NOTE(review): same shape question as in Decoders01 — (1, 768, 336)
    // here vs (1, 384, 1344) in End2End; presumably resized internally.
    Eigen::Tensor3dXf xt_fake_dec_3(1, 768, 336);

    demucscpp_v3::apply_time_decoder_0(model, pre_t, xt_fake_dec_3);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_3, "xt_fake_dec_3");
}

// Runs shared decoder 0, freq decoder 1, time decoder 0, and the first pair
// of common (freq + time) decoders with synthetic inputs and skip
// connections, dumping every intermediate tensor for reference comparison.
TEST(DemucsCPP_V3_Layers, AllDecoders)
{
    setUpTestSuite();

    std::cout << std::fixed << std::setprecision(20) << std::endl;

    // calculate segment in samples
    int segment_samples =
        (int)(demucscpp::SEGMENT_LEN_SECS * demucscpp::SUPPORTED_SAMPLE_RATE);

    // let's create reusable buffers with padded sizes
    struct demucscpp_v3::demucs_v3_segment_buffers buffers(2, segment_samples,
                                                           4);

    Eigen::Tensor3dXf x_fake_shared_enc_5(1, 1536, 168);

    // fill with -1 (even time index k), +1 (odd k) alternating
    for (long j = 0; j < 1536; ++j)
    {
        for (long k = 0; k < 168; ++k)
        {
            if (k % 2 == 0)
            {
                x_fake_shared_enc_5(0, j, k) = -1.0;
            }
            else
            {
                x_fake_shared_enc_5(0, j, k) = 1.0;
            }
        }
    }

    // Synthetic skip connections for every freq (skip_fake_dec_*) and time
    // (skip_fake_tdec_*) decoder level, all filled with the same
    // +0.5 (even k) / -0.5 (odd k) alternating pattern.

    Eigen::Tensor3dXf skip_fake_dec_4(768, 1, 336);
    // fill with +0.5 (even k), -0.5 (odd k) alternating

    for (long j = 0; j < 768; ++j)
    {
        for (long k = 0; k < 336; ++k)
        {
            if (k % 2 == 0)
            {
                skip_fake_dec_4(j, 0, k) = 0.5;
            }
            else
            {
                skip_fake_dec_4(j, 0, k) = -0.5;
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_3(384, 8, 336);
    // fill with +0.5 (even k), -0.5 (odd k) alternating

    for (long i = 0; i < 8; ++i)
    {
        for (long j = 0; j < 384; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_3(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_3(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_2(192, 32, 336);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 32; ++i)
    {
        for (long j = 0; j < 192; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_2(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_2(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_1(96, 128, 336);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 128; ++i)
    {
        for (long j = 0; j < 96; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_1(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_1(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_0(48, 512, 336);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 512; ++i)
    {
        for (long j = 0; j < 48; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_0(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_0(j, i, k) = -0.5;
                }
            }
        }
    }

    // time-branch skips: note the loop over j covers the singleton first dim
    Eigen::Tensor3dXf skip_fake_tdec_3(1, 384, 1344);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 384; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 1344; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_3(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_3(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_2(1, 192, 5375);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 192; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 5375; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_2(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_2(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_1(1, 96, 21499);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 96; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 21499; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_1(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_1(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_0(1, 48, 85995);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 48; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 85995; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_0(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_0(j, i, k) = -0.5;
                }
            }
        }
    }

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_shared_enc_5, "x_fake_shared_enc_5");
    demucscppdebug_test_v3::debug_tensor_3dxf(skip_fake_dec_4, "skip_fake_dec_4");

    Eigen::Tensor3dXf x_fake_dec_4(768, 1, 336);
    Eigen::Tensor3dXf pre_t_unused = demucscpp_v3::apply_shared_decoder_0(model, x_fake_dec_4, x_fake_shared_enc_5);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_4, "x_fake_dec_4");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t_unused, "pre_t_unused");

    Eigen::Tensor3dXf x_fake_dec_3(384, 8, 336);
    Eigen::Tensor3dXf pre_t = demucscpp_v3::apply_freq_decoder_1(
        model, x_fake_dec_4, x_fake_dec_3, skip_fake_dec_4);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t, "pre_t");

    // remember we leapfrog xt_fake_dec_4
    // NOTE(review): declared (1, 768, 336) here but fed below to the time
    // decoder level whose skip is (1, 384, 1344) — presumably
    // apply_time_decoder_0 resizes its output; confirm.
    Eigen::Tensor3dXf xt_fake_dec_3(1, 768, 336);
    demucscpp_v3::apply_time_decoder_0(model, pre_t, xt_fake_dec_3);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_3, "xt_fake_dec_3");

    Eigen::Tensor3dXf x_fake_dec_2(192, 32, 336);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3");
    demucscppdebug_test_v3::debug_tensor_3dxf(skip_fake_dec_3, "skip_fake_dec_3");

    // apply_common_decoder(model, branch, level, in, out, skip):
    // branch 0 = frequency, branch 1 = time (inferred from usage below)
    demucscpp_v3::apply_common_decoder(model, 0, 0, x_fake_dec_3, x_fake_dec_2, skip_fake_dec_3);

    Eigen::Tensor3dXf xt_fake_dec_2(1, 384, 1344);
    demucscpp_v3::apply_common_decoder(model, 1, 0, xt_fake_dec_3, xt_fake_dec_2, skip_fake_tdec_3);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_2, "x_fake_dec_2");
    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_2, "xt_fake_dec_2");
}

// write a basic test case for a stereo file
// End-to-end pass over the v3 network with synthetic inputs: all four freq
// and time encoders, the shared (time-injected) encoders 4/5, then the full
// decoder stack down to the final (16, 2048, 336) spectral and
// (1, 8, 343980) waveform outputs. All intermediates are dumped for diffing
// against the PyTorch reference; there are no assertions.
TEST(DemucsCPP_V3_Layers, End2End)
{
    setUpTestSuite();

    std::cout << std::fixed << std::setprecision(20) << std::endl;

    Eigen::Tensor3dXf x_fake(4, 2048, 336);

    // fill with -1 (even time index k), +1 (odd k) alternating
    for (long i = 0; i < 4; ++i)
    {
        for (long j = 0; j < 2048; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    x_fake(i, j, k) = -1.0;
                }
                else
                {
                    x_fake(i, j, k) = 1.0;
                }
            }
        }
    }

    // frequency-branch encoders 0..3
    Eigen::Tensor3dXf x_fake_enc_0(48, 512, 336);
    demucscpp_v3::apply_freq_encoder_v3(model, 0, x_fake, x_fake_enc_0);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake, "x_fake");
    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_0, "x_fake_enc_0");

    Eigen::Tensor3dXf x_fake_enc_1(96, 128, 336);
    demucscpp_v3::apply_freq_encoder_v3(model, 1, x_fake_enc_0, x_fake_enc_1);
    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_1, "x_fake_enc_1");

    Eigen::Tensor3dXf x_fake_enc_2(192, 32, 336);
    demucscpp_v3::apply_freq_encoder_v3(model, 2, x_fake_enc_1, x_fake_enc_2);
    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_2, "x_fake_enc_2");

    Eigen::Tensor3dXf x_fake_enc_3(384, 8, 336);
    demucscpp_v3::apply_freq_encoder_v3(model, 3, x_fake_enc_2, x_fake_enc_3);
    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_3, "x_fake_enc_3");

    // fake stereo waveform input for the time branch
    Eigen::Tensor3dXf xt_fake(1, 2, 343980);

    // fill with -1 (even sample index k), +1 (odd k) alternating
    for (long i = 0; i < 1; ++i)
    {
        for (long j = 0; j < 2; ++j)
        {
            for (long k = 0; k < 343980; ++k)
            {
                if (k % 2 == 0)
                {
                    xt_fake(i, j, k) = -1.0;
                }
                else
                {
                    xt_fake(i, j, k) = 1.0;
                }
            }
        }
    }

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake, "xt_fake");

    // time-branch encoders 0..3
    Eigen::Tensor3dXf xt_fake_enc_0(1, 48, 85995);
    demucscpp_v3::apply_time_encoder_v3(model, 0, xt_fake, xt_fake_enc_0);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_0, "xt_fake_enc_0");

    Eigen::Tensor3dXf xt_fake_enc_1(1, 96, 21499);
    demucscpp_v3::apply_time_encoder_v3(model, 1, xt_fake_enc_0, xt_fake_enc_1);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_1, "xt_fake_enc_1");

    Eigen::Tensor3dXf xt_fake_enc_2(1, 192, 5375);

    demucscpp_v3::apply_time_encoder_v3(model, 2, xt_fake_enc_1, xt_fake_enc_2);
    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_2, "xt_fake_enc_2");

    Eigen::Tensor3dXf xt_fake_enc_3(1, 384, 1344);

    demucscpp_v3::apply_time_encoder_v3(model, 3, xt_fake_enc_2, xt_fake_enc_3);
    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_3, "xt_fake_enc_3");

    // calculate segment in samples
    int segment_samples =
        (int)(demucscpp::SEGMENT_LEN_SECS * demucscpp::SUPPORTED_SAMPLE_RATE);

    // let's create reusable buffers with padded sizes
    struct demucscpp_v3::demucs_v3_segment_buffers buffers(2, segment_samples,
                                                           4);

    Eigen::Tensor3dXf xt_fake_enc_4(1, 768, 336);
    demucscpp_v3::apply_time_encoder_4(model, xt_fake_enc_3, xt_fake_enc_4);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_enc_4, "xt_fake_enc_4");

    // now apply the shared encoders with time inject

    Eigen::Tensor3dXf x_fake_enc_4(768, 1, 336);
    demucscpp_v3::apply_freq_encoder_4(model, x_fake_enc_3, xt_fake_enc_4,
                                       x_fake_enc_4, buffers);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_enc_4, "x_fake_enc_4");

    Eigen::Tensor3dXf x_fake_shared_enc_5(1536, 1, 168);
    demucscpp_v3::apply_shared_encoder_5(model, x_fake_enc_4, x_fake_shared_enc_5, buffers);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_shared_enc_5, "x_fake_shared_enc_5");

    Eigen::Tensor3dXf skip_fake_dec_4(768, 1, 336);
    // fill with +0.5 (even time index k), -0.5 (odd k) alternating

    for (long j = 0; j < 768; ++j)
    {
        for (long k = 0; k < 336; ++k)
        {
            if (k % 2 == 0)
            {
                skip_fake_dec_4(j, 0, k) = 0.5;
            }
            else
            {
                skip_fake_dec_4(j, 0, k) = -0.5;
            }
        }
    }

    demucscppdebug_test_v3::debug_tensor_3dxf(skip_fake_dec_4, "skip_fake_dec_4");

    Eigen::Tensor3dXf x_fake_dec_4(768, 1, 336);
    Eigen::Tensor3dXf pre_t_unused = demucscpp_v3::apply_shared_decoder_0(
        model, x_fake_dec_4, x_fake_shared_enc_5);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_4, "x_fake_dec_4");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t_unused, "pre_t_unused");

    Eigen::Tensor3dXf x_fake_dec_3(384, 8, 336);
    Eigen::Tensor3dXf pre_t = demucscpp_v3::apply_freq_decoder_1(
        model, x_fake_dec_4, x_fake_dec_3, skip_fake_dec_4);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3");
    demucscppdebug_test_v3::debug_tensor_3dxf(pre_t, "pre_t");

    // remember we leapfrog xt_fake_dec_4
    Eigen::Tensor3dXf xt_fake_dec_3(1, 384, 1344);

    demucscpp_v3::apply_time_decoder_0(model, pre_t, xt_fake_dec_3);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_3, "xt_fake_dec_3");

    // Synthetic skip connections for the remaining freq and time decoder
    // levels, all +0.5 (even k) / -0.5 (odd k) alternating.

    Eigen::Tensor3dXf skip_fake_dec_3(384, 8, 336);
    // fill with +0.5 (even k), -0.5 (odd k) alternating

    for (long i = 0; i < 8; ++i)
    {
        for (long j = 0; j < 384; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_3(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_3(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_2(192, 32, 336);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 32; ++i)
    {
        for (long j = 0; j < 192; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_2(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_2(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_1(96, 128, 336);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 128; ++i)
    {
        for (long j = 0; j < 96; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_1(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_1(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_dec_0(48, 512, 336);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 512; ++i)
    {
        for (long j = 0; j < 48; ++j)
        {
            for (long k = 0; k < 336; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_dec_0(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_dec_0(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_3(1, 384, 1344);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 384; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 1344; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_3(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_3(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_2(1, 192, 5375);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 192; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 5375; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_2(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_2(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_1(1, 96, 21499);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 96; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 21499; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_1(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_1(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf skip_fake_tdec_0(1, 48, 85995);

    // fill with +0.5 (even k), -0.5 (odd k) alternating
    for (long i = 0; i < 48; ++i)
    {
        for (long j = 0; j < 1; ++j)
        {
            for (long k = 0; k < 85995; ++k)
            {
                if (k % 2 == 0)
                {
                    skip_fake_tdec_0(j, i, k) = 0.5;
                }
                else
                {
                    skip_fake_tdec_0(j, i, k) = -0.5;
                }
            }
        }
    }

    Eigen::Tensor3dXf x_fake_dec_2(192, 32, 336);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_3, "x_fake_dec_3");
    demucscppdebug_test_v3::debug_tensor_3dxf(skip_fake_dec_3, "skip_fake_dec_3");

    demucscpp_v3::apply_common_decoder(model, 0, 0, x_fake_dec_3, x_fake_dec_2, skip_fake_dec_3);

    // NOTE(review): declared with 384 channels although the next level's
    // skip (skip_fake_tdec_2) has 192 — and the AllDecoders test declares
    // this tensor as (1, 384, 1344). Presumably apply_common_decoder
    // resizes its output argument; confirm.
    Eigen::Tensor3dXf xt_fake_dec_2(1, 384, 5375);
    demucscpp_v3::apply_common_decoder(model, 1, 0, xt_fake_dec_3, xt_fake_dec_2, skip_fake_tdec_3);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_2, "x_fake_dec_2");
    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_2, "xt_fake_dec_2");

    Eigen::Tensor3dXf x_fake_dec_1(96, 128, 336);
    demucscpp_v3::apply_common_decoder(model, 0, 1, x_fake_dec_2, x_fake_dec_1, skip_fake_dec_2);

    Eigen::Tensor3dXf xt_fake_dec_1(1, 192, 21499);
    demucscpp_v3::apply_common_decoder(model, 1, 1, xt_fake_dec_2, xt_fake_dec_1, skip_fake_tdec_2);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_1, "x_fake_dec_1");
    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_1, "xt_fake_dec_1");

    Eigen::Tensor3dXf x_fake_dec_0(48, 512, 336);
    demucscpp_v3::apply_common_decoder(model, 0, 2, x_fake_dec_1, x_fake_dec_0, skip_fake_dec_1);

    Eigen::Tensor3dXf xt_fake_dec_0(1, 96, 85995);
    demucscpp_v3::apply_common_decoder(model, 1, 2, xt_fake_dec_1, xt_fake_dec_0, skip_fake_tdec_1);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_fake_dec_0, "x_fake_dec_0");
    demucscppdebug_test_v3::debug_tensor_3dxf(xt_fake_dec_0, "xt_fake_dec_0");

    // now apply the final decoder (frequency branch)
    Eigen::Tensor3dXf x_out(16, 2048, 336);
    demucscpp_v3::apply_common_decoder(model, 0, 3, x_fake_dec_0, x_out, skip_fake_dec_0);

    demucscppdebug_test_v3::debug_tensor_3dxf(x_out, "x_out");

    // now apply the final decoder (time branch)
    Eigen::Tensor3dXf xt_out(1, 8, 343980);
    demucscpp_v3::apply_common_decoder(model, 1, 3, xt_fake_dec_0, xt_out, skip_fake_tdec_0);

    demucscppdebug_test_v3::debug_tensor_3dxf(xt_out, "xt_out");
}