diff --git a/README.md b/README.md index c55f69468..bdf2b53c0 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,7 @@ See [Building And Artifacts](doc/12_building_and_artifacts.md) - [The Spark Shell](doc/13_spark_shell.md) - [DataFrames](doc/14_data_frames.md) - [Python](doc/15_python.md) + - [Partitioner](doc/16_partitioning.md) - [Frequently Asked Questions](doc/FAQ.md) - [Configuration Parameter Reference Table](doc/reference.md) diff --git a/doc/FAQ.md b/doc/FAQ.md index 2d9da0711..185289131 100644 --- a/doc/FAQ.md +++ b/doc/FAQ.md @@ -104,6 +104,12 @@ the rpc_address is set to. When troubleshooting Cassandra connections it is sometimes useful to set the rpc_address in the C* yaml file to `0.0.0.0` so any incoming connection will work. +### How does the connector evaluate the number of Spark partitions? + +The Connector evaluates the number of Spark partitions by dividing the table size estimate by +the `input.split.size_in_mb` value. The resulting number of partitions is never smaller than +`1 + 2 * SparkContext.defaultParallelism`. + ### What does input.split.size_in_mb use to determine size? Input.split.size_in_mb uses a internal system table in C* ( >= 2.1.5) to determine the size diff --git a/doc/reference.md b/doc/reference.md index e8cadbf86..82e8c973a 100644 --- a/doc/reference.md +++ b/doc/reference.md @@ -203,7 +203,7 @@ OSS Cassandra this should never be used. input.split.size_in_mb 64 - Approx amount of data to be fetched into a Spark partition + Approx amount of data to be fetched into a Spark partition. Minimum number of resulting Spark partitions is 1 + 2 * SparkContext.defaultParallelism diff --git a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraJavaRDDSpec.scala b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraJavaRDDSpec.scala index 12c594fa3..62df97ad4 100644 --- a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraJavaRDDSpec.scala +++ b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraJavaRDDSpec.scala @@ -2,19 +2,16 @@ package com.datastax.spark.connector.rdd import java.io.IOException -import scala.concurrent.duration.Duration -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import com.datastax.spark.connector._ import com.datastax.spark.connector.cql.CassandraConnector -import com.datastax.spark.connector.embedded.SparkTemplate._ import com.datastax.spark.connector.embedded._ import com.datastax.spark.connector.japi.CassandraJavaUtil._ import com.datastax.spark.connector.japi.CassandraRow import com.datastax.spark.connector.types.TypeConverter import org.apache.commons.lang3.tuple import org.apache.spark.api.java.function.{Function => JFunction} - import scala.collection.JavaConversions._ class CassandraJavaRDDSpec extends SparkCassandraITFlatSpecBase { @@ -80,6 +77,13 @@ class CassandraJavaRDDSpec extends SparkCassandraITFlatSpecBase { session.execute(s"INSERT INTO $ks.wide_rows(key, group, value) VALUES (20, 20, '2020')") session.execute(s"INSERT INTO $ks.wide_rows(key, group, value) VALUES (20, 21, '2021')") session.execute(s"INSERT INTO $ks.wide_rows(key, group, value) VALUES (20, 22, '2022')") + }, + + Future { + session.execute(s"CREATE TABLE $ks.limit_test_table (key INT, value TEXT, PRIMARY KEY (key))") + for(i <- 0 to 30) { + session.execute(s"INSERT INTO $ks.limit_test_table (key, value) VALUES ($i, '$i')") + } } ) } @@ -413,9 +417,10 @@ class CassandraJavaRDDSpec extends SparkCassandraITFlatSpecBase { } it
should "allow to set limit" in { - val rdd = javaFunctions(sc).cassandraTable(ks, "test_table").limit(1L) + val limit = 1 + val rdd = javaFunctions(sc).cassandraTable(ks, "limit_test_table").limit(limit.toLong) val result = rdd.collect() - result should have size 1 + result.size shouldBe <= (rdd.getNumPartitions * limit) } it should "allow to set ascending ordering" in { diff --git a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala index 60caecfa7..062b23e5e 100644 --- a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala +++ b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala @@ -357,11 +357,6 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase { result.head.getString("value") should startWith("000") } - it should "use a single partition per node for a tiny table" in { - val rdd = sc.cassandraTable(ks, "key_value") - rdd.partitions should have length conn.hosts.size - } - it should "support single partition where clauses" in { val someCass = sc .cassandraTable[KeyValue](ks, "key_value") diff --git a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDDSpec.scala b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDDSpec.scala new file mode 100644 index 000000000..9c5203968 --- /dev/null +++ b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDDSpec.scala @@ -0,0 +1,116 @@ +package com.datastax.spark.connector.rdd + +import org.apache.cassandra.tools.NodeProbe +import org.scalatest.Inspectors + +import com.datastax.spark.connector.SparkCassandraITFlatSpecBase +import com.datastax.spark.connector.cql.CassandraConnector +import com.datastax.spark.connector.embedded.{CassandraRunner, EmbeddedCassandra} +import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory + +class CassandraTableScanRDDSpec extends SparkCassandraITFlatSpecBase with Inspectors { + + useCassandraConfig(Seq("cassandra-default.yaml.template")) + useSparkConf(defaultConf) + + val conn = CassandraConnector(defaultConf) + val tokenFactory = TokenFactory.forSystemLocalPartitioner(conn) + val tableName = "data" + val noMinimalThreshold = Int.MinValue + + "CassandraTableScanRDD" should "favor user provided split count over minimal threshold" in { + val userProvidedSplitCount = 8 + val minimalSplitCountThreshold = 32 + val rddWith64MB = getCassandraTableScanRDD(splitSizeMB = 1, splitCount = Some(userProvidedSplitCount), + minimalSplitCountThreshold = minimalSplitCountThreshold) + + val partitions = rddWith64MB.getPartitions + + partitions.length should be(userProvidedSplitCount +- 1) + } + + it should "favor user provided split count over size-estimated partitions" in { + val userProvidedSplitCount = 8 + val rddWith64MB = getCassandraTableScanRDD(splitSizeMB = 1, splitCount = Some(userProvidedSplitCount), + minimalSplitCountThreshold = noMinimalThreshold) + + val partitions = rddWith64MB.getPartitions + + partitions.length should be(userProvidedSplitCount +- 1) + } + + it should "create size-estimated partitions with splitSize size" in { + val rddWith64MB = getCassandraTableScanRDD(splitSizeMB = 1, minimalSplitCountThreshold = noMinimalThreshold) + + val partitions = rddWith64MB.getPartitions + 
+ // theoretically there should be 64 splits, but it is ok to be "a little" inaccurate + partitions.length should (be >= 16 and be <= 256) + } + + it should "create size-estimated partitions when above minimal threshold" in { + val minimalSplitCountThreshold = 2 + val rddWith64MB = getCassandraTableScanRDD(splitSizeMB = 1, minimalSplitCountThreshold = minimalSplitCountThreshold) + + val partitions = rddWith64MB.getPartitions + + // theoretically there should be 64 splits, but it is ok to be "a little" inaccurate + partitions.length should (be >= 16 and be <= 256) + } + + it should "create size-estimated partitions but not less than minimum partitions threshold" in { + val minimalSplitCountThreshold = 64 + val rddWith64MB = getCassandraTableScanRDD(splitSizeMB = 32, minimalSplitCountThreshold = minimalSplitCountThreshold) + + val partitions = rddWith64MB.getPartitions + + partitions.length should be >= minimalSplitCountThreshold + } + + it should "align index fields of partitions with their place in the array" in { + val minimalSplitCountThreshold = 64 + val rddWith64MB = getCassandraTableScanRDD(splitSizeMB = 32, minimalSplitCountThreshold = minimalSplitCountThreshold) + + val partitions = rddWith64MB.getPartitions + + forAll(partitions.zipWithIndex) { case (part, index) => part.index should be(index) } + } + + override def beforeAll(): Unit = { + conn.withSessionDo { session => + + session.execute(s"CREATE KEYSPACE IF NOT EXISTS $ks " + + s"WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }") + + session.execute(s"CREATE TABLE $ks.$tableName(key int primary key, value text)") + val st = session.prepare(s"INSERT INTO $ks.$tableName(key, value) VALUES(?, ?)") + // 1M rows x 64 bytes of payload = 64 MB of data + overhead + for (i <- (1 to 1000000).par) { + val key = i.asInstanceOf[AnyRef] + val value = "123456789.123456789.123456789.123456789.123456789.123456789." + session.execute(st.bind(key, value)) + } + } + for (host <- conn.hosts) { + val nodeProbe = new NodeProbe(host.getHostAddress, + EmbeddedCassandra.cassandraRunners(0).map(_.jmxPort).getOrElse(CassandraRunner.DefaultJmxPort)) + nodeProbe.forceKeyspaceFlush(ks, tableName) + } + + val timeout = CassandraRunner.SizeEstimatesUpdateIntervalInSeconds * 1000 * 5 + assert(DataSizeEstimates.waitForDataSizeEstimates(conn, ks, tableName, timeout), + s"Data size estimates not present after $timeout ms. 
Test cannot be finished.") + } + + private def getCassandraTableScanRDD( + splitSizeMB: Int, + splitCount: Option[Int] = None, + minimalSplitCountThreshold: Int): CassandraTableScanRDD[AnyRef] = { + val readConf = new ReadConf(splitSizeInMB = splitSizeMB, splitCount = splitCount) + + new CassandraTableScanRDD[AnyRef](sc, conn, ks, tableName, readConf = readConf) { + override def minimalSplitCount: Int = minimalSplitCountThreshold + } + } +} diff --git a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGeneratorSpec.scala b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGeneratorSpec.scala index 9501ffb1f..ca094f8c2 100644 --- a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGeneratorSpec.scala +++ b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGeneratorSpec.scala @@ -1,19 +1,16 @@ package com.datastax.spark.connector.rdd.partitioner -import org.apache.cassandra.tools.NodeProbe -import org.scalatest.{Inspectors, Matchers, FlatSpec} +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory import com.datastax.spark.connector.SparkCassandraITFlatSpecBase import com.datastax.spark.connector.cql.{Schema, CassandraConnector} -import com.datastax.spark.connector.embedded.{CassandraRunner, SparkTemplate, EmbeddedCassandra} -import com.datastax.spark.connector.rdd.CqlWhereClause -import com.datastax.spark.connector.testkit.SharedEmbeddedCassandra class CassandraPartitionGeneratorSpec - extends SparkCassandraITFlatSpecBase with Inspectors { + extends SparkCassandraITFlatSpecBase { useCassandraConfig(Seq("cassandra-default.yaml.template")) val conn = CassandraConnector(defaultConf) + implicit val tokenFactory = TokenFactory.forSystemLocalPartitioner(conn) conn.withSessionDo { session => createKeyspace(session) @@ -25,7 +22,7 @@ class CassandraPartitionGeneratorSpec // Should be improved in the future. private def testPartitionCount(numPartitions: Int, min: Int, max: Int): Unit = { val table = Schema.fromCassandra(conn, Some(ks), Some("empty")).tables.head - val partitioner = CassandraPartitionGenerator(conn, table, Some(numPartitions), 10000) + val partitioner = CassandraPartitionGenerator(conn, table, numPartitions) val partitions = partitioner.partitions partitions.length should be >= min partitions.length should be <= max @@ -39,44 +36,4 @@ class CassandraPartitionGeneratorSpec it should "create about 10000 partitions when splitCount == 10000" in { testPartitionCount(10000, 9000, 11000) } - - it should "create multiple partitions if the amount of data is big enough" in { - val tableName = "data" - conn.withSessionDo { session => - session.execute(s"CREATE TABLE $ks.$tableName(key int primary key, value text)") - val st = session.prepare(s"INSERT INTO $ks.$tableName(key, value) VALUES(?, ?)") - // 1M rows x 64 bytes of payload = 64 MB of data + overhead - for (i <- (1 to 1000000).par) { - val key = i.asInstanceOf[AnyRef] - val value = "123456789.123456789.123456789.123456789.123456789.123456789." 
- session.execute(st.bind(key, value)) - } - } - - for (host <- conn.hosts) { - val nodeProbe = new NodeProbe(host.getHostAddress, - EmbeddedCassandra.cassandraRunners(0).map(_.jmxPort).getOrElse(CassandraRunner.DefaultJmxPort)) - nodeProbe.forceKeyspaceFlush(ks, tableName) - } - - val timeout = CassandraRunner.SizeEstimatesUpdateIntervalInSeconds * 1000 * 5 - assert(DataSizeEstimates.waitForDataSizeEstimates(conn, ks, tableName, timeout), - s"Data size estimates not present after $timeout ms. Test cannot be finished.") - - val table = Schema.fromCassandra(conn, Some(ks), Some(tableName)).tables.head - val partitioner = CassandraPartitionGenerator(conn, table, splitCount = None, splitSize = 1000000) - val partitions = partitioner.partitions - - // theoretically there should be 64 splits, but it is ok to be "a little" inaccurate - partitions.length should be >= 16 - partitions.length should be <= 256 - } - - it should "align index fields of partitions with their place in the array" in { - val table = Schema.fromCassandra(conn, Some(ks), Some("data")).tables.head - val partitioner = CassandraPartitionGenerator(conn, table, splitCount = Some(1000), splitSize = 100) - val partToIndex = partitioner.partitions.zipWithIndex - forAll (partToIndex) { case (part, index) => part.index should be (index) } - } - } diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDD.scala index 57261d9da..16b45f039 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDD.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraTableScanRDD.scala @@ -4,7 +4,7 @@ import java.io.IOException import com.datastax.spark.connector._ import com.datastax.spark.connector.cql._ -import com.datastax.spark.connector.rdd.partitioner._ +import com.datastax.spark.connector.rdd.partitioner.{CassandraPartition, CassandraPartitionGenerator, CqlTokenRange, DataSizeEstimates, NodeAddresses, _} import com.datastax.spark.connector.rdd.partitioner.dht.{Token => ConnectorToken} import com.datastax.spark.connector.rdd.reader._ import com.datastax.spark.connector.types.ColumnType @@ -72,7 +72,8 @@ class CassandraTableScanRDD[R] private[connector]( val classTag: ClassTag[R], @transient val rowReaderFactory: RowReaderFactory[R]) extends CassandraRDD[R](sc, Seq.empty) - with CassandraTableRowReaderProvider[R] { + with CassandraTableRowReaderProvider[R] + with SplitSizeEstimator[R] { override type Self = CassandraTableScanRDD[R] @@ -218,9 +219,10 @@ class CassandraTableScanRDD[R] private[connector]( @transient lazy val partitionGenerator = { if (containsPartitionKey(where)) { - CassandraPartitionGenerator(connector, tableDef, Some(1), splitSize) + CassandraPartitionGenerator(connector, tableDef, 1) } else { - CassandraPartitionGenerator(connector, tableDef, splitCount, splitSize) + val reevaluatedSplitCount = splitCount.getOrElse(estimateSplitCount(splitSize)) + CassandraPartitionGenerator(connector, tableDef, reevaluatedSplitCount) } } diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala index 66a1372a7..1cf93c682 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala +++ 
b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala @@ -32,7 +32,10 @@ object ReadConf { name = "spark.cassandra.input.split.size_in_mb", section = ReferenceSection, default = 64, - description = """Approx amount of data to be fetched into a Spark partition""") + description = + """Approx amount of data to be fetched into a Spark partition. Minimum number of resulting Spark + | partitions is 1 + 2 * SparkContext.defaultParallelism + |""".stripMargin.filter(_ >= ' ')) val FetchSizeInRowsParam = ConfigParameter[Int]( name = "spark.cassandra.input.fetch.size_in_rows", diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala index cd68ad2af..3a08c7705 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionGenerator.scala @@ -19,8 +19,7 @@ import com.datastax.spark.connector.writer.RowWriterFactory private[connector] class CassandraPartitionGenerator[V, T <: Token[V]]( connector: CassandraConnector, tableDef: TableDef, - splitCount: Option[Int], - splitSize: Long)( + splitCount: Int)( implicit tokenFactory: TokenFactory[V, T]) extends Logging{ @@ -30,23 +29,11 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]]( private val keyspaceName = tableDef.keyspaceName private val tableName = tableDef.tableName - private val totalDataSize: Long = { - // If we know both the splitCount and splitSize, we should pretend the total size of the data is - // their multiplication. TokenRangeSplitter will try to produce splits of desired size, and this way - // their number will be close to desired splitCount. Otherwise, if splitCount is not set, - // we just go to C* and read the estimated data size from an appropriate system table - splitCount match { - case Some(c) => c * splitSize - case None => new DataSizeEstimates(connector, keyspaceName, tableName).dataSizeInBytes - } - } - private def tokenRange(range: DriverTokenRange, metadata: Metadata): TokenRange = { val startToken = tokenFactory.tokenFromString(range.getStart.getValue.toString) val endToken = tokenFactory.tokenFromString(range.getEnd.getValue.toString) val replicas = metadata.getReplicas(Metadata.quote(keyspaceName), range).map(_.getAddress).toSet - val dataSize = (tokenFactory.ringFraction(startToken, endToken) * totalDataSize).toLong - new TokenRange(startToken, endToken, replicas, dataSize) + new TokenRange(startToken, endToken, replicas, tokenFactory) } private def describeRing: Seq[TokenRange] = { @@ -59,29 +46,19 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]]( * When we have a single Spark Partition use a single global range. 
This * will let us more easily deal with Partition Key equals and In clauses */ - if (splitCount == Some(1)) { - Seq(ranges(0).copy[V, T](tokenFactory.minToken, tokenFactory.maxToken)) + if (splitCount == 1) { + Seq(ranges.head.copy[V, T](tokenFactory.minToken, tokenFactory.minToken)) } else { ranges } } - private def splitsOf( - tokenRanges: Iterable[TokenRange], - splitter: TokenRangeSplitter[V, T]): Iterable[TokenRange] = { - - val parTokenRanges = tokenRanges.par - parTokenRanges.tasksupport = new ForkJoinTaskSupport(CassandraPartitionGenerator.pool) - (for (tokenRange <- parTokenRanges; - split <- splitter.split(tokenRange, splitSize)) yield split).seq - } - private def createTokenRangeSplitter: TokenRangeSplitter[V, T] = { tokenFactory.asInstanceOf[TokenFactory[_, _]] match { case TokenFactory.RandomPartitionerTokenFactory => - new RandomPartitionerTokenRangeSplitter(totalDataSize).asInstanceOf[TokenRangeSplitter[V, T]] + new RandomPartitionerTokenRangeSplitter().asInstanceOf[TokenRangeSplitter[V, T]] case TokenFactory.Murmur3TokenFactory => - new Murmur3PartitionerTokenRangeSplitter(totalDataSize).asInstanceOf[TokenRangeSplitter[V, T]] + new Murmur3PartitionerTokenRangeSplitter().asInstanceOf[TokenRangeSplitter[V, T]] case _ => throw new UnsupportedOperationException(s"Unsupported TokenFactory $tokenFactory") } @@ -93,18 +70,20 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]]( def partitions: Seq[CassandraPartition[V, T]] = { val tokenRanges = describeRing val endpointCount = tokenRanges.map(_.replicas).reduce(_ ++ _).size - val splitter = createTokenRangeSplitter - val splits = splitsOf(tokenRanges, splitter).toSeq val maxGroupSize = tokenRanges.size / endpointCount - val clusterer = new TokenRangeClusterer[V, T](splitSize, maxGroupSize) + + val splitter = createTokenRangeSplitter + val splits = splitter.split(tokenRanges, splitCount).toSeq + + val clusterer = new TokenRangeClusterer[V, T](splitCount, maxGroupSize) val tokenRangeGroups = clusterer.group(splits).toArray val partitions = for (group <- tokenRangeGroups) yield { val replicas = group.map(_.replicas).reduce(_ intersect _) - val rowCount = group.map(_.dataSize).sum + val rowCount = group.map(_.rangeSize).sum val cqlRanges = group.flatMap(rangeToCql) // partition index will be set later - CassandraPartition(0, replicas, cqlRanges, rowCount) + CassandraPartition(0, replicas, cqlRanges, rowCount.toLong) } // sort partitions and assign sequential numbers so that @@ -139,14 +118,6 @@ private[connector] class CassandraPartitionGenerator[V, T <: Token[V]]( } object CassandraPartitionGenerator { - /** Affects how many concurrent threads are used to fetch split information from cassandra nodes, in `getPartitions`. - * Does not affect how many Spark threads fetch data from Cassandra. 
*/ - val MaxParallelism = 16 - - /** How many token rangesContaining to sample in order to estimate average number of rows per token */ - val TokenRangeSampleSize = 16 - - private val pool: ForkJoinPool = new ForkJoinPool(MaxParallelism) type V = t forSome { type t } type T = t forSome { type t <: Token[V] } @@ -158,17 +129,9 @@ object CassandraPartitionGenerator { def apply( conn: CassandraConnector, tableDef: TableDef, - splitCount: Option[Int], - splitSize: Int): CassandraPartitionGenerator[V, T] = { + splitCount: Int)( + implicit tokenFactory: TokenFactory[V, T]): CassandraPartitionGenerator[V, T] = { - val tokenFactory = getTokenFactory(conn) - new CassandraPartitionGenerator(conn, tableDef, splitCount, splitSize)(tokenFactory) - } - - def getTokenFactory(conn: CassandraConnector) : TokenFactory[V, T] = { - val partitionerName = conn.withSessionDo { session => - session.execute("SELECT partitioner FROM system.local").one().getString(0) - } - TokenFactory.forCassandraPartitioner(partitionerName) + new CassandraPartitionGenerator(conn, tableDef, splitCount)(tokenFactory) } } diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala index 87a81fe0f..0417a11f0 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitter.scala @@ -1,34 +1,21 @@ package com.datastax.spark.connector.rdd.partitioner -import com.datastax.spark.connector.rdd.partitioner.dht.{LongToken, TokenFactory, TokenRange} +import com.datastax.spark.connector.rdd.partitioner.dht.LongToken -/** Fast token range splitter assuming that data are spread out evenly in the whole range. - * @param dataSize estimate of the size of the data in the whole ring */ -class Murmur3PartitionerTokenRangeSplitter(dataSize: Long) +/** Fast token range splitter assuming that data are spread out evenly in the whole range. */ +private[partitioner] class Murmur3PartitionerTokenRangeSplitter extends TokenRangeSplitter[Long, LongToken] { - private val tokenFactory = - TokenFactory.Murmur3TokenFactory + private type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[Long, LongToken] - private type TR = TokenRange[Long, LongToken] + override def split(tokenRange: TokenRange, splitSize: Int): Seq[TokenRange] = { + val rangeSize = tokenRange.rangeSize + val splitPointsCount = if (rangeSize < splitSize) rangeSize.toInt else splitSize + val splitPoints = (0 until splitPointsCount).map({ i => + new LongToken(tokenRange.start.value + (rangeSize * i / splitPointsCount).toLong) + }) :+ tokenRange.end - /** Splits the token range uniformly into sub-ranges. 
- * @param splitSize requested sub-split size, given in the same units as `dataSize` */ - def split(range: TR, splitSize: Long): Seq[TR] = { - val rangeSize = range.dataSize - val rangeTokenCount = tokenFactory.distance(range.start, range.end) - val n = math.max(1, math.round(rangeSize.toDouble / splitSize).toInt) - - val left = range.start.value - val right = range.end.value - val splitPoints = - (for (i <- 0 until n) yield left + (rangeTokenCount * i / n).toLong) :+ right - - for (Seq(l, r) <- splitPoints.sliding(2).toSeq) yield - new TokenRange[Long, LongToken]( - new LongToken(l), - new LongToken(r), - range.replicas, - rangeSize / n) + for (Seq(left, right) <- splitPoints.sliding(2).toSeq) yield + new TokenRange(left, right, tokenRange.replicas, tokenRange.tokenFactory) } } diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala index b896f52e7..e35e557bf 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitter.scala @@ -1,40 +1,28 @@ package com.datastax.spark.connector.rdd.partitioner -import com.datastax.spark.connector.rdd.partitioner.dht.{BigIntToken, TokenFactory, TokenRange} +import com.datastax.spark.connector.rdd.partitioner.dht.BigIntToken - -/** Fast token range splitter assuming that data are spread out evenly in the whole range. - * @param dataSize estimate of the size of the data in the whole ring */ -class RandomPartitionerTokenRangeSplitter(dataSize: Long) +/** Fast token range splitter assuming that data are spread out evenly in the whole range. */ +private[partitioner] class RandomPartitionerTokenRangeSplitter extends TokenRangeSplitter[BigInt, BigIntToken] { - private val tokenFactory = - TokenFactory.RandomPartitionerTokenFactory + private type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[BigInt, BigIntToken] - private def wrap(token: BigInt): BigInt = { - val max = tokenFactory.maxToken.value + private def wrapWithMax(max: BigInt)(token: BigInt): BigInt = { if (token <= max) token else token - max } - private type TR = TokenRange[BigInt, BigIntToken] - - /** Splits the token range uniformly into sub-ranges. 
- * @param splitSize requested sub-split size, given in the same units as `dataSize` */ - def split(range: TR, splitSize: Long): Seq[TR] = { - val rangeSize = range.dataSize - val rangeTokenCount = tokenFactory.distance(range.start, range.end) - val n = math.max(1, math.round(rangeSize.toDouble / splitSize)).toInt + override def split(tokenRange: TokenRange, splitCount: Int): Seq[TokenRange] = { + val rangeSize = tokenRange.rangeSize + val wrap = wrapWithMax(tokenRange.tokenFactory.maxToken.value)(_) - val left = range.start.value - val right = range.end.value - val splitPoints = - (for (i <- 0 until n) yield wrap(left + (rangeTokenCount * i / n))) :+ right + val splitPointsCount = if (rangeSize < splitCount) rangeSize.toInt else splitCount + val splitPoints = (0 until splitPointsCount).map({ i => + val nextToken: BigInt = tokenRange.start.value + (rangeSize * i / splitPointsCount) + new BigIntToken(wrap(nextToken)) + }) :+ tokenRange.end - for (Seq(l, r) <- splitPoints.sliding(2).toSeq) yield - new TokenRange[BigInt, BigIntToken]( - new BigIntToken(l.bigInteger), - new BigIntToken(r.bigInteger), - range.replicas, - rangeSize / n) + for (Seq(left, right) <- splitPoints.sliding(2).toSeq) yield + new TokenRange(left, right, tokenRange.replicas, tokenRange.tokenFactory) } } \ No newline at end of file diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/SplitSizeEstimator.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/SplitSizeEstimator.scala new file mode 100644 index 000000000..e67d38daa --- /dev/null +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/SplitSizeEstimator.scala @@ -0,0 +1,25 @@ +package com.datastax.spark.connector.rdd.partitioner + +import com.datastax.spark.connector.rdd.CassandraRDD +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory + +private[rdd] trait SplitSizeEstimator[R] { + this: CassandraRDD[R] => + + @transient implicit lazy val tokenFactory = TokenFactory.forSystemLocalPartitioner(connector) + + private def estimateDataSize: Long = + new DataSizeEstimates(connector, keyspaceName, tableName).dataSizeInBytes + + private[rdd] def minimalSplitCount: Int = { + val coreCount = context.defaultParallelism + 1 + coreCount * 2 + } + + def estimateSplitCount(splitSize: Int): Int = { + require(splitSize > 0, "Split size must be greater than zero.") + val splitCountEstimate = estimateDataSize / splitSize + Math.max(splitCountEstimate.toInt, minimalSplitCount) + } + +} diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala index 7edcd219a..19fafaab7 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClusterer.scala @@ -5,20 +5,24 @@ import java.net.InetAddress import Ordering.Implicits._ import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenRange} - import scala.annotation.tailrec -/** Divides a set of token rangesContaining into groups containing not more than `maxRowCountPerGroup` rows - * and not more than `maxGroupSize` token rangesContaining. Each group will form a single `CassandraPartition`. 
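/* A simplified, self-contained sketch of the grouping rule described in the replacement scaladoc
 * just below: ranges (modelled here only by their ring fraction and first replica) are taken
 * greedily while the accumulated ring fraction stays within 1.0 / groupCount and the first replica
 * stays the same; the maxGroupSize cap is omitted for brevity. Types and names in this sketch are
 * illustrative only, not connector API. */
object RingFractionGroupingSketch {
  final case class RangeStub(ringFraction: Double, firstReplica: String)

  def group(ranges: List[RangeStub], groupCount: Int): List[List[RangeStub]] = {
    val limitPerGroup = 1.0 / groupCount
    ranges match {
      case Nil => Nil
      case head :: _ =>
        // cumulative ring fraction of the prefix ending at each range
        val cumulative = ranges.map(_.ringFraction).scanLeft(0.0)(_ + _).tail
        val cluster = ranges.zip(cumulative)
          .takeWhile { case (r, sum) =>
            sum <= math.max(limitPerGroup, head.ringFraction) && r.firstReplica == head.firstReplica
          }
          .map(_._1)
        val taken = math.max(cluster.size, 1) // always take at least one range
        ranges.take(taken) :: group(ranges.drop(taken), groupCount)
    }
  }
  // Example: four ranges of fraction 0.25 on the same node grouped with groupCount = 2
  // yield two groups of two ranges each.
}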
+import com.datastax.spark.connector.rdd.partitioner.TokenRangeClusterer.WholeRing + +/** Groups a set of token ranges into `groupCount` groups containing not more than `maxGroupSize` token + * ranges. + * Each group will form a single `CassandraRDDPartition`. * * The algorithm is as follows: - * 1. Sort token rangesContaining by endpoints lexicographically. - * 2. Take the highest possible number of token rangesContaining from the beginning of the list, - * such that their sum of rowCounts does not exceed `maxRowCountPerGroup` and they all contain at + * 1. Sort token ranges by endpoints lexicographically. + * 2. Take the highest possible number of token ranges from the beginning of the list, + * such that their sum of ringFraction does not exceed `ringFractionPerGroup` and they all contain at * least one common endpoint. If it is not possible, take at least one item. - * Those token rangesContaining will make a group. - * 3. Repeat the previous step until no more token rangesContaining left.*/ -class TokenRangeClusterer[V, T <: Token[V]](maxRowCountPerGroup: Long, maxGroupSize: Int = Int.MaxValue) { + * Those token ranges will make a group. + * 3. Repeat the previous step until no more token ranges left.*/ +class TokenRangeClusterer[V, T <: Token[V]](groupCount: Int, maxGroupSize: Int = Int.MaxValue) { + + private val ringFractionPerGroup = WholeRing / groupCount.toDouble private implicit object InetAddressOrdering extends Ordering[InetAddress] { override def compare(x: InetAddress, y: InetAddress) = @@ -27,22 +31,23 @@ class TokenRangeClusterer[V, T <: Token[V]](maxRowCountPerGroup: Long, maxGroupS @tailrec private def group(tokenRanges: Stream[TokenRange[V, T]], - result: Vector[Seq[TokenRange[V, T]]]): Iterable[Seq[TokenRange[V, T]]] = { + result: Vector[Seq[TokenRange[V, T]]], + ringFractionPerGroup: Double): Iterable[Seq[TokenRange[V, T]]] = { tokenRanges match { case Stream.Empty => result case head #:: rest => val firstEndpoint = head.replicas.min - val rowCounts = tokenRanges.map(_.dataSize) - val cumulativeRowCounts = rowCounts.scanLeft(0L)(_ + _).tail // drop first item always == 0 - val rowLimit = math.max(maxRowCountPerGroup, head.dataSize) // make sure first element will be always included + val ringFractions = tokenRanges.map(_.ringFraction) + val cumulativeRingFractions = ringFractions.scanLeft(.0)(_ + _).tail // drop first item always == 0 + val ringFractionLimit = math.max(ringFractionPerGroup, head.ringFraction) // make sure first element will be always included val cluster = tokenRanges .take(math.max(1, maxGroupSize)) - .zip(cumulativeRowCounts) - .takeWhile { case (tr, count) => count <= rowLimit && tr.replicas.min == firstEndpoint } + .zip(cumulativeRingFractions) + .takeWhile { case (tr, count) => count <= ringFractionLimit && tr.replicas.min == firstEndpoint } .map(_._1) .toVector val remainingTokenRanges = tokenRanges.drop(cluster.length) - group(remainingTokenRanges, result :+ cluster) + group(remainingTokenRanges, result :+ cluster, ringFractionPerGroup) } } @@ -53,7 +58,10 @@ class TokenRangeClusterer[V, T <: Token[V]](maxRowCountPerGroup: Long, maxGroupS // sort by endpoints lexicographically // this way ranges on the same host are grouped together val sortedRanges = tokenRanges.sortBy(_.replicas.toSeq.sorted) - group(sortedRanges.toStream, Vector.empty) + group(sortedRanges.toStream, Vector.empty, ringFractionPerGroup) } +} +object TokenRangeClusterer { + private val WholeRing = 1.0 } diff --git 
a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala index 069947d5c..280f723f3 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeSplitter.scala @@ -1,16 +1,36 @@ package com.datastax.spark.connector.rdd.partitioner +import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.forkjoin.ForkJoinPool + +import com.datastax.spark.connector.rdd.partitioner.TokenRangeSplitter.WholeRing import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenRange} -/** Splits a token range into smaller sub-ranges, + +/** Splits a token ranges into smaller sub-ranges, * each with the desired approximate number of rows. */ -trait TokenRangeSplitter[V, T <: Token[V]] { +private[partitioner] trait TokenRangeSplitter[V, T <: Token[V]] { - /** Splits given token range into n equal sub-ranges. */ - def split(range: TokenRange[V, T], splitSize: Long): Seq[TokenRange[V, T]] -} + def split(tokenRanges: Iterable[TokenRange[V, T]], splitCount: Int): Iterable[TokenRange[V, T]] = { + val ringFractionPerSplit = WholeRing / splitCount.toDouble + val parTokenRanges = tokenRanges.par + parTokenRanges.tasksupport = new ForkJoinTaskSupport(TokenRangeSplitter.pool) + parTokenRanges.flatMap(tokenRange => { + val splitCount = Math.rint(tokenRange.ringFraction / ringFractionPerSplit).toInt + split(tokenRange, math.max(1, splitCount)) + }).toList + } + /** Splits the token range uniformly into splitCount sub-ranges. */ + def split(tokenRange: TokenRange[V, T], splitCount: Int): Seq[TokenRange[V, T]] +} + +object TokenRangeSplitter { + private val MaxParallelism = 16 + private val WholeRing = 1.0 + private val pool = new ForkJoinPool(MaxParallelism) +} diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala index 10be6ac03..37135b0f3 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala +++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenFactory.scala @@ -1,5 +1,7 @@ package com.datastax.spark.connector.rdd.partitioner.dht +import com.datastax.spark.connector.cql.CassandraConnector + import scala.language.existentials import com.datastax.spark.connector.rdd.partitioner.MonotonicBucketing @@ -91,6 +93,13 @@ object TokenFactory { } partitioner.asInstanceOf[TokenFactory[V, T]] } + + def forSystemLocalPartitioner(connector: CassandraConnector): TokenFactory[V, T] = { + val partitionerClassName = connector.withSessionDo { session => + session.execute("SELECT partitioner FROM system.local").one().getString(0) + } + forCassandraPartitioner(partitionerClassName) + } } diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala index d0cd73437..2b0fc3b48 100644 --- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala +++ 
b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRange.scala @@ -3,7 +3,11 @@ package com.datastax.spark.connector.rdd.partitioner.dht import java.net.InetAddress case class TokenRange[V, T <: Token[V]] ( - start: T, end: T, replicas: Set[InetAddress], dataSize: Long) { + start: T, end: T, replicas: Set[InetAddress], tokenFactory: TokenFactory[V, T]) { + + private[partitioner] lazy val rangeSize = tokenFactory.distance(start, end) + + private[partitioner] lazy val ringFraction = tokenFactory.ringFraction(start, end) def isWrappedAround(implicit tf: TokenFactory[V, T]): Boolean = start >= end && end != tf.minToken @@ -18,8 +22,8 @@ case class TokenRange[V, T <: Token[V]] ( val minToken = tf.minToken if (isWrappedAround) Seq( - TokenRange(start, minToken, replicas, dataSize / 2), - TokenRange(minToken, end, replicas, dataSize / 2)) + TokenRange(start, minToken, replicas, tokenFactory), + TokenRange(minToken, end, replicas, tokenFactory)) else Seq(this) } diff --git a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala index 4c7f94788..6b220ead6 100644 --- a/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala +++ b/spark-cassandra-connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala @@ -15,6 +15,7 @@ import org.apache.spark.SparkConf import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema} import com.datastax.spark.connector.rdd.partitioner.CassandraPartitionGenerator._ import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner import com.datastax.spark.connector.rdd.{CassandraRDD, ReadConf} import com.datastax.spark.connector.types.{InetType, UUIDType, VarIntType} import com.datastax.spark.connector.util.Quote._ @@ -251,7 +252,7 @@ object CassandraSourceRelation { val tableSizeInBytes = tableSizeInBytesString match { case Some(size) => Option(size.toLong) case None => - val tokenFactory = getTokenFactory(cassandraConnector) + val tokenFactory = forSystemLocalPartitioner(cassandraConnector) val dataSizeInBytes = new DataSizeEstimates( cassandraConnector, diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala index 8cffb881e..c192e44a7 100644 --- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala +++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/BucketingRangeIndexSpec.scala @@ -123,7 +123,7 @@ class BucketingRangeIndexSpec extends FlatSpec with PropertyChecks with ShouldMa longTokenFactory.minToken, longTokenFactory.minToken, Set.empty, - 0) + Murmur3TokenFactory) "Murmur3Bucketing" should " map all tokens to a single wrapping range" in { @@ -146,7 +146,7 @@ class BucketingRangeIndexSpec extends FlatSpec with PropertyChecks with ShouldMa bigTokenFactory.minToken, bigTokenFactory.minToken, Set.empty, - 0) + RandomPartitionerTokenFactory) "RandomBucketing" should " map all tokens to a single wrapping range" in { diff --git 
a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterSpec.scala new file mode 100644 index 000000000..a6bffd576 --- /dev/null +++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterSpec.scala @@ -0,0 +1,35 @@ +package com.datastax.spark.connector.rdd.partitioner + +import java.net.InetAddress + +import org.scalatest.{FlatSpec, Matchers} + +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory._ +import com.datastax.spark.connector.rdd.partitioner.dht.{LongToken, TokenRange} + +class Murmur3PartitionerTokenRangeSplitterSpec + extends FlatSpec + with SplitterBehaviors[Long, LongToken] + with Matchers { + + private val splitter = new Murmur3PartitionerTokenRangeSplitter + + "Murmur3PartitionerSplitter" should "split tokens" in testSplittingTokens(splitter) + + it should "split token sequences" in testSplittingTokenSequences(splitter) + + override def splitWholeRingIn(count: Int): Seq[TokenRange[Long, LongToken]] = { + val hugeTokensIncrement = totalTokenCount / count + (0 until count).map(i => + range(minToken.value + i * hugeTokensIncrement, minToken.value + (i + 1) * hugeTokensIncrement) + ) + } + + override def range(start: BigInt, end: BigInt): TokenRange[Long, LongToken] = + new TokenRange[Long, LongToken]( + LongToken(start.toLong), + LongToken(end.toLong), + Set(InetAddress.getLocalHost), + Murmur3TokenFactory) +} \ No newline at end of file diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterTest.scala deleted file mode 100644 index c034452fc..000000000 --- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/Murmur3PartitionerTokenRangeSplitterTest.scala +++ /dev/null @@ -1,84 +0,0 @@ -package com.datastax.spark.connector.rdd.partitioner - -import java.net.InetAddress - -import org.junit.Assert._ -import org.junit.Test - -import com.datastax.spark.connector.rdd.partitioner.dht.LongToken -import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory - -class Murmur3PartitionerTokenRangeSplitterTest { - - type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[Long, LongToken] - - private def assertNoHoles(tokenRanges: Seq[TokenRange]) { - for (Seq(range1, range2) <- tokenRanges.sliding(2)) - assertEquals(range1.end, range2.start) - } - - private def assertSimilarSize(tokenRanges: Seq[TokenRange]): Unit = { - val sizes = tokenRanges.map(r => Murmur3TokenFactory.distance(r.start, r.end)).toVector - val maxSize = sizes.max.toDouble - val minSize = sizes.min.toDouble - assertTrue(s"maxSize / minSize = ${maxSize / minSize} > 1.01", maxSize / minSize <= 1.01) - } - - @Test - def testSplit() { - val node = InetAddress.getLocalHost - val dataSize = 1000 - val splitter = new Murmur3PartitionerTokenRangeSplitter(dataSize) - val range = new TokenRange(LongToken(0), LongToken(0), Set(node), dataSize) - val out = splitter.split(range, 100) - - assertEquals(10, out.size) - assertEquals(0L, 
out.head.start.value) - assertEquals(0L, out.last.end.value) - assertTrue(out.forall(s => s.end.value != s.start.value)) - assertTrue(out.forall(_.replicas == Set(node))) - assertNoHoles(out) - assertSimilarSize(out) - } - - - @Test - def testNoSplit() { - val splitter = new Murmur3PartitionerTokenRangeSplitter(1000) - val range = new TokenRange(LongToken(0), new LongToken(100), Set.empty, 0) - val out = splitter.split(range, 500) - - // range is too small to contain 500 units - assertEquals(1, out.size) - assertEquals(0L, out.head.start.value) - assertEquals(100L, out.last.end.value) - } - - @Test - def testZeroRows() { - val dataSize = 0 - val splitter = new Murmur3PartitionerTokenRangeSplitter(dataSize) - val range = new TokenRange(LongToken(0), LongToken(100), Set.empty, dataSize) - val out = splitter.split(range, 500) - assertEquals(1, out.size) - assertEquals(0L, out.head.start.value) - assertEquals(100L, out.last.end.value) - } - - @Test - def testWrapAround() { - val dataSize = 2000 - val splitter = new Murmur3PartitionerTokenRangeSplitter(dataSize) - val start = Murmur3TokenFactory.maxToken.value - Long.MaxValue / 2 - val end = Murmur3TokenFactory.minToken.value + Long.MaxValue / 2 - val range = new TokenRange(LongToken(start), LongToken(end), Set.empty, dataSize / 2) - val splits = splitter.split(range, 100) - - // range is half of the ring; 2000 * 0.5 / 100 = 10 - assertEquals(10, splits.size) - assertEquals(start, splits.head.start.value) - assertEquals(end, splits.last.end.value) - assertNoHoles(splits) - assertSimilarSize(splits) - } -} diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterSpec.scala new file mode 100644 index 000000000..99af7667a --- /dev/null +++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterSpec.scala @@ -0,0 +1,35 @@ +package com.datastax.spark.connector.rdd.partitioner + +import java.net.InetAddress + +import org.scalatest.{Matchers, _} + +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.RandomPartitionerTokenFactory +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.RandomPartitionerTokenFactory.{minToken, totalTokenCount} +import com.datastax.spark.connector.rdd.partitioner.dht.{BigIntToken, TokenRange} + +class RandomPartitionerTokenRangeSplitterSpec + extends FlatSpec + with SplitterBehaviors[BigInt, BigIntToken] + with Matchers { + + private val splitter = new RandomPartitionerTokenRangeSplitter + + "RandomPartitionerSplitter" should "split tokens" in testSplittingTokens(splitter) + + it should "split token sequences" in testSplittingTokenSequences(splitter) + + override def splitWholeRingIn(count: Int): Seq[TokenRange[BigInt, BigIntToken]] = { + val hugeTokensIncrement = totalTokenCount / count + (0 until count).map(i => + range(minToken.value + i * hugeTokensIncrement, minToken.value + (i + 1) * hugeTokensIncrement) + ) + } + + override def range(start: BigInt, end: BigInt): TokenRange[BigInt, BigIntToken] = + new TokenRange[BigInt, BigIntToken]( + BigIntToken(start), + BigIntToken(end), + Set(InetAddress.getLocalHost), + RandomPartitionerTokenFactory) +} diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterTest.scala 
b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterTest.scala deleted file mode 100644 index ab7a06922..000000000 --- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/RandomPartitionerTokenRangeSplitterTest.scala +++ /dev/null @@ -1,91 +0,0 @@ -package com.datastax.spark.connector.rdd.partitioner - -import java.net.InetAddress - -import org.junit.Assert._ -import org.junit.Test - -import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.RandomPartitionerTokenFactory -import com.datastax.spark.connector.rdd.partitioner.dht.{BigIntToken, TokenFactory} - -class RandomPartitionerTokenRangeSplitterTest { - - type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[BigInt, BigIntToken] - - private def assertNoHoles(tokenRanges: Seq[TokenRange]) { - for (Seq(range1, range2) <- tokenRanges.sliding(2)) - assertEquals(range1.end, range2.start) - } - - private def assertSimilarSize(tokenRanges: Seq[TokenRange]): Unit = { - val sizes = tokenRanges.map(r => RandomPartitionerTokenFactory.distance(r.start, r.end)).toVector - val maxSize = sizes.max.toDouble - val minSize = sizes.min.toDouble - assertTrue(s"maxSize / minSize = ${maxSize / minSize} > 1.01", maxSize / minSize <= 1.01) - } - - @Test - def testSplit() { - val dataSize = 1000 - val node = InetAddress.getLocalHost - val splitter = new RandomPartitionerTokenRangeSplitter(dataSize) - val rangeLeft = BigInt("0") - val rangeRight = BigInt("0") - val range = new TokenRange(BigIntToken(rangeLeft), BigIntToken(rangeRight), Set(node), dataSize) - val out = splitter.split(range, 100) - - assertEquals(10, out.size) - assertEquals(rangeLeft, out.head.start.value) - assertEquals(rangeRight, out.last.end.value) - assertTrue(out.forall(_.replicas == Set(node))) - assertNoHoles(out) - assertSimilarSize(out) - } - - @Test - def testNoSplit() { - val splitter = new RandomPartitionerTokenRangeSplitter(1000) - val rangeLeft = BigInt("0") - val rangeRight = BigInt("100") - val range = new TokenRange(BigIntToken(rangeLeft), BigIntToken(rangeRight), Set.empty, 0) - val out = splitter.split(range, 500) - - // range is too small to contain 500 rows - assertEquals(1, out.size) - assertEquals(rangeLeft, out.head.start.value) - assertEquals(rangeRight, out.last.end.value) - } - - @Test - def testZeroRows() { - val splitter = new RandomPartitionerTokenRangeSplitter(0) - val rangeLeft = BigInt("0") - val rangeRight = BigInt("100") - val range = new TokenRange(BigIntToken(rangeLeft), BigIntToken(rangeRight), Set.empty, 0) - val out = splitter.split(range, 500) - assertEquals(1, out.size) - assertEquals(rangeLeft, out.head.start.value) - assertEquals(rangeRight, out.last.end.value) - } - - @Test - def testWrapAround() { - val dataSize = 2000 - val splitter = new RandomPartitionerTokenRangeSplitter(dataSize) - val totalTokenCount = RandomPartitionerTokenFactory.totalTokenCount - val rangeLeft = RandomPartitionerTokenFactory.maxToken.value - totalTokenCount / 4 - val rangeRight = RandomPartitionerTokenFactory.minToken.value + totalTokenCount / 4 - val range = new TokenRange( - new BigIntToken(rangeLeft), - new BigIntToken(rangeRight), - Set.empty, - dataSize / 2) - val out = splitter.split(range, 100) - assertEquals(10, out.size) - assertEquals(rangeLeft, out.head.start.value) - assertEquals(rangeRight, out.last.end.value) - assertNoHoles(out) - assertSimilarSize(out) - } - -} diff --git 
a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/SplitterBehaviors.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/SplitterBehaviors.scala new file mode 100644 index 000000000..da2f95fa3 --- /dev/null +++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/SplitterBehaviors.scala @@ -0,0 +1,97 @@ +package com.datastax.spark.connector.rdd.partitioner + +import org.scalatest.Matchers + +import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenRange} + +private[partitioner] trait SplitterBehaviors[V, T <: Token[V]] { + this: Matchers => + + case class SplitResult(splitCount: Int, minSize: BigInt, maxSize: BigInt) + + def splitWholeRingIn(count: Int): Seq[TokenRange[V, T]] + + def range(start: BigInt, end: BigInt): TokenRange[V, T] + + def splittedIn(splitCount: Int): Int = splitCount + + def outputs(splits: Int, withSize: BigInt, sizeTolerance: BigInt = BigInt(0)): SplitResult = + SplitResult(splits, withSize - sizeTolerance, withSize + sizeTolerance) + + def testSplittingTokens(splitter: => TokenRangeSplitter[V, T]) { + val hugeRanges = splitWholeRingIn(10) + + val splitCases = Seq[(TokenRange[V, T], Int, SplitResult)]( + (range(start = 0, end = 1), splittedIn(1), outputs(splits = 1, withSize = 1)), + (range(start = 0, end = 10), splittedIn(1), outputs(splits = 1, withSize = 10)), + (range(start = 0, end = 1), splittedIn(10), outputs(splits = 1, withSize = 1)), + + (range(start = 0, end = 9), splittedIn(10), outputs(splits = 9, withSize = 1)), + (range(start = 0, end = 10), splittedIn(10), outputs(splits = 10, withSize = 1)), + (range(start = 0, end = 11), splittedIn(10), outputs(splits = 10, withSize = 1, sizeTolerance = 1)), + + (range(start = 10, end = 50), splittedIn(10), outputs(splits = 10, withSize = 4)), + (range(start = 0, end = 1000), splittedIn(3), outputs(splits = 3, withSize = 333, sizeTolerance = 1)), + + (hugeRanges.head, splittedIn(100), outputs(splits = 100, withSize = hugeRanges.head.rangeSize / 100, sizeTolerance = 4)), + (hugeRanges.last, splittedIn(100), outputs(splits = 100, withSize = hugeRanges.last.rangeSize / 100, sizeTolerance = 4)) + ) + + for ((range, splittedIn, expected) <- splitCases) { + + val splits = splitter.split(range, splittedIn) + + withClue(s"Splitting range (${range.start}, ${range.end}) in $splittedIn splits failed.") { + splits.size should be(expected.splitCount) + splits.head.start should be(range.start) + splits.last.end should be(range.end) + splits.foreach(_.replicas should be(range.replicas)) + splits.foreach(s => s.rangeSize should (be >= expected.minSize and be <= expected.maxSize)) + splits.map(_.rangeSize).sum should be(range.rangeSize) + splits.map(_.ringFraction).sum should be(range.ringFraction +- .000000001) + for (Seq(range1, range2) <- splits.sliding(2)) range1.end should be(range2.start) + } + } + } + + def testSplittingTokenSequences(splitter: TokenRangeSplitter[V, T]) { + val mediumRanges = splitWholeRingIn(100) + val wholeRingSize = mediumRanges.map(_.rangeSize).sum + + val splitCases = Seq[(Seq[TokenRange[V, T]], Int, SplitResult)]( + // we have 100 ranges, so 100 splits is minimum + (mediumRanges, splittedIn(3), outputs(splits = 100, withSize = wholeRingSize / 100)), + + (mediumRanges, splittedIn(100), outputs(splits = 100, withSize = wholeRingSize / 100)), + + (mediumRanges, splittedIn(101), outputs(splits = 100, withSize = wholeRingSize / 100)), + + (mediumRanges, splittedIn(149), outputs(splits 
= 100, withSize = wholeRingSize / 100)), + + (mediumRanges, splittedIn(150), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)), + + (mediumRanges, splittedIn(151), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)), + + (mediumRanges, splittedIn(199), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)), + + (mediumRanges, splittedIn(200), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)), + + (mediumRanges, splittedIn(201), outputs(splits = 200, withSize = wholeRingSize / 200, sizeTolerance = 1)) + ) + + for ((ranges, splittedIn, expected) <- splitCases) { + + val splits = splitter.split(ranges, splittedIn) + + withClue(s"Splitting ${ranges.size} ranges in $splittedIn splits failed.") { + splits.size should be(expected.splitCount) + splits.head.start should be(ranges.head.start) + splits.last.end should be(ranges.last.end) + splits.foreach(s => s.rangeSize should (be >= expected.minSize and be <= expected.maxSize)) + splits.map(_.rangeSize).sum should be(ranges.map(_.rangeSize).sum) + splits.map(_.ringFraction).sum should be(1.0 +- .000000001) + for (Seq(range1, range2) <- splits.sliding(2)) range1.end should be(range2.start) + } + } + } +} diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala index 3030b6670..71b7f5822 100644 --- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala +++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/TokenRangeClustererTest.scala @@ -2,22 +2,24 @@ package com.datastax.spark.connector.rdd.partitioner import java.net.InetAddress +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory +import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.Murmur3TokenFactory.{maxToken, minToken} +import com.datastax.spark.connector.rdd.partitioner.dht.{LongToken, TokenRange} import org.junit.Assert._ import org.junit.Test -import com.datastax.spark.connector.rdd.partitioner.dht.LongToken - class TokenRangeClustererTest { - type TokenRange = com.datastax.spark.connector.rdd.partitioner.dht.TokenRange[Long, LongToken] - val node1 = InetAddress.getByName("192.168.123.1") val node2 = InetAddress.getByName("192.168.123.2") val node3 = InetAddress.getByName("192.168.123.3") val node4 = InetAddress.getByName("192.168.123.4") val node5 = InetAddress.getByName("192.168.123.5") - private def token(x: Long) = new com.datastax.spark.connector.rdd.partitioner.dht.LongToken(x) + private def tokenRange(start: Long, end: Long, nodes: Set[InetAddress]): TokenRange[Long, LongToken] = + new TokenRange[Long, LongToken](new LongToken(start), new LongToken(end), nodes, Murmur3TokenFactory) + + private implicit def tokenToLong(token: LongToken): Long = token.value @Test def testEmpty() { @@ -28,9 +30,9 @@ class TokenRangeClustererTest { @Test def testTrivialClustering() { - val tr1 = new TokenRange(token(0), token(10), Set(node1), 5) - val tr2 = new TokenRange(token(10), token(20), Set(node1), 5) - val trc = new TokenRangeClusterer[Long, LongToken](10) + val tr1 = tokenRange(start = 0, end = 10, nodes = Set(node1)) + val tr2 = tokenRange(start = 10, end = 20, nodes = Set(node1)) + val trc = new TokenRangeClusterer[Long, LongToken](1) val groups = trc.group(Seq(tr1, 
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala
index 1736153b2..1451823af 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/partitioner/dht/TokenRangeSpec.scala
@@ -2,7 +2,8 @@ package com.datastax.spark.connector.rdd.partitioner.dht
 
 import org.scalatest.{FlatSpec, Matchers}
 
-import com.datastax.driver.core.{TokenRange => DTokenRange, Token => DToken}
+import com.datastax.driver.core.{Token => DToken, TokenRange => DTokenRange}
+import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.{Murmur3TokenFactory, RandomPartitionerTokenFactory}
 
 class TokenRangeSpec extends FlatSpec with Matchers {
 
@@ -11,7 +12,7 @@
 
   "LongRanges " should " contain tokens with easy no wrapping bounds" in {
-    val lr = new LongRange(LongToken(-100), LongToken(10000), Set.empty, 0)
+    val lr = new LongRange(LongToken(-100), LongToken(10000), Set.empty, Murmur3TokenFactory)
 
     //Tokens Inside
     for (l <- 1 to 1000) {
      lr.contains(LongToken(l)) should be (true)
@@ -24,7 +25,7 @@
   }
 
   it should " contain tokens with wrapping bounds in" in {
-    val lr = new LongRange(LongToken(1000), LongToken(-1000), Set.empty, 0)
+    val lr = new LongRange(LongToken(1000), LongToken(-1000), Set.empty, Murmur3TokenFactory)
 
     //Tokens Inside
     for (l <- 30000 to 30500) {
@@ -39,7 +40,7 @@
   }
 
   "BigRanges " should " contain tokens with easy no wrapping bounds" in {
-    val lr = new BigRange(BigIntToken(-100), BigIntToken(10000), Set.empty, 0)
+    val lr = new BigRange(BigIntToken(-100), BigIntToken(10000), Set.empty, RandomPartitionerTokenFactory)
     //Tokens Inside
     for (l <- 1 to 1000) {
       lr.contains(BigIntToken(l)) should be (true)
@@ -52,7 +53,7 @@
   }
 
   it should " contain tokens with wrapping bounds in" in {
-    val lr = new BigRange(BigIntToken(1000), BigIntToken(100), Set.empty, 0)
+    val lr = new BigRange(BigIntToken(1000), BigIntToken(100), Set.empty, RandomPartitionerTokenFactory)
 
     //Tokens Inside
     for (l <- 0 to 50) {