diff --git a/README.md b/README.md
index c55f69468..bdf2b53c0 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,7 @@ See [Building And Artifacts](doc/12_building_and_artifacts.md)
   - [The Spark Shell](doc/13_spark_shell.md)
   - [DataFrames](doc/14_data_frames.md)
   - [Python](doc/15_python.md)
+  - [Partitioner](doc/16_partitioning.md)
   - [Frequently Asked Questions](doc/FAQ.md)
   - [Configuration Parameter Reference Table](doc/reference.md)
diff --git a/doc/FAQ.md b/doc/FAQ.md
index 2d9da0711..185289131 100644
--- a/doc/FAQ.md
+++ b/doc/FAQ.md
@@ -104,6 +104,12 @@ the rpc_address is set to. When troubleshooting Cassandra connections it is
 sometimes useful to set the rpc_address in the C* yaml file to `0.0.0.0` so any
 incoming connection will work.
 
+### How does the connector evaluate the number of Spark partitions?
+
+The connector estimates the number of Spark partitions by dividing the table size estimate by
+the `input.split.size_in_mb` value. The resulting number of partitions is never smaller than
+`1 + 2 * SparkContext.defaultParallelism`.
+
 ### What does input.split.size_in_mb use to determine size?
 
 Input.split.size_in_mb uses a internal system table in C* ( >= 2.1.5) to determine the size
diff --git a/doc/reference.md b/doc/reference.md
index e8cadbf86..82e8c973a 100644
--- a/doc/reference.md
+++ b/doc/reference.md
@@ -203,7 +203,7 @@ OSS Cassandra this should never be used.
input.split.size_in_mb
1 + 2 * SparkContext.defaultParallelism
1 + 2 * SparkContext.defaultParallelism
+ |""".stripMargin.filter(_ >= ' '))
val FetchSizeInRowsParam = ConfigParameter[Int](
name = "spark.cassandra.input.fetch.size_in_rows",
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala
index cd68ad2af..3a08c7705 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala
@@ -19,8 +19,7 @@ import com.datastax.spark.connector.writer.RowWriterFactory
private[connector] class CassandraPartitionGenerator[V, T <: Token[V]](
connector: CassandraConnector,
tableDef: TableDef,
- splitCount: Option[Int],
- splitSize: Long)(
+ splitCount: Int)(
implicit
tokenFactory: TokenFactory[V, T]) extends Logging{
@@ -30,23 +29,11 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]](
private val keyspaceName = tableDef.keyspaceName
private val tableName = tableDef.tableName
- private val totalDataSize: Long = {
- // If we know both the splitCount and splitSize, we should pretend the total size of the data is
- // their multiplication. TokenRangeSplitter will try to produce splits of desired size, and this way
- // their number will be close to desired splitCount. Otherwise, if splitCount is not set,
- // we just go to C* and read the estimated data size from an appropriate system table
- splitCount match {
- case Some(c) => c * splitSize
- case None => new DataSizeEstimates(connector, keyspaceName, tableName).dataSizeInBytes
- }
- }
-
private def tokenRange(range: DriverTokenRange, metadata: Metadata): TokenRange = {
val startToken = tokenFactory.tokenFromString(range.getStart.getValue.toString)
val endToken = tokenFactory.tokenFromString(range.getEnd.getValue.toString)
val replicas = metadata.getReplicas(Metadata.quote(keyspaceName), range).map(_.getAddress).toSet
- val dataSize = (tokenFactory.ringFraction(startToken, endToken) * totalDataSize).toLong
- new TokenRange(startToken, endToken, replicas, dataSize)
+ new TokenRange(startToken, endToken, replicas, tokenFactory)
}
private def describeRing: Seq[TokenRange] = {
@@ -59,29 +46,19 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]](
* When we have a single Spark Partition use a single global range. This
* will let us more easily deal with Partition Key equals and In clauses
*/
- if (splitCount == Some(1)) {
- Seq(ranges(0).copy[V, T](tokenFactory.minToken, tokenFactory.maxToken))
+ if (splitCount == 1) {
+ Seq(ranges.head.copy[V, T](tokenFactory.minToken, tokenFactory.minToken))
} else {
ranges
}
}
- private def splitsOf(
- tokenRanges: Iterable[TokenRange],
- splitter: TokenRangeSplitter[V, T]): Iterable[TokenRange] = {
-
- val parTokenRanges = tokenRanges.par
- parTokenRanges.tasksupport = new ForkJoinTaskSupport(CassandraPartitionGenerator.pool)
- (for (tokenRange <- parTokenRanges;
- split <- splitter.split(tokenRange, splitSize)) yield split).seq
- }
-
private def createTokenRangeSplitter: TokenRangeSplitter[V, T] = {
tokenFactory.asInstanceOf[TokenFactory[_, _]] match {
case TokenFactory.RandomPartitionerTokenFactory =>
- new RandomPartitionerTokenRangeSplitter(totalDataSize).asInstanceOf[TokenRangeSplitter[V, T]]
+ new RandomPartitionerTokenRangeSplitter().asInstanceOf[TokenRangeSplitter[V, T]]
case TokenFactory.Murmur3TokenFactory =>
- new Murmur3PartitionerTokenRangeSplitter(totalDataSize).asInstanceOf[TokenRangeSplitter[V, T]]
+ new Murmur3PartitionerTokenRangeSplitter().asInstanceOf[TokenRangeSplitter[V, T]]
case _ =>
throw new UnsupportedOperationException(s"Unsupported TokenFactory $tokenFactory")
}
@@ -93,18 +70,20 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]](
def partitions: Seq[CassandraPartition[V, T]] = {
val tokenRanges = describeRing
val endpointCount = tokenRanges.map(_.replicas).reduce(_ ++ _).size
- val splitter = createTokenRangeSplitter
- val splits = splitsOf(tokenRanges, splitter).toSeq
val maxGroupSize = tokenRanges.size / endpointCount
- val clusterer = new TokenRangeClusterer[V, T](splitSize, maxGroupSize)
+
+ val splitter = createTokenRangeSplitter
+ val splits = splitter.split(tokenRanges, splitCount).toSeq
+
+ val clusterer = new TokenRangeClusterer[V, T](splitCount, maxGroupSize)
val tokenRangeGroups = clusterer.group(splits).toArray
val partitions = for (group <- tokenRangeGroups) yield {
val replicas = group.map(_.replicas).reduce(_ intersect _)
- val rowCount = group.map(_.dataSize).sum
+ val rowCount = group.map(_.rangeSize).sum
val cqlRanges = group.flatMap(rangeToCql)
// partition index will be set later
- CassandraPartition(0, replicas, cqlRanges, rowCount)
+ CassandraPartition(0, replicas, cqlRanges, rowCount.toLong)
}
// sort partitions and assign sequential numbers so that
@@ -139,14 +118,6 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]](
}
object CassandraPartitionGenerator {
- /** Affects how many concurrent threads are used to fetch split information from cassandra nodes, in `getPartitions`.
- * Does not affect how many Spark threads fetch data from Cassandra. */
- val MaxParallelism = 16
-
- /** How many token rangesContaining to sample in order to estimate average number of rows per token */
- val TokenRangeSampleSize = 16
-
- private val pool: ForkJoinPool = new ForkJoinPool(MaxParallelism)
type V = t forSome { type t }
type T = t forSome { type t <: Token[V] }
@@ -158,17 +129,9 @@ object CassandraPartitionGenerator {
def apply(
conn: CassandraConnector,
tableDef: TableDef,
- splitCount: Option[Int],
- splitSize: Int): CassandraPartitionGenerator[V, T] = {
+ splitCount: Int)(
+ implicit tokenFactory: TokenFactory[V, T]): CassandraPartitionGenerator[V, T] = {
- val tokenFactory = getTokenFactory(conn)
- new CassandraPartitionGenerator(conn, tableDef, splitCount, splitSize)(tokenFactory)
- }
-
- def getTokenFactory(conn: CassandraConnector) : TokenFactory[V, T] = {
- val partitionerName = conn.withSessionDo { session =>
- session.execute("SELECT partitioner FROM system.local").one().getString(0)
- }
- TokenFactory.forCassandraPartitioner(partitionerName)
+ new CassandraPartitionGenerator(conn, tableDef, splitCount)(tokenFactory)
}
}
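After this refactor the generator is driven purely by a split count plus an implicit `TokenFactory`, instead of a data-size estimate. A hedged usage sketch of the new shape (the generator is `private[connector]`, so this is only meaningful inside the connector's own packages):

```scala
import com.datastax.spark.connector.cql.{CassandraConnector, TableDef}
import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator
import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory

// Sketch only: resolve the TokenFactory once from system.local, pass it implicitly,
// and ask the generator for partitions driven by an explicit split count.
def partitionsFor(conn: CassandraConnector, table: TableDef, splitCount: Int) = {
  implicit val tokenFactory = TokenFactory.forSystemLocalPartitioner(conn)
  CassandraPartitionGenerator(conn, table, splitCount).partitions
}
```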
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala
index 87a81fe0f..0417a11f0 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala
@@ -1,34 +1,21 @@
package com.datastax.spark.connector.rdd.partitioner
-import com.datastax.spark.connector.rdd.partitioner.dht.{LongToken, TokenFactory, TokenRange}
+import com.datastax.spark.connector.rdd.partitioner.dht.LongToken
-/** Fast token range splitter assuming that data are spread out evenly in the whole range.
- * @param dataSize estimate of the size of the data in the whole ring */
-class Murmur3PartitionerTokenRangeSplitter(dataSize: Long)
+/** Fast token range splitter assuming that data are spread out evenly in the whole range. */
+private[partitioner] class Murmur3PartitionerTokenRangeSplitter
extends TokenRangeSplitter[Long, LongToken] {
- private val tokenFactory =
- TokenFactory.Murmur3TokenFactory
+ private type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[Long, LongToken]
- private type TR = TokenRange[Long, LongToken]
+ override def split(tokenRange: TokenRange, splitCount: Int): Seq[TokenRange] = {
+ val rangeSize = tokenRange.rangeSize
+ val splitPointsCount = if (rangeSize < splitCount) rangeSize.toInt else splitCount
+ val splitPoints = (0 until splitPointsCount).map({ i =>
+ new LongToken(tokenRange.start.value + (rangeSize * i / splitPointsCount).toLong)
+ }) :+ tokenRange.end
- /** Splits the token range uniformly into sub-ranges.
- * @param splitSize requested sub-split size, given in the same units as `dataSize` */
- def split(range: TR, splitSize: Long): Seq[TR] = {
- val rangeSize = range.dataSize
- val rangeTokenCount = tokenFactory.distance(range.start, range.end)
- val n = math.max(1, math.round(rangeSize.toDouble / splitSize).toInt)
-
- val left = range.start.value
- val right = range.end.value
- val splitPoints =
- (for (i <- 0 until n) yield left + (rangeTokenCount * i / n).toLong) :+ right
-
- for (Seq(l, r) <- splitPoints.sliding(2).toSeq) yield
- new TokenRange[Long, LongToken](
- new LongToken(l),
- new LongToken(r),
- range.replicas,
- rangeSize / n)
+ for (Seq(left, right) <- splitPoints.sliding(2).toSeq) yield
+ new TokenRange(left, right, tokenRange.replicas, tokenRange.tokenFactory)
}
}
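The splitter's arithmetic is easier to see on plain integers. A minimal sketch (not connector code) of cutting a range uniformly into a requested number of contiguous sub-ranges, capped by the number of tokens in the range:

```scala
// Illustration of the split arithmetic used above, on plain BigInt bounds.
// When the range holds fewer tokens than the requested split count, the number
// of sub-ranges is capped at the range size, just like in the splitter.
def splitUniformly(start: BigInt, end: BigInt, splitCount: Int): Seq[(BigInt, BigInt)] = {
  val rangeSize = end - start
  val splitPointsCount = if (rangeSize < splitCount) rangeSize.toInt else splitCount
  val splitPoints = (0 until splitPointsCount).map(i => start + rangeSize * i / splitPointsCount) :+ end
  splitPoints.sliding(2).map { case Seq(l, r) => (l, r) }.toSeq
}

// splitUniformly(0, 1000, 3) yields (0, 333), (333, 666), (666, 1000)
```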
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala
index b896f52e7..e35e557bf 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala
@@ -1,40 +1,28 @@
package com.datastax.spark.connector.rdd.partitioner
-import com.datastax.spark.connector.rdd.partitioner.dht.{BigIntToken, TokenFactory, TokenRange}
+import com.datastax.spark.connector.rdd.partitioner.dht.BigIntToken
-
-/** Fast token range splitter assuming that data are spread out evenly in the whole range.
- * @param dataSize estimate of the size of the data in the whole ring */
-class RandomPartitionerTokenRangeSplitter(dataSize: Long)
+/** Fast token range splitter assuming that data are spread out evenly in the whole range. */
+private[partitioner] class RandomPartitionerTokenRangeSplitter
extends TokenRangeSplitter[BigInt, BigIntToken] {
- private val tokenFactory =
- TokenFactory.RandomPartitionerTokenFactory
+ private type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[BigInt, BigIntToken]
- private def wrap(token: BigInt): BigInt = {
- val max = tokenFactory.maxToken.value
+ private def wrapWithMax(max: BigInt)(token: BigInt): BigInt = {
if (token <= max) token else token - max
}
- private type TR = TokenRange[BigInt, BigIntToken]
-
- /** Splits the token range uniformly into sub-ranges.
- * @param splitSize requested sub-split size, given in the same units as `dataSize` */
- def split(range: TR, splitSize: Long): Seq[TR] = {
- val rangeSize = range.dataSize
- val rangeTokenCount = tokenFactory.distance(range.start, range.end)
- val n = math.max(1, math.round(rangeSize.toDouble / splitSize)).toInt
+ override def split(tokenRange: TokenRange, splitCount: Int): Seq[TokenRange] = {
+ val rangeSize = tokenRange.rangeSize
+ val wrap = wrapWithMax(tokenRange.tokenFactory.maxToken.value)(_)
- val left = range.start.value
- val right = range.end.value
- val splitPoints =
- (for (i <- 0 until n) yield wrap(left + (rangeTokenCount * i / n))) :+ right
+ val splitPointsCount = if (rangeSize < splitCount) rangeSize.toInt else splitCount
+ val splitPoints = (0 until splitPointsCount).map({ i =>
+ val nextToken: BigInt = tokenRange.start.value + (rangeSize * i / splitPointsCount)
+ new BigIntToken(wrap(nextToken))
+ }) :+ tokenRange.end
- for (Seq(l, r) <- splitPoints.sliding(2).toSeq) yield
- new TokenRange[BigInt, BigIntToken](
- new BigIntToken(l.bigInteger),
- new BigIntToken(r.bigInteger),
- range.replicas,
- rangeSize / n)
+ for (Seq(left, right) <- splitPoints.sliding(2).toSeq) yield
+ new TokenRange(left, right, tokenRange.replicas, tokenRange.tokenFactory)
}
}
\ No newline at end of file
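For the random partitioner the ring is the interval [0, 2^127], so a split point computed past the maximum token is folded back to the start of the ring. A tiny illustration of the wrap rule used above:

```scala
// Same wrap rule as wrapWithMax above: tokens past the ring maximum fold back
// to the beginning of the ring.
def wrapWithMax(max: BigInt)(token: BigInt): BigInt =
  if (token <= max) token else token - max

val ringMax = BigInt(2).pow(127)
assert(wrapWithMax(ringMax)(ringMax + 5) == 5)
```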
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/SplitSizeEstimator.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/SplitSizeEstimator.scala
new file mode 100644
index 000000000..e67d38daa
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/SplitSizeEstimator.scala
@@ -0,0 +1,25 @@
+package com.datastax.spark.connector.rdd.partitioner
+
+import com.datastax.spark.connector.rdd.CassandraRDD
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory
+
+private[rdd] trait SplitSizeEstimator[R] {
+ this: CassandraRDD[R] =>
+
+ @transient implicit lazy val tokenFactory = TokenFactory.forSystemLocalPartitioner(connector)
+
+ private def estimateDataSize: Long =
+ new DataSizeEstimates(connector, keyspaceName, tableName).dataSizeInBytes
+
+ private[rdd] def minimalSplitCount: Int = {
+ val coreCount = context.defaultParallelism
+ 1 + coreCount * 2
+ }
+
+ def estimateSplitCount(splitSize: Int): Int = {
+ require(splitSize > 0, "Split size must be greater than zero.")
+ val splitCountEstimate = estimateDataSize / splitSize
+ Math.max(splitCountEstimate.toInt, minimalSplitCount)
+ }
+
+}
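`estimateSplitCount` divides a byte-based size estimate, so a caller presumably converts `input.split.size_in_mb` to bytes first. A hedged sketch of that call pattern (the helper below is hypothetical, not part of this patch):

```scala
// Hypothetical helper, not from this patch: the estimator divides a data size
// expressed in bytes, so the configured split size (in MB) is converted first.
def splitSizeInBytes(splitSizeInMB: Int): Int = splitSizeInMB * 1024 * 1024

// e.g. an RDD mixing in SplitSizeEstimator would call something like:
//   val splitCount = estimateSplitCount(splitSizeInBytes(readConf.splitSizeInMB))
```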
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala
index 7edcd219a..19fafaab7 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala
@@ -5,20 +5,24 @@ import java.net.InetAddress
import Ordering.Implicits._
import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenRange}
-
import scala.annotation.tailrec
-/** Divides a set of token rangesContaining into groups containing not more than `maxRowCountPerGroup` rows
- * and not more than `maxGroupSize` token rangesContaining. Each group will form a single `CassandraPartition`.
+import com.datastax.spark.connector.rdd.partitioner.TokenRangeClusterer.WholeRing
+
+/** Groups a set of token ranges into `groupCount` groups containing not more than `maxGroupSize` token
+ * ranges.
+ * Each group will form a single `CassandraPartition`.
*
* The algorithm is as follows:
- * 1. Sort token rangesContaining by endpoints lexicographically.
- * 2. Take the highest possible number of token rangesContaining from the beginning of the list,
- * such that their sum of rowCounts does not exceed `maxRowCountPerGroup` and they all contain at
+ * 1. Sort token ranges by endpoints lexicographically.
+ * 2. Take the highest possible number of token ranges from the beginning of the list,
+ * such that the sum of their ringFractions does not exceed `ringFractionPerGroup` and they all contain at
* least one common endpoint. If it is not possible, take at least one item.
- * Those token rangesContaining will make a group.
- * 3. Repeat the previous step until no more token rangesContaining left.*/
-class TokenRangeClusterer[V, T <: Token[V]](maxRowCountPerGroup: Long, maxGroupSize: Int = Int.MaxValue) {
+ * Those token ranges will make a group.
+ * 3. Repeat the previous step until no more token ranges are left. */
+class TokenRangeClusterer[V, T <: Token[V]](groupCount: Int, maxGroupSize: Int = Int.MaxValue) {
+
+ private val ringFractionPerGroup = WholeRing / groupCount.toDouble
private implicit object InetAddressOrdering extends Ordering[InetAddress] {
override def compare(x: InetAddress, y: InetAddress) =
@@ -27,22 +31,23 @@ class TokenRangeClusterer[V, T <: Token[V]](maxRowCountPerGroup: Long, maxGroupS
@tailrec
private def group(tokenRanges: Stream[TokenRange[V, T]],
- result: Vector[Seq[TokenRange[V, T]]]): Iterable[Seq[TokenRange[V, T]]] = {
+ result: Vector[Seq[TokenRange[V, T]]],
+ ringFractionPerGroup: Double): Iterable[Seq[TokenRange[V, T]]] = {
tokenRanges match {
case Stream.Empty => result
case head #:: rest =>
val firstEndpoint = head.replicas.min
- val rowCounts = tokenRanges.map(_.dataSize)
- val cumulativeRowCounts = rowCounts.scanLeft(0L)(_ + _).tail // drop first item always == 0
- val rowLimit = math.max(maxRowCountPerGroup, head.dataSize) // make sure first element will be always included
+ val ringFractions = tokenRanges.map(_.ringFraction)
+ val cumulativeRingFractions = ringFractions.scanLeft(.0)(_ + _).tail // drop first item always == 0
+ val ringFractionLimit = math.max(ringFractionPerGroup, head.ringFraction) // make sure the first element is always included
val cluster = tokenRanges
.take(math.max(1, maxGroupSize))
- .zip(cumulativeRowCounts)
- .takeWhile { case (tr, count) => count <= rowLimit && tr.replicas.min == firstEndpoint }
+ .zip(cumulativeRingFractions)
+ .takeWhile { case (tr, count) => count <= ringFractionLimit && tr.replicas.min == firstEndpoint }
.map(_._1)
.toVector
val remainingTokenRanges = tokenRanges.drop(cluster.length)
- group(remainingTokenRanges, result :+ cluster)
+ group(remainingTokenRanges, result :+ cluster, ringFractionPerGroup)
}
}
@@ -53,7 +58,10 @@ class TokenRangeClusterer[V, T <: Token[V]](maxRowCountPerGroup: Long, maxGroupS
// sort by endpoints lexicographically
// this way ranges on the same host are grouped together
val sortedRanges = tokenRanges.sortBy(_.replicas.toSeq.sorted)
- group(sortedRanges.toStream, Vector.empty)
+ group(sortedRanges.toStream, Vector.empty, ringFractionPerGroup)
}
+}
+object TokenRangeClusterer {
+ private val WholeRing = 1.0
}
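The grouping algorithm described in the class comment above can be modelled on plain data. A simplified sketch (hypothetical names; primary host only, no `maxGroupSize` cap) of the greedy packing by ring fraction:

```scala
import scala.annotation.tailrec

// Simplified model of the clustering above: ranges are (primaryHost, ringFraction)
// pairs, greedily packed so that each group stays on one host and covers at most
// 1.0 / groupCount of the ring (the first range of a group is always taken).
def groupByFraction(ranges: List[(String, Double)], groupCount: Int): List[List[(String, Double)]] = {
  val limit = 1.0 / groupCount

  @tailrec
  def loop(rest: List[(String, Double)], acc: List[List[(String, Double)]]): List[List[(String, Double)]] =
    rest match {
      case Nil => acc.reverse
      case (host, fraction) :: _ =>
        val cumulative = rest.scanLeft(0.0)(_ + _._2).tail
        val cluster = rest.zip(cumulative)
          .takeWhile { case ((h, _), cum) => h == host && cum <= math.max(limit, fraction) }
          .map(_._1)
        loop(rest.drop(cluster.length), cluster :: acc)
    }

  loop(ranges, Nil)
}

// groupByFraction(List("a" -> 0.3, "a" -> 0.3, "b" -> 0.4), groupCount = 2)
//   == List(List("a" -> 0.3), List("a" -> 0.3), List("b" -> 0.4))
// groupByFraction(List("a" -> 0.3, "a" -> 0.3, "b" -> 0.4), groupCount = 1)
//   == List(List("a" -> 0.3, "a" -> 0.3), List("b" -> 0.4))
```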
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala
index 069947d5c..280f723f3 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala
@@ -1,16 +1,36 @@
package com.datastax.spark.connector.rdd.partitioner
+import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.forkjoin.ForkJoinPool
+
+import com.datastax.spark.connector.rdd.partitioner.TokenRangeSplitter.WholeRing
import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenRange}
-/** Splits a token range into smaller sub-ranges,
+
+/** Splits token ranges into smaller sub-ranges,
* each with the desired approximate number of rows. */
-trait TokenRangeSplitter[V, T <: Token[V]] {
+private[partitioner] trait TokenRangeSplitter[V, T <: Token[V]] {
- /** Splits given token range into n equal sub-ranges. */
- def split(range: TokenRange[V, T], splitSize: Long): Seq[TokenRange[V, T]]
-}
+ def split(tokenRanges: Iterable[TokenRange[V, T]], splitCount: Int): Iterable[TokenRange[V, T]] = {
+ val ringFractionPerSplit = WholeRing / splitCount.toDouble
+ val parTokenRanges = tokenRanges.par
+ parTokenRanges.tasksupport = new ForkJoinTaskSupport(TokenRangeSplitter.pool)
+ parTokenRanges.flatMap(tokenRange => {
+ val rangeSplitCount = Math.rint(tokenRange.ringFraction / ringFractionPerSplit).toInt
+ split(tokenRange, math.max(1, rangeSplitCount))
+ }).toList
+ }
+ /** Splits the token range uniformly into splitCount sub-ranges. */
+ def split(tokenRange: TokenRange[V, T], splitCount: Int): Seq[TokenRange[V, T]]
+}
+
+object TokenRangeSplitter {
+ private val MaxParallelism = 16
+ private val WholeRing = 1.0
+ private val pool = new ForkJoinPool(MaxParallelism)
+}
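The default `split(tokenRanges, splitCount)` above shares the requested split budget among the input ranges in proportion to their ring fraction, rounding to the nearest whole number of sub-splits and never going below one per range. A small standalone sketch of that arithmetic, consistent with the expectations in the `SplitterBehaviors` test further down:

```scala
// Sketch of the budget-sharing rule above: each range receives a number of
// sub-splits proportional to its share of the ring, rounded to nearest, min 1.
def splitsPerRange(ringFractions: Seq[Double], splitCount: Int): Seq[Int] = {
  val ringFractionPerSplit = 1.0 / splitCount
  ringFractions.map(f => math.max(1, math.rint(f / ringFractionPerSplit).toInt))
}

// 100 equal ranges (ring fraction 0.01 each):
//   splitsPerRange(Seq.fill(100)(0.01), 3).sum   == 100  (never fewer than one split per range)
//   splitsPerRange(Seq.fill(100)(0.01), 300).sum == 300  (three sub-splits per range)
```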
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala
index 10be6ac03..37135b0f3 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala
@@ -1,5 +1,7 @@
package com.datastax.spark.connector.rdd.partitioner.dht
+import com.datastax.spark.connector.cql.CassandraConnector
+
import scala.language.existentials
import com.datastax.spark.connector.rdd.partitioner.MonotonicBucketing
@@ -91,6 +93,13 @@ object TokenFactory {
}
partitioner.asInstanceOf[TokenFactory[V, T]]
}
+
+ def forSystemLocalPartitioner(connector: CassandraConnector): TokenFactory[V, T] = {
+ val partitionerClassName = connector.withSessionDo { session =>
+ session.execute("SELECT partitioner FROM system.local").one().getString(0)
+ }
+ forCassandraPartitioner(partitionerClassName)
+ }
}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala
index d0cd73437..2b0fc3b48 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala
@@ -3,7 +3,11 @@ package com.datastax.spark.connector.rdd.partitioner.dht
import java.net.InetAddress
case class TokenRange[V, T <: Token[V]] (
- start: T, end: T, replicas: Set[InetAddress], dataSize: Long) {
+ start: T, end: T, replicas: Set[InetAddress], tokenFactory: TokenFactory[V, T]) {
+
+ private[partitioner] lazy val rangeSize = tokenFactory.distance(start, end)
+
+ private[partitioner] lazy val ringFraction = tokenFactory.ringFraction(start, end)
def isWrappedAround(implicit tf: TokenFactory[V, T]): Boolean =
start >= end && end != tf.minToken
@@ -18,8 +22,8 @@ case class TokenRange[V, T <: Token[V]] (
val minToken = tf.minToken
if (isWrappedAround)
Seq(
- TokenRange(start, minToken, replicas, dataSize / 2),
- TokenRange(minToken, end, replicas, dataSize / 2))
+ TokenRange(start, minToken, replicas, tokenFactory),
+ TokenRange(minToken, end, replicas, tokenFactory))
else
Seq(this)
}
diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala
index 4c7f94788..6b220ead6 100644
--- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala
+++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala
@@ -15,6 +15,7 @@ import org.apache.spark.SparkConf
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema}
import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._
import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner
import com.datastax.spark.connector.rdd.{CassandraRDD, ReadConf}
import com.datastax.spark.connector.types.{InetType, UUIDType, VarIntType}
import com.datastax.spark.connector.util.Quote._
@@ -251,7 +252,7 @@ object CassandraSourceRelation {
val tableSizeInBytes = tableSizeInBytesString match {
case Some(size) => Option(size.toLong)
case None =>
- val tokenFactory = getTokenFactory(cassandraConnector)
+ val tokenFactory = forSystemLocalPartitioner(cassandraConnector)
val dataSizeInBytes =
new DataSizeEstimates(
cassandraConnector,
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala
index 8cffb881e..c192e44a7 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala
@@ -123,7 +123,7 @@ class BucketingRangeIndexSpec extends FlatSpec with PropertyChecks with ShouldMa
longTokenFactory.minToken,
longTokenFactory.minToken,
Set.empty,
- 0)
+ Murmur3TokenFactory)
"Murmur3Bucketing" should " map all tokens to a single wrapping range" in {
@@ -146,7 +146,7 @@ class BucketingRangeIndexSpec extends FlatSpec with PropertyChecks with ShouldMa
bigTokenFactory.minToken,
bigTokenFactory.minToken,
Set.empty,
- 0)
+ RandomPartitionerTokenFactory)
"RandomBucketing" should " map all tokens to a single wrapping range" in {
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterSpec.scala
new file mode 100644
index 000000000..a6bffd576
--- /dev/null
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterSpec.scala
@@ -0,0 +1,35 @@
+package com.datastax.spark.connector.rdd.partitioner
+
+import java.net.InetAddress
+
+import org.scalatest.{FlatSpec, Matchers}
+
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory._
+import com.datastax.spark.connector.rdd.partitioner.dht.{LongToken, TokenRange}
+
+class Murmur3PartitionerTokenRangeSplitterSpec
+ extends FlatSpec
+ with SplitterBehaviors[Long, LongToken]
+ with Matchers {
+
+ private val splitter = new Murmur3PartitionerTokenRangeSplitter
+
+ "Murmur3PartitionerSplitter" should "split tokens" in testSplittingTokens(splitter)
+
+ it should "split token sequences" in testSplittingTokenSequences(splitter)
+
+ override def splitWholeRingIn(count: Int): Seq[TokenRange[Long, LongToken]] = {
+ val hugeTokensIncrement = totalTokenCount / count
+ (0 until count).map(i =>
+ range(minToken.value + i * hugeTokensIncrement, minToken.value + (i + 1) * hugeTokensIncrement)
+ )
+ }
+
+ override def range(start: BigInt, end: BigInt): TokenRange[Long, LongToken] =
+ new TokenRange[Long, LongToken](
+ LongToken(start.toLong),
+ LongToken(end.toLong),
+ Set(InetAddress.getLocalHost),
+ Murmur3TokenFactory)
+}
\ No newline at end of file
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterTest.scala
deleted file mode 100644
index c034452fc..000000000
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterTest.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-package com.datastax.spark.connector.rdd.partitioner
-
-import java.net.InetAddress
-
-import org.junit.Assert._
-import org.junit.Test
-
-import com.datastax.spark.connector.rdd.partitioner.dht.LongToken
-import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory
-
-class Murmur3PartitionerTokenRangeSplitterTest {
-
- type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[Long, LongToken]
-
- private def assertNoHoles(tokenRanges: Seq[TokenRange]) {
- for (Seq(range1, range2) <- tokenRanges.sliding(2))
- assertEquals(range1.end, range2.start)
- }
-
- private def assertSimilarSize(tokenRanges: Seq[TokenRange]): Unit = {
- val sizes = tokenRanges.map(r => Murmur3TokenFactory.distance(r.start, r.end)).toVector
- val maxSize = sizes.max.toDouble
- val minSize = sizes.min.toDouble
- assertTrue(s"maxSize / minSize = ${maxSize / minSize} > 1.01", maxSize / minSize <= 1.01)
- }
-
- @Test
- def testSplit() {
- val node = InetAddress.getLocalHost
- val dataSize = 1000
- val splitter = new Murmur3PartitionerTokenRangeSplitter(dataSize)
- val range = new TokenRange(LongToken(0), LongToken(0), Set(node), dataSize)
- val out = splitter.split(range, 100)
-
- assertEquals(10, out.size)
- assertEquals(0L, out.head.start.value)
- assertEquals(0L, out.last.end.value)
- assertTrue(out.forall(s => s.end.value != s.start.value))
- assertTrue(out.forall(_.replicas == Set(node)))
- assertNoHoles(out)
- assertSimilarSize(out)
- }
-
-
- @Test
- def testNoSplit() {
- val splitter = new Murmur3PartitionerTokenRangeSplitter(1000)
- val range = new TokenRange(LongToken(0), new LongToken(100), Set.empty, 0)
- val out = splitter.split(range, 500)
-
- // range is too small to contain 500 units
- assertEquals(1, out.size)
- assertEquals(0L, out.head.start.value)
- assertEquals(100L, out.last.end.value)
- }
-
- @Test
- def testZeroRows() {
- val dataSize = 0
- val splitter = new Murmur3PartitionerTokenRangeSplitter(dataSize)
- val range = new TokenRange(LongToken(0), LongToken(100), Set.empty, dataSize)
- val out = splitter.split(range, 500)
- assertEquals(1, out.size)
- assertEquals(0L, out.head.start.value)
- assertEquals(100L, out.last.end.value)
- }
-
- @Test
- def testWrapAround() {
- val dataSize = 2000
- val splitter = new Murmur3PartitionerTokenRangeSplitter(dataSize)
- val start = Murmur3TokenFactory.maxToken.value - Long.MaxValue / 2
- val end = Murmur3TokenFactory.minToken.value + Long.MaxValue / 2
- val range = new TokenRange(LongToken(start), LongToken(end), Set.empty, dataSize / 2)
- val splits = splitter.split(range, 100)
-
- // range is half of the ring; 2000 * 0.5 / 100 = 10
- assertEquals(10, splits.size)
- assertEquals(start, splits.head.start.value)
- assertEquals(end, splits.last.end.value)
- assertNoHoles(splits)
- assertSimilarSize(splits)
- }
-}
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterSpec.scala
new file mode 100644
index 000000000..99af7667a
--- /dev/null
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterSpec.scala
@@ -0,0 +1,35 @@
+package com.datastax.spark.connector.rdd.partitioner
+
+import java.net.InetAddress
+
+import org.scalatest.{Matchers, _}
+
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.RandomPartitionerTokenFactory
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.RandomPartitionerTokenFactory.{minToken, totalTokenCount}
+import com.datastax.spark.connector.rdd.partitioner.dht.{BigIntToken, TokenRange}
+
+class RandomPartitionerTokenRangeSplitterSpec
+ extends FlatSpec
+ with SplitterBehaviors[BigInt, BigIntToken]
+ with Matchers {
+
+ private val splitter = new RandomPartitionerTokenRangeSplitter
+
+ "RandomPartitionerSplitter" should "split tokens" in testSplittingTokens(splitter)
+
+ it should "split token sequences" in testSplittingTokenSequences(splitter)
+
+ override def splitWholeRingIn(count: Int): Seq[TokenRange[BigInt, BigIntToken]] = {
+ val hugeTokensIncrement = totalTokenCount / count
+ (0 until count).map(i =>
+ range(minToken.value + i * hugeTokensIncrement, minToken.value + (i + 1) * hugeTokensIncrement)
+ )
+ }
+
+ override def range(start: BigInt, end: BigInt): TokenRange[BigInt, BigIntToken] =
+ new TokenRange[BigInt, BigIntToken](
+ BigIntToken(start),
+ BigIntToken(end),
+ Set(InetAddress.getLocalHost),
+ RandomPartitionerTokenFactory)
+}
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterTest.scala
deleted file mode 100644
index ab7a06922..000000000
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterTest.scala
+++ /dev/null
@@ -1,91 +0,0 @@
-package com.datastax.spark.connector.rdd.partitioner
-
-import java.net.InetAddress
-
-import org.junit.Assert._
-import org.junit.Test
-
-import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.RandomPartitionerTokenFactory
-import com.datastax.spark.connector.rdd.partitioner.dht.{BigIntToken, TokenFactory}
-
-class RandomPartitionerTokenRangeSplitterTest {
-
- type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[BigInt, BigIntToken]
-
- private def assertNoHoles(tokenRanges: Seq[TokenRange]) {
- for (Seq(range1, range2) <- tokenRanges.sliding(2))
- assertEquals(range1.end, range2.start)
- }
-
- private def assertSimilarSize(tokenRanges: Seq[TokenRange]): Unit = {
- val sizes = tokenRanges.map(r => RandomPartitionerTokenFactory.distance(r.start, r.end)).toVector
- val maxSize = sizes.max.toDouble
- val minSize = sizes.min.toDouble
- assertTrue(s"maxSize / minSize = ${maxSize / minSize} > 1.01", maxSize / minSize <= 1.01)
- }
-
- @Test
- def testSplit() {
- val dataSize = 1000
- val node = InetAddress.getLocalHost
- val splitter = new RandomPartitionerTokenRangeSplitter(dataSize)
- val rangeLeft = BigInt("0")
- val rangeRight = BigInt("0")
- val range = new TokenRange(BigIntToken(rangeLeft), BigIntToken(rangeRight), Set(node), dataSize)
- val out = splitter.split(range, 100)
-
- assertEquals(10, out.size)
- assertEquals(rangeLeft, out.head.start.value)
- assertEquals(rangeRight, out.last.end.value)
- assertTrue(out.forall(_.replicas == Set(node)))
- assertNoHoles(out)
- assertSimilarSize(out)
- }
-
- @Test
- def testNoSplit() {
- val splitter = new RandomPartitionerTokenRangeSplitter(1000)
- val rangeLeft = BigInt("0")
- val rangeRight = BigInt("100")
- val range = new TokenRange(BigIntToken(rangeLeft), BigIntToken(rangeRight), Set.empty, 0)
- val out = splitter.split(range, 500)
-
- // range is too small to contain 500 rows
- assertEquals(1, out.size)
- assertEquals(rangeLeft, out.head.start.value)
- assertEquals(rangeRight, out.last.end.value)
- }
-
- @Test
- def testZeroRows() {
- val splitter = new RandomPartitionerTokenRangeSplitter(0)
- val rangeLeft = BigInt("0")
- val rangeRight = BigInt("100")
- val range = new TokenRange(BigIntToken(rangeLeft), BigIntToken(rangeRight), Set.empty, 0)
- val out = splitter.split(range, 500)
- assertEquals(1, out.size)
- assertEquals(rangeLeft, out.head.start.value)
- assertEquals(rangeRight, out.last.end.value)
- }
-
- @Test
- def testWrapAround() {
- val dataSize = 2000
- val splitter = new RandomPartitionerTokenRangeSplitter(dataSize)
- val totalTokenCount = RandomPartitionerTokenFactory.totalTokenCount
- val rangeLeft = RandomPartitionerTokenFactory.maxToken.value - totalTokenCount / 4
- val rangeRight = RandomPartitionerTokenFactory.minToken.value + totalTokenCount / 4
- val range = new TokenRange(
- new BigIntToken(rangeLeft),
- new BigIntToken(rangeRight),
- Set.empty,
- dataSize / 2)
- val out = splitter.split(range, 100)
- assertEquals(10, out.size)
- assertEquals(rangeLeft, out.head.start.value)
- assertEquals(rangeRight, out.last.end.value)
- assertNoHoles(out)
- assertSimilarSize(out)
- }
-
-}
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/SplitterBehaviors.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/SplitterBehaviors.scala
new file mode 100644
index 000000000..da2f95fa3
--- /dev/null
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/SplitterBehaviors.scala
@@ -0,0 +1,97 @@
+package com.datastax.spark.connector.rdd.partitioner
+
+import org.scalatest.Matchers
+
+import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenRange}
+
+private[partitioner] trait SplitterBehaviors[V, T <: Token[V]] {
+ this: Matchers =>
+
+ case class SplitResult(splitCount: Int, minSize: BigInt, maxSize: BigInt)
+
+ def splitWholeRingIn(count: Int): Seq[TokenRange[V, T]]
+
+ def range(start: BigInt, end: BigInt): TokenRange[V, T]
+
+ def splittedIn(splitCount: Int): Int = splitCount
+
+ def outputs(splits: Int, withSize: BigInt, sizeTolerance: BigInt = BigInt(0)): SplitResult =
+ SplitResult(splits, withSize - sizeTolerance, withSize + sizeTolerance)
+
+ def testSplittingTokens(splitter: => TokenRangeSplitter[V, T]) {
+ val hugeRanges = splitWholeRingIn(10)
+
+ val splitCases = Seq[(TokenRange[V, T], Int, SplitResult)](
+ (range(start = 0, end = 1), splittedIn(1), outputs(splits = 1, withSize = 1)),
+ (range(start = 0, end = 10), splittedIn(1), outputs(splits = 1, withSize = 10)),
+ (range(start = 0, end = 1), splittedIn(10), outputs(splits = 1, withSize = 1)),
+
+ (range(start = 0, end = 9), splittedIn(10), outputs(splits = 9, withSize = 1)),
+ (range(start = 0, end = 10), splittedIn(10), outputs(splits = 10, withSize = 1)),
+ (range(start = 0, end = 11), splittedIn(10), outputs(splits = 10, withSize = 1, sizeTolerance = 1)),
+
+ (range(start = 10, end = 50), splittedIn(10), outputs(splits = 10, withSize = 4)),
+ (range(start = 0, end = 1000), splittedIn(3), outputs(splits = 3, withSize = 333, sizeTolerance = 1)),
+
+ (hugeRanges.head, splittedIn(100), outputs(splits = 100, withSize = hugeRanges.head.rangeSize / 100, sizeTolerance = 4)),
+ (hugeRanges.last, splittedIn(100), outputs(splits = 100, withSize = hugeRanges.last.rangeSize / 100, sizeTolerance = 4))
+ )
+
+ for ((range, splittedIn, expected) <- splitCases) {
+
+ val splits = splitter.split(range, splittedIn)
+
+ withClue(s"Splitting range (${range.start}, ${range.end}) in $splittedIn splits failed.") {
+ splits.size should be(expected.splitCount)
+ splits.head.start should be(range.start)
+ splits.last.end should be(range.end)
+ splits.foreach(_.replicas should be(range.replicas))
+ splits.foreach(s => s.rangeSize should (be >= expected.minSize and be <= expected.maxSize))
+ splits.map(_.rangeSize).sum should be(range.rangeSize)
+ splits.map(_.ringFraction).sum should be(range.ringFraction +- .000000001)
+ for (Seq(range1, range2) <- splits.sliding(2)) range1.end should be(range2.start)
+ }
+ }
+ }
+
+ def testSplittingTokenSequences(splitter: TokenRangeSplitter[V, T]) {
+ val mediumRanges = splitWholeRingIn(100)
+ val wholeRingSize = mediumRanges.map(_.rangeSize).sum
+
+ val splitCases = Seq[(Seq[TokenRange[V, T]], Int, SplitResult)](
+ // we have 100 ranges, so 100 splits is minimum
+ (mediumRanges, splittedIn(3), outputs(splits = 100, withSize = wholeRingSize / 100)),
+
+ (mediumRanges, splittedIn(100), outputs(splits = 100, withSize = wholeRingSize / 100)),
+
+ (mediumRanges, splittedIn(101), outputs(splits = 100, withSize = wholeRingSize / 100)),
+
+ (mediumRanges, splittedIn(149), outputs(splits = 100, withSize = wholeRingSize / 100)),
+
+ (mediumRanges, splittedIn(150), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)),
+
+ (mediumRanges, splittedIn(151), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)),
+
+ (mediumRanges, splittedIn(199), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)),
+
+ (mediumRanges, splittedIn(200), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)),
+
+ (mediumRanges, splittedIn(201), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1))
+ )
+
+ for ((ranges, splittedIn, expected) <- splitCases) {
+
+ val splits = splitter.split(ranges, splittedIn)
+
+ withClue(s"Splitting ${ranges.size} ranges in $splittedIn splits failed.") {
+ splits.size should be(expected.splitCount)
+ splits.head.start should be(ranges.head.start)
+ splits.last.end should be(ranges.last.end)
+ splits.foreach(s => s.rangeSize should (be >= expected.minSize and be <= expected.maxSize))
+ splits.map(_.rangeSize).sum should be(ranges.map(_.rangeSize).sum)
+ splits.map(_.ringFraction).sum should be(1.0 +- .000000001)
+ for (Seq(range1, range2) <- splits.sliding(2)) range1.end should be(range2.start)
+ }
+ }
+ }
+}
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala
index 3030b6670..71b7f5822 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala
@@ -2,22 +2,24 @@ package com.datastax.spark.connector.rdd.partitioner
import java.net.InetAddress
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory.{maxToken, minToken}
+import com.datastax.spark.connector.rdd.partitioner.dht.{LongToken, TokenRange}
import org.junit.Assert._
import org.junit.Test
-import com.datastax.spark.connector.rdd.partitioner.dht.LongToken
-
class TokenRangeClustererTest {
- type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[Long, LongToken]
-
val node1 = InetAddress.getByName("192.168.123.1")
val node2 = InetAddress.getByName("192.168.123.2")
val node3 = InetAddress.getByName("192.168.123.3")
val node4 = InetAddress.getByName("192.168.123.4")
val node5 = InetAddress.getByName("192.168.123.5")
- private def token(x: Long) = new com.datastax.spark.connector.rdd.partitioner.dht.LongToken(x)
+ private def tokenRange(start: Long, end: Long, nodes: Set[InetAddress]): TokenRange[Long, LongToken] =
+ new TokenRange[Long, LongToken](new LongToken(start), new LongToken(end), nodes, Murmur3TokenFactory)
+
+ private implicit def tokenToLong(token: LongToken): Long = token.value
@Test
def testEmpty() {
@@ -28,9 +30,9 @@ class TokenRangeClustererTest {
@Test
def testTrivialClustering() {
- val tr1 = new TokenRange(token(0), token(10), Set(node1), 5)
- val tr2 = new TokenRange(token(10), token(20), Set(node1), 5)
- val trc = new TokenRangeClusterer[Long, LongToken](10)
+ val tr1 = tokenRange(start = 0, end = 10, nodes = Set(node1))
+ val tr2 = tokenRange(start = 10, end = 20, nodes = Set(node1))
+ val trc = new TokenRangeClusterer[Long, LongToken](1)
val groups = trc.group(Seq(tr1, tr2))
assertEquals(1, groups.size)
assertEquals(Set(tr1, tr2), groups.head.toSet)
@@ -38,12 +40,12 @@ class TokenRangeClustererTest {
@Test
def testSplitByHost() {
- val tr1 = new TokenRange(token(0), token(10), Set(node1), 2)
- val tr2 = new TokenRange(token(10), token(20), Set(node1), 2)
- val tr3 = new TokenRange(token(20), token(30), Set(node2), 2)
- val tr4 = new TokenRange(token(30), token(40), Set(node2), 2)
+ val tr1 = tokenRange(start = 0, end = 10, nodes = Set(node1))
+ val tr2 = tokenRange(start = 10, end = 20, nodes = Set(node1))
+ val tr3 = tokenRange(start = 20, end = 30, nodes = Set(node2))
+ val tr4 = tokenRange(start = 30, end = 40, nodes = Set(node2))
- val trc = new TokenRangeClusterer[Long, LongToken](10)
+ val trc = new TokenRangeClusterer[Long, LongToken](1)
val groups = trc.group(Seq(tr1, tr2, tr3, tr4)).map(_.toSet).toSet
assertEquals(2, groups.size)
assertTrue(groups.contains(Set(tr1, tr2)))
@@ -52,36 +54,25 @@ class TokenRangeClustererTest {
@Test
def testSplitByCount() {
- val tr1 = new TokenRange(token(0), token(10), Set(node1), 5)
- val tr2 = new TokenRange(token(10), token(20), Set(node1), 5)
- val tr3 = new TokenRange(token(20), token(30), Set(node1), 5)
- val tr4 = new TokenRange(token(30), token(40), Set(node1), 5)
+ val tr1 = tokenRange(start = minToken, end = minToken / 2, Set(node1))
+ val tr2 = tokenRange(start = minToken / 2, end = 0, Set(node1))
+ val tr3 = tokenRange(start = 0, end = maxToken / 2, Set(node1))
+ val tr4 = tokenRange(start = maxToken / 2, end = maxToken, Set(node1))
- val trc = new TokenRangeClusterer[Long, LongToken](10)
+ val trc = new TokenRangeClusterer[Long, LongToken](2)
val groups = trc.group(Seq(tr1, tr2, tr3, tr4)).map(_.toSet).toSet
assertEquals(2, groups.size)
assertTrue(groups.contains(Set(tr1, tr2)))
assertTrue(groups.contains(Set(tr3, tr4)))
}
- @Test
- def testTooLargeRanges() {
- val tr1 = new TokenRange(token(0), token(10), Set(node1), 100000)
- val tr2 = new TokenRange(token(10), token(20), Set(node1), 100000)
- val trc = new TokenRangeClusterer[Long, LongToken](10)
- val groups = trc.group(Seq(tr1, tr2)).map(_.toSet).toSet
- assertEquals(2, groups.size)
- assertTrue(groups.contains(Set(tr1)))
- assertTrue(groups.contains(Set(tr2)))
- }
-
@Test
def testMultipleEndpoints() {
- val tr1 = new TokenRange(token(0), token(10), Set(node2, node1, node3), 1)
- val tr2 = new TokenRange(token(10), token(20), Set(node1, node3, node4), 1)
- val tr3 = new TokenRange(token(20), token(30), Set(node3, node1, node5), 1)
- val tr4 = new TokenRange(token(30), token(40), Set(node3, node1, node4), 1)
- val trc = new TokenRangeClusterer[Long, LongToken](10)
+ val tr1 = tokenRange(start = 0, end = 10, nodes = Set(node2, node1, node3))
+ val tr2 = tokenRange(start = 10, end = 20, nodes = Set(node1, node3, node4))
+ val tr3 = tokenRange(start = 20, end = 30, nodes = Set(node3, node1, node5))
+ val tr4 = tokenRange(start = 30, end = 40, nodes = Set(node3, node1, node4))
+ val trc = new TokenRangeClusterer[Long, LongToken](1)
val groups = trc.group(Seq(tr1, tr2, tr3, tr4))
assertEquals(1, groups.size)
assertEquals(4, groups.head.size)
@@ -89,13 +80,12 @@ class TokenRangeClustererTest {
}
@Test
- def testMaxClusterSize() {
- val tr1 = new TokenRange(token(0), token(10), Set(node1, node2, node3), 1)
- val tr2 = new TokenRange(token(10), token(20), Set(node1, node2, node3), 1)
- val tr3 = new TokenRange(token(20), token(30), Set(node1, node2, node3), 1)
- val trc = new TokenRangeClusterer[Long, LongToken](maxRowCountPerGroup = 10, maxGroupSize = 1)
+ def testMaxGroupSize() {
+ val tr1 = tokenRange(start = 0, end = 10, nodes = Set(node1, node2, node3))
+ val tr2 = tokenRange(start = 10, end = 20, nodes = Set(node1, node2, node3))
+ val tr3 = tokenRange(start = 20, end = 30, nodes = Set(node1, node2, node3))
+ val trc = new TokenRangeClusterer[Long, LongToken](groupCount = 1, maxGroupSize = 1)
val groups = trc.group(Seq(tr1, tr2, tr3))
assertEquals(3, groups.size)
}
-
-}
+}
\ No newline at end of file
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala
index 1736153b2..1451823af 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala
@@ -2,7 +2,8 @@ package com.datastax.spark.connector.rdd.partitioner.dht
import org.scalatest.{FlatSpec, Matchers}
-import com.datastax.driver.core.{TokenRange => DTokenRange, Token => DToken}
+import com.datastax.driver.core.{Token => DToken, TokenRange => DTokenRange}
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.{Murmur3TokenFactory, RandomPartitionerTokenFactory}
class TokenRangeSpec extends FlatSpec with Matchers {
@@ -11,7 +12,7 @@ class TokenRangeSpec extends FlatSpec with Matchers {
"LongRanges " should " contain tokens with easy no wrapping bounds" in {
- val lr = new LongRange(LongToken(-100), LongToken(10000), Set.empty, 0)
+ val lr = new LongRange(LongToken(-100), LongToken(10000), Set.empty, Murmur3TokenFactory)
//Tokens Inside
for (l <- 1 to 1000) {
lr.contains(LongToken(l)) should be (true)
@@ -24,7 +25,7 @@ class TokenRangeSpec extends FlatSpec with Matchers {
}
it should " contain tokens with wrapping bounds in" in {
- val lr = new LongRange(LongToken(1000), LongToken(-1000), Set.empty, 0)
+ val lr = new LongRange(LongToken(1000), LongToken(-1000), Set.empty, Murmur3TokenFactory)
//Tokens Inside
for (l <- 30000 to 30500) {
@@ -39,7 +40,7 @@ class TokenRangeSpec extends FlatSpec with Matchers {
}
"BigRanges " should " contain tokens with easy no wrapping bounds" in {
- val lr = new BigRange(BigIntToken(-100), BigIntToken(10000), Set.empty, 0)
+ val lr = new BigRange(BigIntToken(-100), BigIntToken(10000), Set.empty, RandomPartitionerTokenFactory)
//Tokens Inside
for (l <- 1 to 1000) {
lr.contains(BigIntToken(l)) should be (true)
@@ -52,7 +53,7 @@ class TokenRangeSpec extends FlatSpec with Matchers {
}
it should " contain tokens with wrapping bounds in" in {
- val lr = new BigRange(BigIntToken(1000), BigIntToken(100), Set.empty, 0)
+ val lr = new BigRange(BigIntToken(1000), BigIntToken(100), Set.empty, RandomPartitionerTokenFactory)
//Tokens Inside
for (l <- 0 to 50) {