Skip to content

Commit

Permalink
Use TFLite as default for Separator
Browse files Browse the repository at this point in the history
  • Loading branch information
jinay1991 committed Nov 1, 2020
1 parent 66c218a commit 39f7cba
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion spleeter/argument_parser/cli_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct CLIOptions

/// @brief Inference Engine Parameters (contains model_path, input/output tensor names)
InferenceEngineParameters inference_engine_params{
"external/models/5stems/saved_model",
"external/models/5stems.tflite",
"waveform",
{"strided_slice_18", "strided_slice_38", "strided_slice_48", "strided_slice_28", "strided_slice_58"},
"spleeter:5stems"};
Expand Down
2 changes: 1 addition & 1 deletion spleeter/separator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Separator::Separator(const InferenceEngineParameters& inference_engine_param, co
inference_engine_{},
waveform_name_{internal::GetWaveformNames(inference_engine_param.configuration)}
{
inference_engine_.SelectInferenceEngine(InferenceEngineType::kTensorFlow, inference_engine_param);
inference_engine_.SelectInferenceEngine(InferenceEngineType::kTensorFlowLite, inference_engine_param);
inference_engine_.Init();
}

Expand Down
4 changes: 2 additions & 2 deletions spleeter/test/separator_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class SeparatorTest : public ::testing::TestWithParam<std::int32_t>
{
const auto stem{std::to_string(GetParam()) + "stems"};
cli_options_.configuration = "spleeter:" + stem;
cli_options_.inference_engine_params.model_path = "external/models/" + stem + "/saved_model";
cli_options_.inference_engine_params.model_path = "external/models/" + stem + "/" + stem + ".tflite";
cli_options_.inference_engine_params.output_tensor_names = GetOutputTensorNames(cli_options_.configuration);
cli_options_.inference_engine_params.configuration = cli_options_.configuration;

Expand All @@ -36,7 +36,7 @@ class SeparatorTest : public ::testing::TestWithParam<std::int32_t>
unit_ = std::make_unique<Separator>(cli_options_.inference_engine_params, cli_options_.mwf);
}

std::vector<std::string> GetOutputTensorNames(const std::string& configuration)
static std::vector<std::string> GetOutputTensorNames(const std::string& configuration)
{
auto output_tensor_names = std::vector<std::string>{};
if (configuration == "spleeter:2stems")
Expand Down

0 comments on commit 39f7cba

Please sign in to comment.