Skip to content

Commit

Permalink
Get voltage via prop->node. (#1414)
Browse files Browse the repository at this point in the history
* Get voltage via `prop->node`.

When calling functions directly, we don't have access to the `NrnThread`
and therefore can't get node properties, such as the voltage, from there.

The solution is to create a link from the Prop to the Node. Then use
that link to figure out the node properties.

The trick to make the two cases uniform is the same one we use for
instance data: we create pointers to arrays of length one (by taking the
address of the element) and set `_iml`/`id` to `0`.

* Artificial cells aren't associated with a node.

* Support top LOCALs in non-vectorized MOD files.

* Test function calls with non-threadsafe MOD files.

* Test globals in non-VECTORIZED MOD files.

* Improve function calling coverage.

* Test function calls for ARTIFICIAL_CELLs.

* Remove debugging output.
  • Loading branch information
1uc authored Sep 9, 2024
1 parent 8c83476 commit 1c4cb6c
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 14 deletions.
58 changes: 51 additions & 7 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ void CodegenNeuronCppVisitor::print_check_table_function_prototypes() {
printer->push_block();
printer->add_line("_nrn_mechanism_cache_range _lmc{_sorted_token, *nt, *_ml, _type};");
printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!info.artificial_cell) {
printer->fmt_line("auto node_data = make_node_data_{}(*nt, *_ml);", info.mod_suffix);
}
if (!codegen_thread_variables.empty()) {
printer->fmt_line("auto _thread_vars = {}(_thread[{}].get<double*>());",
thread_variables_struct(),
Expand Down Expand Up @@ -251,7 +254,9 @@ void CodegenNeuronCppVisitor::print_function_or_procedure(
printer->fmt_line("int ret_{} = 0;", name);
}

printer->fmt_line("auto v = inst.{}[id];", naming::VOLTAGE_UNUSED_VARIABLE);
if (!info.artificial_cell) {
printer->add_line("auto v = node_data.node_voltages[node_data.nodeindices[id]];");
}

print_statement_block(*node.get_statement_block(), false, false);
printer->fmt_line("return ret_{};", name);
Expand Down Expand Up @@ -294,6 +299,8 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
Datum* _thread;
NrnThread* nt;
)CODE");

std::string prop_name;
if (info.point_process) {
printer->add_multi_line(R"CODE(
auto* const _pnt = static_cast<Point_process*>(_vptr);
Expand All @@ -307,6 +314,8 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
_thread = _extcall_thread.data();
nt = static_cast<NrnThread*>(_pnt->_vnt);
)CODE");

prop_name = "_p";
} else if (wrapper_type == InterpreterWrapper::HOC) {
if (program_symtab->lookup(block_name)->has_all_properties(NmodlType::use_range_ptr_var)) {
printer->push_block("if (!_prop_id)");
Expand All @@ -328,16 +337,22 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
_thread = _extcall_thread.data();
nt = nrn_threads;
)CODE");
prop_name = "_local_prop";
} else { // wrapper_type == InterpreterWrapper::Python
printer->add_multi_line(R"CODE(
_nrn_mechanism_cache_instance _lmc{_prop};
size_t const id{};
size_t const id = 0;
_ppvar = _nrn_mechanism_access_dparam(_prop);
_thread = _extcall_thread.data();
nt = nrn_threads;
)CODE");
prop_name = "_prop";
}

printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!info.artificial_cell) {
printer->fmt_line("auto node_data = make_node_data_{}({});", info.mod_suffix, prop_name);
}
if (!codegen_thread_variables.empty()) {
printer->fmt_line("auto _thread_vars = {}(_thread[{}].get<double*>());",
thread_variables_struct(),
Expand Down Expand Up @@ -415,6 +430,9 @@ CodegenCppVisitor::ParamVector CodegenNeuronCppVisitor::internal_method_paramete
ParamVector params;
params.emplace_back("", "_nrn_mechanism_cache_range&", "", "_lmc");
params.emplace_back("", fmt::format("{}&", instance_struct()), "", "inst");
if (!info.artificial_cell) {
params.emplace_back("", fmt::format("{}&", node_data_struct()), "", "node_data");
}
params.emplace_back("", "size_t", "", "id");
params.emplace_back("", "Datum*", "", "_ppvar");
params.emplace_back("", "Datum*", "", "_thread");
Expand Down Expand Up @@ -804,10 +822,6 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
// TODO implement these when needed.
}

if (!info.vectorize && !info.top_local_variables.empty()) {
throw std::runtime_error("Not implemented, global vectorize something.");
}

if (!info.thread_variables.empty()) {
size_t prefix_sum = 0;
for (size_t i = 0; i < info.thread_variables.size(); ++i) {
Expand All @@ -834,6 +848,14 @@ void CodegenNeuronCppVisitor::print_mechanism_global_var_structure(bool print_in
}
}

if (!info.vectorize && !info.top_local_variables.empty()) {
for (size_t i = 0; i < info.top_local_variables.size(); ++i) {
const auto& var = info.top_local_variables[i];
codegen_global_variables.push_back(var);
}
}


if (!codegen_thread_variables.empty()) {
if (!info.vectorize) {
// MOD files that aren't "VECTORIZED" don't have thread data.
Expand Down Expand Up @@ -1385,6 +1407,26 @@ void CodegenNeuronCppVisitor::print_make_node_data() const {

printer->pop_block(";");
printer->pop_block();


printer->fmt_push_block("static {} make_node_data_{}(Prop * _prop)",
node_data_struct(),
info.mod_suffix);
printer->add_line("static std::vector<int> node_index{0};");
printer->add_line("Node* _node = _nrn_mechanism_access_node(_prop);");

make_node_data_args = {"node_index.data()",
"&_nrn_mechanism_access_voltage(_node)",
"&_nrn_mechanism_access_d(_node)",
"&_nrn_mechanism_access_rhs(_node)",
"1"};

printer->fmt_push_block("return {}", node_data_struct());
printer->add_multi_line(fmt::format("{}", fmt::join(make_node_data_args, ",\n")));

printer->pop_block(";");
printer->pop_block();
printer->add_newline();
}

void CodegenNeuronCppVisitor::print_thread_variables_structure(bool print_initializers) {
Expand Down Expand Up @@ -1475,7 +1517,6 @@ void CodegenNeuronCppVisitor::print_nrn_init(bool skip_init_check) {
if (!info.artificial_cell) {
printer->add_line("int node_id = node_data.nodeindices[id];");
printer->add_line("auto v = node_data.node_voltages[node_id];");
printer->fmt_line("inst.{}[id] = v;", naming::VOLTAGE_UNUSED_VARIABLE);
}

print_rename_state_vars();
Expand Down Expand Up @@ -2164,6 +2205,9 @@ void CodegenNeuronCppVisitor::print_net_receive() {
printer->add_line("auto * _ppvar = _nrn_mechanism_access_dparam(_pnt->prop);");

printer->fmt_line("auto inst = make_instance_{}(_lmc);", info.mod_suffix);
if (!info.artificial_cell) {
printer->fmt_line("auto node_data = make_node_data_{}(_pnt->prop);", info.mod_suffix);
}
printer->fmt_line("// nocmodl has a nullptr dereference for thread variables.");
printer->fmt_line("// NMODL will fail to compile at a later point, because of");
printer->fmt_line("// missing '_thread_vars'.");
Expand Down
34 changes: 34 additions & 0 deletions test/usecases/function/artificial_functions.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
: Test mechanism: an ARTIFICIAL_CELL whose FUNCTIONs are called directly
: from HOC/Python. Artificial cells are not associated with a node, so
: these functions must not depend on a membrane voltage.
NEURON {
ARTIFICIAL_CELL art_functions
RANGE x
GLOBAL gbl
}

ASSIGNED {
gbl
: NOTE(review): `v` is declared, but an ARTIFICIAL_CELL has no node to
: read a voltage from -- presumably an ordinary assigned variable here; confirm.
v
x
}

: Returns the RANGE variable `x` plus the argument `a`.
FUNCTION x_plus_a(a) {
x_plus_a = x + a
}

: Returns its argument unchanged; the parameter is named `v`, shadowing
: the ASSIGNED variable of the same name.
FUNCTION identity(v) {
identity = v
}

: Starting values read back by the accompanying tests.
INITIAL {
x = 1.0
gbl = 42.0
}

: A LINEAR block makes a MOD file not VECTORIZED.
STATE {
z
}

LINEAR lin {
~ z = 2
}

38 changes: 38 additions & 0 deletions test/usecases/function/non_threadsafe.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
: Test mechanism: a density mechanism (SUFFIX) in a non-threadsafe,
: non-VECTORIZED MOD file, used to exercise direct FUNCTION calls
: from HOC/Python.
NEURON {
SUFFIX non_threadsafe
RANGE x
GLOBAL gbl
}

ASSIGNED {
gbl
v
x
}

: Returns the RANGE variable `x` plus the argument `a`.
FUNCTION x_plus_a(a) {
x_plus_a = x + a
}

: Returns the membrane voltage `v` plus the argument `a`; exercises
: reading the voltage when the function is called directly.
FUNCTION v_plus_a(a) {
v_plus_a = v + a
}

: Returns its argument unchanged; the parameter is named `v`, shadowing
: the voltage variable of the same name.
FUNCTION identity(v) {
identity = v
}

: Starting values read back by the accompanying tests.
INITIAL {
x = 1.0
gbl = 42.0
}

: A LINEAR block makes a MOD file not VECTORIZED.
STATE {
z
}

LINEAR lin {
~ z = 2
}

38 changes: 38 additions & 0 deletions test/usecases/function/point_non_threadsafe.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
: Test mechanism: a POINT_PROCESS in a non-threadsafe, non-VECTORIZED
: MOD file, used to exercise direct FUNCTION calls from HOC/Python.
NEURON {
POINT_PROCESS point_non_threadsafe
RANGE x
GLOBAL gbl
}

ASSIGNED {
gbl
v
x
}

: Returns the RANGE variable `x` plus the argument `a`.
FUNCTION x_plus_a(a) {
x_plus_a = x + a
}

: Returns the membrane voltage `v` plus the argument `a`; exercises
: reading the voltage when the function is called directly.
FUNCTION v_plus_a(a) {
v_plus_a = v + a
}

: Returns its argument unchanged; the parameter is named `v`, shadowing
: the voltage variable of the same name.
FUNCTION identity(v) {
identity = v
}

: Starting values read back by the accompanying tests.
INITIAL {
x = 1.0
gbl = 42.0
}

: A LINEAR block makes a MOD file not VECTORIZED.
STATE {
z
}

LINEAR lin {
~ z = 2
}

22 changes: 15 additions & 7 deletions test/usecases/function/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from neuron import h


def check_functions(get_instance):
def check_callable(get_instance, has_voltage=True):
for x, value in zip(coords, values):
get_instance(x).x = value

Expand All @@ -19,22 +19,30 @@ def check_functions(get_instance):
actual = get_instance(x).identity(expected)
assert actual == expected, f"{actual} == {expected}"

# Check `f` using `v`.
expected = -2.0
actual = get_instance(x).v_plus_a(40.0)
assert actual == expected, f"{actual} == {expected}"
if has_voltage:
# Check `f` using `v`.
expected = -2.0
actual = get_instance(x).v_plus_a(40.0)
assert actual == expected, f"{actual} == {expected}"


nseg = 5
s = h.Section()
s.nseg = nseg

s.insert("functions")
s.insert("non_threadsafe")

coords = [(0.5 + k) * 1.0 / nseg for k in range(nseg)]
values = [0.1 + k for k in range(nseg)]

point_processes = {x: h.point_functions(s(x)) for x in coords}
point_non_threadsafe = {x: h.point_non_threadsafe(s(x)) for x in coords}

art_cells = {x: h.art_functions() for x in coords}

check_functions(lambda x: s(x).functions)
check_functions(lambda x: point_processes[x])
check_callable(lambda x: s(x).functions)
check_callable(lambda x: s(x).non_threadsafe)
check_callable(lambda x: point_processes[x])
check_callable(lambda x: point_non_threadsafe[x])
check_callable(lambda x: art_cells[x], has_voltage=False)
43 changes: 43 additions & 0 deletions test/usecases/global/non_threadsafe.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
: Test mechanism: exercises GLOBAL, PARAMETER and top-level LOCAL
: variables in a non-threadsafe, non-vectorized MOD file. Each getter
: FUNCTION exposes one variable so the Python test can read it back.
NEURON {
SUFFIX non_threadsafe
GLOBAL gbl
}

: Top-level LOCAL: file-scope storage, not exposed to HOC/Python.
LOCAL top_local

PARAMETER {
parameter = 41.0
}

ASSIGNED {
gbl
}

: Returns the GLOBAL variable `gbl`.
FUNCTION get_gbl() {
get_gbl = gbl
}

: Returns the top-level LOCAL `top_local`; this is the only way the
: test can observe it, since top LOCALs have no HOC/Python handle.
FUNCTION get_top_local() {
get_top_local = top_local
}

: Returns the PARAMETER `parameter`.
FUNCTION get_parameter() {
get_parameter = parameter
}

: Starting values checked by the accompanying Python test.
INITIAL {
gbl = 42.0
top_local = 43.0
}

: A LINEAR block makes the MOD file not thread-safe and not
: vectorized. We don't otherwise care about anything below
: this comment.
STATE {
z
}

LINEAR lin {
~ z = 2
}

33 changes: 33 additions & 0 deletions test/usecases/global/test_non_threadsafe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np

from neuron import h, gui
from neuron.units import ms


def test_non_threadsafe():
    """Check GLOBAL, PARAMETER and top-LOCAL access in a non-threadsafe MOD file.

    Inserts the ``non_threadsafe`` mechanism into a single-segment section,
    runs ``finitialize`` so the INITIAL block fires, then reads each variable
    back through its getter FUNCTION, before and after reassigning the
    HOC-visible globals.
    """
    section = h.Section()
    section.insert("non_threadsafe")
    section.nseg = 1

    h.finitialize()

    mech = section(0.5).non_threadsafe

    # Values assigned by the INITIAL block.
    assert mech.get_parameter() == 41.0
    assert mech.get_gbl() == 42.0
    assert mech.get_top_local() == 43.0

    # Reassign through the HOC global names; top LOCAL variables have no
    # HOC/Python handle, so only `parameter` and `gbl` can be rewritten.
    h.parameter_non_threadsafe = 32.1
    h.gbl_non_threadsafe = 33.2

    assert mech.get_parameter() == 32.1
    assert mech.get_gbl() == 33.2


if __name__ == "__main__":
    test_non_threadsafe()

0 comments on commit 1c4cb6c

Please sign in to comment.