Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add message handler and tests: #23 #27

Merged
merged 10 commits into from
Dec 4, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ void do_dopamine_plasticity(
if (step - synapse->rule_.last_spike_step_ < synapse->rule_.dopamine_plasticity_period_)
{
// Change synapse resource.
float d_r =
neuron.dopamine_value_ * std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F);
float d_r = neuron.dopamine_value_ *
std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F) / 1000.F;
synapse->rule_.synaptic_resource_ += d_r;
neuron.free_synaptic_resource_ -= d_r;
}
Expand Down
1 change: 1 addition & 0 deletions knp/base-framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ knp_add_library("${PROJECT_NAME}-core"
impl/model.cpp
impl/model_executor.cpp
impl/model_loader.cpp
impl/message_handlers.cpp
impl/input_converter.cpp
impl/output_channel.cpp
impl/synchronization.cpp
Expand Down
127 changes: 127 additions & 0 deletions knp/base-framework/impl/message_handlers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/**
* @file message_handlers.cpp
* @brief Implementation of message handler functionality.
* @kaspersky_support A. Vartenkov
* @date 25.11.2024
* @license Apache 2.0
* @copyright © 2024 AO Kaspersky Lab
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <knp/framework/message_handlers.h>

#include <unordered_set>
#include <utility>


/**
* @brief namespace for message modifier callables.
*/
namespace knp::framework::modifier
{

knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty())
{
return {};
}

auto &msg = messages[0];
if (msg.neuron_indexes_.size() < num_winners_)
{
return msg.neuron_indexes_;
}

knp::core::messaging::SpikeData out_spikes;
for (size_t i = 0; i < num_winners_; ++i)
{
const size_t index = distribution_(random_engine_) % (msg.neuron_indexes_.size() - i);
out_spikes.push_back(msg.neuron_indexes_[index]);
std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]);
}

return out_spikes;
}


knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty())
{
return {};
}

if (num_winners_ > group_borders_.size())
{
return messages[0].neuron_indexes_;
}

const auto &spikes = messages[0].neuron_indexes_;
if (spikes.empty())
{
return {};
}

std::vector<knp::core::messaging::SpikeData> spikes_per_group(group_borders_.size() + 1);

// Fill groups in.
for (const auto &spike : spikes)
{
const size_t group_index =
std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin();
spikes_per_group[group_index].push_back(spike);
}

// Sort groups by number of elements.
std::sort(
spikes_per_group.begin(), spikes_per_group.end(),
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });

// Find all groups with the same number of spikes as the K-th one.
const auto &last_group = spikes_per_group[num_winners_ - 1];
auto group_interval = std::equal_range(
spikes_per_group.begin(), spikes_per_group.end(), last_group,
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });
const size_t already_decided = group_interval.first - spikes_per_group.begin() + 1;
assert(already_decided <= num_winners_);
// The approach could be more efficient, but I don't think it's necessary.
std::shuffle(group_interval.first, group_interval.second, random_engine_);
knp::core::messaging::SpikeData result;
for (size_t i = 0; i < num_winners_; ++i)
{
for (const auto &spike : spikes_per_group[i])
{
result.push_back(spike);
}
}
return result;
}


knp::core::messaging::SpikeData SpikeUnionHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
std::unordered_set<knp::core::messaging::SpikeIndex> spikes;
for (const auto &msg : messages)
{
spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end());
}
knp::core::messaging::SpikeData result;
result.reserve(spikes.size());
std::copy(spikes.begin(), spikes.end(), std::back_inserter(result));
return result;
}
} // namespace knp::framework::modifier
76 changes: 76 additions & 0 deletions knp/base-framework/impl/model_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,62 @@
namespace knp::framework
{


class ModelExecutor::SpikeMessageHandler
{
public:
using MessageIn = knp::core::messaging::SpikeMessage;
using MessageData = knp::core::messaging::SpikeData;
using FunctionType = std::function<MessageData(std::vector<MessageIn> &)>;

SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {})
: message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid}
{
}

SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default;

SpikeMessageHandler(const SpikeMessageHandler &) = delete;

void subscribe(const std::vector<core::UID> &entities) { endpoint_.subscribe<MessageIn>(base_.uid_, entities); }

void update(size_t step);

[[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; };

~SpikeMessageHandler() = default;

private:
FunctionType message_handler_function_;
knp::core::MessageEndpoint endpoint_;
knp::core::BaseData base_;
};


void ModelExecutor::SpikeMessageHandler::update(size_t step)
{
endpoint_.receive_all_messages();
auto incoming_messages = endpoint_.unload_messages<MessageIn>(base_.uid_);
knp::core::messaging::SpikeMessage outgoing_message = {
{base_.uid_, step}, message_handler_function_(incoming_messages)};
if (!(outgoing_message.neuron_indexes_.empty()))
{
endpoint_.send_message(outgoing_message);
}
}


