Skip to content

Commit

Permalink
Only log when enabled to avoid unnecessary printing on stdout (#753)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Aug 5, 2024
1 parent 18c25d8 commit 03594a1
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/logging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ void SetLogBool(std::string_view name, bool value) {
g_log.model_output_values = value;
else if (name == "model_logits")
g_log.model_logits = value;
else if (name == "ort_lib")
g_log.ort_lib = value;
else
throw JSON::unknown_value_error{};
}
Expand Down
1 change: 1 addition & 0 deletions src/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct LogItems {
bool model_output_shapes{}; // Before the model runs there are only the output shapes, no values in them. Useful for pre Session::Run debugging
bool model_output_values{}; // After the model runs the output tensor values can be displayed
bool model_logits{}; // Same as model_output_values but only for the logits
bool ort_lib{}; // Log the onnxruntime library loading and api calls.
};

extern LogItems g_log;
Expand Down
57 changes: 57 additions & 0 deletions src/models/env_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "env_utils.h"

#include <stdexcept>

#if _MSC_VER
#include <Windows.h>
#endif

namespace Generators {

std::string GetEnvironmentVariable(const char* var_name) {
#if _MSC_VER
// Why getenv() should be avoided on Windows:
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
// Instead use the Win32 API: GetEnvironmentVariableA()

// Max limit of an environment variable on Windows including the null-terminating character
constexpr DWORD kBufferSize = 32767;

// Create buffer to hold the result
std::string buffer(kBufferSize, '\0');

// The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
// If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
// Therefore, If the function succeeds, kBufferSize should be larger than char_count.
auto char_count = ::GetEnvironmentVariableA(var_name, buffer.data(), kBufferSize);

if (kBufferSize > char_count) {
buffer.resize(char_count);
return buffer;
}

return {};
#else
const char* val = getenv(var_name);
return val == nullptr ? "" : std::string(val);
#endif // _MSC_VER
}

void GetEnvironmentVariable(const char* var_name, bool& value) {
std::string str_value = GetEnvironmentVariable(var_name);
if (str_value == "1" || str_value == "true") {
value = true;
} else if (str_value == "0" || str_value == "false") {
value = false;
} else if (!str_value.empty()) {
throw std::invalid_argument("Invalid value for environment variable " + std::string(var_name) + ": " + str_value +
". Expected '1' or 'true' for true, '0' or 'false' for false.");
}

// Otherwise, value will not be modified.
}

} // namespace Generators
15 changes: 15 additions & 0 deletions src/models/env_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <string>

namespace Generators {

std::string GetEnvironmentVariable(const char* var_name);

// This overload is used to get boolean environment variables.
// If the environment variable is set to "1" or "true" (case-sensitive), value will be set to true.
// Otherwise, value will not be modified.
void GetEnvironmentVariable(const char* var_name, bool& value);

} // namespace Generators
21 changes: 16 additions & 5 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
#include "onnxruntime_c_api.h"
#include "../span.h"
#include "../logging.h"
#include "env_utils.h"

#if defined(__ANDROID__)
#include <android/log.h>
Expand All @@ -93,11 +94,14 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
#define PATH_MAX (4096)
#endif

#define LOG_DEBUG(...) Generators::Log("debug", __VA_ARGS__)
#define LOG_INFO(...) Generators::Log("info", __VA_ARGS__)
#define LOG_WARN(...) Generators::Log("warning", __VA_ARGS__)
#define LOG_ERROR(...) Generators::Log("error", __VA_ARGS__)
#define LOG_FATAL(...) Generators::Log("fatal", __VA_ARGS__)
#define LOG_WHEN_ENABLED(LOG_FUNC) \
if (Generators::g_log.enabled && Generators::g_log.ort_lib) LOG_FUNC

#define LOG_DEBUG(...) LOG_WHEN_ENABLED(Generators::Log("debug", __VA_ARGS__))
#define LOG_INFO(...) LOG_WHEN_ENABLED(Generators::Log("info", __VA_ARGS__))
#define LOG_WARN(...) LOG_WHEN_ENABLED(Generators::Log("warning", __VA_ARGS__))
#define LOG_ERROR(...) LOG_WHEN_ENABLED(Generators::Log("error", __VA_ARGS__))
#define LOG_FATAL(...) LOG_WHEN_ENABLED(Generators::Log("fatal", __VA_ARGS__))

#endif

Expand Down Expand Up @@ -175,6 +179,13 @@ inline void InitApi() {
return;
}

bool ort_lib = false;
Generators::GetEnvironmentVariable("ORTGENAI_LOG_ORT_LIB", ort_lib);
if (ort_lib) {
Generators::SetLogBool("enabled", true);
Generators::SetLogBool("ort_lib", true);
}

#if defined(__linux__)
// If the GenAI library links against the onnxruntime library, it will have a dependency on a specific
// version of OrtGetApiBase.
Expand Down

0 comments on commit 03594a1

Please sign in to comment.