diff --git a/example/main.c b/example/main.c
index 3f2cb7d..32d17c2 100644
--- a/example/main.c
+++ b/example/main.c
@@ -6,8 +6,8 @@
 
 #define COMMAND_TIMEOUT 5000
 
-static void message_arrived(lwmqtt_client_t *c, lwmqtt_string_t *t, lwmqtt_message_t *m) {
-  printf("message_arrived: %.*s => %.*s\n", t->len, t->data, m->payload_len, (char*)m->payload);
+static void message_arrived(lwmqtt_client_t *c, void *ref, lwmqtt_string_t *t, lwmqtt_message_t *m) {
+  printf("message_arrived: %.*s => %.*s\n", t->len, t->data, m->payload_len, (char *)m->payload);
 }
 
 int main() {
@@ -22,7 +22,7 @@ int main() {
 
   lwmqtt_set_network(&client, &network, lwmqtt_unix_network_read, lwmqtt_unix_network_write);
   lwmqtt_set_timers(&client, &timer1, &timer2, lwmqtt_unix_timer_set, lwmqtt_unix_timer_get);
-  lwmqtt_set_callback(&client, message_arrived);
+  lwmqtt_set_callback(&client, NULL, message_arrived);
 
   lwmqtt_err_t err = lwmqtt_unix_network_connect(&network, "broker.shiftr.io", 1883);
   if (err != LWMQTT_SUCCESS) {
diff --git a/include/lwmqtt.h b/include/lwmqtt.h
index 60a6f5a..651da1a 100644
--- a/include/lwmqtt.h
+++ b/include/lwmqtt.h
@@ -103,7 +103,7 @@ typedef unsigned int (*lwmqtt_timer_get_t)(lwmqtt_client_t *c, void *ref);
 /**
  * The callback used to forward incoming messages.
  */
-typedef void (*lwmqtt_callback_t)(lwmqtt_client_t *, lwmqtt_string_t *, lwmqtt_message_t *);
+typedef void (*lwmqtt_callback_t)(lwmqtt_client_t *, void *ref, lwmqtt_string_t *, lwmqtt_message_t *);
 
 /**
  * The client object.
@@ -116,6 +116,7 @@ struct lwmqtt_client_t {
   int write_buf_size, read_buf_size;
   unsigned char *write_buf, *read_buf;
 
+  void *callback_ref;
   lwmqtt_callback_t callback;
 
   void *network;
@@ -166,9 +167,10 @@ void lwmqtt_set_timers(lwmqtt_client_t *client, void *keep_alive_timer, void *ne
  * Will set the callback used to receive incoming messages.
  *
  * @param client - The client object.
+ * @param ref - A custom reference that will passed to the callback.
  * @param cb - The callback to be called.
  */
-void lwmqtt_set_callback(lwmqtt_client_t *client, lwmqtt_callback_t cb);
+void lwmqtt_set_callback(lwmqtt_client_t *client, void *ref, lwmqtt_callback_t cb);
 
 /**
  * The object defining the last will of a client.
diff --git a/src/client.c b/src/client.c
index 1497d63..02e7f11 100644
--- a/src/client.c
+++ b/src/client.c
@@ -43,7 +43,10 @@ void lwmqtt_set_timers(lwmqtt_client_t *client, void *keep_alive_timer, void *ne
   client->timer_set(client, client->command_timer, 0);
 }
 
-void lwmqtt_set_callback(lwmqtt_client_t *client, lwmqtt_callback_t cb) { client->callback = cb; }
+void lwmqtt_set_callback(lwmqtt_client_t *client, void *ref, lwmqtt_callback_t cb) {
+  client->callback_ref = ref;
+  client->callback = cb;
+}
 
 static unsigned short lwmqtt_get_next_packet_id(lwmqtt_client_t *c) {
   return c->next_packet_id = (unsigned short)((c->next_packet_id == 65535) ? 1 : c->next_packet_id + 1);
@@ -154,7 +157,7 @@ static lwmqtt_err_t lwmqtt_cycle(lwmqtt_client_t *c, int *read, lwmqtt_packet_ty
 
       // call callback if set
       if (c->callback != NULL) {
-        c->callback(c, &topic, &msg);
+        c->callback(c, c->callback_ref, &topic, &msg);
       }
 
       // break early of qos zero
diff --git a/tests/client.cpp b/tests/client.cpp
index a4b0ec2..b760b8d 100644
--- a/tests/client.cpp
+++ b/tests/client.cpp
@@ -12,7 +12,11 @@ char payload[PAYLOAD_LEN + 1];
 
 volatile int counter;
 
-static void message_arrived(lwmqtt_client_t *c, lwmqtt_string_t *t, lwmqtt_message_t *m) {
+const char *custom_ref = "cool";
+
+static void message_arrived(lwmqtt_client_t *c, void *ref, lwmqtt_string_t *t, lwmqtt_message_t *m) {
+  ASSERT_EQ(ref, custom_ref);
+
   int res = lwmqtt_strcmp(t, (char *)"lwmqtt");
   ASSERT_EQ(res, 0);
 
@@ -34,7 +38,7 @@ TEST(Client, PublishSubscribeQOS0) {
 
   lwmqtt_set_network(&client, &network, lwmqtt_unix_network_read, lwmqtt_unix_network_write);
   lwmqtt_set_timers(&client, &timer1, &timer2, lwmqtt_unix_timer_set, lwmqtt_unix_timer_get);
-  lwmqtt_set_callback(&client, message_arrived);
+  lwmqtt_set_callback(&client, (void *)custom_ref, message_arrived);
 
   lwmqtt_err_t err = lwmqtt_unix_network_connect(&network, (char *)"broker.shiftr.io", 1883);
   ASSERT_EQ(err, LWMQTT_SUCCESS);
@@ -95,7 +99,7 @@ TEST(Client, PublishSubscribeQOS1) {
 
   lwmqtt_set_network(&client, &network, lwmqtt_unix_network_read, lwmqtt_unix_network_write);
   lwmqtt_set_timers(&client, &timer1, &timer2, lwmqtt_unix_timer_set, lwmqtt_unix_timer_get);
-  lwmqtt_set_callback(&client, message_arrived);
+  lwmqtt_set_callback(&client, (void *)custom_ref, message_arrived);
 
   lwmqtt_err_t err = lwmqtt_unix_network_connect(&network, (char *)"broker.shiftr.io", 1883);
   ASSERT_EQ(err, LWMQTT_SUCCESS);
@@ -156,7 +160,7 @@ TEST(Client, PublishSubscribeQOS2) {
 
   lwmqtt_set_network(&client, &network, lwmqtt_unix_network_read, lwmqtt_unix_network_write);
   lwmqtt_set_timers(&client, &timer1, &timer2, lwmqtt_unix_timer_set, lwmqtt_unix_timer_get);
-  lwmqtt_set_callback(&client, message_arrived);
+  lwmqtt_set_callback(&client, (void *)custom_ref, message_arrived);
 
   lwmqtt_err_t err = lwmqtt_unix_network_connect(&network, (char *)"broker.shiftr.io", 1883);
   ASSERT_EQ(err, LWMQTT_SUCCESS);