Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add message handler and tests: #23 #27

Merged
merged 10 commits into from
Dec 4, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ void do_dopamine_plasticity(
if (step - synapse->rule_.last_spike_step_ < synapse->rule_.dopamine_plasticity_period_)
{
// Change synapse resource.
float d_r =
neuron.dopamine_value_ * std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F);
float d_r = neuron.dopamine_value_ *
std::min(static_cast<float>(std::pow(2, -neuron.stability_)), 1.F) / 1000.F;
synapse->rule_.synaptic_resource_ += d_r;
neuron.free_synaptic_resource_ -= d_r;
}
Expand Down
1 change: 1 addition & 0 deletions knp/base-framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ knp_add_library("${PROJECT_NAME}-core"
impl/model.cpp
impl/model_executor.cpp
impl/model_loader.cpp
impl/message_handler.cpp
impl/input_converter.cpp
impl/output_channel.cpp
impl/synchronization.cpp
Expand Down
97 changes: 97 additions & 0 deletions knp/base-framework/impl/message_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
//
artiomn marked this conversation as resolved.
Show resolved Hide resolved
// Created by an_vartenkov on 22.11.24.
//
#include <knp/framework/message_handler.h>

#include <unordered_set>
#include <utility>

artiomn marked this conversation as resolved.
Show resolved Hide resolved
namespace knp::framework::modifier
{

void SpikeMessageHandler::update(size_t step)
{
endpoint_.receive_all_messages();
auto incoming_messages = endpoint_.unload_messages<MessageIn>(base_.uid_);
MessageOut outgoing_message = {{base_.uid_, step}, message_handler_function_(incoming_messages)};
if (!(outgoing_message.neuron_indexes_.empty()))
{
endpoint_.send_message(outgoing_message);
}
}


knp::core::messaging::SpikeData KWtaRandomHandler::operator()(std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty()) return {};

auto &msg = messages[0];
artiomn marked this conversation as resolved.
Show resolved Hide resolved
if (msg.neuron_indexes_.size() < num_winners_) return msg.neuron_indexes_;

knp::core::messaging::SpikeData out_spikes;
for (size_t i = 0; i < num_winners_; ++i)
{
const size_t index = distribution_(random_engine_) % (msg.neuron_indexes_.size() - i);
out_spikes.push_back(msg.neuron_indexes_[index]);
std::swap(msg.neuron_indexes_[index], msg.neuron_indexes_[msg.neuron_indexes_.size() - 1 - i]);
}
return out_spikes;
}


knp::core::messaging::SpikeData GroupWtaRandomHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
if (messages.empty()) return {};
if (num_winners_ > group_borders_.size()) return messages[0].neuron_indexes_;

const auto &spikes = messages[0].neuron_indexes_;
if (spikes.empty()) return {};

std::vector<std::vector<size_t>> spikes_per_group(group_borders_.size() + 1);

// Fill groups in
artiomn marked this conversation as resolved.
Show resolved Hide resolved
for (auto spike : spikes)
artiomn marked this conversation as resolved.
Show resolved Hide resolved
{
const size_t group_index =
std::upper_bound(group_borders_.begin(), group_borders_.end(), spike) - group_borders_.begin();
spikes_per_group[group_index].push_back(spike);
}

// Sort groups by number of elements
artiomn marked this conversation as resolved.
Show resolved Hide resolved
std::sort(
spikes_per_group.begin(), spikes_per_group.end(),
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });

// Find all groups with the same number of spikes as the K-th one
const auto &last_group = spikes_per_group[num_winners_ - 1];
auto group_interval = std::equal_range(
spikes_per_group.begin(), spikes_per_group.end(), last_group,
[](const auto &el1, const auto &el2) { return el1.size() > el2.size(); });
const size_t already_decided = group_interval.first - spikes_per_group.begin() + 1;
assert(already_decided <= num_winners_);
// The approach could be more efficient, but I don't think it's necessary.
std::shuffle(group_interval.first, group_interval.second, random_engine_);
knp::core::messaging::SpikeData result;
for (size_t i = 0; i < num_winners_; ++i)
{
for (auto spike : spikes_per_group[i]) result.push_back(spike);
artiomn marked this conversation as resolved.
Show resolved Hide resolved
}
return result;
}


knp::core::messaging::SpikeData SpikeUnionHandler::operator()(
const std::vector<knp::core::messaging::SpikeMessage> &messages)
{
std::unordered_set<size_t> spikes;
for (const auto &msg : messages)
{
spikes.insert(msg.neuron_indexes_.begin(), msg.neuron_indexes_.end());
}
knp::core::messaging::SpikeData result;
result.reserve(spikes.size());
std::copy(spikes.begin(), spikes.end(), std::back_inserter(result));
return result;
}
} // namespace knp::framework::modifier
20 changes: 20 additions & 0 deletions knp/base-framework/impl/model_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ void ModelExecutor::start(core::Backend::RunPredicate run_predicate)
{
o_ch.update();
}
// Running handlers
for (auto &handler : message_handlers_)
{
handler.update(get_backend()->get_step());
}
// Run monitoring observers.
for (auto &observer : observers_)
{
Expand All @@ -72,4 +77,19 @@ void ModelExecutor::stop()
get_backend()->stop();
}


