Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: disconnect client on authZ failure #116

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,59 @@
*/
package io.moquette.broker;

import io.moquette.broker.subscriptions.Topic;
import io.moquette.broker.security.IAuthenticator;
import io.moquette.broker.security.PemUtils;
import io.moquette.broker.subscriptions.Topic;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.mqtt.*;
import io.netty.handler.codec.mqtt.MqttConnAckMessage;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttConnectPayload;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader;
import io.netty.handler.codec.mqtt.MqttMessageType;
import io.netty.handler.codec.mqtt.MqttPubAckMessage;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import io.netty.handler.codec.mqtt.MqttPublishVariableHeader;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttSubAckMessage;
import io.netty.handler.codec.mqtt.MqttSubscribeMessage;
import io.netty.handler.codec.mqtt.MqttUnsubAckMessage;
import io.netty.handler.codec.mqtt.MqttUnsubscribeMessage;
import io.netty.handler.codec.mqtt.MqttVersion;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import javax.net.ssl.SSLPeerUnverifiedException;
import java.net.InetSocketAddress;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.util.*;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLPeerUnverifiedException;

import static io.netty.channel.ChannelFutureListener.CLOSE_ON_FAILURE;
import static io.netty.channel.ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE;
import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.*;
import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_ACCEPTED;
import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD;
import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_REFUSED_IDENTIFIER_REJECTED;
import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_REFUSED_SERVER_UNAVAILABLE;
import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_REFUSED_UNACCEPTABLE_PROTOCOL_VERSION;
import static io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader.from;
import static io.netty.handler.codec.mqtt.MqttQoS.*;
import static io.netty.handler.codec.mqtt.MqttQoS.AT_LEAST_ONCE;
import static io.netty.handler.codec.mqtt.MqttQoS.AT_MOST_ONCE;

