Skip to content

Commit

Permalink
Add support for SCRAM-*-PLUS SASL mechanisms
Browse files Browse the repository at this point in the history
This fixes #133

Signed-off-by: Steffen Jaeckel <[email protected]>
  • Loading branch information
sjaeckel committed Nov 8, 2023
1 parent c5b6026 commit 1abb4b5
Show file tree
Hide file tree
Showing 11 changed files with 309 additions and 43 deletions.
129 changes: 106 additions & 23 deletions src/auth.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ static int _handle_digestmd5_rspauth(xmpp_conn_t *conn,
static int _handle_scram_challenge(xmpp_conn_t *conn,
xmpp_stanza_t *stanza,
void *userdata);
static char *_make_scram_init_msg(xmpp_conn_t *conn);
struct scram_user_data;
static int _make_scram_init_msg(struct scram_user_data *scram);

static int _handle_missing_features_sasl(xmpp_conn_t *conn, void *userdata);
static int _handle_missing_bind(xmpp_conn_t *conn, void *userdata);
Expand Down Expand Up @@ -250,8 +251,12 @@ _handle_features(xmpp_conn_t *conn, xmpp_stanza_t *stanza, void *userdata)
conn->sasl_support |= SASL_MASK_EXTERNAL;
else if (strcasecmp(text, "DIGEST-MD5") == 0)
conn->sasl_support |= SASL_MASK_DIGESTMD5;
else if (strcasecmp(text, "SCRAM-SHA-1-PLUS") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA1_PLUS;
else if (strcasecmp(text, "SCRAM-SHA-1") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA1;
else if (strcasecmp(text, "SCRAM-SHA-256-PLUS") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA256_PLUS;
else if (strcasecmp(text, "SCRAM-SHA-256") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA256;
else if (strcasecmp(text, "SCRAM-SHA-512") == 0)
Expand Down Expand Up @@ -439,7 +444,11 @@ static int _handle_digestmd5_rspauth(xmpp_conn_t *conn,
}

struct scram_user_data {
xmpp_conn_t *conn;
int sasl_plus;
char *scram_init;
char *channel_binding;
const char *first_bare;
const struct hash_alg *alg;
};

Expand Down Expand Up @@ -471,8 +480,9 @@ static int _handle_scram_challenge(xmpp_conn_t *conn,
if (!challenge)
goto err;

response = sasl_scram(conn->ctx, scram_ctx->alg, challenge,
scram_ctx->scram_init, conn->jid, conn->pass);
response =
sasl_scram(conn->ctx, scram_ctx->alg, scram_ctx->channel_binding,
challenge, scram_ctx->first_bare, conn->jid, conn->pass);
strophe_free(conn->ctx, challenge);
if (!response)
goto err;
Expand Down Expand Up @@ -506,7 +516,8 @@ static int _handle_scram_challenge(xmpp_conn_t *conn,
*/
rc = _handle_sasl_result(conn, stanza,
(void *)scram_ctx->alg->scram_name);
strophe_free(conn->ctx, scram_ctx->scram_init);
strophe_free_and_null(conn->ctx, scram_ctx->channel_binding);
strophe_free_and_null(conn->ctx, scram_ctx->scram_init);
strophe_free(conn->ctx, scram_ctx);
}

Expand All @@ -517,33 +528,97 @@ static int _handle_scram_challenge(xmpp_conn_t *conn,
err_free_response:
strophe_free(conn->ctx, response);
err:
strophe_free(conn->ctx, scram_ctx->scram_init);
strophe_free_and_null(conn->ctx, scram_ctx->channel_binding);
strophe_free_and_null(conn->ctx, scram_ctx->scram_init);
strophe_free(conn->ctx, scram_ctx);
disconnect_mem_error(conn);
return 0;
}

static char *_make_scram_init_msg(xmpp_conn_t *conn)
static int _make_scram_init_msg(struct scram_user_data *scram)
{
xmpp_conn_t *conn = scram->conn;
xmpp_ctx_t *ctx = conn->ctx;
size_t message_len;
char *node;
char *message;
char nonce[32];
const void *binding_data;
char *node, *message, *binding_type;
size_t message_len, binding_type_len = 0, binding_data_len;
int l, is_secured = xmpp_conn_is_secured(conn);
char buf[64];

if (scram->sasl_plus) {
if (!is_secured) {
strophe_error(
ctx, "xmpp",
"SASL: Server requested a -PLUS variant to authenticate, "
"but the connection is not secured. This is an error on "
"the server side we can't do anything about.");
return -1;
}
if (tls_init_channel_binding(conn->tls, &binding_type,
&binding_type_len)) {
return -1;
}
/* directly account for the '=' char in 'p=<binding-type>' */
binding_type_len += 1;
}

node = xmpp_jid_node(ctx, conn->jid);
if (!node) {
return NULL;
return -1;
}
xmpp_rand_nonce(ctx->rand, nonce, sizeof(nonce));
message_len = strlen(node) + strlen(nonce) + 8 + 1;
xmpp_rand_nonce(ctx->rand, buf, sizeof(buf));
message_len = strlen(node) + strlen(buf) + 8 + binding_type_len + 1;
message = strophe_alloc(ctx, message_len);
if (message) {
strophe_snprintf(message, message_len, "n,,n=%s,r=%s", node, nonce);
/* increase length to account for 'y,,', 'n,,' or 'p,,'.
* In the 'p' case the '=' sign has already been accounted for above.
*/
binding_type_len += 3;
if (scram->sasl_plus) {
l = strophe_snprintf(message, message_len, "p=%s,,n=%s,r=%s",
binding_type, node, buf);
} else {
l = strophe_snprintf(message, message_len, "%c,,n=%s,r=%s",
is_secured ? 'y' : 'n', node, buf);
}
if (l < 0 || (size_t)l >= message_len) {
goto err_out;
} else {
/* Make `first_bare` point to the 'n' of the client-first-message */
scram->first_bare = message + binding_type_len;
memcpy(buf, message, binding_type_len);
if (scram->sasl_plus) {
binding_data =
tls_get_channel_binding_data(conn->tls, &binding_data_len);
if (!binding_data) {
goto err_out;
}
if (binding_data_len > sizeof(buf) - binding_type_len) {
strophe_error(ctx, "xmpp",
"Channel binding data len is too long (%zu)",
binding_data_len);
goto err_out;
}
memcpy(&buf[binding_type_len], binding_data, binding_data_len);
binding_type_len += binding_data_len;
}
if (scram->channel_binding)
strophe_free(ctx, scram->channel_binding);
scram->channel_binding =
xmpp_base64_encode(ctx, (void *)buf, binding_type_len);
memset(buf, 0, binding_type_len);
}
}
strophe_free(ctx, node);
scram->scram_init = message;

return message;
return message == NULL ? -1 : 0;
err_out:
strophe_free(ctx, node);
strophe_free(ctx, message);
scram->first_bare = NULL;
scram->scram_init = NULL;
return -1;
}

static xmpp_stanza_t *_make_starttls(xmpp_conn_t *conn)
Expand Down Expand Up @@ -636,7 +711,7 @@ static void _auth(xmpp_conn_t *conn)
return;
}

if (anonjid && conn->sasl_support & SASL_MASK_ANONYMOUS) {
if (anonjid && (conn->sasl_support & SASL_MASK_ANONYMOUS)) {
/* some crap here */
auth = _make_sasl_auth(conn, "ANONYMOUS");
if (!auth) {
Expand Down Expand Up @@ -703,21 +778,29 @@ static void _auth(xmpp_conn_t *conn)
xmpp_disconnect(conn);
} else if (conn->sasl_support & SASL_MASK_SCRAM) {
scram_ctx = strophe_alloc(conn->ctx, sizeof(*scram_ctx));
if (conn->sasl_support & SASL_MASK_SCRAMSHA512)
memset(scram_ctx, 0, sizeof(*scram_ctx));
if (conn->sasl_support & SASL_MASK_SCRAMSHA256_PLUS) {
scram_ctx->alg = &scram_sha256_plus;
} else if (conn->sasl_support & SASL_MASK_SCRAMSHA1_PLUS) {
scram_ctx->alg = &scram_sha1_plus;
} else if (conn->sasl_support & SASL_MASK_SCRAMSHA512) {
scram_ctx->alg = &scram_sha512;
else if (conn->sasl_support & SASL_MASK_SCRAMSHA256)
} else if (conn->sasl_support & SASL_MASK_SCRAMSHA256) {
scram_ctx->alg = &scram_sha256;
else if (conn->sasl_support & SASL_MASK_SCRAMSHA1)
} else if (conn->sasl_support & SASL_MASK_SCRAMSHA1) {
scram_ctx->alg = &scram_sha1;
}

auth = _make_sasl_auth(conn, scram_ctx->alg->scram_name);
if (!auth) {
disconnect_mem_error(conn);
return;
}

/* don't free scram_init on success */
scram_ctx->scram_init = _make_scram_init_msg(conn);
if (!scram_ctx->scram_init) {
scram_ctx->conn = conn;
scram_ctx->sasl_plus =
scram_ctx->alg->mask & SASL_MASK_SCRAM_PLUS ? 1 : 0;
if (_make_scram_init_msg(scram_ctx)) {
strophe_free(conn->ctx, scram_ctx);
xmpp_stanza_release(auth);
disconnect_mem_error(conn);
Expand Down Expand Up @@ -753,7 +836,7 @@ static void _auth(xmpp_conn_t *conn)

send_stanza(conn, auth, XMPP_QUEUE_STROPHE);

/* SASL SCRAM-SHA-1 was tried, unset flag */
/* SASL algorithm was tried, unset flag */
conn->sasl_support &= ~scram_ctx->alg->mask;
} else if (conn->sasl_support & SASL_MASK_DIGESTMD5) {
auth = _make_sasl_auth(conn, "DIGEST-MD5");
Expand Down
7 changes: 6 additions & 1 deletion src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,14 @@ struct _xmpp_send_queue_t {
#define SASL_MASK_SCRAMSHA256 (1 << 4)
#define SASL_MASK_SCRAMSHA512 (1 << 5)
#define SASL_MASK_EXTERNAL (1 << 6)
#define SASL_MASK_SCRAMSHA1_PLUS (1 << 7)
#define SASL_MASK_SCRAMSHA256_PLUS (1 << 8)

#define SASL_MASK_SCRAM \
#define SASL_MASK_SCRAM_PLUS \
(SASL_MASK_SCRAMSHA1_PLUS | SASL_MASK_SCRAMSHA256_PLUS)
#define SASL_MASK_SCRAM_WEAK \
(SASL_MASK_SCRAMSHA1 | SASL_MASK_SCRAMSHA256 | SASL_MASK_SCRAMSHA512)
#define SASL_MASK_SCRAM (SASL_MASK_SCRAM_PLUS | SASL_MASK_SCRAM_WEAK)

enum {
XMPP_PORT_CLIENT = 5222,
Expand Down
43 changes: 26 additions & 17 deletions src/sasl.c
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ char *sasl_digest_md5(xmpp_ctx_t *ctx,
/** generate auth response string for the SASL SCRAM mechanism */
char *sasl_scram(xmpp_ctx_t *ctx,
const struct hash_alg *alg,
const char *channel_binding,
const char *challenge,
const char *first_bare,
const char *jid,
Expand All @@ -398,6 +399,7 @@ char *sasl_scram(xmpp_ctx_t *ctx,
char *result = NULL;
size_t response_len;
size_t auth_len;
int l;

UNUSED(jid);

Expand Down Expand Up @@ -428,37 +430,44 @@ char *sasl_scram(xmpp_ctx_t *ctx,
}
ival = strtol(i, &saveptr, 10);

auth_len = 10 + strlen(r) + strlen(first_bare) + strlen(challenge);
auth = strophe_alloc(ctx, auth_len);
if (!auth) {
/* "c=<channel_binding>," + r + ",p=" + sign_b64 + '\0' */
response_len = 3 + strlen(channel_binding) + strlen(r) + 3 +
((alg->digest_size + 2) / 3 * 4) + 1;
response = strophe_alloc(ctx, response_len);
if (!response) {
goto out_sval;
}

/* "c=biws," + r + ",p=" + sign_b64 + '\0' */
response_len = 7 + strlen(r) + 3 + ((alg->digest_size + 2) / 3 * 4) + 1;
response = strophe_alloc(ctx, response_len);
if (!response) {
goto out_auth;
auth_len = 3 + response_len + strlen(first_bare) + strlen(challenge);
auth = strophe_alloc(ctx, auth_len);
if (!auth) {
goto out_response;
}

strophe_snprintf(response, response_len, "c=biws,%s", r);
strophe_snprintf(auth, auth_len, "%s,%s,%s", first_bare + 3, challenge,
response);
l = strophe_snprintf(response, response_len, "c=%s,%s", channel_binding, r);
if (l < 0 || (size_t)l >= response_len) {
goto out_response;
}
l = strophe_snprintf(auth, auth_len, "%s,%s,%s", first_bare, challenge,
response);
if (l < 0 || (size_t)l >= auth_len) {
goto out_response;
}

SCRAM_ClientKey(alg, (uint8_t *)password, strlen(password), (uint8_t *)sval,
sval_len, (uint32_t)ival, key);
SCRAM_ClientSignature(alg, key, (uint8_t *)auth, strlen(auth), sign);
SCRAM_ClientProof(alg, sign, key, sign);
SCRAM_ClientProof(alg, key, sign, sign);

sign_b64 = xmpp_base64_encode(ctx, sign, alg->digest_size);
if (!sign_b64) {
goto out_response;
goto out_auth;
}

/* Check for buffer overflow */
if (strlen(response) + strlen(sign_b64) + 3 + 1 > response_len) {
strophe_free(ctx, sign_b64);
goto out_response;
goto out_auth;
}
strcat(response, ",p=");
strcat(response, sign_b64);
Expand All @@ -467,14 +476,14 @@ char *sasl_scram(xmpp_ctx_t *ctx,
response_b64 =
xmpp_base64_encode(ctx, (unsigned char *)response, strlen(response));
if (!response_b64) {
goto out_response;
goto out_auth;
}
result = response_b64;

out_response:
strophe_free(ctx, response);
out_auth:
strophe_free(ctx, auth);
out_response:
strophe_free(ctx, response);
out_sval:
strophe_free(ctx, sval);
out:
Expand Down
1 change: 1 addition & 0 deletions src/sasl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ char *sasl_digest_md5(xmpp_ctx_t *ctx,
const char *password);
char *sasl_scram(xmpp_ctx_t *ctx,
const struct hash_alg *alg,
const char *channel_binding,
const char *challenge,
const char *first_bare,
const char *jid,
Expand Down
18 changes: 18 additions & 0 deletions src/scram.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ const struct hash_alg scram_sha1 = {
(void (*)(void *, const uint8_t *, size_t))crypto_SHA1_Update,
(void (*)(void *, uint8_t *))crypto_SHA1_Final};

const struct hash_alg scram_sha1_plus = {
"SCRAM-SHA-1-PLUS",
SASL_MASK_SCRAMSHA1_PLUS,
SHA1_DIGEST_SIZE,
(void (*)(const uint8_t *, size_t, uint8_t *))crypto_SHA1,
(void (*)(void *))crypto_SHA1_Init,
(void (*)(void *, const uint8_t *, size_t))crypto_SHA1_Update,
(void (*)(void *, uint8_t *))crypto_SHA1_Final};

const struct hash_alg scram_sha256 = {
"SCRAM-SHA-256",
SASL_MASK_SCRAMSHA256,
Expand All @@ -51,6 +60,15 @@ const struct hash_alg scram_sha256 = {
(void (*)(void *, const uint8_t *, size_t))sha256_process,
(void (*)(void *, uint8_t *))sha256_done};

const struct hash_alg scram_sha256_plus = {
"SCRAM-SHA-256-PLUS",
SASL_MASK_SCRAMSHA256_PLUS,
SHA256_DIGEST_SIZE,
(void (*)(const uint8_t *, size_t, uint8_t *))sha256_hash,
(void (*)(void *))sha256_init,
(void (*)(void *, const uint8_t *, size_t))sha256_process,
(void (*)(void *, uint8_t *))sha256_done};

const struct hash_alg scram_sha512 = {
"SCRAM-SHA-512",
SASL_MASK_SCRAMSHA512,
Expand Down
2 changes: 2 additions & 0 deletions src/scram.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ struct hash_alg {
};

extern const struct hash_alg scram_sha1;
extern const struct hash_alg scram_sha1_plus;
extern const struct hash_alg scram_sha256;
extern const struct hash_alg scram_sha256_plus;
extern const struct hash_alg scram_sha512;

void SCRAM_ClientKey(const struct hash_alg *alg,
Expand Down
4 changes: 4 additions & 0 deletions src/tls.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ unsigned int tls_id_on_xmppaddr_num(xmpp_conn_t *conn);

xmpp_tlscert_t *tls_peer_cert(xmpp_conn_t *conn);
int tls_set_credentials(tls_t *tls, const char *cafilename);
int tls_init_channel_binding(tls_t *tls,
char **binding_prefix,
size_t *binding_prefix_len);
const void *tls_get_channel_binding_data(tls_t *tls, size_t *size);

int tls_start(tls_t *tls);
int tls_stop(tls_t *tls);
Expand Down
Loading

0 comments on commit 1abb4b5

Please sign in to comment.