[SPARK-45378][CORE] Add convertToNettyForSsl to ManagedBuffer
### What changes were proposed in this pull request?

As the title suggests: add a `convertToNettyForSsl` API to `ManagedBuffer`. In addition to that API, add a config to `TransportConf` to let users adjust the chunk size used for these transfers if desired.
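For illustration, a hedged sketch of overriding that config (the key and its `64k` default appear in the `TransportConf` change below; the `128k` value here is just an example):

```java
import org.apache.spark.SparkConf;

public class SslChunkSizeExample {
  public static void main(String[] args) {
    // Example only: raise the SSL chunk size above its 64k default.
    SparkConf conf = new SparkConf()
        .set("spark.network.ssl.maxEncryptedBlockSize", "128k");
  }
}
```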

### Why are the changes needed?

Netty's SSL support does not allow zero-copy transfers. To support SSL over Netty, we need another API on `ManagedBuffer` that lets buffers return a data type the SSL handler can encrypt, as sketched below.
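To make the contrast concrete, here is a minimal sketch (the file path and chunk size are illustrative, mirroring what `FileSegmentManagedBuffer` does in the diff below): a plaintext transport can hand Netty a zero-copy `DefaultFileRegion`, while an SSL pipeline must pull the bytes through the JVM, e.g. as a `ChunkedStream`:

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import io.netty.channel.DefaultFileRegion;
import io.netty.handler.stream.ChunkedStream;

class SslTransferSketch {
  static Object plaintextBody(File file) {
    // Zero-copy: the kernel can move file bytes straight to the socket.
    return new DefaultFileRegion(file, 0, file.length());
  }

  static Object sslBody(File file, int chunkSize) throws IOException {
    // Bytes must pass through the SslHandler to be encrypted, so stream them
    // through the JVM in fixed-size chunks instead of using zero-copy.
    return new ChunkedStream(new FileInputStream(file), chunkSize);
  }
}
```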

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

CI. Tests will be added later; this change is exercised as part of #42685, from which this PR is split out.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43166 from hasnain-db/spark-tls-buffers.

Authored-by: Hasnain Lakhani <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
hasnain-db authored and Mridul Muralidharan committed Oct 3, 2023
1 parent e53abbb commit b01dce2
Showing 12 changed files with 91 additions and 0 deletions.
@@ -28,6 +28,7 @@

import com.google.common.io.ByteStreams;
import io.netty.channel.DefaultFileRegion;
import io.netty.handler.stream.ChunkedStream;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;

@@ -137,6 +138,12 @@ public Object convertToNetty() throws IOException {
}
}

@Override
public Object convertToNettyForSsl() throws IOException {
// Cannot use zero-copy with SSL/TLS
return new ChunkedStream(createInputStream(), conf.sslShuffleChunkSize());
}

public File getFile() { return file; }

public long getOffset() { return offset; }
@@ -75,4 +75,18 @@ public abstract class ManagedBuffer {
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;

/**
* Convert the buffer into a Netty object, used to write the data out with SSL encryption,
* which cannot use {@link io.netty.channel.FileRegion}.
* The return value is either a {@link io.netty.buffer.ByteBuf},
* a {@link io.netty.handler.stream.ChunkedStream}, or a {@link java.io.InputStream}.
*
* If this method returns a ByteBuf, then that buffer's reference count will be incremented and
* the caller will be responsible for releasing this new reference.
*
* Once `kernel.ssl.sendfile` and OpenSSL's `ssl_sendfile` are more widely adopted (and supported
* in Netty), we can potentially deprecate these APIs and just use `convertToNetty`.
*/
public abstract Object convertToNettyForSsl() throws IOException;
}
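For context, a hypothetical caller-side sketch of how the two conversions might be selected (none of this wiring is in this PR; it lands in #42685, and the helper name here is made up): choose the SSL-safe conversion whenever an `SslHandler` sits in the channel pipeline.

```java
import java.io.IOException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.ssl.SslHandler;
import org.apache.spark.network.buffer.ManagedBuffer;

final class SslAwareBodyConverter { // hypothetical helper, for illustration only
  static Object toNettyBody(ChannelHandlerContext ctx, ManagedBuffer buf)
      throws IOException {
    // FileRegion zero-copy would bypass the SslHandler entirely, so fall back
    // to the SSL-safe conversion whenever encryption is on the pipeline.
    boolean sslEnabled = ctx.pipeline().get(SslHandler.class) != null;
    return sslEnabled ? buf.convertToNettyForSsl() : buf.convertToNetty();
  }
}
```

Note that a `ChunkedStream` returned this way still needs a `ChunkedWriteHandler` downstream in the pipeline to actually be written out.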
@@ -68,6 +68,11 @@ public Object convertToNetty() throws IOException {
return buf.duplicate().retain();
}

@Override
public Object convertToNettyForSsl() throws IOException {
return buf.duplicate().retain();
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
@@ -66,6 +66,11 @@ public Object convertToNetty() throws IOException {
return Unpooled.wrappedBuffer(buf);
}

@Override
public Object convertToNettyForSsl() throws IOException {
return Unpooled.wrappedBuffer(buf);
}

@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
@@ -249,6 +249,14 @@ public boolean saslServerAlwaysEncrypt() {
return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}

/**
* When secure (SSL/TLS) shuffle is enabled, the chunk size to use for shuffling files.
*/
public int sslShuffleChunkSize() {
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
conf.get("spark.network.ssl.maxEncryptedBlockSize", "64k")));
}

/**
* Flag indicating whether to share the pooled ByteBuf allocators between the different Netty
* channels. If enabled then only two pooled ByteBuf allocators are created: one where caching
@@ -80,6 +80,11 @@ public Object convertToNetty() throws IOException {
return underlying.convertToNetty();
}

@Override
public Object convertToNettyForSsl() throws IOException {
return underlying.convertToNettyForSsl();
}

@Override
public int hashCode() {
return underlying.hashCode();
@@ -85,6 +85,13 @@ private[spark] trait BlockData {
*/
def toNetty(): Object

/**
* Returns a Netty-friendly wrapper for the block's data.
*
* Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
*/
def toNettyForSsl(): Object

def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer

def toByteBuffer(): ByteBuffer
@@ -103,6 +110,8 @@ private[spark] class ByteBufferBlockData(

override def toNetty(): Object = buffer.toNetty

override def toNettyForSsl(): AnyRef = buffer.toNettyForSsl

override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
buffer.copy(allocator)
}
@@ -51,6 +51,8 @@ private[storage] class BlockManagerManagedBuffer(

override def convertToNetty(): Object = data.toNetty()

override def convertToNettyForSsl(): Object = data.toNettyForSsl()

override def retain(): ManagedBuffer = {
refCount.incrementAndGet()
val locked = blockInfoManager.lockForReading(blockId, blocking = false)
core/src/main/scala/org/apache/spark/storage/DiskStore.scala (13 additions, 0 deletions)
@@ -184,6 +184,14 @@ private class DiskBlockData(
*/
override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size)

/**
* Returns a Netty-friendly wrapper for the block's data.
*
* Please see `ManagedBuffer.convertToNettyForSsl()` for more details.
*/
override def toNettyForSsl(): AnyRef =
toChunkedByteBuffer(ByteBuffer.allocate).toNettyForSsl

override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = {
Utils.tryWithResource(open()) { channel =>
var remaining = blockSize
@@ -234,6 +242,9 @@ private[spark] class EncryptedBlockData(

override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize)

override def toNettyForSsl(): AnyRef =
toChunkedByteBuffer(ByteBuffer.allocate).toNettyForSsl

override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
val source = open()
try {
@@ -297,6 +308,8 @@ private[spark] class EncryptedManagedBuffer(

override def convertToNetty(): AnyRef = blockData.toNetty()

override def convertToNettyForSsl(): AnyRef = blockData.toNettyForSsl()

override def createInputStream(): InputStream = blockData.toInputStream()

override def retain(): ManagedBuffer = this
@@ -23,6 +23,7 @@ import java.nio.channels.WritableByteChannel

import com.google.common.io.ByteStreams
import com.google.common.primitives.UnsignedBytes
import io.netty.handler.stream.ChunkedStream
import org.apache.commons.io.IOUtils

import org.apache.spark.SparkEnv
@@ -131,6 +132,14 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Ex
new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize)
}

/**
* Wrap this in a ChunkedStream which allows us to provide the data in a manner
* compatible with SSL encryption
*/
def toNettyForSsl: ChunkedStream = {
new ChunkedStream(toInputStream(), bufferWriteChunkSize)
}

/**
* Copy this buffer into a new byte array.
*
@@ -284,6 +293,17 @@ private[spark] class ChunkedByteBufferInputStream(
}
}

override def available(): Int = {
if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
currentChunk = chunks.next()
}
if (currentChunk != null && currentChunk.hasRemaining) {
currentChunk.remaining
} else {
0
}
}

override def read(): Int = {
if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
currentChunk = chunks.next()
@@ -74,6 +74,8 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits {
override def release(): ManagedBuffer = this

override def convertToNetty(): AnyRef = null

override def convertToNettyForSsl(): AnyRef = null
}
listener.onBlockFetchSuccess("block-id-unused", badBuffer)
}
@@ -43,6 +43,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed
override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
override def convertToNettyForSsl(): AnyRef = underlyingBuffer.convertToNettyForSsl()

override def retain(): ManagedBuffer = {
callsToRetain += 1
