Merge pull request datastax#987 from datastax/SPARKC-278
SPARKC-278: Make Repartition by Cassandra Replica Deterministic
pkolaczk committed Jun 7, 2016
2 parents 4b74042 + 0c4151a commit fd0e59e
Showing 5 changed files with 86 additions and 28 deletions.
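
For context, a minimal usage sketch of the behaviour this merge changes (the SparkContext `sc`, keyspace and table names below are placeholders, not part of the commit): before this change the replica and the partition slot were picked with `java.util.Random`, so two identical `repartitionByCassandraReplica` calls could scatter the same keys differently; after it, repeated calls agree.

```scala
import org.apache.spark.SparkContext
import com.datastax.spark.connector._

// Hypothetical example; "test_ks" / "kv_table" are placeholder names.
def repartitionTwice(sc: SparkContext): Unit = {
  val source = sc.parallelize(1 to 1000).map(Tuple1(_))

  // Two independent repartitionings of the same data.
  val first  = source.repartitionByCassandraReplica("test_ks", "kv_table", 10)
  val second = source.repartitionByCassandraReplica("test_ks", "kv_table", 10)

  // With SPARKC-278, every element lands in the same partition index in both
  // RDDs; previously the assignment depended on java.util.Random.
}
```
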
RDDSpec.scala
@@ -192,6 +192,16 @@ class RDDSpec extends SparkCassandraITFlatSpecBase {
checkArrayCassandraRow(result)
}

it should "be deterministically repartitionable" in {
val source = sc.parallelize(keys).map(Tuple1(_))
val repartRDDs = (1 to 10).map(_ =>
source
.repartitionByCassandraReplica(ks, tableName, 10)
.mapPartitionsWithIndex((index, it) => it.map((_, index))))
val first = repartRDDs(1).collect
repartRDDs.foreach( rdd => rdd.collect should be(first))
}

"A case-class RDD specifying partition keys" should "be retrievable from Cassandra" in {
val source = sc.parallelize(keys).map(x => new KVRow(x))
val someCass = source.joinWithCassandraTable(ks, tableName)
RDDFunctions.scala
@@ -188,12 +188,22 @@ class RDDFunctions[T](rdd: RDD[T]) extends WritableToCassandra[T] with Serializable {
partitionKeyMapper: ColumnSelector)(
implicit
connector: CassandraConnector,
currentType: ClassTag[T]): CassandraPartitionedRDD[T] = {
currentType: ClassTag[T],
rwf: RowWriterFactory[T]): CassandraPartitionedRDD[T] = {

val partitioner = new ReplicaPartitioner[T](
tableName,
keyspaceName,
partitionsPerHost,
partitionKeyMapper,
connector)

val repart = rdd
.map((_,None))
.partitionBy(partitioner)
.mapPartitions(_.map(_._1), preservesPartitioning = true)

val part = new ReplicaPartitioner(partitionsPerHost, connector)
val repart = rdd.keyByCassandraReplica(replicaLocator).partitionBy(part)
val output = repart.mapPartitions(_.map(_._2), preservesPartitioning = true)
new CassandraPartitionedRDD[T](output, keyspaceName, tableName)
new CassandraPartitionedRDD[T](repart, keyspaceName, tableName)
}


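The new implementation relies on a generic Spark idiom rather than the removed `keyByCassandraReplica` path: `partitionBy` is only available on pair RDDs, so each element is wrapped as `(element, None)`, routed by the custom partitioner (which now hashes the element's partition key itself), and unwrapped again with `preservesPartitioning = true`. A standalone sketch of that idiom (a generic illustration, not connector code):

```scala
import scala.reflect.ClassTag
import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD

// Generic illustration of the wrap / partitionBy / unwrap pattern.
def partitionWith[T: ClassTag](rdd: RDD[T], partitioner: Partitioner): RDD[T] =
  rdd
    .map((_, None))                                            // T => (T, None)
    .partitionBy(partitioner)                                   // getPartition sees the element as the key
    .mapPartitions(_.map(_._1), preservesPartitioning = true)   // back to T, layout kept
```
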
CassandraPartitionedRDD.scala
@@ -24,11 +24,13 @@ class CassandraPartitionedRDD[T](
@transient
override val partitioner: Option[Partitioner] = prev.partitioner

private val replicaPartitioner: ReplicaPartitioner =
private val replicaPartitioner: ReplicaPartitioner[_] =
partitioner match {
case Some(rp: ReplicaPartitioner) => rp
case _ => throw new IllegalArgumentException("CassandraPartitionedRDD hasn't been " +
"partitioned by ReplicaPartitioner. Unable to do any work with data locality.")
case Some(rp: ReplicaPartitioner[_]) => rp
case other => throw new IllegalArgumentException(
s"""CassandraPartitionedRDD hasn't been
|partitioned by ReplicaPartitioner. Unable to do any work with data locality.
|Found: $other""".stripMargin)
}

private lazy val nodeAddresses = new NodeAddresses(replicaPartitioner.connector)
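The only functional change here is that `ReplicaPartitioner` is now generic in the element type, so the field and the match use the existential `ReplicaPartitioner[_]`: type arguments are erased at runtime and cannot be checked in a pattern. A toy sketch of the same pattern (the `KeyedPartitioner` below is hypothetical, not connector code):

```scala
import org.apache.spark.Partitioner

// Toy partitioner standing in for ReplicaPartitioner[T].
class KeyedPartitioner[K](override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = Math.floorMod(key.hashCode, numPartitions)
}

// The type argument is erased, so the match can only use KeyedPartitioner[_].
def requireKeyed(p: Option[Partitioner]): KeyedPartitioner[_] = p match {
  case Some(kp: KeyedPartitioner[_]) => kp
  case other => throw new IllegalArgumentException(s"Expected a KeyedPartitioner, found: $other")
}
```
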
ReplicaPartitioner.scala
@@ -2,9 +2,14 @@ package com.datastax.spark.connector.rdd.partitioner

import java.net.InetAddress

import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.ColumnSelector
import com.datastax.spark.connector.cql.{CassandraConnector, Schema}
import com.datastax.spark.connector.writer.RowWriterFactory
import org.apache.spark.{Partition, Partitioner}

import scala.reflect.ClassTag
import scala.collection.JavaConversions._


case class ReplicaPartition(index: Int, endpoints: Set[InetAddress]) extends EndpointPartition

@@ -13,21 +18,41 @@ case class ReplicaPartition(index: Int, endpoints: Set[InetAddress]) extends EndpointPartition
* Hosts . It will group keys which share a common IP address into partitionsPerReplicaSet Partitions.
* @param partitionsPerReplicaSet The number of Spark Partitions to make Per Unique Endpoint
*/
class ReplicaPartitioner(partitionsPerReplicaSet: Int, val connector: CassandraConnector) extends Partitioner {
/* TODO We Need JAVA-312 to get sets of replicas instead of single endpoints. Once we have that we'll be able to
build a map of Set[ip,ip,...] => Index before looking at our data and give the all options for the preferred location
for a partition*/
class ReplicaPartitioner[T](
table: String,
keyspace: String,
partitionsPerReplicaSet: Int,
partitionKeyMapper: ColumnSelector,
val connector: CassandraConnector)(
implicit
currentType: ClassTag[T],
@transient rwf: RowWriterFactory[T]) extends Partitioner {

val tableDef = Schema.tableFromCassandra(connector, keyspace, table)
val rowWriter = implicitly[RowWriterFactory[T]].rowWriter(
tableDef,
partitionKeyMapper.selectFrom(tableDef)
)

@transient lazy private val tokenGenerator = new TokenGenerator[T](connector, tableDef, rowWriter)
@transient lazy private val metadata = connector.withClusterDo(_.getMetadata)
@transient lazy private val protocolVersion = connector
.withClusterDo(_.getConfiguration.getProtocolOptions.getProtocolVersion)
@transient lazy private val clazz = implicitly[ClassTag[T]].runtimeClass

private val hosts = connector.hosts.toVector
private val hostSet = connector.hosts
private val numHosts = hosts.size
private val partitionIndexes = (0 until partitionsPerReplicaSet * numHosts).grouped(partitionsPerReplicaSet).toList
private val partitionIndexes = (0 until partitionsPerReplicaSet * numHosts)
.grouped(partitionsPerReplicaSet)
.toList

private val hostMap = (hosts zip partitionIndexes).toMap
// Ip1 -> (0,1,2,..), Ip2 -> (11,12,13...)
private val indexMap = for ((ip, partitions) <- hostMap; partition <- partitions) yield (partition, ip)
// 0->IP1, 1-> IP1, ...
private val rand = new java.util.Random()

private def randomHost: InetAddress =
hosts(rand.nextInt(numHosts))
private def randomHost(index: Int): InetAddress = hosts(index % hosts.length)

/**
* Given a set of endpoints, pick a random endpoint, and then a random partition owned by that
@@ -38,14 +63,20 @@ class ReplicaPartitioner(partitionsPerReplicaSet: Int, val connector: CassandraConnector) extends Partitioner {
*/
override def getPartition(key: Any): Int = {
key match {
case key: Set[_] if key.size > 0 && key.forall(_.isInstanceOf[InetAddress]) =>
case key: T if clazz.isInstance(key) =>
//Only use ReplicaEndpoints in the connected DC
val replicaSetInDC = (hosts.toSet & key.asInstanceOf[Set[InetAddress]]).toVector
val token = tokenGenerator.getTokenFor(key)
val tokenHash = Math.abs(token.hashCode())
val replicas = metadata
.getReplicas(keyspace, token.serialize(protocolVersion))
.map(_.getBroadcastAddress)

val replicaSetInDC = (hostSet & replicas).toVector
if (replicaSetInDC.nonEmpty) {
val endpoint = replicaSetInDC(rand.nextInt(replicaSetInDC.size))
hostMap(endpoint)(rand.nextInt(partitionsPerReplicaSet))
val endpoint = replicaSetInDC(tokenHash % replicaSetInDC.size)
hostMap(endpoint)(tokenHash % partitionsPerReplicaSet)
} else {
hostMap(randomHost)(rand.nextInt(partitionsPerReplicaSet))
hostMap(randomHost(tokenHash))(tokenHash % partitionsPerReplicaSet)
}
case _ => throw new IllegalArgumentException(
"ReplicaPartitioner can only determine the partition of a tuple whose key is a non-empty Set[InetAddress]. " +
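The essence of the determinism fix is in `getPartition`: the key is no longer a `Set[InetAddress]` picked apart with `java.util.Random`, but the row itself, from which a partition-key token is computed; both the replica and the partition slot are then derived from a hash of that token, so the same key always maps to the same Spark partition. A condensed, plain-Scala sketch of that selection (the helper below is illustrative, not the connector's code):

```scala
import java.net.InetAddress

// Deterministic replica / partition-slot selection derived from a token hash.
// hostMap mirrors the connector's "endpoint -> list of partition indexes" map.
def deterministicPartition(
    tokenHash: Int,
    replicasInDc: Vector[InetAddress],
    hostMap: Map[InetAddress, Seq[Int]],
    partitionsPerReplicaSet: Int): Int = {
  val endpoint = replicasInDc(Math.floorMod(tokenHash, replicasInDc.size))
  // Same token hash -> same endpoint -> same partition index, on every run.
  hostMap(endpoint)(Math.floorMod(tokenHash, partitionsPerReplicaSet))
}
```
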
DStreamFunctions.scala
@@ -1,11 +1,11 @@
package com.datastax.spark.connector.streaming

import com.datastax.spark.connector._
import com.datastax.spark.connector.cql.{CassandraConnectorConf, CassandraConnector}
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
import com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner
import com.datastax.spark.connector.rdd.{EmptyCassandraRDD, ValidRDDType}
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.writer._

import org.apache.spark._
import org.apache.spark.SparkContext
import org.apache.spark.streaming.Duration
@@ -68,9 +68,14 @@ class DStreamFunctions[T](dstream: DStream[T])
currentType: ClassTag[T],
rwf: RowWriterFactory[T]): DStream[T] = {

val replicaLocator = ReplicaLocator[T](connector, keyspaceName, tableName, partitionKeyMapper)
dstream.transform(rdd =>
rdd.repartitionByCassandraReplica(replicaLocator, keyspaceName, tableName, partitionsPerHost, partitionKeyMapper))
val partitioner = new ReplicaPartitioner[T](
tableName,
keyspaceName,
partitionsPerHost,
partitionKeyMapper,
connector)

dstream.transform(rdd => rdd.map((_, None)).partitionBy(partitioner).map(_._1))
}

/**
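Finally, the streaming variant applies the same `ReplicaPartitioner` to every batch RDD through `transform`, so a keyed DStream gets the same stable, locality-aware layout. A hypothetical usage sketch (the stream, keyspace and table names are placeholders):

```scala
import org.apache.spark.streaming.dstream.DStream
import com.datastax.spark.connector.streaming._

// Hypothetical: `lines` is a DStream[String] whose values match the table's partition key.
def repartitionStream(lines: DStream[String]): DStream[Tuple1[String]] =
  lines
    .map(Tuple1(_))
    .repartitionByCassandraReplica("test_ks", "kv_table", 10)
```
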
