diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp
index cb8cb39f9..116b8d054 100644
--- a/src/codegen/codegen_neuron_cpp_visitor.cpp
+++ b/src/codegen/codegen_neuron_cpp_visitor.cpp
@@ -168,6 +168,9 @@ void CodegenNeuronCppVisitor::print_check_table_function_prototypes() {
     printer->push_block();
     printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml, _type};");
     printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
+    if (!info.artificial_cell) {
+        printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml);", info.mod_suffix);
+    }
     if (!codegen_thread_variables.empty()) {
         printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());",
                           thread_variables_struct(),
@@ -251,7 +254,9 @@ void CodegenNeuronCppVisitor::print_function_or_procedure(
         printer->fmt_line("int ret_{} = 0;", name);
     }
 
-    printer->fmt_line("auto v = inst.{}[id];", naming::VOLTAGE_UNUSED_VARIABLE);
+    if (!info.artificial_cell) {
+        printer->add_line("auto v = node_data.node_voltages[node_data.nodeindices[id]];");
+    }
     print_statement_block(*node.get_statement_block(), false, false);
     printer->fmt_line("return ret_{};", name);
 
@@ -294,6 +299,8 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
         Datum* _thread;
         NrnThread* nt;
     )CODE");
+
+    std::string prop_name;
     if (info.point_process) {
         printer->add_multi_line(R"CODE(
         auto* const _pnt = static_cast<Point_process*>(_vptr);
@@ -307,6 +314,8 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
         _thread = _extcall_thread.data();
         nt = static_cast<NrnThread*>(_pnt->_vnt);
     )CODE");
+
+        prop_name = "_p";
     } else if (wrapper_type == InterpreterWrapper::HOC) {
         if (program_symtab->lookup(block_name)->has_all_properties(NmodlType::use_range_ptr_var)) {
             printer->push_block("if (!_prop_id)");
@@ -328,16 +337,22 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
         _thread = _extcall_thread.data();
         nt = nrn_threads;
     )CODE");
+        prop_name = "_local_prop";
     } else {  // wrapper_type == InterpreterWrapper::Python
         printer->add_multi_line(R"CODE(
         _nrn_mechanism_cache_instance _lmc{_prop};
-        size_t const id{};
+        size_t const id = 0;
         _ppvar = _nrn_mechanism_access_dparam(_prop);
         _thread = _extcall_thread.data();
         nt = nrn_threads;
     )CODE");
+        prop_name = "_prop";
     }
 
+    printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
+    if (!info.artificial_cell) {
+        printer->fmt_line("auto node_data = make_node_data_{}({});", info.mod_suffix, prop_name);
+    }
     if (!codegen_thread_variables.empty()) {
         printer->fmt_line("auto _thread_vars = {}(_thread[{}].get());",
                           thread_variables_struct(),
@@ -415,6 +430,9 @@ CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_paramete
     ParamVector params;
     params.emplace_back("", "_nrn_mechanism_cache_range&", "", "_lmc");
     params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst");
+    if (!info.artificial_cell) {
+        params.emplace_back("", fmt::format("{}&", node_data_struct()), "", "node_data");
+    }
     params.emplace_back("", "size_t", "", "id");
     params.emplace_back("", "Datum*", "", "_ppvar");
     params.emplace_back("", "Datum*", "", "_thread");
@@ -804,10 +822,6 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
         // TODO implement these when needed.
     }
 
-    if (!info.vectorize && !info.top_local_variables.empty()) {
-        throw std::runtime_error("Not implemented, global vectorize something.");
-    }
-
     if (!info.thread_variables.empty()) {
         size_t prefix_sum = 0;
         for (size_t i = 0; i < info.thread_variables.size(); ++i) {
@@ -834,6 +848,14 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
         }
     }
 
+    if (!info.vectorize && !info.top_local_variables.empty()) {
+        for (size_t i = 0; i < info.top_local_variables.size(); ++i) {
+            const auto& var = info.top_local_variables[i];
+            codegen_global_variables.push_back(var);
+        }
+    }
+
+
     if (!codegen_thread_variables.empty()) {
         if (!info.vectorize) {
             // MOD files that aren't "VECTORIZED" don't have thread data.
@@ -1385,6 +1407,26 @@ void CodegenNeuronCppVisitor::print_make_node_data() const {
     printer->pop_block(";");
     printer->pop_block();
+
+
+    printer->fmt_push_block("static {} make_node_data_{}(Prop * _prop)",
+                            node_data_struct(),
+                            info.mod_suffix);
+    printer->add_line("static std::vector<int> node_index{0};");
+    printer->add_line("Node* _node = _nrn_mechanism_access_node(_prop);");
+
+    make_node_data_args = {"node_index.data()",
+                           "&_nrn_mechanism_access_voltage(_node)",
+                           "&_nrn_mechanism_access_d(_node)",
+                           "&_nrn_mechanism_access_rhs(_node)",
+                           "1"};
+
+    printer->fmt_push_block("return {}", node_data_struct());
+    printer->add_multi_line(fmt::format("{}", fmt::join(make_node_data_args, ",\n")));
+
+    printer->pop_block(";");
+    printer->pop_block();
+    printer->add_newline();
 }
 
 
 void CodegenNeuronCppVisitor::print_thread_variables_structure(bool print_initializers) {
@@ -1475,7 +1517,6 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
     if (!info.artificial_cell) {
         printer->add_line("int node_id = node_data.nodeindices[id];");
         printer->add_line("auto v = node_data.node_voltages[node_id];");
-        printer->fmt_line("inst.{}[id] = v;", naming::VOLTAGE_UNUSED_VARIABLE);
     }
 
     print_rename_state_vars();
@@ -2164,6 +2205,9 @@ void CodegenNeuronCppVisitor::print_net_receive() {
 
     printer->add_line("auto * _ppvar = _nrn_mechanism_access_dparam(_pnt->prop);");
     printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
+    if (!info.artificial_cell) {
+        printer->fmt_line("auto node_data = make_node_data_{}(_pnt->prop);", info.mod_suffix);
+    }
     printer->fmt_line("// nocmodl has a nullptr dereference for thread variables.");
     printer->fmt_line("// NMODL will fail to compile at a later point, because of");
     printer->fmt_line("// missing '_thread_vars'.");
diff --git a/test/usecases/function/artificial_functions.mod b/test/usecases/function/artificial_functions.mod
new file mode 100644
index 000000000..a6d574fde
--- /dev/null
+++ b/test/usecases/function/artificial_functions.mod
@@ -0,0 +1,34 @@
+NEURON {
+    ARTIFICIAL_CELL art_functions
+    RANGE x
+    GLOBAL gbl
+}
+
+ASSIGNED {
+    gbl
+    v
+    x
+}
+
+FUNCTION x_plus_a(a) {
+    x_plus_a = x + a
+}
+
+FUNCTION identity(v) {
+    identity = v
+}
+
+INITIAL {
+    x = 1.0
+    gbl = 42.0
+}
+
+: A LINEAR block makes a MOD file not VECTORIZED.
+STATE {
+    z
+}
+
+LINEAR lin {
+    ~ z = 2
+}
+
diff --git a/test/usecases/function/non_threadsafe.mod b/test/usecases/function/non_threadsafe.mod
new file mode 100644
index 000000000..8c8405642
--- /dev/null
+++ b/test/usecases/function/non_threadsafe.mod
@@ -0,0 +1,38 @@
+NEURON {
+    SUFFIX non_threadsafe
+    RANGE x
+    GLOBAL gbl
+}
+
+ASSIGNED {
+    gbl
+    v
+    x
+}
+
+FUNCTION x_plus_a(a) {
+    x_plus_a = x + a
+}
+
+FUNCTION v_plus_a(a) {
+    v_plus_a = v + a
+}
+
+FUNCTION identity(v) {
+    identity = v
+}
+
+INITIAL {
+    x = 1.0
+    gbl = 42.0
+}
+
+: A LINEAR block makes a MOD file not VECTORIZED.
+STATE {
+    z
+}
+
+LINEAR lin {
+    ~ z = 2
+}
+
diff --git a/test/usecases/function/point_non_threadsafe.mod b/test/usecases/function/point_non_threadsafe.mod
new file mode 100644
index 000000000..27326822e
--- /dev/null
+++ b/test/usecases/function/point_non_threadsafe.mod
@@ -0,0 +1,38 @@
+NEURON {
+    POINT_PROCESS point_non_threadsafe
+    RANGE x
+    GLOBAL gbl
+}
+
+ASSIGNED {
+    gbl
+    v
+    x
+}
+
+FUNCTION x_plus_a(a) {
+    x_plus_a = x + a
+}
+
+FUNCTION v_plus_a(a) {
+    v_plus_a = v + a
+}
+
+FUNCTION identity(v) {
+    identity = v
+}
+
+INITIAL {
+    x = 1.0
+    gbl = 42.0
+}
+
+: A LINEAR block makes a MOD file not VECTORIZED.
+STATE {
+    z
+}
+
+LINEAR lin {
+    ~ z = 2
+}
+
diff --git a/test/usecases/function/test_functions.py b/test/usecases/function/test_functions.py
index 197245bbc..046cb8555 100644
--- a/test/usecases/function/test_functions.py
+++ b/test/usecases/function/test_functions.py
@@ -1,7 +1,7 @@
 from neuron import h
 
 
-def check_functions(get_instance):
+def check_callable(get_instance, has_voltage=True):
     for x, value in zip(coords, values):
         get_instance(x).x = value
 
@@ -19,10 +19,11 @@ def check_functions(get_instance):
         actual = get_instance(x).identity(expected)
         assert actual == expected, f"{actual} == {expected}"
 
-        # Check `f` using `v`.
-        expected = -2.0
-        actual = get_instance(x).v_plus_a(40.0)
-        assert actual == expected, f"{actual} == {expected}"
+        if has_voltage:
+            # Check `f` using `v`.
+            expected = -2.0
+            actual = get_instance(x).v_plus_a(40.0)
+            assert actual == expected, f"{actual} == {expected}"
 
 
 nseg = 5
@@ -30,11 +31,18 @@ def check_functions(get_instance):
 s.nseg = nseg
 
 s.insert("functions")
+s.insert("non_threadsafe")
 
 coords = [(0.5 + k) * 1.0 / nseg for k in range(nseg)]
 values = [0.1 + k for k in range(nseg)]
 
 point_processes = {x: h.point_functions(s(x)) for x in coords}
+point_non_threadsafe = {x: h.point_non_threadsafe(s(x)) for x in coords}
+
+art_cells = {x: h.art_functions() for x in coords}
 
-check_functions(lambda x: s(x).functions)
-check_functions(lambda x: point_processes[x])
+check_callable(lambda x: s(x).functions)
+check_callable(lambda x: s(x).non_threadsafe)
+check_callable(lambda x: point_processes[x])
+check_callable(lambda x: point_non_threadsafe[x])
+check_callable(lambda x: art_cells[x], has_voltage=False)
diff --git a/test/usecases/global/non_threadsafe.mod b/test/usecases/global/non_threadsafe.mod
new file mode 100644
index 000000000..860578248
--- /dev/null
+++ b/test/usecases/global/non_threadsafe.mod
@@ -0,0 +1,43 @@
+NEURON {
+    SUFFIX non_threadsafe
+    GLOBAL gbl
+}
+
+LOCAL top_local
+
+PARAMETER {
+    parameter = 41.0
+}
+
+ASSIGNED {
+    gbl
+}
+
+FUNCTION get_gbl() {
+    get_gbl = gbl
+}
+
+FUNCTION get_top_local() {
+    get_top_local = top_local
+}
+
+FUNCTION get_parameter() {
+    get_parameter = parameter
+}
+
+INITIAL {
+    gbl = 42.0
+    top_local = 43.0
+}
+
+: A LINEAR block makes the MOD file not thread-safe and not
+: vectorized. We don't otherwise care about anything below
+: this comment.
+STATE {
+    z
+}
+
+LINEAR lin {
+    ~ z = 2
+}
+
diff --git a/test/usecases/global/test_non_threadsafe.py b/test/usecases/global/test_non_threadsafe.py
new file mode 100644
index 000000000..fcd413fb7
--- /dev/null
+++ b/test/usecases/global/test_non_threadsafe.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from neuron import h, gui
+from neuron.units import ms
+
+
+def test_non_threadsafe():
+    nseg = 1
+
+    s = h.Section()
+    s.insert("non_threadsafe")
+    s.nseg = nseg
+
+    h.finitialize()
+
+    instance = s(0.5).non_threadsafe
+
+    # Check INITIAL values.
+    assert instance.get_parameter() == 41.0
+    assert instance.get_gbl() == 42.0
+    assert instance.get_top_local() == 43.0
+
+    # Check reassigning a value. Top LOCAL variables
+    # are not exposed to HOC/Python.
+    h.parameter_non_threadsafe = 32.1
+    h.gbl_non_threadsafe = 33.2
+
+    assert instance.get_parameter() == 32.1
+    assert instance.get_gbl() == 33.2
+
+
+if __name__ == "__main__":
+    test_non_threadsafe()
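
For reference, a rough sketch of the code the new `make_node_data_{}(Prop *)` overload in `print_make_node_data()` is expected to emit. The member list mirrors `make_node_data_args` in the patch; the struct and function names below (`non_threadsafe_NodeData`, `make_node_data_non_threadsafe`) are assumptions based on what `node_data_struct()` and `info.mod_suffix` typically expand to, and the exact formatting depends on the code printer:

    // Sketch only: approximate generated code for a mechanism with suffix
    // "non_threadsafe"; names and layout are assumed, not taken from the patch.
    static non_threadsafe_NodeData make_node_data_non_threadsafe(Prop * _prop) {
        // The HOC/Python wrappers and net_receive operate on the single
        // instance behind this Prop, so one local node index (0) and a
        // count of 1 are sufficient.
        static std::vector<int> node_index{0};
        Node* _node = _nrn_mechanism_access_node(_prop);
        return non_threadsafe_NodeData{node_index.data(),
                                       &_nrn_mechanism_access_voltage(_node),
                                       &_nrn_mechanism_access_d(_node),
                                       &_nrn_mechanism_access_rhs(_node),
                                       1};
    }

This Prop-based overload is what lets `print_hoc_py_wrapper_function_body()` and `print_net_receive()` hand a `node_data` object to FUNCTION/PROCEDURE bodies of non-artificial mechanisms without having a Memb_list at hand.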