diff --git a/ml-module/model-example/main.py b/ml-module/model-example/main.py index cce16cd..11f3e7e 100644 --- a/ml-module/model-example/main.py +++ b/ml-module/model-example/main.py @@ -7,6 +7,8 @@ TOTAL_SAMPLES = SAMPLES_LEN * 3 acc_x_y_z = [0] * TOTAL_SAMPLES +print("Model labels: {}".format(ml.get_labels())) + i = 0 while True: acc_x_y_z[i + 0] = accelerometer.get_x() @@ -24,4 +26,12 @@ else: print("t[{}] {}".format(time.ticks_ms() - t, result)) i = 0 + if button_a.is_pressed(): + print("Use build-in model") + ml.internal_model(True) + sleep(500) + if button_b.is_pressed(): + print("Use flash model") + ml.internal_model(False) + sleep(500) sleep(20) diff --git a/ml-module/model-example/model_example.c b/ml-module/model-example/model_example.c index 8ecd5f8..dad5d0e 100644 --- a/ml-module/model-example/model_example.c +++ b/ml-module/model-example/model_example.c @@ -1,6 +1,33 @@ #include "model_example.h" -const unsigned int ml4f_model_example[13852] = { +#define ml4f_model_example_header_len 52 +#define ml4f_model_example_size 13852 +#define ml4f_full_model_size (ml4f_model_example_header_len + ml4f_model_example_size) + + +// This is a struct representation of the header included at the beginning of model_example +/* const ml_model_header_t ml4f_model_example_header = { + .magic0 = MODEL_LABELS_MAGIC0, + .header_size = 0x31, // 49 + .model_offset = 0x34, // 52 + .number_of_labels = 0x04, + .reserved = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + // 33 bytes + .labels = { + "Jumping\0" + "Running\0" + "Standing\0" + "Walking\0\0\0" + } +}; */ + +const unsigned int model_example[ml4f_full_model_size] = { + // Manually converted ml4f_model_example_header + 0x4D444C42, 0x00340031, 0x00000004, 0x00000000, + 0x706D754A, 0x00676E69, 0x6E6E7552, 0x00676E69, + 0x6E617453, 0x676E6964, 0x6C615700, 0x676E696B, + 0x00000000, + // Original ML4F model from this point forward 0x30470f62, 0x46344c4d, 0x00000054, 0x0000d864, 0x00001874, 0x00000000, 0x00000000, 0x00002ec8, 0x00000008, 0x00000001, 0x00000008, 0x00000001, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x000000fa, 0x00000003, 0x00000000, 0x00000004, 0x00000000, 0x5ff0e92d, 0x6901460f, 0x60391809, diff --git a/ml-module/model-example/model_example.h b/ml-module/model-example/model_example.h index b85128a..291d846 100644 --- a/ml-module/model-example/model_example.h +++ b/ml-module/model-example/model_example.h @@ -1,5 +1,7 @@ #pragma once +#include + #define ml4f_model_example_input_num_elements 750 -#define ml4f_model_example_num_labels 4 -extern const unsigned int ml4f_model_example[13852]; + +extern const unsigned int model_example[]; diff --git a/ml-module/src/mlmodel.c b/ml-module/src/mlmodel.c index 35b31ea..5051f41 100644 --- a/ml-module/src/mlmodel.c +++ b/ml-module/src/mlmodel.c @@ -1,3 +1,15 @@ +/** + * @brief Functions to interact with the ML model. + * + * @copyright + * Copyright 2024 Micro:bit Educational Foundation. + * SPDX-License-Identifier: MIT + * + * @details + * The ML4F model has its own header, but does not include the labels. + * So an extra header with the labels is added on top. + * We call the "full model" the labels header + the ML4F model. + */ #include #include #include @@ -14,7 +26,18 @@ static bool USE_BUILT_IN = true; /*****************************************************************************/ /* Private API */ /*****************************************************************************/ -static int get_model_start_address() { +/** + * @brief Get the start address of the full model (header + ML4F model). + * + * This would also be the start address of the model header with the labels. + * This function does not check if the data in flash is a valid model. + * + * @return The start address to where the full model is stored in flash. + */ +static uint32_t get_full_model_start_address() { + if (USE_BUILT_IN) { + return (uint32_t)model_example; + } // The last section in FLASH is meant to be text, but data section contents // are placed immediately after it (to be copied to RAM), but there isn't // a symbol to indicate its end in FLASH, so we calculate how long data is @@ -30,26 +53,137 @@ static int get_model_start_address() { return (end_of_flash_data + flash_page_size - 1) & ~(flash_page_size - 1); } -static ml4f_header_t* get_model_header() { - return (ml4f_header_t *)get_model_start_address(); +/** + * @brief Get a pointer to the full model header. + * + * @return The model header or NULL if the model is not present or invalid. + */ +static ml_model_header_t* get_model_header() { + ml_model_header_t *model_header = (ml_model_header_t *)get_full_model_start_address(); + if (model_header->magic0 != MODEL_LABELS_MAGIC0) { + return NULL; + } + // We should have at least one label + if (model_header->number_of_labels == 0) { + return NULL; + } + // Also check the ML4F header magic values to ensure it's there too + ml4f_header_t *ml4f_model = (ml4f_header_t *)((uint32_t)model_header + model_header->model_offset); + if (ml4f_model->magic0 != ML4F_MAGIC0 || ml4f_model->magic1 != ML4F_MAGIC1) { + return NULL; + } + return model_header; +} + +/** + * @brief Get a pointer to the ML4F model. + * + * @return The ML4F model or NULL if the model is not present or invalid. + */ +static ml4f_header_t* get_ml4f_model() { + // get_model_header() already checks the ML4F header magic values + ml_model_header_t *model_header = get_model_header(); + if (model_header == NULL) { + return NULL; + } + return (ml4f_header_t *)((uint32_t)model_header + model_header->model_offset); } /*****************************************************************************/ /* Public API */ /*****************************************************************************/ -bool use_built_in_model(bool use) { - USE_BUILT_IN = use; +bool get_use_built_in_model(void) { return USE_BUILT_IN; } +void set_use_built_in_model(bool use) { + USE_BUILT_IN = use; +} + bool is_model_present(void) { - // TODO: Implement built-in module - return USE_BUILT_IN ? true : false; + ml_model_header_t *model_header = get_model_header(); + return model_header != NULL; } size_t get_model_label_num(void) { - // TODO: Implement built-in module - return USE_BUILT_IN ? ml4f_model_example_num_labels : 0; + ml_model_header_t *model_header = get_model_header(); + return (model_header != NULL) ? model_header->number_of_labels : 0; +} + +ml_labels_t* get_model_labels(void) { + static ml_labels_t labels = { + .num_labels = 0, + .labels = NULL + }; + + const ml_model_header_t* const model_header = get_model_header(); + if (model_header == NULL) { + labels.num_labels = 0; + if (labels.labels != NULL) { + free(labels.labels); + labels.labels = NULL; + } + return NULL; + } + + // Workout the addresses in flash from each label, there are as many strings + // as indicated by model_header->number_of_labels, they start from address + // model_header->labels and are null-terminated. + uint32_t header_end = (uint32_t)model_header + model_header->header_size; + const char* flash_labels[model_header->number_of_labels]; + flash_labels[0] = &model_header->labels[0]; + for (int i = 1; i < model_header->number_of_labels; i++) { + // Find the end of the previous string by looking for the null terminator + flash_labels[i] = flash_labels[i - 1]; + while (*flash_labels[i] != '\0' && (uint32_t)flash_labels[i] < header_end) { + flash_labels[i]++; + } + if ((uint32_t)flash_labels[i] >= header_end) { + // We reached the end of the header without finding the null terminator + free(flash_labels); + return NULL; + } + // Currently pointing to the null terminator, so point to the following string + flash_labels[i]++; + } + // Check the last string is null terminated at the end of header + if (*(char *)(header_end - 1) != '\0') { + free(flash_labels); + return NULL; + } + + // First check if the labels are the same, if not we need to set them again + bool set_labels = false; + if (labels.num_labels == 0 || labels.labels == NULL) { + set_labels = true; + } else if (labels.num_labels != model_header->number_of_labels) { + set_labels = true; + } else { + for (int i = 0; i < labels.num_labels; i++) { + if (labels.labels[i] != flash_labels[i]) { + set_labels = true; + break; + } + } + } + if (set_labels) { + // First clear them out if needed + labels.num_labels = 0; + if (labels.labels != NULL) { + free(labels.labels); + } + // Then set them to point to the strings in flash + labels.labels = (const char **)malloc(model_header->number_of_labels * sizeof(char *)); + if (labels.labels == NULL) { + return NULL; + } + labels.num_labels = model_header->number_of_labels; + for (int i = 0; i < labels.num_labels; i++) { + labels.labels[i] = flash_labels[i]; + } + } + + return &labels; } size_t get_model_input_num() { @@ -58,52 +192,40 @@ size_t get_model_input_num() { } ml_prediction_t* model_predict(const float *input) { - if (!USE_BUILT_IN) { - (void)get_model_header(); - return NULL; - } - - typedef struct out_s { - size_t len; - float* values; - } out_t; - - static out_t out = { - .len = 0, - .values = NULL - }; static ml_prediction_t predictions = { .max_index = 0, - .num_labels = ml4f_model_example_num_labels, - .predictions = { - { .prediction = 0.0, .label = "Jumping" }, - { .prediction = 0.0, .label = "Running" }, - { .prediction = 0.0, .label = "Standing" }, - { .prediction = 0.0, .label = "Walking" }, - } + .num_labels = 0, + .labels = NULL, + .predictions = NULL, }; - size_t label_num = get_model_label_num(); + ml_labels_t* labels = get_model_labels(); + if (labels == NULL) { + return NULL; + } - // The model shouldn't really change (only during testing while we built-in - // one), so this should be a one-time allocation. - if (out.len != label_num || out.values == NULL) { - if (out.values != NULL) { - free(out.values); + // Check if we need to resize the predictions array + if (predictions.num_labels != labels->num_labels) { + if (predictions.predictions != NULL) { + free(predictions.predictions); + } + predictions.num_labels = labels->num_labels; + predictions.predictions = (float *)malloc(predictions.num_labels * sizeof(float)); + if (predictions.predictions == NULL) { + predictions.num_labels = 0; + return NULL; } - out.len = label_num; - out.values = (float *)malloc(out.len * sizeof(float)); } + // Always update the labels in case they changed + predictions.labels = labels->labels; - int r = ml4f_full_invoke((ml4f_header_t *)ml4f_model_example, input, out.values); + ml4f_header_t* ml4f_model = get_ml4f_model(); + int r = ml4f_full_invoke(ml4f_model, input, predictions.predictions); if (r != 0) { return NULL; } - for (int i = 0; i < out.len; i++) { - predictions.predictions[i].prediction = out.values[i]; - } - predictions.max_index = ml4f_argmax(out.values, out.len); + predictions.max_index = ml4f_argmax(predictions.predictions, predictions.num_labels); return &predictions; } diff --git a/ml-module/src/mlmodel.h b/ml-module/src/mlmodel.h index f32bfbb..c6675e6 100644 --- a/ml-module/src/mlmodel.h +++ b/ml-module/src/mlmodel.h @@ -17,19 +17,22 @@ typedef __PACKED_STRUCT ml_model_header_t { char labels[]; // Mutiple null-terminated strings, as many as number_of_labels } ml_model_header_t; -typedef struct ml_label_prediction_s { - float prediction; - char* label; -} ml_label_prediction_t; +typedef struct ml_labels_s { + size_t num_labels; + const char **labels; +} ml_labels_t; typedef struct ml_prediction_s { size_t max_index; size_t num_labels; - ml_label_prediction_t predictions[]; + const char **labels; + float *predictions; } ml_prediction_t; -bool use_built_in_model(bool use); +bool get_use_built_in_model(void); +void set_use_built_in_model(bool use); bool is_model_present(void); size_t get_model_label_num(void); +ml_labels_t* get_model_labels(void); size_t get_model_input_num(void); ml_prediction_t* model_predict(const float *input); diff --git a/ml-module/src/mlmodule.c b/ml-module/src/mlmodule.c index 23e4105..bc65b7e 100644 --- a/ml-module/src/mlmodule.c +++ b/ml-module/src/mlmodule.c @@ -1,23 +1,37 @@ #include #include "mlmodel.h" -// Flag to control usage of model included in model_example.h/c -bool USE_BUILT_IN_MODULE = false; - mp_obj_t internal_model_func(size_t n_args, const mp_obj_t *args) { - if (n_args == 0) { - return mp_obj_new_bool(USE_BUILT_IN_MODULE); + if (n_args == 1) { + bool use_internal_model = !!mp_obj_get_int(args[0]); + set_use_built_in_model(use_internal_model); } - bool use_internal_model = mp_obj_is_true(args[0]); - USE_BUILT_IN_MODULE = use_internal_model; - return mp_obj_new_bool(USE_BUILT_IN_MODULE); + return mp_obj_new_bool(get_use_built_in_model()); } static MP_DEFINE_CONST_FUN_OBJ_VAR(internal_model_func_obj, 0, internal_model_func); +mp_obj_t get_labels_func(void) { + ml_labels_t* labels = get_model_labels(); + if (labels == NULL) { + return mp_const_none; + } + mp_obj_t tup_labels[labels->num_labels]; + for (int i = 0; i < labels->num_labels; i++) { + tup_labels[i] = mp_obj_new_str( + labels->labels[i], + strlen(labels->labels[i]) + ); + } + return mp_obj_new_tuple(labels->num_labels, tup_labels); +} +static MP_DEFINE_CONST_FUN_OBJ_0(get_labels_func_obj, get_labels_func); + + mp_obj_t predict_func(mp_obj_t x_y_z_obj) { // TODO: Expand the types of input accepted if (!mp_obj_is_type(x_y_z_obj, &mp_type_list)) { + // TODO: Use a better exception type mp_raise_ValueError(MP_ERROR_TEXT("Input data must be a list.")); } size_t input_len; @@ -29,7 +43,6 @@ mp_obj_t predict_func(mp_obj_t x_y_z_obj) { } const size_t model_input_num = get_model_input_num(); - const size_t model_label_num = get_model_label_num(); if (input_len != model_input_num) { mp_raise_ValueError(MP_ERROR_TEXT("Input data number of elements invalid.")); @@ -47,15 +60,15 @@ mp_obj_t predict_func(mp_obj_t x_y_z_obj) { } // Create a tuple with tuples of (label, prediction_value) - mp_obj_t tup_values[model_label_num]; - for (int i = 0; i < model_label_num; i++) { + mp_obj_t tup_values[predictions->num_labels]; + for (int i = 0; i < predictions->num_labels; i++) { tup_values[i] = mp_obj_new_tuple(2, (mp_obj_t[]){ - mp_obj_new_str(predictions->predictions[i].label, - strlen(predictions->predictions[i].label)), - mp_obj_new_float(predictions->predictions[i].prediction), + mp_obj_new_str(predictions->labels[i], + strlen(predictions->labels[i])), + mp_obj_new_float(predictions->predictions[i]), }); } - mp_obj_t tuple_values = mp_obj_new_tuple(model_label_num, tup_values); + mp_obj_t tuple_values = mp_obj_new_tuple(predictions->num_labels, tup_values); // And a tuple with the index of the max value and the tuple of labels+predictions return mp_obj_new_tuple(2, (mp_obj_t[]){ @@ -75,9 +88,7 @@ static mp_obj_t ml___init__(void) { // __init__ for builtins is called each time the module is imported, // so ensure that initialisation only happens once. MP_STATE_VM(ml_initialised) = true; - mp_printf(&mp_plat_print, "ml.__init_ run\n"); - - USE_BUILT_IN_MODULE = true; + mp_printf(&mp_plat_print, "ml.__init_ called\n"); } else { mp_printf(&mp_plat_print, "ml.__init_ already initialised\n"); } @@ -95,6 +106,7 @@ static const mp_rom_map_elem_t ml_module_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ml) }, { MP_ROM_QSTR(MP_QSTR___init__), MP_ROM_PTR(&ml___init___obj) }, { MP_ROM_QSTR(MP_QSTR_internal_model), MP_ROM_PTR(&internal_model_func_obj) }, + { MP_ROM_QSTR(MP_QSTR_get_labels), MP_ROM_PTR(&get_labels_func_obj) }, { MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&predict_func_obj) }, }; static MP_DEFINE_CONST_DICT(ml_module_globals, ml_module_globals_table);