diff --git a/newrelic/config.py b/newrelic/config.py index 6a1ba2c69..b7c5e0008 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -911,6 +911,29 @@ def apply_local_high_security_mode_setting(settings): return settings +def _toml_config_to_configparser_dict(d, top=None, _path=None): + top = top or {"newrelic": {}} + _path = _path or "" + for key, value in d.items(): + if isinstance(value, dict): + _toml_config_to_configparser_dict(value, top, f"{_path}.{key}" if _path else key) + else: + fixed_value = " ".join(value) if isinstance(value, list) else value + path_split = _path.split(".") + # Handle environments + if _path.startswith("env."): + env_key = f"newrelic:{path_split[1]}" + fixed_key = ".".join((*path_split[2:], key)) + top[env_key] = {**top.get(env_key, {}), fixed_key: fixed_value} + # Handle import-hook:... configuration + elif _path.startswith("import-hook."): + import_hook_key = f"import-hook:{path_split[1]}" + top[import_hook_key] = {**top.get(import_hook_key, {}), key: fixed_value} + else: + top["newrelic"][f"{_path}.{key}" if _path else key] = fixed_value + return top + + def _load_configuration( config_file=None, environment=None, @@ -994,8 +1017,20 @@ def _load_configuration( # Now read in the configuration file. Cache the config file # name in internal settings object as indication of succeeding. - - if not _config_object.read([config_file]): + if config_file.endswith(".toml"): + try: + import tomllib + except ImportError: + raise newrelic.api.exceptions.ConfigurationError( + "TOML configuration file can only be used if tomllib is available (Python 3.11+)." + ) + with open(config_file, "rb") as f: + content = tomllib.load(f) + newrelic_section = content.get("tool", {}).get("newrelic") + if not newrelic_section: + raise newrelic.api.exceptions.ConfigurationError("New Relic configuration not found in TOML file.") + _config_object.read_dict(_toml_config_to_configparser_dict(newrelic_section)) + elif not _config_object.read([config_file]): raise newrelic.api.exceptions.ConfigurationError(f"Unable to open configuration file {config_file}.") _settings.config_file = config_file diff --git a/tests/agent_features/test_configuration.py b/tests/agent_features/test_configuration.py index 83082398a..b2987b6db 100644 --- a/tests/agent_features/test_configuration.py +++ b/tests/agent_features/test_configuration.py @@ -949,6 +949,69 @@ def test_initialize_developer_mode(section, expect_error, logger): assert "CONFIGURATION ERROR" not in logger.caplog.records +newrelic_toml_contents = b""" +[tool.newrelic] +app_name = "test11" +monitor_mode = true + +[tool.newrelic.env.development] +app_name = "test11 (Development)" + +[tool.newrelic.env.production] +app_name = "test11 (Production)" +log_level = "error" + +[tool.newrelic.env.production.distributed_tracing] +enabled = false + +[tool.newrelic.error_collector] +enabled = true +ignore_errors = ["module:name1", "module:name"] + +[tool.newrelic.transaction_tracer] +enabled = true + +[tool.newrelic.import-hook.django] +"instrumentation.scripts.django_admin" = ["stuff", "stuff2"] +""" + + +def test_toml_parse_development(): + settings = global_settings() + _reset_configuration_done() + _reset_config_parser() + _reset_instrumentation_done() + + with tempfile.NamedTemporaryFile(suffix=".toml") as f: + f.write(newrelic_toml_contents) + f.seek(0) + + initialize(config_file=f.name, environment="development") + value = fetch_config_setting(settings, "app_name") + assert value != "test11" + value = fetch_config_setting(settings, "monitor_mode") + assert value is True + value = fetch_config_setting(settings, "error_collector") + assert value.enabled is True + assert value.ignore_classes[0] == "module:name1" + assert value.ignore_classes[1] == "module:name" + + +def test_toml_parse_production(): + settings = global_settings() + # _reset_configuration_done() + + with tempfile.NamedTemporaryFile(suffix=".toml") as f: + f.write(newrelic_toml_contents) + f.seek(0) + + initialize(config_file=f.name, environment="production") + value = fetch_config_setting(settings, "app_name") + assert value == "test11 (Production)" + value = fetch_config_setting(settings, "distributed_tracing") + assert value.enabled is False + + @pytest.fixture def caplog_handler(): class CaplogHandler(logging.StreamHandler):