From b921dc183cca54adba0042cabb9042aa21c819a8 Mon Sep 17 00:00:00 2001
From: "Peter J. Martel"
Date: Fri, 17 Dec 2021 01:54:28 -0500
Subject: [PATCH] Spark 3.2 (#9)

* Bump spark version to 3.2.0

* Bump alchemy version for dependency change

* Code updates for changes to Spark APIs
---
 VERSION                                       |  2 +-
 .../spark/expressions/hll/HLLFunctions.scala  | 31 +++++++++++++++----
 .../org/apache/spark/sql/test/SQLHelper.scala |  2 +-
 .../apache/spark/sql/test/SQLTestUtils.scala  |  2 +-
 .../spark/sql/test/TestSparkSession.scala     |  2 +-
 build.sbt                                     |  2 +-
 6 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/VERSION b/VERSION
index 9084fa2..26aaba0 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.1.0
+1.2.0
diff --git a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala
index 7b58d0b..ff51cc3 100644
--- a/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala
+++ b/alchemy/src/main/scala/com/swoop/alchemy/spark/expressions/hll/HLLFunctions.scala
@@ -10,6 +10,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus.v
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, ExpressionDescription, Literal, UnaryExpression}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
@@ -27,7 +28,7 @@ object HyperLogLogBase {

   def resolveImplementation(exp: String): Implementation = exp match {
     case null => resolveImplementation
-    case s => nameToImpl(s.toString)
+    case s => nameToImpl(s)
   }

   def resolveImplementation(implicit impl: Implementation = null): Implementation =
@@ -60,7 +61,7 @@ object HyperLogLogBase {
   }
 }

-trait HyperLogLogInit extends Expression with HyperLogLogBase {
+trait HyperLogLogInit extends Expression with UnaryLike[Expression] with HyperLogLogBase {
   def relativeSD: Double

   // This formula for `p` came from org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus:93
@@ -141,7 +142,7 @@ trait HyperLogLogInitAgg extends NullableSketchAggregation with HyperLogLogInit
   }
 }

-trait NullableSketchAggregation extends TypedImperativeAggregate[Option[Instance]] with HyperLogLogBase {
+trait NullableSketchAggregation extends TypedImperativeAggregate[Option[Instance]] with HyperLogLogBase with UnaryLike[Expression] {

   override def createAggregationBuffer(): Option[Instance] = None

@@ -159,8 +160,6 @@ trait NullableSketchAggregation extends TypedImperativeAggregate[Option[Instance

   def child: Expression

-  override def children: Seq[Expression] = Seq(child)
-
   override def nullable: Boolean = child.nullable

   override def serialize(hll: Option[Instance]): Array[Byte] =
@@ -214,6 +213,8 @@ case class HyperLogLogInitSimple(
   }

   override def prettyName: String = "hll_init"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }


@@ -267,6 +268,8 @@ case class HyperLogLogInitSimpleAgg(
     copy(inputAggBufferOffset = newOffset)

   override def prettyName: String = "hll_init_agg"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }

 /**
@@ -313,6 +316,8 @@ case class HyperLogLogInitCollection(


   override def prettyName: String = "hll_init_collection"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }


@@ -367,6 +372,8 @@ case class HyperLogLogInitCollectionAgg(
     copy(inputAggBufferOffset = newOffset)

   override def prettyName: String = "hll_init_collection_agg"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }


@@ -427,6 +434,8 @@ case class HyperLogLogMerge(
     copy(inputAggBufferOffset = newOffset)

   override def prettyName: String = "hll_merge"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }

 /**
@@ -455,7 +464,7 @@ case class HyperLogLogRowMerge(
       assert(children.nonEmpty, s"function requires at least one argument")
       children
     }.last match {
-      case Literal(s: Any, StringType) => children.init
+      case Literal(_: Any, StringType) => children.init
       case _ => children
     },
     children.last match {
@@ -490,6 +499,9 @@ case class HyperLogLogRowMerge(
   }

   override def prettyName: String = "hll_row_merge"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }

 /**
@@ -527,6 +539,8 @@ case class HyperLogLogCardinality(
   }

   override def prettyName: String = "hll_cardinality"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }

 /**
@@ -598,6 +612,9 @@ case class HyperLogLogIntersectionCardinality(
   }

   override def prettyName: String = "hll_intersect_cardinality"
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+    copy(left = newLeft, right = newRight)
 }


@@ -648,6 +665,8 @@ case class HyperLogLogConvert(
   }

   override def prettyName: String = "hll_convert"
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
 }

 object functions extends HLLFunctions {
diff --git a/alchemy/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala b/alchemy/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
index e801254..0c4f5b9 100644
--- a/alchemy/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
+++ b/alchemy/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
@@ -23,7 +23,7 @@ trait SQLHelper {
       }
     }
     (keys, values).zipped.foreach { (k, v) =>
-      if (SQLConf.staticConfKeys.contains(k)) {
+      if (SQLConf.isStaticConfigKey(k)) {
         throw new AnalysisException(s"Cannot modify the value of a static config: $k")
       }
       conf.setConfString(k, v)
diff --git a/alchemy/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/alchemy/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d0c266f..7f1ac0a 100644
--- a/alchemy/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/alchemy/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -231,7 +231,7 @@ trait SQLTestUtilsBase
   // Blocking uncache table for tests
   protected def uncacheTable(tableName: String): Unit = {
     val tableIdent = spark.sessionState.sqlParser.parseTableIdentifier(tableName)
-    val cascade = !spark.sessionState.catalog.isTemporaryTable(tableIdent)
+    val cascade = !spark.sessionState.catalog.isTempView(tableIdent)
     spark.sharedState.cacheManager.uncacheQuery(
       spark,
       spark.table(tableName).logicalPlan,
diff --git a/alchemy/src/test/scala/org/apache/spark/sql/test/TestSparkSession.scala b/alchemy/src/test/scala/org/apache/spark/sql/test/TestSparkSession.scala
index ddf1c83..b1f668e 100644
--- a/alchemy/src/test/scala/org/apache/spark/sql/test/TestSparkSession.scala
+++ b/alchemy/src/test/scala/org/apache/spark/sql/test/TestSparkSession.scala
@@ -55,7 +55,7 @@ object TestSQLContext {
 private[sql] class TestSQLSessionStateBuilder(
   session: SparkSession,
   state: Option[SessionState])
-  extends SessionStateBuilder(session, state, Map.empty[String, String]) with WithTestConf {
+  extends SessionStateBuilder(session, state) with WithTestConf {
   override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs
   override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _)
 }
diff --git a/build.sbt b/build.sbt
index c652029..49b4248 100644
--- a/build.sbt
+++ b/build.sbt
@@ -6,7 +6,7 @@ ThisBuild / crossScalaVersions := Seq("2.12.11")

 ThisBuild / javacOptions ++= Seq("-source", "1.8", "-target", "1.8")

-val sparkVersion = "3.1.2"
+val sparkVersion = "3.2.0"

 lazy val scalaSettings = Seq(
   scalaVersion := "2.12.11",
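
Note (reviewer addition, not part of the patch): the repeated withNewChildInternal / withNewChildrenInternal overrides are the core of this upgrade. Spark 3.2 refactored Catalyst's TreeNode children handling into specialized traits (SPARK-34906). UnaryLike[Expression] now derives children from a single child, which is why the explicit "override def children: Seq[Expression] = Seq(child)" is deleted from NullableSketchAggregation, and it leaves one abstract method that every concrete node must implement so Catalyst can rebuild the node with replacement children. The sketch below illustrates the pattern on a hypothetical unary expression; PlusOne is invented for illustration and appears nowhere in this patch, and Spark 3.2.0 is assumed on the classpath.

    import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
    import org.apache.spark.sql.types.{DataType, LongType}

    // Hypothetical Catalyst expression: adds 1 to a LONG input.
    case class PlusOne(child: Expression) extends UnaryExpression with CodegenFallback {
      override def dataType: DataType = LongType

      // UnaryExpression's default eval handles null inputs and delegates here.
      override protected def nullSafeEval(input: Any): Any =
        input.asInstanceOf[Long] + 1L

      // New in Spark 3.2: plan transformations rebuild nodes through this
      // hook rather than by reflective case-class copying, so the idiomatic
      // body is a one-line copy; the diff adds exactly this to each HLL
      // expression above.
      override protected def withNewChildInternal(newChild: Expression): PlusOne =
        copy(child = newChild)
    }

The remaining test-utility edits (SQLConf.isStaticConfigKey replacing staticConfKeys.contains, the catalog's isTempView replacing isTemporaryTable, and SessionStateBuilder dropping its options-map constructor parameter) track renames and signature changes in Spark 3.2's internal test APIs; no behavior change is intended.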