void ModelExecutor::add_message_handler(
typename modifier::SpikeMessageHandler::FunctionType &&message_handler_function,
const std::vector<core::UID> &senders, const std::vector<core::UID> &receivers, const knp::core::UID &uid)
{
knp::core::MessageEndpoint endpoint = get_backend()->get_message_bus().create_endpoint();
message_handlers_.emplace_back(
modifier::SpikeMessageHandler{std::move(message_handler_function), std::move(endpoint), uid});
message_handlers_.back().subscribe(senders);
for (const knp::core::UID &rec_uid : receivers)
{
get_backend()->subscribe<typename modifier::SpikeMessageHandler::MessageOut>(rec_uid, {uid});
}
}

} // namespace knp::framework
194 changes: 194 additions & 0 deletions knp/base-framework/include/knp/framework/message_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/**
* @file message_handler.h
* @brief A class that processes a number of messages then sends messages of its own.
* @kaspersky_support Vartenkov A.
* @date 19.11.2024
* @license Apache 2.0
* @copyright © 2024 AO Kaspersky Lab
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once


#include <knp/core/message_endpoint.h>
#include <knp/core/messaging/messaging.h>

#include <algorithm>
#include <random>
#include <string>
#include <utility>
#include <vector>


namespace knp::framework::modifier
{

/**
* @brief An object that receives and processes messages.
*/
class SpikeMessageHandler
{
public:
using MessageIn = knp::core::messaging::SpikeMessage;
using MessageOut = knp::core::messaging::SpikeMessage;
using FunctionType = std::function<core::messaging::SpikeData(std::vector<MessageIn> &)>;

/**
* @brief Handler constructor.
* @param function a function that takes a vector of spike messages and returns a vector of spikes.
* @param endpoint message endpoint.
* @param uid the uid of this object.
*/
SpikeMessageHandler(FunctionType &&function, knp::core::MessageEndpoint &&endpoint, const knp::core::UID &uid = {})
: message_handler_function_(std::move(function)), endpoint_(std::move(endpoint)), base_{uid}
{
}

/**
* @brief Default move constructor.
* @param other object to move from.
*/
SpikeMessageHandler(SpikeMessageHandler &&other) noexcept = default;

artiomn marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Is not copyable.
*/
SpikeMessageHandler(const SpikeMessageHandler &) = delete;

artiomn marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Subscribe handler to a number of other entities.
* @param entities network uids.
* @note For internal use, don't try to call it manually.
*/
void subscribe(const std::vector<core::UID> &entities) { endpoint_.subscribe<MessageIn>(base_.uid_, entities); }

artiomn marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Read, process and send messages.
* @param step current step.
* @note for internal use, don't try to call it manually.
artiomn marked this conversation as resolved.
Show resolved Hide resolved
*/
void update(size_t step);


/**
* @brief Get handler UID.
* @return object UID.
*/
[[nodiscard]] knp::core::UID get_uid() const { return base_.uid_; };

artiomn marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Get a tag.
* @param tag_name tag name.
* @return tag value reference.
*/
[[nodiscard]] std::any &get_tag(const std::string &tag_name) { return base_.tags_[tag_name]; };

private:
FunctionType message_handler_function_;
core::MessageEndpoint endpoint_;
knp::core::BaseData base_;
};


/**
* @brief A modifier functor to process spikes and select random K spikes out of the whole set.
* @note Only processes a single message.
*/
class KWtaRandomHandler
artiomn marked this conversation as resolved.
Show resolved Hide resolved
{
public:
/**
* @brief Constructor
artiomn marked this conversation as resolved.
Show resolved Hide resolved
* @param winners_number Max number of output spikes.
* @param seed random generator seed.
* @note uses mt19937 for random number generation.
*/
explicit KWtaRandomHandler(size_t winners_number = 1, int seed = 0)
: num_winners_(winners_number), random_engine_(seed)
{
}

artiomn marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief operator that takes a number of messages and returns a set of spikes.
* @param messages spike messages.
* @return spikes data containing no more than K spikes.
* @note it's assumed that it gets no more than one message per step, so all messages except first are ignored.
*/
knp::core::messaging::SpikeData operator()(std::vector<knp::core::messaging::SpikeMessage> &messages);

private:
size_t num_winners_;
std::mt19937 random_engine_;
std::uniform_int_distribution<size_t> distribution_;
};


/**
* @brief MessageHandler functor that only passes through spikes from no more than a fixed number of groups at once.
* @note Group is considered to be winning if it is in the top K groups sorted by number of spikes in descending order.
* @note If last place in the top K is shared between groups, the functor selects random ones among the sharing groups.
*/
class GroupWtaRandomHandler
{
public:
/**
* @brief Functor constructor.
* @param group_borders right borders of the intervals.
* @param num_winning_groups max number of groups that are allowed to pass their spikes further.
* @param seed seed for internal random number generator.
*/
explicit GroupWtaRandomHandler(
artiomn marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<size_t> &group_borders, size_t num_winning_groups = 1, int seed = 0)
: group_borders_(group_borders), num_winners_(num_winning_groups), random_engine_(seed)
{
std::sort(group_borders_.begin(), group_borders_.end());
}

artiomn marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Functor operator.
artiomn marked this conversation as resolved.
Show resolved Hide resolved
* @param messages input messages.
* @return spikes from winning groups.
*/
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);

private:
std::vector<size_t> group_borders_;
size_t num_winners_;
std::mt19937 random_engine_;
std::uniform_int_distribution<size_t> distribution_;
};


/**
* @brief Spike handler functor. An output vector has a spike if that spike was present in at least one input message.
*/
class SpikeUnionHandler
{
public:
/**
* @brief Functor operator, receives a vector of messages, returns a union of all spike sets from those messages.
* @param messages incoming spike messages.
* @return spikes vector containing the union of input message spike sets.
*/
knp::core::messaging::SpikeData operator()(const std::vector<knp::core::messaging::SpikeMessage> &messages);
};


} // namespace knp::framework::modifier
Loading
Loading