Skip to content

Commit

Permalink
Add error handing (#2765)
Browse files Browse the repository at this point in the history
* add handling for larger Float and Decimal numbers
* add tracing details to the error response
* handle invalid data type
* add data type check parsing
  • Loading branch information
shtirlets authored Apr 25, 2024
1 parent f2c2be1 commit 7ce143d
Show file tree
Hide file tree
Showing 25 changed files with 690 additions and 283 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ protected SessionHandler aggregateResult(SessionHandler aggregate, SessionHandle
}
else
{
throw new RuntimeException("Conflicting handlers for query");
throw new PostgresServerException("Conflicting handlers for query");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.finos.legend.engine.postgres.handler.PostgresResultSetMetaData;
import org.finos.legend.engine.postgres.types.PGType;
import org.finos.legend.engine.postgres.types.PGTypes;
import org.finos.legend.engine.postgres.utils.ErrorMessageFormatter;
import org.finos.legend.engine.postgres.utils.OpenTelemetryUtil;
import org.slf4j.Logger;

Expand All @@ -58,17 +59,20 @@
public class Messages
{

private Messages()
{
}

//private static final Logger LOGGER = LogManager.getLogger(Messages.class);
private static final Logger LOGGER = org.slf4j.LoggerFactory.getLogger(Messages.class);

private static final byte[] METHOD_NAME_CLIENT_AUTH = "ClientAuthentication".getBytes(
StandardCharsets.UTF_8);

public static ChannelFuture sendAuthenticationOK(Channel channel)
private final ErrorMessageFormatter errorMessageFormatter;

public Messages(ErrorMessageFormatter errorMessageFormatter)
{
this.errorMessageFormatter = errorMessageFormatter;
}


public ChannelFuture sendAuthenticationOK(Channel channel)
{
ByteBuf buffer = channel.alloc().buffer(9);
buffer.writeByte('R');
Expand All @@ -90,7 +94,7 @@ public static ChannelFuture sendAuthenticationOK(Channel channel)
* @param rowCount : number of rows in the result set or number of rows affected by the DML
* statement
*/
static ChannelFuture sendCommandComplete(Channel channel, String query, long rowCount)
ChannelFuture sendCommandComplete(Channel channel, String query, long rowCount)
{
query = query.trim().split(" ", 2)[0].toUpperCase(Locale.ENGLISH);
String commandTag;
Expand Down Expand Up @@ -140,7 +144,7 @@ else if ("INSERT".equals(query))
* transaction block); 'T' if in a transaction block; or 'E' if in a failed transaction block
* (queries will be rejected until block is ended).
*/
static ChannelFuture sendReadyForQuery(Channel channel)
ChannelFuture sendReadyForQuery(Channel channel)
{
ByteBuf buffer = channel.alloc().buffer(6);
buffer.writeByte('Z');
Expand All @@ -167,7 +171,7 @@ static ChannelFuture sendReadyForQuery(Channel channel)
* session_authorization, - DateStyle, - IntervalStyle, - TimeZone, - integer_datetimes, -
* standard_conforming_string
*/
static void sendParameterStatus(Channel channel, final String name, final String value)
void sendParameterStatus(Channel channel, final String name, final String value)
{
byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8);
byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8);
Expand All @@ -186,7 +190,7 @@ static void sendParameterStatus(Channel channel, final String name, final String
}
}

static void sendAuthenticationError(Channel channel, String message)
void sendAuthenticationError(Channel channel, String message)
{
LOGGER.warn(message);
byte[] msg = (message != null ? message : "Unknown Auth Error").getBytes(StandardCharsets.UTF_8);
Expand All @@ -198,7 +202,7 @@ static void sendAuthenticationError(Channel channel, String message)
}


private static String buildErrorMessage(Throwable throwable)
private String buildErrorMessage(Throwable throwable)
{
TextMapSetter<Map> TEXT_MAP_SETTER = (map, key, value) -> Objects.requireNonNull(map).put(key, value);
Map<String, String> keys = new HashMap<>();
Expand All @@ -209,11 +213,13 @@ private static String buildErrorMessage(Throwable throwable)
return errorMessage.toString();
}

static ChannelFuture sendErrorResponse(Channel channel, Throwable throwable)
ChannelFuture sendErrorResponse(Channel channel, Throwable throwable)
{
String errorMessage = buildErrorMessage(throwable);
//wrap exception to add tracing if available
throwable = PostgresServerException.wrapException(throwable);
String errorMessage = errorMessageFormatter.format(throwable);
LOGGER.error(errorMessage, throwable);
final PGError error = new PGError(PGErrorStatus.INTERNAL_ERROR, errorMessage, throwable);
final PGError error = new PGError(PGErrorStatus.INTERNAL_ERROR, errorMessage, throwable);

ByteBuf buffer = channel.alloc().buffer();
buffer.writeByte('E');
Expand Down Expand Up @@ -280,14 +286,14 @@ static ChannelFuture sendErrorResponse(Channel channel, Throwable throwable)
* See https://www.postgresql.org/docs/9.2/static/protocol-error-fields.html for a list of error
* codes
*/
private static ChannelFuture sendErrorResponse(Channel channel,
String message,
byte[] msg,
byte[] severity,
byte[] lineNumber,
byte[] fileName,
byte[] methodName,
byte[] errorCode)
private ChannelFuture sendErrorResponse(Channel channel,
String message,
byte[] msg,
byte[] severity,
byte[] lineNumber,
byte[] fileName,
byte[] methodName,
byte[] errorCode)
{
int length = 4 +
1 + (severity.length + 1) +
Expand Down Expand Up @@ -348,8 +354,8 @@ private static ChannelFuture sendErrorResponse(Channel channel,
* above length.
*/
@SuppressWarnings({"unchecked", "rawtypes"})
static void sendDataRow(Channel channel, PostgresResultSet rs, List<PGType<?>> columnTypes,
FormatCodes.FormatCode[] formatCodes) throws Exception
void sendDataRow(Channel channel, PostgresResultSet rs, List<PGType<?>> columnTypes,
FormatCodes.FormatCode[] formatCodes) throws Exception
{
int length = 4 + 2;
assert columnTypes.size() == rs.getMetaData().getColumnCount()
Expand Down Expand Up @@ -393,7 +399,7 @@ static void sendDataRow(Channel channel, PostgresResultSet rs, List<PGType<?>> c

default:
buffer.release();
throw new AssertionError("Unrecognized formatCode: " + formatCode);
throw new PostgresServerException("Unrecognized formatCode: " + formatCode);
}
}
}
Expand All @@ -402,13 +408,13 @@ static void sendDataRow(Channel channel, PostgresResultSet rs, List<PGType<?>> c
channel.writeAndFlush(buffer);
}

static void writeCString(ByteBuf buffer, byte[] valBytes)
void writeCString(ByteBuf buffer, byte[] valBytes)
{
buffer.writeBytes(valBytes);
buffer.writeByte(0);
}

static void writeByteArray(ByteBuf buffer, byte[] valBytes)
void writeByteArray(ByteBuf buffer, byte[] valBytes)
{
buffer.writeBytes(valBytes);
}
Expand All @@ -433,7 +439,7 @@ static void writeByteArray(ByteBuf buffer, byte[] valBytes)
* @param channel The channel to write the parameter description to.
* @param parameters A {@link SortedSet} containing the parameters from index 1 upwards.
*/
static void sendParameterDescription(Channel channel, ParameterMetaData parameters) throws SQLException
void sendParameterDescription(Channel channel, ParameterMetaData parameters) throws SQLException
{
final int messageByteSize = 4 + 2 + parameters.getParameterCount() * 4;
ByteBuf buffer = channel.alloc().buffer(messageByteSize);
Expand Down Expand Up @@ -470,9 +476,9 @@ static void sendParameterDescription(Channel channel, ParameterMetaData paramete
* <p>
* See https://www.postgresql.org/docs/current/static/protocol-message-formats.html
*/
static void sendRowDescription(Channel channel,
PostgresResultSetMetaData resultSetMetaData,
FormatCodes.FormatCode[] formatCodes) throws Exception
void sendRowDescription(Channel channel,
PostgresResultSetMetaData resultSetMetaData,
FormatCodes.FormatCode[] formatCodes) throws Exception
{
int length = 4 + 2;
int columnSize = 4 + 2 + 4 + 2 + 4 + 2;
Expand Down Expand Up @@ -527,12 +533,12 @@ static void sendRowDescription(Channel channel,
/**
* ParseComplete | '1' | int32 len |
*/
static void sendParseComplete(Channel channel)
void sendParseComplete(Channel channel)
{
sendShortMsg(channel, '1', "sentParseComplete");
}

static void sendGssOutToken(Channel channel, byte[] outputToken)
void sendGssOutToken(Channel channel, byte[] outputToken)
{
int integerLength = 8;
int gssSuccessFlag = 8;
Expand All @@ -552,31 +558,31 @@ static void sendGssOutToken(Channel channel, byte[] outputToken)
/**
* BindComplete | '2' | int32 len |
*/
static void sendBindComplete(Channel channel)
void sendBindComplete(Channel channel)
{
sendShortMsg(channel, '2', "sentBindComplete");
}

/**
* EmptyQueryResponse | 'I' | int32 len |
*/
static void sendEmptyQueryResponse(Channel channel)
void sendEmptyQueryResponse(Channel channel)
{
sendShortMsg(channel, 'I', "sentEmptyQueryResponse");
}

/**
* NoData | 'n' | int32 len |
*/
static void sendNoData(Channel channel)
void sendNoData(Channel channel)
{
sendShortMsg(channel, 'n', "sentNoData");
}

/**
* Send a message that just contains the msgType and the msg length
*/
private static void sendShortMsg(Channel channel, char msgType, final String traceLogMsg)
private void sendShortMsg(Channel channel, char msgType, final String traceLogMsg)
{
ByteBuf buffer = channel.alloc().buffer(5);
buffer.writeByte(msgType);
Expand All @@ -589,15 +595,15 @@ private static void sendShortMsg(Channel channel, char msgType, final String tra
}
}

static void sendPortalSuspended(Channel channel)
void sendPortalSuspended(Channel channel)
{
sendShortMsg(channel, 's', "sentPortalSuspended");
}

/**
* CloseComplete | '3' | int32 len |
*/
static void sendCloseComplete(Channel channel)
void sendCloseComplete(Channel channel)
{
sendShortMsg(channel, '3', "sentCloseComplete");
}
Expand All @@ -613,7 +619,7 @@ static void sendCloseComplete(Channel channel)
*
* @param channel The channel to write to.
*/
static void sendAuthenticationCleartextPassword(Channel channel)
void sendAuthenticationCleartextPassword(Channel channel)
{
ByteBuf buffer = channel.alloc().buffer(9);
buffer.writeByte('R');
Expand All @@ -627,7 +633,7 @@ static void sendAuthenticationCleartextPassword(Channel channel)
}
}

static void sendAuthenticationKerberos(Channel channel)
void sendAuthenticationKerberos(Channel channel)
{
int integerLength = 8;
int authReqGss = 7;
Expand All @@ -651,7 +657,7 @@ static void sendAuthenticationKerberos(Channel channel)
/**
* CancelRequest | 'K' | int32 request code | int32 pid | int32 secret key |
*/
static void sendKeyData(Channel channel, int pid, int secretKey)
void sendKeyData(Channel channel, int pid, int secretKey)
{
ByteBuf buffer = channel.alloc().buffer(13);
buffer.writeByte('K');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,16 @@
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;

import java.net.InetSocketAddress;
import java.net.SocketAddress;

import org.finos.legend.engine.postgres.auth.AuthenticationProvider;
import org.finos.legend.engine.postgres.config.GSSConfig;
import org.finos.legend.engine.postgres.config.ServerConfig;
import org.finos.legend.engine.postgres.transport.Netty4OpenChannelsHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class PostgresServer
{

Expand All @@ -57,12 +56,15 @@ public class PostgresServer
private EventLoopGroup bossGroup;
private EventLoopGroup workerGroup;

public PostgresServer(ServerConfig serverConfig, SessionsFactory sessionsFactory, AuthenticationProvider authenticationProvider)
private final Messages messages;

public PostgresServer(ServerConfig serverConfig, SessionsFactory sessionsFactory, AuthenticationProvider authenticationProvider, Messages messages)
{
this.port = serverConfig.getPort();
this.sessionsFactory = sessionsFactory;
this.authenticationProvider = authenticationProvider;
this.gssConfig = serverConfig.getGss();
this.messages = messages;
}

public void run()
Expand All @@ -86,7 +88,7 @@ public void run()
protected void initChannel(SocketChannel ch)
{
PostgresWireProtocol postgresWireProtocol = new PostgresWireProtocol(sessionsFactory,
authenticationProvider, gssConfig, () -> null);
authenticationProvider, gssConfig, () -> null, messages);
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("open_channels", openChannelsHandler);
pipeline.addLast("frame-decoder", postgresWireProtocol.decoder);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2023 Goldman Sachs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.finos.legend.engine.postgres;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.propagation.TextMapSetter;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.finos.legend.engine.postgres.utils.OpenTelemetryUtil;

public class PostgresServerException extends RuntimeException
{
private static final TextMapSetter<Map> TEXT_MAP_SETTER = (map, key, value) -> Objects.requireNonNull(map).put(key, value);

private Map<String, String> tracingDetails = new HashMap<>();

public PostgresServerException(Throwable cause)
{
super(cause);
addTracingDetails();
}

public PostgresServerException(String message)
{
super(message);
addTracingDetails();
}

public PostgresServerException(String message, Throwable cause)
{
super(message, cause);
addTracingDetails();
}

public static PostgresServerException wrapException(Throwable e)
{
if (!(e instanceof PostgresServerException))
{
return new PostgresServerException(e);
}
return (PostgresServerException) e;

}

private void addTracingDetails()
{
OpenTelemetryUtil.getPropagators().inject(Context.current(), tracingDetails, TEXT_MAP_SETTER);
}

public Map<String, String> getTracingDetails()
{
return Collections.unmodifiableMap(tracingDetails);
}
}
Loading

0 comments on commit 7ce143d

Please sign in to comment.