Skip to content

Commit

Permalink
Update cast_from_pyobject to throw on unsupported types rather than r…
Browse files Browse the repository at this point in the history
…eturning null (#451)

* Currently when `cast_from_pyobject` encounters an unsupported type it returns a json null.
* Updates the method to throw a `pybind11::type_error`, matching the `TypeError` exception raised by the Python std `json.dumps` method.
* Add `get_py_type_name` helper method
* Breaking behavior change

Closes #450

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #451
  • Loading branch information
dagardner-nv authored Mar 12, 2024
1 parent 2dbd985 commit a920644
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 8 deletions.
9 changes: 8 additions & 1 deletion python/mrc/_pymrc/include/pymrc/utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -54,6 +54,13 @@ void from_import_as(pybind11::module_& dest, const std::string& from, const std:
*/
const std::type_info* cpptype_info_from_object(pybind11::object& obj);

/**
* @brief Given a pybind11 object, return the Python type name essentially the same as `str(type(obj))`
* @param obj : pybind11 object
* @return std::string.
*/
std::string get_py_type_name(const pybind11::object& obj);

void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level = 1);

#pragma GCC visibility pop
Expand Down
50 changes: 44 additions & 6 deletions python/mrc/_pymrc/src/utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -17,6 +17,9 @@

#include "pymrc/utils.hpp"

#include "pymrc/utilities/acquire_gil.hpp"

#include <glog/logging.h>
#include <nlohmann/json.hpp>
#include <pybind11/cast.h>
#include <pybind11/detail/internals.h>
Expand All @@ -25,6 +28,7 @@
#include <pyerrors.h>
#include <warnings.h>

#include <sstream>
#include <string>
#include <utility>

Expand Down Expand Up @@ -72,6 +76,18 @@ const std::type_info* cpptype_info_from_object(py::object& obj)
return nullptr;
}

std::string get_py_type_name(const pybind11::object& obj)
{
if (!obj)
{
// calling py::type::of on a null object will trigger an abort
return "";
}

const auto py_type = py::type::of(obj);
return py_type.attr("__name__").cast<std::string>();
}

py::object cast_from_json(const json& source)
{
if (source.is_null())
Expand Down Expand Up @@ -123,7 +139,7 @@ py::object cast_from_json(const json& source)
// throw std::runtime_error("Unsupported conversion type.");
}

json cast_from_pyobject(const py::object& source)
json cast_from_pyobject_impl(const py::object& source, const std::string& parent_path = "")
{
// Dont return via initializer list with JSON. It performs type deduction and gives different results
// NOLINTBEGIN(modernize-return-braced-init-list)
Expand All @@ -137,7 +153,9 @@ json cast_from_pyobject(const py::object& source)
auto json_obj = json::object();
for (const auto& p : py_dict)
{
json_obj[py::cast<std::string>(p.first)] = cast_from_pyobject(p.second.cast<py::object>());
std::string key{p.first.cast<std::string>()};
std::string path{parent_path + "/" + key};
json_obj[key] = cast_from_pyobject_impl(p.second.cast<py::object>(), path);
}

return json_obj;
Expand All @@ -148,7 +166,7 @@ json cast_from_pyobject(const py::object& source)
auto json_arr = json::array();
for (const auto& p : py_list)
{
json_arr.push_back(cast_from_pyobject(p.cast<py::object>()));
json_arr.push_back(cast_from_pyobject_impl(p.cast<py::object>(), parent_path));
}

return json_arr;
Expand All @@ -170,11 +188,31 @@ json cast_from_pyobject(const py::object& source)
return json(py::cast<std::string>(source));
}

// else unsupported return null
return json();
// else unsupported return throw a type error
{
AcquireGIL gil;
std::ostringstream error_message;
std::string path{parent_path};
if (path.empty())
{
path = "/";
}

error_message << "Object (" << py::str(source).cast<std::string>() << ") of type: " << get_py_type_name(source)
<< " at path: " << path << " is not JSON serializable";

DVLOG(5) << error_message.str();
throw py::type_error(error_message.str());
}

// NOLINTEND(modernize-return-braced-init-list)
}

json cast_from_pyobject(const py::object& source)
{
return cast_from_pyobject_impl(source);
}

void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level)
{
PyErr_WarnEx(PyExc_DeprecationWarning, deprecation_message.c_str(), stack_level);
Expand Down
29 changes: 28 additions & 1 deletion python/mrc/_pymrc/tests/test_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -41,6 +41,7 @@
namespace py = pybind11;
namespace pymrc = mrc::pymrc;
using namespace std::string_literals;
using namespace pybind11::literals; // to bring in the `_a` literal

// Create values too big to fit in int & float types to ensure we can pass
// long & double types to both nlohmann/json and python
Expand Down Expand Up @@ -143,6 +144,32 @@ TEST_F(TestUtils, CastFromPyObject)
}
}

TEST_F(TestUtils, CastFromPyObjectSerializeErrors)
{
// Test to verify that cast_from_pyobject throws a python TypeError when encountering something that is not json
// serializable issue #450

// decimal.Decimal is not serializable
py::object Decimal = py::module_::import("decimal").attr("Decimal");
py::object o = Decimal("1.0");
EXPECT_THROW(pymrc::cast_from_pyobject(o), py::type_error);

// Test with object in a nested dict
py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = o))), "other"_a = 2);
EXPECT_THROW(pymrc::cast_from_pyobject(d), py::type_error);
}

TEST_F(TestUtils, GetTypeName)
{
// invalid objects should return an empty string
EXPECT_EQ(pymrc::get_py_type_name(py::object()), "");
EXPECT_EQ(pymrc::get_py_type_name(py::none()), "NoneType");

py::object Decimal = py::module_::import("decimal").attr("Decimal");
py::object o = Decimal("1.0");
EXPECT_EQ(pymrc::get_py_type_name(o), "Decimal");
}

TEST_F(TestUtils, PyObjectWrapper)
{
py::list test_list;
Expand Down

0 comments on commit a920644

Please sign in to comment.