Skip to content

Commit

Permalink
v1.1 C (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
laves authored Sep 20, 2024
1 parent 9fff9d6 commit 10a1c13
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/c-demos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
echo PATH=$PATH >> $GITHUB_ENV
- name: Download resource files
run: curl http://${{secrets.PV_CICD_RES_SERVER_AUTHORITY}}/github/picollm/res/phi2-290.pllm/latest/phi2-290.pllm -o phi2-290.pllm
run: curl http://${{secrets.PV_CICD_RES_SERVER_AUTHORITY}}/github/picollm/res/phi2-290.pllm/03-280e68c/phi2-290.pllm -o phi2-290.pllm

- name: Create build directory
run: cmake -G "${{ matrix.make_file }}" -B ./build
Expand Down
2 changes: 2 additions & 0 deletions demo/c/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ picoLLM Inference Engine supports the following open-weight models. The models a
- `mixtral-8x7b-instruct-v0.1`
- Phi-2
- `phi2`
- Phi-3
- `phi3`

## Usage

Expand Down
67 changes: 44 additions & 23 deletions demo/c/picollm_demo_completion.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#endif

#include <getopt.h>
#include <signal.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
Expand Down Expand Up @@ -82,8 +83,19 @@ static void print_dl_error(const char *message) {
#endif
}

static void (*pv_picollm_interrupt_func)(pv_picollm_t *) = NULL;
static pv_picollm_t *picollm = NULL;
static volatile bool is_interrupt = false;

void interrupt_handler(int _) {
(void) _;
is_interrupt = true;
fprintf(stdout, "\n\nInterrupting generation...\n");
pv_picollm_interrupt_func(picollm);
}

void print_error_message(
char **message_stack,
char **message_stack,
int32_t message_stack_depth,
pv_status_t (*pv_get_error_stack_func)(char ***, int32_t *),
void (*pv_free_error_stack_func)(char **),
Expand Down Expand Up @@ -127,6 +139,7 @@ static const char *pv_picollm_endpoint_to_string(pv_picollm_endpoint_t x) {
"END_OF_SENTENCE",
"COMPLETION_TOKEN_LIMIT_REACHED",
"STOP_PHRASE_ENCOUNTERED",
"INTERRUPTED"
};

return STRINGS[x];
Expand All @@ -137,13 +150,14 @@ struct timeval tic;

static void progress_callback(const char *token, void *context) {
(void) context;

fprintf(stdout, "%s", token);
fflush(stdout);
if (num_tokens == -1) {
gettimeofday(&tic, NULL);
if (!is_interrupt) {
fprintf(stdout, "%s", token);
fflush(stdout);
if (num_tokens == -1) {
gettimeofday(&tic, NULL);
}
num_tokens += 1;
}
num_tokens += 1;
}

int picovoice_main(int argc, char **argv) {
Expand Down Expand Up @@ -283,7 +297,7 @@ int picovoice_main(int argc, char **argv) {
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_init_func)(const char *, const char *, const char *, pv_picollm_t **) =
pv_status_t (*pv_picollm_init_func)(const char *, const char *, const char *, pv_picollm_t **) =
load_symbol(dl_handle, "pv_picollm_init");
if (!pv_picollm_init_func) {
print_dl_error("failed to load `pv_picollm_init`");
Expand Down Expand Up @@ -320,63 +334,69 @@ int picovoice_main(int argc, char **argv) {
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_delete_completion_tokens_func)(pv_picollm_completion_token_t *, int32_t) =
pv_picollm_interrupt_func = load_symbol(dl_handle, "pv_picollm_interrupt");
if (!pv_picollm_interrupt_func) {
print_dl_error("failed to load `pv_picollm_interrupt`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_delete_completion_tokens_func)(pv_picollm_completion_token_t *, int32_t) =
load_symbol(dl_handle, "pv_picollm_delete_completion_tokens");
if (!pv_picollm_delete_completion_tokens_func) {
print_dl_error("failed to load `pv_picollm_delete_completion_tokens`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_delete_completion_func)(char *) =
pv_status_t (*pv_picollm_delete_completion_func)(char *) =
load_symbol(dl_handle, "pv_picollm_delete_completion");
if (!pv_picollm_delete_completion_func) {
print_dl_error("failed to load `pv_picollm_delete_completion`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_context_length_func)(const pv_picollm_t *, int32_t *) =
pv_status_t (*pv_picollm_context_length_func)(const pv_picollm_t *, int32_t *) =
load_symbol(dl_handle, "pv_picollm_context_length");
if (!pv_picollm_context_length_func) {
print_dl_error("failed to load `pv_picollm_context_length`");
exit(EXIT_FAILURE);
}

const char *(*pv_picollm_version_func)(void) =
const char *(*pv_picollm_version_func)(void) =
load_symbol(dl_handle, "pv_picollm_version");
if (!pv_picollm_version_func) {
print_dl_error("failed to load `pv_picollm_version`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_model_func)(pv_picollm_t *, char **) =
pv_status_t (*pv_picollm_model_func)(pv_picollm_t *, char **) =
load_symbol(dl_handle, "pv_picollm_model");
if (!pv_picollm_model_func) {
print_dl_error("failed to load `pv_picollm_model`");
exit(EXIT_FAILURE);
}

int32_t (*pv_picollm_max_top_choices_func)(void) =
int32_t (*pv_picollm_max_top_choices_func)(void) =
load_symbol(dl_handle, "pv_picollm_max_top_choices");
if (!pv_picollm_max_top_choices_func) {
print_dl_error("failed to load `pv_picollm_max_top_choices`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_list_hardware_devices_func)(char ***, int32_t *) =
pv_status_t (*pv_picollm_list_hardware_devices_func)(char ***, int32_t *) =
load_symbol(dl_handle, "pv_picollm_list_hardware_devices");
if (!pv_picollm_list_hardware_devices_func) {
print_dl_error("failed to load `pv_picollm_list_hardware_devices`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_picollm_free_hardware_devices_func)(char **, int32_t) =
pv_status_t (*pv_picollm_free_hardware_devices_func)(char **, int32_t) =
load_symbol(dl_handle, "pv_picollm_free_hardware_devices");
if (!pv_picollm_free_hardware_devices_func) {
print_dl_error("failed to load `pv_picollm_free_hardware_devices`");
exit(EXIT_FAILURE);
}

pv_status_t (*pv_get_error_stack_func)(char ***, int32_t *) =
pv_status_t (*pv_get_error_stack_func)(char ***, int32_t *) =
load_symbol(dl_handle, "pv_get_error_stack");
if (!pv_get_error_stack_func) {
print_dl_error("failed to load 'pv_get_error_stack_func'");
Expand Down Expand Up @@ -406,7 +426,7 @@ int picovoice_main(int argc, char **argv) {
"Failed to list hardware devices with `%s`.\n",
pv_status_to_string_func(status));
print_error_message(
message_stack,
message_stack,
message_stack_depth,
pv_get_error_stack_func,
pv_free_error_stack_func,
Expand Down Expand Up @@ -436,7 +456,6 @@ int picovoice_main(int argc, char **argv) {
exit(EXIT_FAILURE);
}

pv_picollm_t *picollm = NULL;
pv_status_t status = pv_picollm_init_func(
access_key,
model_path,
Expand All @@ -448,7 +467,7 @@ int picovoice_main(int argc, char **argv) {
"failed to init with `%s`",
pv_status_to_string_func(status));
print_error_message(
message_stack,
message_stack,
message_stack_depth,
pv_get_error_stack_func,
pv_free_error_stack_func,
Expand All @@ -464,7 +483,7 @@ int picovoice_main(int argc, char **argv) {
"Failed to get context length with `%s`.\n",
pv_status_to_string_func(status));
print_error_message(
message_stack,
message_stack,
message_stack_depth,
pv_get_error_stack_func,
pv_free_error_stack_func,
Expand All @@ -488,15 +507,17 @@ int picovoice_main(int argc, char **argv) {
"Failed to get model with `%s`.\n",
pv_status_to_string_func(status));
print_error_message(
message_stack,
message_stack,
message_stack_depth,
pv_get_error_stack_func,
pv_free_error_stack_func,
pv_status_to_string_func
);
}

signal(SIGINT, interrupt_handler);
fprintf(stdout, "Loaded model: `%s`\n", model);
fprintf(stdout, "Generating... (press Ctrl+C to interrupt)\n");

pv_picollm_usage_t usage;
pv_picollm_endpoint_t endpoint;
Expand Down Expand Up @@ -530,7 +551,7 @@ int picovoice_main(int argc, char **argv) {
"Failed to generate with `%s`.\n",
pv_status_to_string_func(status));
print_error_message(
message_stack,
message_stack,
message_stack_depth,
pv_get_error_stack_func,
pv_free_error_stack_func,
Expand Down

0 comments on commit 10a1c13

Please sign in to comment.