final class MQTTConnection {

Expand Down Expand Up @@ -454,7 +477,7 @@ PostOffice.RouteResult processPublish(MqttPublishMessage msg) {
return postOffice.routeCommand(clientId, "PUB QoS0", () -> {
if (!isBoundToSession())
return null;
postOffice.receivedPublishQos0(topic, username, clientId, msg);
postOffice.receivedPublishQos0(this, topic, username, clientId, msg);
return null;
}).ifFailed(msg::release);
case AT_LEAST_ONCE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ public void unsubscribe(List<String> topics, MQTTConnection mqttConnection, int
mqttConnection.sendUnsubAckMessage(topics, clientID, messageId);
}

CompletableFuture<Void> receivedPublishQos0(Topic topic, String username, String clientID, MqttPublishMessage msg) {
CompletableFuture<Void> receivedPublishQos0(MQTTConnection connection, Topic topic, String username, String clientID,
MqttPublishMessage msg) {
if (!authorizator.canWrite(topic, username, clientID)) {
LOG.error("client is not authorized to publish on topic: {}", topic);
ReferenceCountUtil.release(msg);
connection.dropConnection();
return CompletableFuture.completedFuture(null);
}
final RoutingResults publishResult = publish2Subscribers(msg.payload(), topic, AT_MOST_ONCE);
Expand Down Expand Up @@ -352,6 +354,7 @@ RoutingResults receivedPublishQos1(MQTTConnection connection, Topic topic, Strin
final String clientId = connection.getClientId();
if (!authorizator.canWrite(topic, username, clientId)) {
LOG.error("MQTT client: {} is not authorized to publish on topic: {}", clientId, topic);
connection.dropConnection();
jcosentino11 marked this conversation as resolved.
Show resolved Hide resolved
ReferenceCountUtil.release(msg);
return RoutingResults.preroutingError();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
*/
package io.moquette.broker;

import io.moquette.broker.security.PermitAllAuthorizatorPolicy;
import io.moquette.broker.security.IAuthenticator;
import io.moquette.broker.security.IAuthorizatorPolicy;
import io.moquette.broker.subscriptions.CTrieSubscriptionDirectory;
import io.moquette.broker.subscriptions.ISubscriptionsDirectory;
import io.moquette.broker.security.IAuthenticator;
import io.moquette.broker.subscriptions.Topic;
import io.moquette.persistence.MemorySubscriptionsRepository;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
Expand Down Expand Up @@ -74,8 +75,17 @@ private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel
subscriptions.init(subscriptionsRepository);
queueRepository = new MemoryQueueRepository();

final PermitAllAuthorizatorPolicy authorizatorPolicy = new PermitAllAuthorizatorPolicy();
final Authorizator permitAll = new Authorizator(authorizatorPolicy);
final Authorizator permitAll = new Authorizator(new IAuthorizatorPolicy() {
@Override
public boolean canWrite(Topic topic, String user, String client) {
return false;
}

@Override
public boolean canRead(Topic topic, String user, String client) {
return false;
}
});
sessionRegistry = new SessionRegistry(subscriptions, queueRepository, permitAll);
final PostOffice postOffice = new PostOffice(subscriptions,
new MemoryRetainedRepository(), sessionRegistry, ConnectionTestUtils.NO_OBSERVERS_INTERCEPTOR, permitAll, 1024);
Expand All @@ -99,4 +109,24 @@ public void dropConnectionOnPublishWithInvalidTopicFormat() throws ExecutionExce
payload.release();
}

@Test
public void dropConnectionOnPublishWithUnauthorized() throws ExecutionException, InterruptedException {
// Connect message with clean session set to true and client id is null.
final ByteBuf payload = Unpooled.copiedBuffer("Hello MQTT world!".getBytes(UTF_8));
MqttPublishMessage publish = MqttMessageBuilders.publish()
.topicName("abc")
.retained(false)
.qos(MqttQoS.AT_MOST_ONCE)
.payload(payload).build();

Session sess = sessionRegistry.createOrReopenSession(connMsg.build(), sut.getClientId(), null).session;
sut.bindSession(sess);
sess.bind(sut);
sut.processPublish(publish).completableFuture().get();

// Verify
assertFalse(channel.isOpen(), "Connection should be closed by the broker");
payload.release();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void testPublishQoS0ToItself() throws ExecutionException, InterruptedExce

// Exercise
final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, FAKE_CLIENT_ID,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, FAKE_CLIENT_ID,
MqttMessageBuilders.publish()
.payload(payload.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down Expand Up @@ -220,7 +220,7 @@ public void testPublishToMultipleSubscribers() throws ExecutionException, Interr

// Exercise
final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, FAKE_CLIENT_ID,
sut.receivedPublishQos0(connection1, new Topic(NEWS_TOPIC), TEST_USER, FAKE_CLIENT_ID,
MqttMessageBuilders.publish()
.payload(payload.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand All @@ -247,7 +247,7 @@ public void testPublishWithEmptyPayloadClearRetainedStore() throws ExecutionExce

// Exercise
final ByteBuf anyPayload = Unpooled.copiedBuffer("Any payload", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, FAKE_CLIENT_ID,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, FAKE_CLIENT_ID,
MqttMessageBuilders.publish()
.payload(anyPayload)
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down Expand Up @@ -426,7 +426,7 @@ public void cleanRetainedMessageStoreWhenPublishWithRetainedQos0IsReceived() thr
// publish a QoS0 retained message
// Exercise
final ByteBuf qos0Payload = Unpooled.copiedBuffer("QoS0 payload", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, connection.getClientId(),
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, connection.getClientId(),
MqttMessageBuilders.publish()
.payload(qos0Payload)
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ public void testCleanSession_maintainClientSubscriptions() throws ExecutionExcep
assertEquals(1, subscriptions.size(), "After a reconnect, subscription MUST be still present");

final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down Expand Up @@ -288,7 +288,7 @@ public void testCleanSession_correctlyClientSubscriptions() throws ExecutionExce

// publish on /news
final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload)
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.netty.channel.Channel;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.mqtt.*;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -142,7 +141,7 @@ public void testDontNotifyClientSubscribedToTopicAfterDisconnectedAndReconnectOn

// publish on /news
final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand All @@ -155,7 +154,7 @@ public void testDontNotifyClientSubscribedToTopicAfterDisconnectedAndReconnectOn

// publish on /news
final ByteBuf payload2 = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload2)
.qos(MqttQoS.AT_MOST_ONCE)
Expand All @@ -180,7 +179,7 @@ public void testDontNotifyClientSubscribedToTopicAfterDisconnectedAndReconnectOn
subscribe(connection, NEWS_TOPIC, AT_MOST_ONCE);
// publish on /news
final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand All @@ -200,7 +199,7 @@ public void testDontNotifyClientSubscribedToTopicAfterDisconnectedAndReconnectOn

// publish on /news
final ByteBuf payload2 = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload2)
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down Expand Up @@ -278,7 +277,7 @@ public void testConnectSubPub_cycle_getTimeout_on_second_disconnect_issue142() t
subscribe(connection, NEWS_TOPIC, AT_MOST_ONCE);
// publish on /news
final ByteBuf payload = Unpooled.copiedBuffer("Hello world!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand All @@ -298,7 +297,7 @@ public void testConnectSubPub_cycle_getTimeout_on_second_disconnect_issue142() t
subscribe(subscriberConnection, NEWS_TOPIC, AT_MOST_ONCE);
// publish on /news
final ByteBuf payload2 = Unpooled.copiedBuffer("Hello world2!", Charset.defaultCharset());
sut.receivedPublishQos0(new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
sut.receivedPublishQos0(connection, new Topic(NEWS_TOPIC), TEST_USER, TEST_PWD,
MqttMessageBuilders.publish()
.payload(payload2.retainedDuplicate())
.qos(MqttQoS.AT_MOST_ONCE)
Expand Down