ModelExecutor::ModelExecutor(
knp::framework::Model &model, std::shared_ptr<core::Backend> backend, ModelLoader::InputChannelMap i_map)
: loader_(backend, i_map)
{
loader_.load(model);
}


ModelExecutor::~ModelExecutor() = default;


void ModelExecutor::start()
{
start([](knp::core::Step) { return true; });
Expand Down Expand Up @@ -55,6 +111,11 @@ void ModelExecutor::start(core::Backend::RunPredicate run_predicate)
{
o_ch.update();
}
// Running handlers
for (auto &handler : message_handlers_)
{
handler->update(get_backend()->get_step());
}
// Run monitoring observers.
for (auto &observer : observers_)
{
Expand All @@ -72,4 +133,19 @@ void ModelExecutor::stop()
get_backend()->stop();
}


void ModelExecutor::add_spike_message_handler(
typename SpikeMessageHandler::FunctionType &&message_handler_function, const std::vector<core::UID> &senders,
const std::vector<core::UID> &receivers, const knp::core::UID &uid)
{
knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint();
message_handlers_.emplace_back(
std::make_unique<SpikeMessageHandler>(std::move(message_handler_function), std::move(endpoint), uid));
message_handlers_.back()->subscribe(senders);
for (const knp::core::UID &rec_uid : receivers)
{
get_backend()->subscribe<knp::core::messaging::SpikeMessage>(rec_uid, {uid});
}
}

} // namespace knp::framework
122 changes: 122 additions & 0 deletions knp/base-framework/include/knp/framework/message_handlers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/**
* @file message_handlers.h
* @brief A set of predefined message handling functions to add to model executor.
* @kaspersky_support Vartenkov A.
* @date 19.11.2024
* @license Apache 2.0
* @copyright © 2024 AO Kaspersky Lab
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once


#include <knp/core/message_endpoint.h>
#include <knp/core/messaging/messaging.h>

#include <algorithm>
#include <random>
#include <string>
#include <utility>
#include <vector>


namespace knp::framework::modifier
{

/**
* @brief A modifier functor to process spikes and select random K spikes out of the whole set.
* @note Only processes a single message.
*/
class KWtaRandomHandler
{
public:
/**
* @brief Constructor.
* @param winners_number Max number of output spikes.
* @param seed random generator seed.
* @note uses mt19937 for random number generation.
*/
explicit KWtaRandomHandler(size_t winners_number = 1, int seed = 0)
: num_winners_(winners_number), random_engine_(seed)
{
}

/**
* @brief Function call operator that takes a number of messages and returns a set of spikes.
* @param messages spike messages.
* @return spikes data containing no more than K spikes.
* @note it's assumed that it gets no more than one message per step, so all messages except first are ignored.
*/
knp::core::messaging::SpikeData operator()(std::vector<knp::core::messaging::SpikeMessage> &messages);

private:
size_t num_winners_;
std::mt19937 random_engine_;
std::uniform_int_distribution<size_t> distribution_;
};


/**
* @brief MessageHandler functor that only passes through spikes from no more than a fixed number of groups at once.
* @note Group is considered to be winning if it is in the top K groups sorted by number of spikes in descending order.
* @note If last place in the top K is shared between groups, the functor selects random ones among the sharing groups.
*/
class GroupWtaRandomHandler
{
public:
/**
* @brief Functor constructor.
* @param group_borders right borders of the intervals.
* @param num_winning_groups max number of groups that are allowed to pass their spikes further.
* @param seed seed for internal random number generator.
*/
explicit GroupWtaRandomHandler(
const std::vector<size_t> &group_borders, size_t num_winning_groups = 1, int seed = 0)
: group_borders_(group_borders), num_winners_(num_winning_groups), random_engine_(seed)
{
std::sort(group_borders_.begin(), group_borders_.end());
}

/**
* @brief Function call operator.
* @param messages input messages.
* @return spikes from winning groups.
*/
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);

private:
std::vector<size_t> group_borders_;
size_t num_winners_;
std::mt19937 random_engine_;
std::uniform_int_distribution<size_t> distribution_;
};


/**
* @brief Spike handler functor. An output vector has a spike if that spike was present in at least one input message.
*/
class SpikeUnionHandler
{
public:
/**
* @brief Function call operator, receives a vector of messages, returns a union of all spike sets from those
* messages.
* @param messages incoming spike messages.
* @return spikes vector containing the union of input message spike sets.
*/
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);
};


} // namespace knp::framework::modifier
Loading
Loading