diff --git a/src/main/scala/com/lucidworks/spark/util/SolrQuerySupport.scala b/src/main/scala/com/lucidworks/spark/util/SolrQuerySupport.scala index 0dc193a9..c3e497cc 100644 --- a/src/main/scala/com/lucidworks/spark/util/SolrQuerySupport.scala +++ b/src/main/scala/com/lucidworks/spark/util/SolrQuerySupport.scala @@ -69,9 +69,9 @@ object SolrQuerySupport extends Logging { "solr.StrField" -> DataTypes.StringType, "solr.TextField" -> DataTypes.StringType, "solr.BoolField" -> DataTypes.BooleanType, - "solr.TrieIntField" -> DataTypes.IntegerType, + "solr.TrieIntField" -> DataTypes.LongType, "solr.TrieLongField" -> DataTypes.LongType, - "solr.TrieFloatField" -> DataTypes.FloatType, + "solr.TrieFloatField" -> DataTypes.DoubleType, "solr.TrieDoubleField" -> DataTypes.DoubleType, "solr.TrieDateField" -> DataTypes.TimestampType, "solr.BinaryField" -> DataTypes.BinaryType diff --git a/src/main/scala/com/lucidworks/spark/util/SolrRelationUtil.scala b/src/main/scala/com/lucidworks/spark/util/SolrRelationUtil.scala index 0b6eebb1..dfb0bb67 100644 --- a/src/main/scala/com/lucidworks/spark/util/SolrRelationUtil.scala +++ b/src/main/scala/com/lucidworks/spark/util/SolrRelationUtil.scala @@ -271,6 +271,8 @@ object SolrRelationUtil extends Logging { if (fieldValues != null) { val iterableValues = fieldValues.iterator().map { case d: Date => new Timestamp(d.getTime) + case i: java.lang.Integer => new java.lang.Long(i.longValue()) + case f: java.lang.Float => new java.lang.Double(f.doubleValue()) case a => a } values.add(iterableValues.toArray) @@ -283,9 +285,13 @@ object SolrRelationUtil extends Logging { fieldValue match { case f: String => values.add(f) case f: Date => values.add(new Timestamp(f.getTime)) + case i: java.lang.Integer => values.add(new java.lang.Long(i.longValue())) + case f: java.lang.Float => values.add(new java.lang.Double(f.doubleValue())) case f: java.util.ArrayList[_] => val jlist = f.iterator.map { case d: Date => new Timestamp(d.getTime) + case i: java.lang.Integer => new java.lang.Long(i.longValue()) + case f: java.lang.Float => new java.lang.Double(f.doubleValue()) case v: Any => v } val arr = jlist.toArray @@ -295,6 +301,8 @@ object SolrRelationUtil extends Logging { case f: Iterable[_] => val iterableValues = f.iterator.map { case d: Date => new Timestamp(d.getTime) + case i: java.lang.Integer => new java.lang.Long(i.longValue()) + case f: java.lang.Float => new java.lang.Double(f.doubleValue()) case v: Any => v } val arr = iterableValues.toArray diff --git a/src/test/java/com/lucidworks/spark/SolrRelationTest.java b/src/test/java/com/lucidworks/spark/SolrRelationTest.java index 4ba3e20b..f7c1453c 100644 --- a/src/test/java/com/lucidworks/spark/SolrRelationTest.java +++ b/src/test/java/com/lucidworks/spark/SolrRelationTest.java @@ -390,7 +390,7 @@ protected void validateSchema(DataFrame df) { if (fieldName.equals("id") || fieldName.endsWith("_s")) { assertEquals("Field '" + fieldName + "' should be a string but has type '" + type + "' instead!", "string", type.typeName()); } else if (fieldName.endsWith("_i")) { - assertEquals("Field '" + fieldName + "' should be an integer but has type '" + type + "' instead!", "integer", type.typeName()); + assertEquals("Field '" + fieldName + "' should be an integer but has type '" + type + "' instead!", "long", type.typeName()); } else if (fieldName.endsWith("_ss")) { assertEquals("Field '"+fieldName+"' should be an array but has '"+type+"' instead!", "array", type.typeName()); ArrayType arrayType = (ArrayType)type; @@ -400,7 +400,7 @@ protected void validateSchema(DataFrame df) { assertEquals("Field '"+fieldName+"' should be an array but has '"+type+"' instead!", "array", type.typeName()); ArrayType arrayType = (ArrayType)type; assertEquals("Field '"+fieldName+"' should have an integer element type but has '"+arrayType.elementType()+ - "' instead!", "integer", arrayType.elementType().typeName()); + "' instead!", "long", arrayType.elementType().typeName()); } } } diff --git a/src/test/java/com/lucidworks/spark/SolrSqlTest.java b/src/test/java/com/lucidworks/spark/SolrSqlTest.java index 35a62e51..3a650a75 100644 --- a/src/test/java/com/lucidworks/spark/SolrSqlTest.java +++ b/src/test/java/com/lucidworks/spark/SolrSqlTest.java @@ -60,7 +60,7 @@ public void testSQLQueries() throws Exception { assert fieldNames.length == 19 + 1 + 1; // extra fields are id and _version_ Assert.assertEquals(schema.apply("ts").dataType().typeName(), DataTypes.TimestampType.typeName()); - Assert.assertEquals(schema.apply("sessionId").dataType().typeName(), DataTypes.IntegerType.typeName()); + Assert.assertEquals(schema.apply("sessionId").dataType().typeName(), DataTypes.LongType.typeName()); Assert.assertEquals(schema.apply("length").dataType().typeName(), DataTypes.DoubleType.typeName()); Assert.assertEquals(schema.apply("song").dataType().typeName(), DataTypes.StringType.typeName()); diff --git a/src/test/resources/eventsim/fields_schema.json b/src/test/resources/eventsim/fields_schema.json index 8af17de3..b4f179fb 100644 --- a/src/test/resources/eventsim/fields_schema.json +++ b/src/test/resources/eventsim/fields_schema.json @@ -3,7 +3,7 @@ "name": "userId", "type": "string", "indexed": "true", "stored": "true", "docValues": "true" }, { - "name": "sessionId", "type": "int", "indexed": "true", "stored": "true" + "name": "sessionId", "type": "tint", "indexed": "true", "stored": "true" }, { "name": "page", "type": "string", "indexed": "true", "stored": "true" @@ -15,7 +15,7 @@ "name": "method", "type": "string", "indexed": "true", "stored": "true" }, { - "name": "status", "type": "int", "indexed": "true", "stored": "true" + "name": "status", "type": "int", "indexed": "true", "stored": "true", "docValues": "true" }, { "name": "level", "type": "string", "indexed": "true", "stored": "true" diff --git a/src/test/scala/com/lucidworks/spark/EventsimTestSuite.scala b/src/test/scala/com/lucidworks/spark/EventsimTestSuite.scala index 90c1e397..ab42c55e 100644 --- a/src/test/scala/com/lucidworks/spark/EventsimTestSuite.scala +++ b/src/test/scala/com/lucidworks/spark/EventsimTestSuite.scala @@ -152,6 +152,32 @@ class EventsimTestSuite extends EventsimBuilder { assert(timeQueryDF.count() == 21) } + test("Streaming query with int field") { + val df: DataFrame = sqlContext.read.format("solr") + .option("zkHost", zkHost) + .option("collection", collectionName) + .option(USE_EXPORT_HANDLER, "true") + .option(ARBITRARY_PARAMS_STRING, "fl=status,length&sort=userId desc") // The test will fail without the fl param here + .load() + df.registerTempTable("events") + + val queryDF = sqlContext.sql("SELECT count(distinct status), avg(length) FROM events") + val values = queryDF.collect() + } + + test("Non streaming query with int field") { + val df: DataFrame = sqlContext.read.format("solr") + .option("zkHost", zkHost) + .option("collection", collectionName) + .option(ARBITRARY_PARAMS_STRING, "fl=status,length&sort=id desc") // The test will fail without the fl param here + .load() + df.registerTempTable("events") + + val queryDF = sqlContext.sql("SELECT count(distinct status), avg(length) FROM events") + val values = queryDF.collect() + assert(values(0)(0) == 3) + } + def testCommons(solrRDD: SolrRDD): Unit = { val sparkCount = solrRDD.count()