Merge pull request #3629 from actiontech/fix/2029

fix: adjust the identification process of the ssl protocol package
This commit is contained in:
LUA
2023-03-16 09:57:10 +08:00
committed by GitHub
10 changed files with 330 additions and 212 deletions

View File

@@ -7,10 +7,11 @@ package com.actiontech.dble.backend.mysql.proto.handler.Impl;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandler;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult;
import com.actiontech.dble.config.model.SystemConfig;
import com.actiontech.dble.net.connection.AbstractConnection;
import com.actiontech.dble.net.mysql.MySQLPacket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.actiontech.dble.net.ssl.GMSslWrapper;
import com.actiontech.dble.net.ssl.OpenSSLWrapper;
import com.actiontech.dble.util.exception.NotSslRecordException;
import javax.annotation.Nonnull;
import java.nio.ByteBuffer;
@@ -22,37 +23,159 @@ import static com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult
* Created by szf on 2020/6/16.
*/
public class MySQLProtoHandlerImpl implements ProtoHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(MySQLProtoHandlerImpl.class);
private byte[] incompleteData = null;
/**
* the length of the ssl record header (in bytes)
*/
static final int SSL_RECORD_HEADER_LENGTH = 5;
/**
* change cipher spec
*/
static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
public MySQLProtoHandlerImpl() {
/**
* alert
*/
static final int SSL_CONTENT_TYPE_ALERT = 21;
/**
* handshake
*/
static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
/**
* application data
*/
static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
/**
* HeartBeat Extension
* <p>
* jdk not support. see:sun.security.ssl.ContentType
* <p>
* can cause <a href="https://en.wikipedia.org/wiki/Heartbleed">Heartbleed</a>
*/
@Deprecated
static final int SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT = 24;
/**
* Not enough data in buffer to parse the record length
*/
static final int NOT_ENOUGH_DATA = -1;
/**
* data is not encrypted
*/
static final int NOT_ENCRYPTED = -2;
/**
* GMSSL Protocol Version
*/
static final int GMSSL_PROTOCOL_VERSION = 0x101;
private final AbstractConnection connection;
private int protocol = OpenSSLWrapper.PROTOCOL;
public MySQLProtoHandlerImpl(AbstractConnection connection) {
this.connection = connection;
}
@Override
@Nonnull
public ProtoHandlerResult handle(ByteBuffer dataBuffer, int offset, boolean isSupportCompress) {
public ProtoHandlerResult handle(ByteBuffer dataBuffer, int offset, boolean isSupportCompress, boolean isContainSSLData) throws NotSslRecordException {
int position = dataBuffer.position();
int length = getPacketLength(dataBuffer, offset, isSupportCompress);
boolean isSSL = false;
//get length
int length;
if (isContainSSLData) {
if ((length = getSSLPacketLength(dataBuffer, offset)) != NOT_ENOUGH_DATA && length != NOT_ENCRYPTED) {
//client hello
isSSL = true;
} else {
//login
length = getNonSSLPacketLength(dataBuffer, offset, isSupportCompress);
}
} else {
length = getNonSSLPacketLength(dataBuffer, offset, isSupportCompress);
}
final ProtoHandlerResult.ProtoHandlerResultBuilder builder = ProtoHandlerResult.builder();
return getProtoHandlerResultBuilder(dataBuffer, offset, position, length, builder).build();
ProtoHandlerResult.ProtoHandlerResultBuilder resultBuilder = getProtoHandlerResultBuilder(dataBuffer, offset, position, length, builder, isSSL);
if (connection != null && resultBuilder.getCode().equals(SSL_PROTO_PACKET)) {
connection.initSSLContext(protocol);
}
return resultBuilder.build();
}
protected int getSSLPacketLength(ByteBuffer buffer, int offset) {
if (buffer.position() < offset + SSL_RECORD_HEADER_LENGTH) {
return NOT_ENOUGH_DATA;
}
int packetLength = 0;
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (buffer.get(offset) & 0xff) {
case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SSL_CONTENT_TYPE_ALERT:
case SSL_CONTENT_TYPE_HANDSHAKE:
case SSL_CONTENT_TYPE_APPLICATION_DATA:
tls = true;
break;
default:
// SSLv2 or bad data
tls = false;
}
if (tls) {
// SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 - Check ProtocolVersion
int majorVersion = buffer.get(offset + 1);
if (majorVersion == 3 || buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) {
if (buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) {
protocol = GMSslWrapper.PROTOCOL;
}
// SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1
packetLength = (buffer.getShort(offset + 3) & 0xffff) + SSL_RECORD_HEADER_LENGTH;
if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
} else {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
}
if (!tls) {
// SSLv2 or bad data - Check the version
int headerLength = (buffer.get(offset) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.get(offset + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
packetLength = headerLength == 2 ? (buffer.getShort(offset) & 0x7FFF) + 2 : (buffer.getShort(offset) & 0x3FFF) + 3;
if (packetLength <= headerLength) {
return NOT_ENOUGH_DATA;
}
} else {
return NOT_ENCRYPTED;
}
}
return packetLength;
}
@Nonnull
public ProtoHandlerResult.ProtoHandlerResultBuilder handlerResultBuilder(ByteBuffer dataBuffer, int offset, boolean isSupportCompress) {
int position = dataBuffer.position();
int length = getPacketLength(dataBuffer, offset, isSupportCompress);
int length = getNonSSLPacketLength(dataBuffer, offset, isSupportCompress);
final ProtoHandlerResult.ProtoHandlerResultBuilder builder = ProtoHandlerResult.builder();
return getProtoHandlerResultBuilder(dataBuffer, offset, position, length, builder);
return getProtoHandlerResultBuilder(dataBuffer, offset, position, length, builder, false);
}
private ProtoHandlerResult.ProtoHandlerResultBuilder getProtoHandlerResultBuilder(ByteBuffer dataBuffer, int offset, int position, int length, ProtoHandlerResult.ProtoHandlerResultBuilder builder) {
private ProtoHandlerResult.ProtoHandlerResultBuilder getProtoHandlerResultBuilder(ByteBuffer dataBuffer, int offset, int position, int length,
ProtoHandlerResult.ProtoHandlerResultBuilder builder, boolean isSSL) {
if (length == -1) {
if (offset != 0) {
return builder.setCode(BUFFER_PACKET_UNCOMPLETE).setHasMorePacket(false).setOffset(offset);
} else if (!dataBuffer.hasRemaining()) {
throw new RuntimeException("invalid dataBuffer capacity ,too little buffer size " +
dataBuffer.capacity());
throw new RuntimeException("invalid dataBuffer capacity ,too little buffer size " + dataBuffer.capacity());
}
return builder.setCode(BUFFER_PACKET_UNCOMPLETE).setHasMorePacket(false).setOffset(offset);
}
@@ -65,7 +188,7 @@ public class MySQLProtoHandlerImpl implements ProtoHandler {
if (data == null) {
builder.setCode(PART_OF_BIG_PACKET);
} else {
builder.setCode(COMPLETE_PACKET);
builder.setCode(isSSL ? SSL_PROTO_PACKET : COMPLETE_PACKET);
}
// offset to next position
offset += length;
@@ -85,10 +208,7 @@ public class MySQLProtoHandlerImpl implements ProtoHandler {
// not read whole message package ,so check if buffer enough and
// compact dataBuffer
if (!dataBuffer.hasRemaining()) {
if (SystemConfig.getInstance().isSupportSSL() && SSLProtoHandler.isSSLPackage(dataBuffer, offset)) {
return builder.setCode(SSL_PROTO_PACKET).setHasMorePacket(false).setOffset(offset);
}
return builder.setCode(BUFFER_NOT_BIG_ENOUGH).setHasMorePacket(false).setOffset(offset).setPacketLength(length);
return builder.setCode(isSSL ? SSL_BUFFER_NOT_BIG_ENOUGH : BUFFER_NOT_BIG_ENOUGH).setHasMorePacket(false).setOffset(offset).setPacketLength(length);
} else {
return builder.setCode(BUFFER_PACKET_UNCOMPLETE).setHasMorePacket(false).setOffset(offset).setPacketLength(length);
}
@@ -96,7 +216,7 @@ public class MySQLProtoHandlerImpl implements ProtoHandler {
}
private int getPacketLength(ByteBuffer buffer, int offset, boolean isSupportCompress) {
private int getNonSSLPacketLength(ByteBuffer buffer, int offset, boolean isSupportCompress) {
int headerSize = MySQLPacket.PACKET_HEADER_SIZE;
if (isSupportCompress) {
headerSize = 7;
@@ -145,4 +265,8 @@ public class MySQLProtoHandlerImpl implements ProtoHandler {
System.arraycopy(data, 0, newData, incompleteData.length, data.length);
return newData;
}
public int getProtocol() {
return protocol;
}
}

View File

@@ -5,97 +5,33 @@
package com.actiontech.dble.backend.mysql.proto.handler.Impl;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandler;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult;
import com.actiontech.dble.net.connection.FrontendConnection;
import com.actiontech.dble.net.ssl.OpenSSLWrapper;
import com.actiontech.dble.net.ssl.GMSslWrapper;
import com.actiontech.dble.net.connection.AbstractConnection;
import com.actiontech.dble.util.exception.NotSslRecordException;
import org.jetbrains.annotations.NotNull;
import java.nio.ByteBuffer;
import static com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResultCode.*;
public class SSLProtoHandler implements ProtoHandler {
public class SSLProtoHandler extends MySQLProtoHandlerImpl {
public static final int PACKET_HEADER_SIZE = 5;
/**
* GMSSL Protocol Version
*/
static final int GMSSL_PROTOCOL_VERSION = 0x101;
//1-ssl 2-gmssl
private int protocol = OpenSSLWrapper.PROTOCOL;
private boolean tls;
private FrontendConnection connection;
public SSLProtoHandler(FrontendConnection connection) {
this.connection = connection;
public SSLProtoHandler(AbstractConnection connection) {
super(connection);
}
@NotNull
@Override
public ProtoHandlerResult handle(ByteBuffer dataBuffer, int offset, boolean isSupportCompress) {
public ProtoHandlerResult handle(ByteBuffer dataBuffer, int offset, boolean isSupportCompress, boolean isContainSSLData) throws NotSslRecordException {
int position = dataBuffer.position();
int length = getPacketLength(dataBuffer, offset);
int length = getSSLPacketLength(dataBuffer, offset);
if (length == NOT_ENOUGH_DATA || length == NOT_ENCRYPTED) {
throw new NotSslRecordException("not an SSL/TLS record");
}
final ProtoHandlerResult.ProtoHandlerResultBuilder builder = ProtoHandlerResult.builder();
if (tls) {
connection.initSSLContext(protocol);
}
return getProtoHandlerResultBuilder(dataBuffer, offset, position, length, builder).build();
}
private int getPacketLength(ByteBuffer buffer, int offset) {
int packetLength = 0;
if (buffer.position() >= offset + PACKET_HEADER_SIZE) {
// SSLv3 or TLS - Check ContentType
tls = isSSLPackage(buffer, offset);
if (tls) {
// SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.get(offset + 1);
if (majorVersion == 3 || buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) {
if (buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) {
protocol = GMSslWrapper.PROTOCOL;
}
// SSLv3 or TLS
packetLength = buffer.getShort(offset + 3) & 0xffff;
packetLength += PACKET_HEADER_SIZE;
if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
} else {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
}
if (!tls) {
// SSLv2 or bad data - Check the version
boolean sslv2 = true;
int headerLength = (buffer.get(offset) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.get(offset + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
if (headerLength == 2) {
packetLength = (buffer.getShort(offset) & 0x7FFF) + 2;
} else {
packetLength = (buffer.getShort(offset) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
sslv2 = false;
}
} else {
sslv2 = false;
}
if (!sslv2) {
return -1;
}
}
return packetLength;
}
return -1;
}
private ProtoHandlerResult.ProtoHandlerResultBuilder getProtoHandlerResultBuilder(ByteBuffer dataBuffer, int offset, int position, int length, ProtoHandlerResult.ProtoHandlerResultBuilder builder) {
if (length == -1) {
if (offset != 0) {
@@ -139,40 +75,4 @@ public class SSLProtoHandler implements ProtoHandler {
}
}
public static boolean isSSLPackage(ByteBuffer buffer, int offset) {
return checkSSLProto(buffer, offset);
}
private static boolean checkSSLProto(ByteBuffer buffer, int offset) {
if (buffer.position() >= offset + PACKET_HEADER_SIZE) {
int packageType = buffer.get(offset) & 0xff;
switch (packageType) {
case 20: // change_cipher_spec
case 21: // alert
case 22: // handshake
case 23: // application_data
int majorVersion = buffer.get(offset + 1);
if (majorVersion == 3) {
int minorVersion = buffer.get(offset + 2);
switch (minorVersion) {
case 0: // SSLv3
case 1: // TLS1.0
case 2: // TLS1.2
case 3: // TLS1.3
return true;
default:
return false;
}
}
return true;
default:
return false;
}
}
return false;
}
public int getProtocol() {
return protocol;
}
}

View File

@@ -5,6 +5,8 @@
package com.actiontech.dble.backend.mysql.proto.handler;
import com.actiontech.dble.util.exception.NotSslRecordException;
import java.nio.ByteBuffer;
/**
@@ -12,6 +14,6 @@ import java.nio.ByteBuffer;
*/
public interface ProtoHandler {
ProtoHandlerResult handle(ByteBuffer dataBuffer, int dataBufferOffset, boolean isSupportCompress);
ProtoHandlerResult handle(ByteBuffer dataBuffer, int dataBufferOffset, boolean isSupportCompress, boolean isContainSSLData) throws NotSslRecordException;
}

View File

@@ -79,6 +79,10 @@ public final class ProtoHandlerResult {
return this;
}
public ProtoHandlerResultCode getCode() {
return code;
}
public ProtoHandlerResultBuilder setPacketData(byte[] val) {
this.packetData = val;
return this;

View File

@@ -8,6 +8,7 @@ package com.actiontech.dble.net.connection;
import com.actiontech.dble.backend.mysql.proto.handler.Impl.MySQLProtoHandlerImpl;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandler;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResultCode;
import com.actiontech.dble.btrace.provider.IODelayProvider;
import com.actiontech.dble.buffer.BufferPoolRecord;
import com.actiontech.dble.buffer.BufferType;
@@ -89,6 +90,8 @@ public abstract class AbstractConnection implements Connection {
private final ConcurrentLinkedQueue<byte[]> compressUnfinishedDataQueue = new ConcurrentLinkedQueue<>();
protected final AtomicInteger writingSize = new AtomicInteger(0);
protected volatile Boolean requestSSL;
public AbstractConnection(NetworkChannel channel, SocketWR socketWR) {
this.channel = channel;
this.socketWR = socketWR;
@@ -96,7 +99,7 @@ public abstract class AbstractConnection implements Connection {
this.startupTime = TimeUtil.currentTimeMillis();
this.lastReadTime = startupTime;
this.lastWriteTime = startupTime;
this.proto = new MySQLProtoHandlerImpl();
this.proto = new MySQLProtoHandlerImpl(this);
FrontActiveRatioStat.getInstance().register(this, startupTime);
}
@@ -229,11 +232,21 @@ public abstract class AbstractConnection implements Connection {
service.handle(serviceTask);
}
public void handle(ByteBuffer dataBuffer) throws IOException {
protected void handle(ByteBuffer dataBuffer) throws IOException {
handle(dataBuffer, false);
}
/**
* need to handle the following scenarios
* 1. isSupportSSL == false
* 2. isSupportSSL == true and client is not using ssl
* 3. isSupportSSL == true and client is using ssl and login request & client hello protocol
*/
protected void handle(ByteBuffer dataBuffer, boolean isContainSSLData) throws IOException {
boolean hasRemaining = true;
int offset = 0;
while (hasRemaining) {
ProtoHandlerResult result = proto.handle(dataBuffer, offset, isSupportCompress);
ProtoHandlerResult result = proto.handle(dataBuffer, offset, isSupportCompress, isContainSSLData);
switch (result.getCode()) {
case PART_OF_BIG_PACKET:
if (!result.isHasMorePacket()) {
@@ -254,11 +267,20 @@ public abstract class AbstractConnection implements Connection {
dataBuffer.clear();
}
break;
case BUFFER_PACKET_UNCOMPLETE:
compactReadBuffer(dataBuffer, result.getOffset());
break;
case SSL_PROTO_PACKET:
compactReadBuffer(dataBuffer, offset);
if (!result.isHasMorePacket()) {
readReachEnd();
}
processSSLProto(result.getPacketData(), result.getCode());
if (!result.isHasMorePacket()) {
dataBuffer.clear();
}
break;
case BUFFER_PACKET_UNCOMPLETE:
compactReadBuffer(dataBuffer, result.getOffset(), false);
break;
case SSL_BUFFER_NOT_BIG_ENOUGH:
compactReadBuffer(dataBuffer, offset, true);
break;
case BUFFER_NOT_BIG_ENOUGH:
ensureFreeSpaceOfReadBuffer(dataBuffer, result.getOffset(), result.getPacketLength());
@@ -274,6 +296,18 @@ public abstract class AbstractConnection implements Connection {
}
}
protected void processSSLProto(byte[] packetData, ProtoHandlerResultCode code) {
AbstractService frontService = getService();
if (packetData != null) {
if (code == ProtoHandlerResultCode.SSL_PROTO_PACKET) {
pushServiceTask(new SSLProtoServerTask(packetData, frontService));
} else {
pushServiceTask(ServiceTaskFactory.getInstance(frontService).createForGracefulClose("ssl close", CloseType.READ));
}
}
}
public void processPacketData(ProtoHandlerResult result) {
byte[] packetData = result.getPacketData();
final AbstractService frontService = service;
@@ -323,7 +357,7 @@ public abstract class AbstractConnection implements Connection {
}
}
public void compactReadBuffer(ByteBuffer buffer, int offset) throws IOException {
public void compactReadBuffer(ByteBuffer buffer, int offset, boolean isSSL) throws IOException {
if (buffer == null) {
return;
}
@@ -348,7 +382,7 @@ public abstract class AbstractConnection implements Connection {
} else {
if (offset != 0) {
// compact bytebuffer only
compactReadBuffer(buffer, offset);
compactReadBuffer(buffer, offset, false);
} else {
throw new RuntimeException(" not enough space");
}
@@ -483,9 +517,10 @@ public abstract class AbstractConnection implements Connection {
}
int bufferSize;
WriteOutTask writeTask;
ByteBuffer newBuffer = null;
try {
if (isSupportCompress) {
ByteBuffer newBuffer = CompressUtil.compressMysqlPacket(buffer, this, compressUnfinishedDataQueue);
newBuffer = CompressUtil.compressMysqlPacket(buffer, this, compressUnfinishedDataQueue);
newBuffer = wrap(newBuffer);
writeTask = new WriteOutTask(newBuffer, false);
bufferSize = newBuffer.position();
@@ -496,6 +531,9 @@ public abstract class AbstractConnection implements Connection {
}
} catch (SSLException e) {
recycle(buffer);
if (newBuffer != null) {
recycle(newBuffer);
}
return;
}
@@ -706,6 +744,21 @@ public abstract class AbstractConnection implements Connection {
this.readBufferChunk = readBufferChunk;
}
public Boolean isRequestSSL() {
return requestSSL;
}
public void setRequestSSL(Boolean requestSSL) {
this.requestSSL = requestSSL;
}
/**
* ssl login request(non ssl)&client hello(ssl)
*/
protected boolean maybeUseSSL() {
return isRequestSSL() == null || isRequestSSL();
}
public ByteBuffer getBottomReadBuffer() {
return this.bottomReadBuffer;
}
@@ -753,4 +806,9 @@ public abstract class AbstractConnection implements Connection {
public void setBottomReadBuffer(ByteBuffer bottomReadBuffer) {
this.bottomReadBuffer = bottomReadBuffer;
}
public void initSSLContext(int protocol) {
}
}

View File

@@ -7,13 +7,13 @@ package com.actiontech.dble.net.connection;
import com.actiontech.dble.backend.mysql.proto.handler.Impl.SSLProtoHandler;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult;
import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResultCode;
import com.actiontech.dble.btrace.provider.IODelayProvider;
import com.actiontech.dble.buffer.BufferType;
import com.actiontech.dble.config.model.SystemConfig;
import com.actiontech.dble.net.IOProcessor;
import com.actiontech.dble.net.SocketWR;
import com.actiontech.dble.net.service.*;
import com.actiontech.dble.net.service.AbstractService;
import com.actiontech.dble.net.service.AuthService;
import com.actiontech.dble.net.ssl.OpenSSLWrapper;
import com.actiontech.dble.net.ssl.SSLWrapperRegistry;
import com.actiontech.dble.services.BusinessService;
@@ -70,6 +70,29 @@ public class FrontendConnection extends AbstractConnection {
this.isSupportSSL = SystemConfig.getInstance().isSupportSSL();
}
@Override
protected void handle(ByteBuffer dataBuffer, boolean isContainSSLData) throws IOException {
if (this.isSupportSSL && isUseSSL()) {
//after ssl-client hello
handleSSLData(dataBuffer);
} else {
//ssl buffer -> bottomRead buffer
transferToReadBuffer(dataBuffer);
if (maybeUseSSL()) {
//ssl login request(non ssl)&client hello(ssl)
super.handle(getBottomReadBuffer(), true);
} else {
//no ssl
handleNonSSL(getBottomReadBuffer());
}
}
}
protected void handleNonSSL(ByteBuffer dataBuffer) throws IOException {
super.handle(dataBuffer, false);
}
@Override
public void initSSLContext(int protocol) {
if (sslHandler != null) {
return;
@@ -94,27 +117,16 @@ public class FrontendConnection extends AbstractConnection {
}
sslHandler.handShake(data);
} catch (SSLException e) {
LOGGER.error("SSL handshake failed, exception: {},", e);
LOGGER.warn("SSL handshake failed, exception: ", e);
close("SSL handshake failed");
} catch (IOException e) {
LOGGER.error("SSL initialization failed, exception: {},", e);
LOGGER.warn("SSL initialization failed, exception: ", e);
close("SSL initialization failed");
}
return;
}
@Override
public void handle(ByteBuffer dataBuffer) throws IOException {
if (isSupportSSL && isUseSSL()) {
handleSSLData(dataBuffer);
} else {
transferToReadBuffer(dataBuffer);
parentHandle(getBottomReadBuffer());
}
}
private void transferToReadBuffer(ByteBuffer dataBuffer) {
if (!isSupportSSL) return;
if (!isSupportSSL || !maybeUseSSL()) return;
dataBuffer.flip();
ByteBuffer readBuffer = findBottomReadBuffer();
int len = readBuffer.position() + dataBuffer.limit();
@@ -125,10 +137,6 @@ public class FrontendConnection extends AbstractConnection {
dataBuffer.clear();
}
public void parentHandle(ByteBuffer buffer) throws IOException {
super.handle(buffer);
}
public void handleSSLData(ByteBuffer dataBuffer) throws IOException {
if (dataBuffer == null) {
return;
@@ -137,7 +145,7 @@ public class FrontendConnection extends AbstractConnection {
SSLProtoHandler proto = new SSLProtoHandler(this);
boolean hasRemaining = true;
while (hasRemaining) {
ProtoHandlerResult result = proto.handle(dataBuffer, offset, false);
ProtoHandlerResult result = proto.handle(dataBuffer, offset, false, true);
switch (result.getCode()) {
case SSL_PROTO_PACKET:
case SSL_CLOSE_PACKET:
@@ -183,8 +191,7 @@ public class FrontendConnection extends AbstractConnection {
// received large message in recent 30 seconds
// then change to direct buffer for performance
ByteBuffer localReadBuffer = netReadBuffer;
if (localReadBuffer != null && !localReadBuffer.isDirect() &&
lastLargeMessageTime < lastReadTime - 30 * 1000L) { // used temp heap
if (localReadBuffer != null && !localReadBuffer.isDirect() && lastLargeMessageTime < lastReadTime - 30 * 1000L) { // used temp heap
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("change to direct con read buffer ,cur temp buf size :" + localReadBuffer.capacity());
}
@@ -198,22 +205,10 @@ public class FrontendConnection extends AbstractConnection {
}
}
private void processSSLProto(byte[] packetData, ProtoHandlerResultCode code) {
AbstractService frontService = getService();
if (packetData != null) {
if (code == ProtoHandlerResultCode.SSL_PROTO_PACKET) {
pushServiceTask(new SSLProtoServerTask(packetData, frontService));
} else {
pushServiceTask(ServiceTaskFactory.getInstance(frontService).createForGracefulClose("ssl close", CloseType.READ));
}
}
}
private void processSSLAppData(byte[] packetData) throws IOException {
if (packetData == null)
return;
if (packetData == null) return;
sslHandler.unwrapAppData(packetData);
parentHandle(getBottomReadBuffer());
handleNonSSL(getBottomReadBuffer());
}
public void processSSLPacketNotBigEnough(ByteBuffer buffer, int offset, final int pkgLength) {
@@ -240,32 +235,33 @@ public class FrontendConnection extends AbstractConnection {
@Override
public void close(String reason) {
if (isUseSSL())
sslHandler.close();
if (isUseSSL()) sslHandler.close();
super.close(reason);
}
@Override
public synchronized void recycleReadBuffer() {
if (netReadBuffer != null) {
this.recycle(netReadBuffer);
recycleNetReadBuffer();
super.recycleReadBuffer();
}
private void recycleNetReadBuffer() {
if (this.netReadBuffer != null) {
this.recycle(this.netReadBuffer);
this.netReadBuffer = null;
}
super.recycleReadBuffer();
}
@Override
public void startFlowControl(int currentWritingSize) {
if (!frontWriteFlowControlled && this.getService() instanceof BusinessService &&
currentWritingSize > FlowController.getFlowHighLevel()) {
if (!frontWriteFlowControlled && this.getService() instanceof BusinessService && currentWritingSize > FlowController.getFlowHighLevel()) {
((BusinessService) this.getService()).getSession().startFlowControl(currentWritingSize);
}
}
@Override
public void stopFlowControl(int currentWritingSize) {
if (this.getService() instanceof BusinessService &&
currentWritingSize <= FlowController.getFlowLowLevel()) {
if (this.getService() instanceof BusinessService && currentWritingSize <= FlowController.getFlowLowLevel()) {
((BusinessService) this.getService()).getSession().stopFlowControl(currentWritingSize);
}
}
@@ -273,10 +269,7 @@ public class FrontendConnection extends AbstractConnection {
@Override
public void cleanup(String reason) {
if (isCleanUp.compareAndSet(false, true)) {
if (netReadBuffer != null) {
this.recycle(netReadBuffer);
this.netReadBuffer = null;
}
recycleNetReadBuffer();
super.cleanup(reason);
AbstractService service = getService();
if (service instanceof FrontendService) {
@@ -293,17 +286,16 @@ public class FrontendConnection extends AbstractConnection {
@Override
public ByteBuffer wrap(ByteBuffer orgBuffer) throws SSLException {
if (!isUseSSL())
return orgBuffer;
if (!isUseSSL()) return orgBuffer;
return sslHandler.wrapAppData(orgBuffer);
}
@Override
public void compactReadBuffer(ByteBuffer dataBuffer, int offset) throws IOException {
public void compactReadBuffer(ByteBuffer dataBuffer, int offset, boolean isSSL) throws IOException {
if (dataBuffer == null) {
return;
}
if (isSupportSSL && SSLProtoHandler.isSSLPackage(dataBuffer, offset)) {
if (isSupportSSL && isSSL) {
dataBuffer.flip();
dataBuffer.position(offset);
int len = netReadBuffer.position() + (dataBuffer.limit() - dataBuffer.position());
@@ -349,19 +341,21 @@ public class FrontendConnection extends AbstractConnection {
@Override
public ByteBuffer findReadBuffer() {
if (isSupportSSL) {
if (isSupportSSL && maybeUseSSL()) {
if (this.netReadBuffer == null) {
netReadBuffer = allocate(processor.getBufferPool().getChunkSize(), generateBufferRecordBuilder().withType(BufferType.POOL));
}
return netReadBuffer;
} else {
//only recycle this read buffer
recycleNetReadBuffer();
return super.findReadBuffer();
}
}
@Override
ByteBuffer getReadBuffer() {
if (isSupportSSL) {
if (isSupportSSL && maybeUseSSL()) {
return netReadBuffer;
} else {
return super.getReadBuffer();
@@ -395,10 +389,7 @@ public class FrontendConnection extends AbstractConnection {
}
public String toString() {
return "FrontendConnection[id = " + id + " port = " + port + " host = " + host + " local_port = " +
localPort + " isManager = " + isManager() + " startupTime = " + startupTime + " skipCheck = " +
isSkipCheck() + " isFlowControl = " + isFrontWriteFlowControlled() + " onlyTcpConnect = " +
isOnlyFrontTcpConnected() + " ssl = " + (isUseSSL() ? sslName : "no") + "]";
return "FrontendConnection[id = " + id + " port = " + port + " host = " + host + " local_port = " + localPort + " isManager = " + isManager() + " startupTime = " + startupTime + " skipCheck = " + isSkipCheck() + " isFlowControl = " + isFrontWriteFlowControlled() + " onlyTcpConnect = " + isOnlyFrontTcpConnected() + " ssl = " + (isUseSSL() ? sslName : "no") + "]";
}
public String getSimple() {

View File

@@ -22,8 +22,8 @@ import java.nio.channels.SocketChannel;
public class SSLHandler {
protected static final Logger LOGGER = LoggerFactory.getLogger(SSLHandler.class);
private FrontendConnection con;
private NetworkChannel channel;
private final FrontendConnection con;
private final NetworkChannel channel;
private volatile ByteBuffer decryptOut;
@@ -54,7 +54,7 @@ public class SSLHandler {
/**
* receive and process the SSL handshake protocol initiated by the client
*/
private void unwrapNonAppData(byte[] data) throws SSLException {
private void unwrapNonAppData(byte[] data) {
ByteBuffer in = con.allocate(data.length);
in.put(data);
in.flip();
@@ -91,7 +91,7 @@ public class SSLHandler {
}
}
} catch (SSLException e) {
LOGGER.error("during the handshake, unwrap data exception: {}", e);
LOGGER.warn("during the handshake, unwrap data exception: ", e);
con.close("during the handshake, unwrap data fail");
} finally {
con.recycle(in);
@@ -124,7 +124,7 @@ public class SSLHandler {
}
}
} catch (SSLException e) {
LOGGER.error("during the interaction, unwrap data exception: {}", e);
LOGGER.warn("during the interaction, unwrap data exception: ", e);
con.close("during the interaction, unwrap data fail");
throw e;
} finally {
@@ -185,7 +185,7 @@ public class SSLHandler {
}
}
} catch (SSLException e) {
LOGGER.error("during the handshake, wrap data exception: {}", e);
LOGGER.warn("during the handshake, wrap data exception: ", e);
con.close("during the handshake, wrap data fail");
throw e;
}
@@ -218,9 +218,11 @@ public class SSLHandler {
return ByteBufferUtil.EMPTY_BYTE_BUFFER;
} catch (SSLException e) {
LOGGER.error("during the interaction, wrap data exception: {}", e);
LOGGER.warn("during the interaction, wrap data exception: ", e);
con.close("during the interaction, wrap data fail");
throw e;
} finally {
con.recycle(appBuffer);
}
}
@@ -270,7 +272,7 @@ public class SSLHandler {
con.recycle(decryptOut);
} catch (SSLException e) {
LOGGER.warn("SSL close failed, exception{}", e);
LOGGER.warn("SSL close failed, exception", e);
}
}

View File

@@ -170,7 +170,15 @@ public class MySQLFrontAuthService extends FrontendService implements AuthServic
private void handleAuthPacket(byte[] data) {
AuthPacket auth = new AuthPacket();
auth.read(data);
if (connection.isRequestSSL() == null) {
/*
++++ Only need to be based on the first CLIENT_SSL value ++++
+ Login request will be sent twice during ssl
+ 1. before the client hello and does not contain account password and other information
+ 2. encrypted after SSL authentication and contains account password and other information
*/
connection.setRequestSSL(auth.getIsSSLRequest());
}
if (auth.getIsSSLRequest())
return;

View File

@@ -27,7 +27,7 @@ public class LoadDataProtoHandlerImpl implements ProtoHandler {
}
@Override
public ProtoHandlerResult handle(ByteBuffer dataBuffer, int dataBufferOffset, boolean isSupportCompress) {
public ProtoHandlerResult handle(ByteBuffer dataBuffer, int dataBufferOffset, boolean isSupportCompress, boolean isContainSSLData) {
ProtoHandlerResult.ProtoHandlerResultBuilder resultBuilder = mySQLProtoHandler.handlerResultBuilder(dataBuffer, dataBufferOffset, isSupportCompress);
ProtoHandlerResult result = resultBuilder.build();
switch (result.getCode()) {

View File

@@ -0,0 +1,29 @@
/*
* Copyright (C) 2016-2023 ActionTech.
* License: http://www.gnu.org/licenses/gpl.html GPL version 2 or higher.
*/
package com.actiontech.dble.util.exception;
import javax.net.ssl.SSLException;
public class NotSslRecordException extends SSLException {
private static final long serialVersionUID = -4316784434770656841L;
public NotSslRecordException() {
super("");
}
public NotSslRecordException(String message) {
super(message);
}
public NotSslRecordException(Throwable cause) {
super(cause);
}
public NotSslRecordException(String message, Throwable cause) {
super(message, cause);
}
}