From 1c4cb6c4124cbf7156a94979edc8e5e62bcc9a95 Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Mon, 9 Sep 2024 12:09:36 +0200 Subject: [PATCH] Get voltage via `prop->node`. (#1414) * Get voltage via `prop->node`. When calling functions directly, don't have access to the `NrnThread` and therefore can't get node properties, like the voltage from there. The solution is to create a link from the Prop to the Node. Then use that link to figure out the node properties. The trick to make the two cases uniform is the same as we use for instance data, we create pointers to array of length one (by taking the address of the element) and setting `_iml`/`id` to `0`. * Artificial cells aren't associated with a node. * Support top LOCALs in non-vectorized MOD files. * Test function calls with non-threadsafe MOD files. * Test globals in non-VECTORIZED MOD files. * Improve function calling coverage. * Test function calls for ARTIFICIAL_CELLs. * Remove debugging output. --- src/codegen/codegen_neuron_cpp_visitor.cpp | 58 ++++++++++++++++--- .../function/artificial_functions.mod | 34 +++++++++++ test/usecases/function/non_threadsafe.mod | 38 ++++++++++++ .../function/point_non_threadsafe.mod | 38 ++++++++++++ test/usecases/function/test_functions.py | 22 ++++--- test/usecases/global/non_threadsafe.mod | 43 ++++++++++++++ test/usecases/global/test_non_threadsafe.py | 33 +++++++++++ 7 files changed, 252 insertions(+), 14 deletions(-) create mode 100644 test/usecases/function/artificial_functions.mod create mode 100644 test/usecases/function/non_threadsafe.mod create mode 100644 test/usecases/function/point_non_threadsafe.mod create mode 100644 test/usecases/global/non_threadsafe.mod create mode 100644 test/usecases/global/test_non_threadsafe.py diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index cb8cb39f9..116b8d054 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -168,6 +168,9 @@ void CodegenNeuronCppVisitor::print_check_table_function_prototypes() { printer->push_block(); printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml, _type};"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!info.artificial_cell) { + printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml);", info.mod_suffix); + } if (!codegen_thread_variables.empty()) { printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", thread_variables_struct(), @@ -251,7 +254,9 @@ void CodegenNeuronCppVisitor::print_function_or_procedure( printer->fmt_line("int ret_{} = 0;", name); } - printer->fmt_line("auto v = inst.{}[id];", naming::VOLTAGE_UNUSED_VARIABLE); + if (!info.artificial_cell) { + printer->add_line("auto v = node_data.node_voltages[node_data.nodeindices[id]];"); + } print_statement_block(*node.get_statement_block(), false, false); printer->fmt_line("return ret_{};", name); @@ -294,6 +299,8 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body( Datum* _thread; NrnThread* nt; )CODE"); + + std::string prop_name; if (info.point_process) { printer->add_multi_line(R"CODE( auto* const _pnt = static_cast(_vptr); @@ -307,6 +314,8 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body( _thread = _extcall_thread.data(); nt = static_cast(_pnt->_vnt); )CODE"); + + prop_name = "_p"; } else if (wrapper_type == InterpreterWrapper::HOC) { if (program_symtab->lookup(block_name)->has_all_properties(NmodlType::use_range_ptr_var)) { printer->push_block("if (!_prop_id)"); @@ -328,16 +337,22 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body( _thread = _extcall_thread.data(); nt = nrn_threads; )CODE"); + prop_name = "_local_prop"; } else { // wrapper_type == InterpreterWrapper::Python printer->add_multi_line(R"CODE( _nrn_mechanism_cache_instance _lmc{_prop}; - size_t const id{}; + size_t const id = 0; _ppvar = _nrn_mechanism_access_dparam(_prop); _thread = _extcall_thread.data(); nt = nrn_threads; )CODE"); + prop_name = "_prop"; } + printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!info.artificial_cell) { + printer->fmt_line("auto node_data = make_node_data_{}({});", info.mod_suffix, prop_name); + } if (!codegen_thread_variables.empty()) { printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());", thread_variables_struct(), @@ -415,6 +430,9 @@ CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_paramete ParamVector params; params.emplace_back("", "_nrn_mechanism_cache_range&", "", "_lmc"); params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst"); + if (!info.artificial_cell) { + params.emplace_back("", fmt::format("{}&", node_data_struct()), "", "node_data"); + } params.emplace_back("", "size_t", "", "id"); params.emplace_back("", "Datum*", "", "_ppvar"); params.emplace_back("", "Datum*", "", "_thread"); @@ -804,10 +822,6 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in // TODO implement these when needed. } - if (!info.vectorize && !info.top_local_variables.empty()) { - throw std::runtime_error("Not implemented, global vectorize something."); - } - if (!info.thread_variables.empty()) { size_t prefix_sum = 0; for (size_t i = 0; i < info.thread_variables.size(); ++i) { @@ -834,6 +848,14 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in } } + if (!info.vectorize && !info.top_local_variables.empty()) { + for (size_t i = 0; i < info.top_local_variables.size(); ++i) { + const auto& var = info.top_local_variables[i]; + codegen_global_variables.push_back(var); + } + } + + if (!codegen_thread_variables.empty()) { if (!info.vectorize) { // MOD files that aren't "VECTORIZED" don't have thread data. @@ -1385,6 +1407,26 @@ void CodegenNeuronCppVisitor::print_make_node_data() const { printer->pop_block(";"); printer->pop_block(); + + + printer->fmt_push_block("static {} make_node_data_{}(Prop * _prop)", + node_data_struct(), + info.mod_suffix); + printer->add_line("static std::vector node_index{0};"); + printer->add_line("Node* _node = _nrn_mechanism_access_node(_prop);"); + + make_node_data_args = {"node_index.data()", + "&_nrn_mechanism_access_voltage(_node)", + "&_nrn_mechanism_access_d(_node)", + "&_nrn_mechanism_access_rhs(_node)", + "1"}; + + printer->fmt_push_block("return {}", node_data_struct()); + printer->add_multi_line(fmt::format("{}", fmt::join(make_node_data_args, ",\n"))); + + printer->pop_block(";"); + printer->pop_block(); + printer->add_newline(); } void CodegenNeuronCppVisitor::print_thread_variables_structure(bool print_initializers) { @@ -1475,7 +1517,6 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) { if (!info.artificial_cell) { printer->add_line("int node_id = node_data.nodeindices[id];"); printer->add_line("auto v = node_data.node_voltages[node_id];"); - printer->fmt_line("inst.{}[id] = v;", naming::VOLTAGE_UNUSED_VARIABLE); } print_rename_state_vars(); @@ -2164,6 +2205,9 @@ void CodegenNeuronCppVisitor::print_net_receive() { printer->add_line("auto * _ppvar = _nrn_mechanism_access_dparam(_pnt->prop);"); printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix); + if (!info.artificial_cell) { + printer->fmt_line("auto node_data = make_node_data_{}(_pnt->prop);", info.mod_suffix); + } printer->fmt_line("// nocmodl has a nullptr dereference for thread variables."); printer->fmt_line("// NMODL will fail to compile at a later point, because of"); printer->fmt_line("// missing '_thread_vars'."); diff --git a/test/usecases/function/artificial_functions.mod b/test/usecases/function/artificial_functions.mod new file mode 100644 index 000000000..a6d574fde --- /dev/null +++ b/test/usecases/function/artificial_functions.mod @@ -0,0 +1,34 @@ +NEURON { + ARTIFICIAL_CELL art_functions + RANGE x + GLOBAL gbl +} + +ASSIGNED { + gbl + v + x +} + +FUNCTION x_plus_a(a) { + x_plus_a = x + a +} + +FUNCTION identity(v) { + identity = v +} + +INITIAL { + x = 1.0 + gbl = 42.0 +} + +: A LINEAR block makes a MOD file not VECTORIZED. +STATE { + z +} + +LINEAR lin { + ~ z = 2 +} + diff --git a/test/usecases/function/non_threadsafe.mod b/test/usecases/function/non_threadsafe.mod new file mode 100644 index 000000000..8c8405642 --- /dev/null +++ b/test/usecases/function/non_threadsafe.mod @@ -0,0 +1,38 @@ +NEURON { + SUFFIX non_threadsafe + RANGE x + GLOBAL gbl +} + +ASSIGNED { + gbl + v + x +} + +FUNCTION x_plus_a(a) { + x_plus_a = x + a +} + +FUNCTION v_plus_a(a) { + v_plus_a = v + a +} + +FUNCTION identity(v) { + identity = v +} + +INITIAL { + x = 1.0 + gbl = 42.0 +} + +: A LINEAR block makes a MOD file not VECTORIZED. +STATE { + z +} + +LINEAR lin { + ~ z = 2 +} + diff --git a/test/usecases/function/point_non_threadsafe.mod b/test/usecases/function/point_non_threadsafe.mod new file mode 100644 index 000000000..27326822e --- /dev/null +++ b/test/usecases/function/point_non_threadsafe.mod @@ -0,0 +1,38 @@ +NEURON { + POINT_PROCESS point_non_threadsafe + RANGE x + GLOBAL gbl +} + +ASSIGNED { + gbl + v + x +} + +FUNCTION x_plus_a(a) { + x_plus_a = x + a +} + +FUNCTION v_plus_a(a) { + v_plus_a = v + a +} + +FUNCTION identity(v) { + identity = v +} + +INITIAL { + x = 1.0 + gbl = 42.0 +} + +: A LINEAR block makes a MOD file not VECTORIZED. +STATE { + z +} + +LINEAR lin { + ~ z = 2 +} + diff --git a/test/usecases/function/test_functions.py b/test/usecases/function/test_functions.py index 197245bbc..046cb8555 100644 --- a/test/usecases/function/test_functions.py +++ b/test/usecases/function/test_functions.py @@ -1,7 +1,7 @@ from neuron import h -def check_functions(get_instance): +def check_callable(get_instance, has_voltage=True): for x, value in zip(coords, values): get_instance(x).x = value @@ -19,10 +19,11 @@ def check_functions(get_instance): actual = get_instance(x).identity(expected) assert actual == expected, f"{actual} == {expected}" - # Check `f` using `v`. - expected = -2.0 - actual = get_instance(x).v_plus_a(40.0) - assert actual == expected, f"{actual} == {expected}" + if has_voltage: + # Check `f` using `v`. + expected = -2.0 + actual = get_instance(x).v_plus_a(40.0) + assert actual == expected, f"{actual} == {expected}" nseg = 5 @@ -30,11 +31,18 @@ def check_functions(get_instance): s.nseg = nseg s.insert("functions") +s.insert("non_threadsafe") coords = [(0.5 + k) * 1.0 / nseg for k in range(nseg)] values = [0.1 + k for k in range(nseg)] point_processes = {x: h.point_functions(s(x)) for x in coords} +point_non_threadsafe = {x: h.point_non_threadsafe(s(x)) for x in coords} + +art_cells = {x: h.art_functions() for x in coords} -check_functions(lambda x: s(x).functions) -check_functions(lambda x: point_processes[x]) +check_callable(lambda x: s(x).functions) +check_callable(lambda x: s(x).non_threadsafe) +check_callable(lambda x: point_processes[x]) +check_callable(lambda x: point_non_threadsafe[x]) +check_callable(lambda x: art_cells[x], has_voltage=False) diff --git a/test/usecases/global/non_threadsafe.mod b/test/usecases/global/non_threadsafe.mod new file mode 100644 index 000000000..860578248 --- /dev/null +++ b/test/usecases/global/non_threadsafe.mod @@ -0,0 +1,43 @@ +NEURON { + SUFFIX non_threadsafe + GLOBAL gbl +} + +LOCAL top_local + +PARAMETER { + parameter = 41.0 +} + +ASSIGNED { + gbl +} + +FUNCTION get_gbl() { + get_gbl = gbl +} + +FUNCTION get_top_local() { + get_top_local = top_local +} + +FUNCTION get_parameter() { + get_parameter = parameter +} + +INITIAL { + gbl = 42.0 + top_local = 43.0 +} + +: A LINEAR block makes the MOD file not thread-safe and not +: vectorized. We don't otherwise care about anything below +: this comment. +STATE { + z +} + +LINEAR lin { + ~ z = 2 +} + diff --git a/test/usecases/global/test_non_threadsafe.py b/test/usecases/global/test_non_threadsafe.py new file mode 100644 index 000000000..fcd413fb7 --- /dev/null +++ b/test/usecases/global/test_non_threadsafe.py @@ -0,0 +1,33 @@ +import numpy as np + +from neuron import h, gui +from neuron.units import ms + + +def test_non_threadsafe(): + nseg = 1 + + s = h.Section() + s.insert("non_threadsafe") + s.nseg = nseg + + h.finitialize() + + instance = s(0.5).non_threadsafe + + # Check INITIAL values. + assert instance.get_parameter() == 41.0 + assert instance.get_gbl() == 42.0 + assert instance.get_top_local() == 43.0 + + # Check reassigning a value. Top LOCAL variables + # are not exposed to HOC/Python. + h.parameter_non_threadsafe = 32.1 + h.gbl_non_threadsafe = 33.2 + + assert instance.get_parameter() == 32.1 + assert instance.get_gbl() == 33.2 + + +if __name__ == "__main__": + test_non_threadsafe()