Skip to content

Commit

Permalink
Merge pull request datastax#983 from datastax/SPARKC-383
Browse files Browse the repository at this point in the history
SPARKC-383: cache all common CassandraRow data in a single CassandraRowMetadata
  • Loading branch information
pkolaczk committed Jun 3, 2016
2 parents 8fc12ba + 8ea88dd commit 7b0147b
Show file tree
Hide file tree
Showing 20 changed files with 257 additions and 157 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package com.datastax.spark.connector.writer

import scala.concurrent.Future

import com.datastax.spark.connector.cql.{CassandraConnector, Schema}
import com.datastax.spark.connector.{CassandraRow, CassandraRowMetadata, SparkCassandraITFlatSpecBase}
import org.apache.cassandra.dht.IPartitioner

import com.datastax.spark.connector.cql.{CassandraConnector, Schema}
import com.datastax.spark.connector.embedded.SparkTemplate._
import com.datastax.spark.connector.{CassandraRow, SparkCassandraITFlatSpecBase}
import scala.concurrent.Future

class RoutingKeyGeneratorSpec extends SparkCassandraITFlatSpecBase {

Expand Down Expand Up @@ -40,7 +38,7 @@ class RoutingKeyGeneratorSpec extends SparkCassandraITFlatSpecBase {
session.execute(bStmt)
val row = session.execute(s"""SELECT TOKEN(id) FROM $ks.one_key WHERE id = 1""").one()

val readTokenStr = CassandraRow.fromJavaDriverRow(row, Array("token(id)")).getString(0)
val readTokenStr = CassandraRow.fromJavaDriverRow(row, CassandraRowMetadata.fromColumnNames(IndexedSeq("token(id)"))).getString(0)

val rk = rkg.apply(bStmt)
val rkToken = cp.getToken(rk)
Expand All @@ -61,7 +59,7 @@ class RoutingKeyGeneratorSpec extends SparkCassandraITFlatSpecBase {
session.execute(bStmt)
val row = session.execute(s"""SELECT TOKEN(id, id2) FROM $ks.two_keys WHERE id = 1 AND id2 = 'one'""").one()

val readTokenStr = CassandraRow.fromJavaDriverRow(row, Array("token(id,id2)")).getString(0)
val readTokenStr = CassandraRow.fromJavaDriverRow(row, CassandraRowMetadata.fromColumnNames(IndexedSeq(("token(id,id2)")))).getString(0)

val rk = rkg.apply(bStmt)
val rkToken = cp.getToken(rk)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.datastax.spark.connector.japi;

import com.datastax.driver.core.Row;
import com.datastax.spark.connector.CassandraRowMetadata;
import com.datastax.spark.connector.ColumnRef;
import com.datastax.spark.connector.cql.TableDef;
import com.datastax.spark.connector.rdd.reader.RowReader;
Expand Down Expand Up @@ -31,10 +32,10 @@ private JavaRowReader() {
}

@Override
public CassandraRow read(Row row, String[] columnNames) {
assert row.getColumnDefinitions().size() == columnNames.length :
public CassandraRow read(Row row, CassandraRowMetadata metaData) {
assert row.getColumnDefinitions().size() == metaData.columnNames().size() :
"Number of columns in a row must match the number of columns in the table metadata";
return CassandraRow$.MODULE$.fromJavaDriverRow(row, columnNames);
return CassandraRow$.MODULE$.fromJavaDriverRow(row, metaData);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.datastax.spark.connector

import com.datastax.driver.core.Row
import com.datastax.driver.core.{CodecRegistry, ResultSet, Row, TypeCodec}

/** Represents a single row fetched from Cassandra.
* Offers getters to read individual fields by column name or column index.
Expand All @@ -17,25 +17,25 @@ import com.datastax.driver.core.Row
*
* Recommended getters for Cassandra types:
*
* - `ascii`: `getString`, `getStringOption`
* - `bigint`: `getLong`, `getLongOption`
* - `blob`: `getBytes`, `getBytesOption`
* - `boolean`: `getBool`, `getBoolOption`
* - `counter`: `getLong`, `getLongOption`
* - `decimal`: `getDecimal`, `getDecimalOption`
* - `double`: `getDouble`, `getDoubleOption`
* - `float`: `getFloat`, `getFloatOption`
* - `inet`: `getInet`, `getInetOption`
* - `int`: `getInt`, `getIntOption`
* - `text`: `getString`, `getStringOption`
* - `timestamp`: `getDate`, `getDateOption`
* - `timeuuid`: `getUUID`, `getUUIDOption`
* - `uuid`: `getUUID`, `getUUIDOption`
* - `varchar`: `getString`, `getStringOption`
* - `varint`: `getVarInt`, `getVarIntOption`
* - `list`: `getList[T]`
* - `set`: `getSet[T]`
* - `map`: `getMap[K, V]`
* - `ascii`: `getString`, `getStringOption`
* - `bigint`: `getLong`, `getLongOption`
* - `blob`: `getBytes`, `getBytesOption`
* - `boolean`: `getBool`, `getBoolOption`
* - `counter`: `getLong`, `getLongOption`
* - `decimal`: `getDecimal`, `getDecimalOption`
* - `double`: `getDouble`, `getDoubleOption`
* - `float`: `getFloat`, `getFloatOption`
* - `inet`: `getInet`, `getInetOption`
* - `int`: `getInt`, `getIntOption`
* - `text`: `getString`, `getStringOption`
* - `timestamp`: `getDate`, `getDateOption`
* - `timeuuid`: `getUUID`, `getUUIDOption`
* - `uuid`: `getUUID`, `getUUIDOption`
* - `varchar`: `getString`, `getStringOption`
* - `varint`: `getVarInt`, `getVarIntOption`
* - `list`: `getList[T]`
* - `set`: `getSet[T]`
* - `map`: `getMap[K, V]`
*
* Collection getters `getList`, `getSet` and `getMap` require to explicitly pass an appropriate item type:
* {{{
Expand All @@ -46,17 +46,17 @@ import com.datastax.driver.core.Row
*
* Generic `get` allows to automatically convert collections to other collection types.
* Supported containers:
* - `scala.collection.immutable.List`
* - `scala.collection.immutable.Set`
* - `scala.collection.immutable.TreeSet`
* - `scala.collection.immutable.Vector`
* - `scala.collection.immutable.Map`
* - `scala.collection.immutable.TreeMap`
* - `scala.collection.Iterable`
* - `scala.collection.IndexedSeq`
* - `java.util.ArrayList`
* - `java.util.HashSet`
* - `java.util.HashMap`
* - `scala.collection.immutable.List`
* - `scala.collection.immutable.Set`
* - `scala.collection.immutable.TreeSet`
* - `scala.collection.immutable.Vector`
* - `scala.collection.immutable.Map`
* - `scala.collection.immutable.TreeMap`
* - `scala.collection.Iterable`
* - `scala.collection.IndexedSeq`
* - `java.util.ArrayList`
* - `java.util.HashSet`
* - `java.util.HashMap`
*
* Example:
* {{{
Expand All @@ -68,16 +68,83 @@ import com.datastax.driver.core.Row
*
*
* Timestamps can be converted to other Date types by using generic `get`. Supported date types:
* - java.util.Date
* - java.sql.Date
* - org.joda.time.DateTime
* - java.util.Date
* - java.sql.Date
* - org.joda.time.DateTime
*/
final class CassandraRow(val metaData: CassandraRowMetadata, val columnValues: IndexedSeq[AnyRef])
  extends ScalaGettableData with Serializable {

  /**
   * Backward-compatible constructor, kept for testing and legacy callers only.
   * Prefer the primary constructor with a shared `CassandraRowMetadata` instance,
   * which saves memory and avoids rebuilding name-to-index maps for every row.
   *
   * @param columnNames  names of the columns, in order
   * @param columnValues values of the columns, in the same order as `columnNames`
   */
  @deprecated("Use default constructor", "1.6.0")
  def this(columnNames: IndexedSeq[String], columnValues: IndexedSeq[AnyRef]) =
    this(CassandraRowMetadata.fromColumnNames(columnNames), columnValues)

  override def toString = "CassandraRow" + dataAsString
}

/**
 * Metadata shared by all `CassandraRow` instances produced from a single result set.
 * Sharing one instance per result set avoids duplicating the column-name data and
 * codec lookups for every row.
 *
 * @param columnNames row column names (as requested by the connector, possibly aliased)
 * @param resultSetColumnNames column names from the Java driver row result set, without connector aliases
 * @param codecs cached Java driver codecs to avoid registry lookups on every row
 */
case class CassandraRowMetadata(columnNames: IndexedSeq[String],
                                resultSetColumnNames: Option[IndexedSeq[String]] = None,
                                // transient because codecs are not serializable and used only at Row parsing
                                // not an Option, as a deserialized field will be null, not None
                                @transient private[connector] val codecs: IndexedSeq[TypeCodec[AnyRef]] = null) {

  // Maps a column name to its index; unknown names map to -1 (used for indexOf/contains checks).
  @transient
  lazy val namesToIndex: Map[String, Int] = columnNames.zipWithIndex.toMap.withDefaultValue(-1)

  // Maps an unaliased CQL column name to its index; throws ColumnNotFoundException for unknown names.
  @transient
  lazy val indexOfCqlColumnOrThrow = unaliasedColumnNames.zipWithIndex.toMap.withDefault { name =>
    throw new ColumnNotFoundException(
      s"Column not found: $name. " +
        s"Available columns are: ${columnNames.mkString("[", ", ", "]")}")
  }

  // Maps an (aliased) column name to its index; throws ColumnNotFoundException for unknown names.
  @transient
  lazy val indexOfOrThrow = namesToIndex.withDefault { name =>
    throw new ColumnNotFoundException(
      s"Column not found: $name. " +
        s"Available columns are: ${columnNames.mkString("[", ", ", "]")}")
  }

  // Result-set (unaliased) names when captured, otherwise falls back to the aliased names.
  def unaliasedColumnNames = resultSetColumnNames.getOrElse(columnNames)
}

object CassandraRowMetadata {

  /** Builds row metadata from a driver `ResultSet`: captures the unaliased
   * result-set column names and caches one codec per column, so that later
   * row parsing can skip per-row codec registry lookups. */
  def fromResultSet(columnNames: IndexedSeq[String], rs: ResultSet) = {
    import scala.collection.JavaConversions._
    val definitions = rs.getColumnDefinitions.asList().toList
    val resultSetNames = definitions.map(_.getName).toIndexedSeq
    val cachedCodecs = definitions
      .map(d => CodecRegistry.DEFAULT_INSTANCE.codecFor(d.getType))
      .asInstanceOf[List[TypeCodec[AnyRef]]]
      .toIndexedSeq
    CassandraRowMetadata(columnNames, Some(resultSetNames), cachedCodecs)
  }

  /** Creates a metadata object without cached codecs. Should be used for testing only:
   * rows read through such metadata fall back to the slower, codec-less decoding path.
   *
   * @param columnNames names of the columns, in order
   */
  def fromColumnNames(columnNames: IndexedSeq[String]): CassandraRowMetadata =
    CassandraRowMetadata(columnNames, None)

  /** Convenience overload of [[fromColumnNames]] accepting any `Seq`. */
  def fromColumnNames(columnNames: Seq[String]): CassandraRowMetadata =
    fromColumnNames(columnNames.toIndexedSeq)
}

object CassandraRow {

  /** Deserializes first n columns from the given `Row` and returns them as
   * a `CassandraRow` object. The number of columns retrieved is determined by the
   * column count of `metaData`. The metadata also carries the column names for
   * the newly created `CassandraRow`, but names are not used to fetch data from
   * the input `Row` in order to improve performance. Fetching column values by
   * name is much slower than fetching by index. */
  def fromJavaDriverRow(row: Row, metaData: CassandraRowMetadata): CassandraRow = {
    new CassandraRow(metaData, CassandraRow.dataFromJavaDriverRow(row, metaData))
  }

  /** Extracts all column values of the given driver `Row` into an array,
   * using the metadata's cached codecs when they are available. */
  def dataFromJavaDriverRow(row: Row, metaData: CassandraRowMetadata): Array[Object] = {
    val length = metaData.columnNames.length
    var i = 0
    val data = new Array[Object](length)

    // Here we use a mutable while loop for performance reasons, scala for loops are
    // converted into range.foreach() and the JVM is unable to inline the foreach closure.
    // 'match' is replaced with 'if' for the same reason.
    // It is also out of the loop for performance.
    if (metaData.codecs == null) {
      // that should not happen in production (codecs are cached in metadata), but just in case
      while (i < length) {
        data(i) = GettableData.get(row, i)
        i += 1
      }
    }
    else {
      while (i < length) {
        data(i) = GettableData.get(row, i, metaData.codecs(i))
        i += 1
      }
    }
    data
  }

  /** Creates a CassandraRow object from a map with keys denoting column names and
   * values denoting column values. */
  def fromMap(map: Map[String, Any]): CassandraRow = {
    val (columnNames, values) = map.unzip
    new CassandraRow(
      CassandraRowMetadata.fromColumnNames(columnNames.toIndexedSeq),
      values.map(_.asInstanceOf[AnyRef]).toIndexedSeq)
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,50 @@ import java.nio.ByteBuffer

import scala.collection.JavaConversions._

import com.datastax.driver.core.{LocalDate, Row, TupleValue => DriverTupleValue, UDTValue => DriverUDTValue}
import com.datastax.driver.core.{LocalDate, Row, TypeCodec, TupleValue => DriverTupleValue, UDTValue => DriverUDTValue}
import com.datastax.spark.connector.types.TypeConverter.StringConverter
import com.datastax.spark.connector.util.ByteBufferUtil

trait GettableData extends GettableByIndexData {

def columnNames: IndexedSeq[String]
def metaData: CassandraRowMetadata

@transient
private[connector] lazy val _indexOf =
columnNames.zipWithIndex.toMap.withDefaultValue(-1)

@transient
private[connector] lazy val _indexOfOrThrow = _indexOf.withDefault { name =>
throw new ColumnNotFoundException(
s"Column not found: $name. " +
s"Available columns are: ${columnNames.mkString("[", ", ", "]")}")
}

/** Returns a column value by index without applying any conversion.
/** Returns a column value by aliased name without applying any conversion.
* The underlying type is the same as the type returned by the low-level Cassandra driver,
* is implementation defined and may change in the future.
* Cassandra nulls are returned as Scala nulls. */
def getRaw(name: String): AnyRef = columnValues(_indexOfOrThrow(name))
def getRaw(name: String): AnyRef = columnValues(metaData.indexOfOrThrow(name))

/**
* Returns a column value by cql Name
* @param name
* @return
*/
def getRawCql(name: String): AnyRef = columnValues(metaData.indexOfCqlColumnOrThrow(name))


/** Returns true if column value is Cassandra null */
def isNullAt(name: String): Boolean = {
columnValues(_indexOfOrThrow(name)) == null
columnValues(metaData.indexOfOrThrow(name)) == null
}

/** Returns index of column with given name or -1 if column not found */
def indexOf(name: String): Int =
_indexOf(name)
metaData.namesToIndex(name)

/** Returns the name of the i-th column. */
def nameOf(index: Int): String =
columnNames(index)
metaData.columnNames(index)

/** Returns true if column with given name is defined and has an
* entry in the underlying value array, i.e. was requested in the result set.
* For columns having null value, returns true. */
def contains(name: String): Boolean =
_indexOf(name) != -1
metaData.namesToIndex(name) != -1

/** Displays the content in human readable form, including the names and values of the columns */
override def dataAsString =
columnNames
metaData.columnNames
.zip(columnValues)
.map(kv => kv._1 + ": " + StringConverter.convert(kv._2))
.mkString("{", ", ", "}")
Expand All @@ -59,13 +56,13 @@ trait GettableData extends GettableByIndexData {

override def equals(o: Any) = o match {
case o: GettableData if
this.columnNames == o.columnNames &&
this.metaData == o.metaData &&
this.columnValues == o.columnValues => true
case _ => false
}

override def hashCode =
columnNames.hashCode * 31 + columnValues.hashCode
metaData.hashCode * 31 + columnValues.hashCode
}

object GettableData {
Expand Down Expand Up @@ -96,12 +93,29 @@ object GettableData {
null
}


/** Reads the field at `index` from the DataStax Java Driver `Row` using the supplied
 * predefined codec and converts it to the connector's representation.
 * A Cassandra null is returned as a plain `null` (not a Scala `Option`). */
def get(row: Row, index: Int, codec: TypeCodec[AnyRef]): AnyRef = {
  val raw = row.get(index, codec)
  if (raw == null) null else convert(raw)
}

/** Looks up the column `name` in the row's column definitions and delegates to the
 * index-based getter. Throws `IllegalArgumentException` (via `require`) when the
 * row contains no such column. */
def get(row: Row, name: String): AnyRef = {
  val columnIndex = row.getColumnDefinitions.getIndexOf(name)
  require(columnIndex >= 0, s"Column not found in Java driver Row: $name")
  get(row, columnIndex)
}

/** Codec-aware variant of the name-based getter: resolves `name` to an index in the
 * row's column definitions, then reads the value with the supplied codec.
 * Throws `IllegalArgumentException` (via `require`) when the row contains no such column. */
def get(row: Row, name: String, codec: TypeCodec[AnyRef]): AnyRef = {
  val columnIndex = row.getColumnDefinitions.getIndexOf(name)
  require(columnIndex >= 0, s"Column not found in Java driver Row: $name")
  get(row, columnIndex, codec)
}

def get(value: DriverUDTValue, name: String): AnyRef = {
val quotedName = "\"" + name + "\""
val data = value.getObject(quotedName)
Expand Down
Loading

0 comments on commit 7b0147b

Please sign in to comment.