diff --git a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
index cea343535..4b3fa79d9 100644
--- a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
+++ b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
@@ -192,6 +192,16 @@ class RDDSpec extends SparkCassandraITFlatSpecBase {
     checkArrayCassandraRow(result)
   }
 
+  it should "be deterministically repartitionable" in {
+    val source = sc.parallelize(keys).map(Tuple1(_))
+    val repartRDDs = (1 to 10).map(_ =>
+      source
+        .repartitionByCassandraReplica(ks, tableName, 10)
+        .mapPartitionsWithIndex((index, it) => it.map((_, index))))
+    val first = repartRDDs(1).collect
+    repartRDDs.foreach( rdd => rdd.collect should be(first))
+  }
+
   "A case-class RDD specifying partition keys" should "be retrievable from Cassandra" in {
     val source = sc.parallelize(keys).map(x => new KVRow(x))
     val someCass = source.joinWithCassandraTable(ks, tableName)
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
index 5b5d3abd8..b59dc6007 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
@@ -188,12 +188,22 @@ class RDDFunctions[T](rdd: RDD[T]) extends WritableToCassandra[T] with Serializa
       partitionKeyMapper: ColumnSelector)(
     implicit
       connector: CassandraConnector,
-      currentType: ClassTag[T]): CassandraPartitionedRDD[T] = {
+      currentType: ClassTag[T],
+      rwf: RowWriterFactory[T]): CassandraPartitionedRDD[T] = {
+
+    val partitioner = new ReplicaPartitioner[T](
+      tableName,
+      keyspaceName,
+      partitionsPerHost,
+      partitionKeyMapper,
+      connector)
+
+    val repart = rdd
+      .map((_,None))
+      .partitionBy(partitioner)
+      .mapPartitions(_.map(_._1), preservesPartitioning = true)
 
-    val part = new ReplicaPartitioner(partitionsPerHost, connector)
-    val repart = rdd.keyByCassandraReplica(replicaLocator).partitionBy(part)
-    val output = repart.mapPartitions(_.map(_._2), preservesPartitioning = true)
-    new CassandraPartitionedRDD[T](output, keyspaceName, tableName)
+    new CassandraPartitionedRDD[T](repart, keyspaceName, tableName)
   }
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionedRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionedRDD.scala
index 592ae5150..8231c1f56 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionedRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/CassandraPartitionedRDD.scala
@@ -24,11 +24,13 @@ class CassandraPartitionedRDD[T](
   @transient
   override val partitioner: Option[Partitioner] = prev.partitioner
 
-  private val replicaPartitioner: ReplicaPartitioner =
+  private val replicaPartitioner: ReplicaPartitioner[_] =
     partitioner match {
-      case Some(rp: ReplicaPartitioner) => rp
-      case _ => throw new IllegalArgumentException("CassandraPartitionedRDD hasn't been " +
-        "partitioned by ReplicaPartitioner. Unable to do any work with data locality.")
+      case Some(rp: ReplicaPartitioner[_]) => rp
+      case other => throw new IllegalArgumentException(
+        s"""CassandraPartitionedRDD hasn't been
+           |partitioned by ReplicaPartitioner. Unable to do any work with data locality.
+           |Found: $other""".stripMargin)
     }
 
   private lazy val nodeAddresses = new NodeAddresses(replicaPartitioner.connector)
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/ReplicaPartitioner.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/ReplicaPartitioner.scala
index ed1688e92..47c8edf59 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/ReplicaPartitioner.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/partitioner/ReplicaPartitioner.scala
@@ -2,9 +2,14 @@ package com.datastax.spark.connector.rdd.partitioner
 
 import java.net.InetAddress
 
-import com.datastax.spark.connector.cql.CassandraConnector
+import com.datastax.spark.connector.ColumnSelector
+import com.datastax.spark.connector.cql.{CassandraConnector, Schema}
+import com.datastax.spark.connector.writer.RowWriterFactory
 import org.apache.spark.{Partition, Partitioner}
 
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
 case class ReplicaPartition(index: Int, endpoints: Set[InetAddress]) extends EndpointPartition
@@ -13,21 +18,41 @@ case class ReplicaPartition(index: Int, endpoints: Set[InetAddress]) extends End
  * Hosts . It will group keys which share a common IP address into partitionsPerReplicaSet Partitions.
  * @param partitionsPerReplicaSet The number of Spark Partitions to make Per Unique Endpoint
  */
-class ReplicaPartitioner(partitionsPerReplicaSet: Int, val connector: CassandraConnector) extends Partitioner {
-  /* TODO We Need JAVA-312 to get sets of replicas instead of single endpoints. Once we have that we'll be able to
-   build a map of Set[ip,ip,...] => Index before looking at our data and give the all options for the preferred location
-   for a partition*/
+class ReplicaPartitioner[T](
+    table: String,
+    keyspace: String,
+    partitionsPerReplicaSet: Int,
+    partitionKeyMapper: ColumnSelector,
+    val connector: CassandraConnector)(
+implicit
+    currentType: ClassTag[T],
+    @transient rwf: RowWriterFactory[T]) extends Partitioner {
+
+  val tableDef = Schema.tableFromCassandra(connector, keyspace, table)
+  val rowWriter = implicitly[RowWriterFactory[T]].rowWriter(
+    tableDef,
+    partitionKeyMapper.selectFrom(tableDef)
+  )
+
+  @transient lazy private val tokenGenerator = new TokenGenerator[T](connector, tableDef, rowWriter)
+  @transient lazy private val metadata = connector.withClusterDo(_.getMetadata)
+  @transient lazy private val protocolVersion = connector
+    .withClusterDo(_.getConfiguration.getProtocolOptions.getProtocolVersion)
+  @transient lazy private val clazz = implicitly[ClassTag[T]].runtimeClass
+
   private val hosts = connector.hosts.toVector
+  private val hostSet = connector.hosts
   private val numHosts = hosts.size
-  private val partitionIndexes = (0 until partitionsPerReplicaSet * numHosts).grouped(partitionsPerReplicaSet).toList
+  private val partitionIndexes = (0 until partitionsPerReplicaSet * numHosts)
+    .grouped(partitionsPerReplicaSet)
+    .toList
+
   private val hostMap = (hosts zip partitionIndexes).toMap // Ip1 -> (0,1,2,..), Ip2 -> (11,12,13...)
   private val indexMap = for ((ip, partitions) <- hostMap; partition <- partitions) yield (partition, ip) // 0->IP1, 1-> IP1, ...
-  private val rand = new java.util.Random()
-  private def randomHost: InetAddress =
-    hosts(rand.nextInt(numHosts))
+  private def randomHost(index: Int): InetAddress = hosts(index % hosts.length)
 
   /**
    * Given a set of endpoints, pick a random endpoint, and then a random partition owned by that
@@ -38,14 +63,20 @@ class ReplicaPartitioner(partitionsPerReplicaSet: Int, val connector: CassandraC
    */
   override def getPartition(key: Any): Int = {
     key match {
-      case key: Set[_] if key.size > 0 && key.forall(_.isInstanceOf[InetAddress]) =>
+      case key: T if clazz.isInstance(key) =>
         //Only use ReplicaEndpoints in the connected DC
-        val replicaSetInDC = (hosts.toSet & key.asInstanceOf[Set[InetAddress]]).toVector
+        val token = tokenGenerator.getTokenFor(key)
+        val tokenHash = Math.abs(token.hashCode())
+        val replicas = metadata
+          .getReplicas(keyspace, token.serialize(protocolVersion))
+          .map(_.getBroadcastAddress)
+
+        val replicaSetInDC = (hostSet & replicas).toVector
         if (replicaSetInDC.nonEmpty) {
-          val endpoint = replicaSetInDC(rand.nextInt(replicaSetInDC.size))
-          hostMap(endpoint)(rand.nextInt(partitionsPerReplicaSet))
+          val endpoint = replicaSetInDC(tokenHash % replicaSetInDC.size)
+          hostMap(endpoint)(tokenHash % partitionsPerReplicaSet)
         } else {
-          hostMap(randomHost)(rand.nextInt(partitionsPerReplicaSet))
+          hostMap(randomHost(tokenHash))(tokenHash % partitionsPerReplicaSet)
         }
       case _ => throw new IllegalArgumentException(
         "ReplicaPartitioner can only determine the partition of a tuple whose key is a non-empty Set[InetAddress]. " +
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/streaming/DStreamFunctions.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/streaming/DStreamFunctions.scala
index 820ebe277..ca441b28f 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/streaming/DStreamFunctions.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/streaming/DStreamFunctions.scala
@@ -1,11 +1,11 @@
 package com.datastax.spark.connector.streaming
 
 import com.datastax.spark.connector._
-import com.datastax.spark.connector.cql.{CassandraConnectorConf, CassandraConnector}
+import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
+import com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner
 import com.datastax.spark.connector.rdd.{EmptyCassandraRDD, ValidRDDType}
 import com.datastax.spark.connector.rdd.reader.RowReaderFactory
 import com.datastax.spark.connector.writer._
-
 import org.apache.spark._
 import org.apache.spark.SparkContext
 import org.apache.spark.streaming.Duration
@@ -68,9 +68,14 @@ class DStreamFunctions[T](dstream: DStream[T])
       currentType: ClassTag[T],
       rwf: RowWriterFactory[T]): DStream[T] = {
 
-    val replicaLocator = ReplicaLocator[T](connector, keyspaceName, tableName, partitionKeyMapper)
-    dstream.transform(rdd =>
-      rdd.repartitionByCassandraReplica(replicaLocator, keyspaceName, tableName, partitionsPerHost, partitionKeyMapper))
+    val partitioner = new ReplicaPartitioner[T](
+      tableName,
+      keyspaceName,
+      partitionsPerHost,
+      partitionKeyMapper,
+      connector)
+
+    dstream.transform(rdd => rdd.map((_, None)).partitionBy(partitioner).map(_._1))
   }
 
 /**
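
Usage sketch (not part of the patch): with this change repartitionByCassandraReplica requires an implicit RowWriterFactory[T] for the element type and derives placement from each element's partition-key token instead of a random host pick, which is what makes repeated calls line up. The keyspace "test", table "kv", and contact point below are placeholders, not names from the patch.

    import com.datastax.spark.connector._
    import org.apache.spark.{SparkConf, SparkContext}

    // Placeholder cluster and table: keyspace "test", table "kv" whose single
    // partition-key column lines up with Tuple1's first field.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("replica-repartition-demo")
      .set("spark.cassandra.connection.host", "127.0.0.1")
    val sc = new SparkContext(conf)

    val source = sc.parallelize(1 to 1000).map(Tuple1(_))

    // The RowWriterFactory for Tuple1[Int] comes in via the connector import;
    // placement is a pure function of the token, so two runs agree.
    def layout() = source
      .repartitionByCassandraReplica("test", "kv", 10)
      .mapPartitionsWithIndex((index, it) => it.map((_, index)))
      .collect()
      .toSet

    assert(layout() == layout())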
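
The same wrap-partition-unwrap pattern that RDDFunctions and DStreamFunctions now share can also be applied to any RDD directly. A minimal sketch, assuming PartitionKeyColumns as the column selector and treating keyspace/table as caller-supplied placeholders:

    import scala.reflect.ClassTag

    import com.datastax.spark.connector.PartitionKeyColumns
    import com.datastax.spark.connector.cql.CassandraConnector
    import com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner
    import com.datastax.spark.connector.writer.RowWriterFactory
    import org.apache.spark.rdd.RDD

    // Pair each element with a dummy value, let the token-aware ReplicaPartitioner
    // place the pair on a partition owned by a replica, then unwrap the element
    // while telling Spark the partitioning was preserved.
    def partitionByReplica[T: ClassTag](
        rdd: RDD[T],
        keyspace: String,
        table: String,
        partitionsPerHost: Int)(
      implicit rwf: RowWriterFactory[T], connector: CassandraConnector): RDD[T] = {

      val partitioner = new ReplicaPartitioner[T](
        table, keyspace, partitionsPerHost, PartitionKeyColumns, connector)

      rdd.map((_, None))
        .partitionBy(partitioner)
        .mapPartitions(_.map(_._1), preservesPartitioning = true)
    }

One note on the design: because the partition index is now computed from Math.abs(token.hashCode()) rather than from java.util.Random state, the assignment is stable across invocations, which is exactly what the new "be deterministically repartitionable" test exercises.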