diff --git a/cluster/src/test/scala/org/apache/spark/sql/execution/SnappyTableMutableAPISuite.scala b/cluster/src/test/scala/org/apache/spark/sql/execution/SnappyTableMutableAPISuite.scala
index cd6e627ef5..a5c31d5cc1 100644
--- a/cluster/src/test/scala/org/apache/spark/sql/execution/SnappyTableMutableAPISuite.scala
+++ b/cluster/src/test/scala/org/apache/spark/sql/execution/SnappyTableMutableAPISuite.scala
@@ -585,6 +585,58 @@ class SnappyTableMutableAPISuite extends SnappyFunSuite with Logging with Before
     assert(resultdf.contains(Row(8, 8, 8)))
   }
 
+  test("update with dataframe API col tables") {
+    val snc = new SnappySession(sc)
+    val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
+    val df1 = snc.createDataFrame(rdd)
+    val rdd2 = sc.parallelize(data1, 2).map(s => DataDiffCol(s(0), s(1), s(2)))
+    val df2 = snc.createDataFrame(rdd2)
+
+    snc.createTable("col_table", "column",
+      df1.schema, Map("key_columns" -> "col2"))
+
+    df1.write.insertInto("col_table")
+    df2.write.update("col_table")
+
+    val resultdf = snc.table("col_table").collect()
+    assert(resultdf.length == 7)
+    assert(resultdf.contains(Row(88, 88, 88)))
+  }
+
+  test("update with dataframe API row tables") {
+    val snc = new SnappySession(sc)
+    val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
+    val df1 = snc.createDataFrame(rdd)
+    val rdd2 = sc.parallelize(data1, 2).map(s => DataDiffCol(s(0), s(1), s(2)))
+    val df2 = snc.createDataFrame(rdd2)
+
+    snc.sql("create table row_table (col1 int, col2 int, col3 int, PRIMARY KEY (col2))")
+
+    df1.write.insertInto("row_table")
+    df2.write.update("row_table")
+
+    val resultdf = snc.table("row_table").collect()
+    assert(resultdf.length == 7)
+    assert(resultdf.contains(Row(88, 88, 88)))
+  }
+
+  test("Update row tables Key columns validation") {
+    val snc = new SnappySession(sc)
+    val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
+    val df1 = snc.createDataFrame(rdd)
+    val rdd2 = sc.parallelize(data1, 2).map(s => DataDiffCol(s(0), s(1), s(2)))
+    val df2 = snc.createDataFrame(rdd2)
+
+    snc.createTable("row_table", "row",
+      df1.schema, Map.empty[String, String])
+
+    df1.write.insertInto("row_table")
+
+    intercept[AnalysisException] {
+      df2.write.update("row_table")
+    }
+  }
+
   test("DeleteFrom Key columns validation") {
     val snc = new SnappySession(sc)
     val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
diff --git a/core/src/main/scala/io/snappydata/Literals.scala b/core/src/main/scala/io/snappydata/Literals.scala
index 612ae7be15..01301235b7 100644
--- a/core/src/main/scala/io/snappydata/Literals.scala
+++ b/core/src/main/scala/io/snappydata/Literals.scala
@@ -355,10 +355,16 @@ object Property extends Enumeration {
       s". Default is true.", Some(true), null, false)
 
   val PutIntoInnerJoinCacheSize =
-    SQLVal[Long](s"${Constant.PROPERTY_PREFIX}cache.putIntoInnerJoinResultSize",
-      "The putInto inner join would be cached if the table is of size less " +
-      "than PutIntoInnerJoinCacheSize. Default value is 100 MB.", Some(100L * 1024 * 1024))
-
+    SQLVal[String](s"${Constant.PROPERTY_PREFIX}cache.putIntoInnerJoinResultSize",
+      "The putInto inner join would be cached if the result of the " +
+      "join with the incoming Dataset is of size less " +
+      "than PutIntoInnerJoinCacheSize. Default value is 100 MB.", Some("100m"))
+
+  val ForceCachePutIntoInnerJoin =
+    SQLVal[Boolean](s"${Constant.PROPERTY_PREFIX}cache.putIntoInnerJoinResult",
+      "If this property is set, the putInto inner join would be cached irrespective of size. "
+ + "Primarily used in Streaming sources where correct stats can not be calculated." + + "", Some(false)) } diff --git a/core/src/main/scala/org/apache/spark/sql/DataFrameWriterJavaFunctions.scala b/core/src/main/scala/org/apache/spark/sql/DataFrameWriterJavaFunctions.scala index debfc47d69..80f6d58608 100644 --- a/core/src/main/scala/org/apache/spark/sql/DataFrameWriterJavaFunctions.scala +++ b/core/src/main/scala/org/apache/spark/sql/DataFrameWriterJavaFunctions.scala @@ -35,4 +35,8 @@ class DataFrameWriterJavaFunctions(val dfWriter: DataFrameWriter[_]) { def deleteFrom(tableName: String): Unit = { new DataFrameWriterExtensions(dfWriter).deleteFrom(tableName) } + + def update(tableName: String): Unit = { + new DataFrameWriterExtensions(dfWriter).update(tableName) + } } diff --git a/core/src/main/scala/org/apache/spark/sql/SnappyImplicits.scala b/core/src/main/scala/org/apache/spark/sql/SnappyImplicits.scala index d3d7da4e76..f27aa2148f 100644 --- a/core/src/main/scala/org/apache/spark/sql/SnappyImplicits.scala +++ b/core/src/main/scala/org/apache/spark/sql/SnappyImplicits.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, SubqueryAlias} import org.apache.spark.sql.internal.ColumnTableBulkOps -import org.apache.spark.sql.sources.{DeleteFromTable, PutIntoTable} +import org.apache.spark.sql.sources.{DeleteFromTable, BulkUpdate} import org.apache.spark.{Partition, TaskContext} /** @@ -165,12 +165,12 @@ object snappy extends Serializable { extends Serializable { /** - * "Puts" the content of the [[DataFrame]] to the specified table. It - * requires that the schema of the [[DataFrame]] is the same as the schema - * of the table. If some rows are already present then they are updated. - * - * This ignores all SaveMode. - */ + * "Puts" the content of the [[DataFrame]] to the specified table. It + * requires that the schema of the [[DataFrame]] is the same as the schema + * of the table. If some rows are already present then they are updated. + * + * This ignores all SaveMode. + */ def putInto(tableName: String): Unit = { val df: DataFrame = dfField.get(writer).asInstanceOf[DataFrame] val session = df.sparkSession match { @@ -191,13 +191,13 @@ object snappy extends Serializable { }.getOrElse(df.logicalPlan) try { - df.sparkSession.sessionState.executePlan(PutIntoTable(UnresolvedRelation( - session.sessionState.catalog.newQualifiedTableName(tableName)), input)) + df.sparkSession.sessionState.executePlan(BulkUpdate(UnresolvedRelation( + session.sessionState.catalog.newQualifiedTableName(tableName)), input, isPutInto = true)) .executedPlan.executeCollect() } finally { df.sparkSession.asInstanceOf[SnappySession]. getContextObject[LogicalPlan](ColumnTableBulkOps.CACHED_PUTINTO_UPDATE_PLAN). - map { cachedPlan => + foreach { cachedPlan => df.sparkSession. 
                  sharedState.cacheManager.uncacheQuery(df.sparkSession, cachedPlan, true)
             }
@@ -228,7 +228,32 @@ object snappy extends Serializable {
           .executedPlan.executeCollect()
     }
 
+    def update(tableName: String): Unit = {
+      val df: DataFrame = dfField.get(writer).asInstanceOf[DataFrame]
+      val session = df.sparkSession match {
+        case sc: SnappySession => sc
+        case _ => sys.error("Expected a SnappyContext for update operation")
+      }
+      val normalizedParCols = parColsMethod.invoke(writer)
+          .asInstanceOf[Option[Seq[String]]]
+      // A partitioned relation's schema can be different from the input
+      // logicalPlan, since partition columns are all moved after data columns.
+      // We Project to adjust the ordering.
+      // TODO: this belongs to the analyzer.
+      val input = normalizedParCols.map { parCols =>
+        val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition {
+          attr => parCols.contains(attr.name)
+        }
+        Project(inputDataCols ++ inputPartCols, df.logicalPlan)
+      }.getOrElse(df.logicalPlan)
+
+      df.sparkSession.sessionState.executePlan(BulkUpdate(UnresolvedRelation(
+        session.sessionState.catalog.newQualifiedTableName(tableName)), input, isPutInto = false))
+          .executedPlan.executeCollect()
+
+    }
   }
+
 }
 
 private[sql] case class SnappyDataFrameOperations(session: SnappySession,
diff --git a/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala b/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
index 3af243c3c0..c51e978171 100644
--- a/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
+++ b/core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.collection.Utils
-import org.apache.spark.sql.sources.{Delete, Insert, PutIntoTable, Update}
+import org.apache.spark.sql.sources.{Delete, Insert, BulkUpdate, Update}
 import org.apache.spark.sql.streaming.WindowLogicalPlan
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{SnappyParserConsts => Consts}
@@ -991,7 +991,8 @@ class SnappyParser(session: SnappySession) extends SnappyDDLParser(session) {
   }
 
   protected final def put: Rule1[LogicalPlan] = rule {
-    PUT ~ INTO ~ TABLE.? ~ relationFactor ~ subSelectQuery ~> PutIntoTable
+    PUT ~ INTO ~ TABLE.? ~ relationFactor ~ subSelectQuery ~>
+        ((t: LogicalPlan, s: LogicalPlan) => BulkUpdate(t, s, isPutInto = true))
   }
 
   protected final def update: Rule1[LogicalPlan] = rule {
diff --git a/core/src/main/scala/org/apache/spark/sql/execution/columnar/impl/ColumnFormatRelation.scala b/core/src/main/scala/org/apache/spark/sql/execution/columnar/impl/ColumnFormatRelation.scala
index 605724d3d6..cbeaeb8a47 100644
--- a/core/src/main/scala/org/apache/spark/sql/execution/columnar/impl/ColumnFormatRelation.scala
+++ b/core/src/main/scala/org/apache/spark/sql/execution/columnar/impl/ColumnFormatRelation.scala
@@ -517,7 +517,7 @@ class ColumnFormatRelation(
     _externalStore,
     _partitioningColumns,
     _context)
-    with ParentRelation with DependentRelation with BulkPutRelation {
+    with ParentRelation with DependentRelation with BulkUpdateRelation {
   val tableOptions = new CaseInsensitiveMutableHashMap(_origOptions)
 
   override def withKeyColumns(relation: LogicalRelation,
@@ -670,11 +670,11 @@ class ColumnFormatRelation(
    * otherwise it gets inserted into the table represented by this relation.
    * The result of SparkPlan execution should be a count of number of rows put.
    */
-  override def getPutPlan(insertPlan: SparkPlan, updatePlan: SparkPlan): SparkPlan = {
+  override def getUpdatePlan(insertPlan: SparkPlan, updatePlan: SparkPlan): SparkPlan = {
     ColumnPutIntoExec(insertPlan, updatePlan)
   }
 
-  override def getPutKeys(): Option[Seq[String]] = {
+  override def getUpdateKeys(): Option[Seq[String]] = {
     val keys = _origOptions.get(ExternalStoreUtils.KEY_COLUMNS)
     keys match {
       case Some(x) => Some(x.split(",").map(s => s.trim).toSeq)
diff --git a/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelationUtil.scala b/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelationUtil.scala
new file mode 100644
index 0000000000..501959b756
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelationUtil.scala
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2017 SnappyData, Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you
+ * may not use this file except in compliance with the License. You
+ * may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License. See accompanying
+ * LICENSE file.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Connection, DriverManager}
+
+import org.apache.spark.Partition
+import org.apache.spark.jdbc.{ConnectionConf, ConnectionConfBuilder, ConnectionUtil}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_DRIVER_CLASS
+import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{Dataset, Row, SnappySession}
+
+/**
+ * This object is used to get an optimized JDBC RDD, which uses pooled connections.
+ */
+object JDBCRelationUtil {
+
+  def jdbcDF(sparkSession: SnappySession,
+      parameters: Map[String, String],
+      table: String,
+      schema: StructType,
+      requiredColumns: Array[String],
+      parts: Array[Partition],
+      conf: ConnectionConf,
+      filters: Array[Filter]): Dataset[Row] = {
+    val url = parameters("url")
+    val options = new JDBCOptions(url, table, parameters)
+
+    val dialect = JdbcDialects.get(url)
+    val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
+    val rdd = new JDBCRDD(
+      sparkSession.sparkContext,
+      createConnectionFactory(url, conf),
+      schema,
+      quotedColumns,
+      filters,
+      parts,
+      url,
+      options).asInstanceOf[RDD[InternalRow]]
+
+    sparkSession.internalCreateDataFrame(rdd, schema)
+  }
+
+  def createConnectionFactory(url: String, conf: ConnectionConf): () => Connection = {
+    () => {
+      ConnectionUtil.getPooledConnection(url, conf)
+    }
+  }
+
+  def buildConf(snappySession: SnappySession,
+      parameters: Map[String, String]): ConnectionConf = {
+    val url = parameters("url")
+    val driverClass = {
+      val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS)
+      userSpecifiedDriverClass.foreach(DriverRegistry.register)
+
+      // Performing this part of the logic on the driver guards against the corner-case where the
+      // driver returned for a URL is different on the driver and executors due to classpath
+      // differences.
+      userSpecifiedDriverClass.getOrElse {
+        DriverManager.getDriver(url).getClass.getCanonicalName
+      }
+    }
+    new ConnectionConfBuilder(snappySession)
+        .setPoolProvider(parameters.getOrElse("poolImpl", "hikari"))
+        .setPoolConf("maximumPoolSize", parameters.getOrElse("maximumPoolSize", "10"))
+        .setPoolConf("minimumIdle", parameters.getOrElse("minimumIdle", "5"))
+        .setDriver(driverClass)
+        .setConf("user", parameters.getOrElse("user", ""))
+        .setConf("password", parameters.getOrElse("password", ""))
+        .setURL(url)
+        .build()
+  }
+
+  def schema(table: String, parameters: Map[String, String]): StructType = {
+    val url = parameters("url")
+    val options = new JDBCOptions(url, table, parameters)
+    JDBCRDD.resolveTable(options)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinExec.scala b/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinExec.scala
index 2d673d4ca6..f9e92eccac 100644
--- a/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinExec.scala
+++ b/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoinExec.scala
@@ -476,6 +476,10 @@ case class HashJoinExec(leftKeys: Seq[Expression],
     val produced = streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
 
     val beforeMap = ctx.freshName("beforeMap")
+    val skipScan = skipProcessForEmptyMap()
+    val skipCondition = if (skipScan) {
+      s"if($hashMapTerm.size() == 0) return;"
+    } else ""
 
     s"""
       boolean $keyIsUniqueTerm = true;
@@ -491,6 +495,7 @@ case class HashJoinExec(leftKeys: Seq[Expression],
       final $entryClass[] $mapDataTerm = ($entryClass[])$hashMapTerm.data();
       long $numRowsTerm = 0L;
       try {
+        $skipCondition
        ${session.evaluateFinallyCode(ctx, produced)}
      } finally {
        $numOutputRows.${metricAdd(numRowsTerm)};
@@ -498,6 +503,13 @@ case class HashJoinExec(leftKeys: Seq[Expression],
     """
   }
 
+  private def skipProcessForEmptyMap(): Boolean = {
+    joinType match {
+      case Inner | LeftSemi => true
+      case _ => false
+    }
+  }
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode],
       row: ExprCode): String = {
     // variable that holds if relation is unique to optimize iteration
diff --git a/core/src/main/scala/org/apache/spark/sql/internal/ColumnTableBulkOps.scala b/core/src/main/scala/org/apache/spark/sql/internal/ColumnTableBulkOps.scala
index e688da1db2..d97fb25bba 100644
--- a/core/src/main/scala/org/apache/spark/sql/internal/ColumnTableBulkOps.scala
+++ b/core/src/main/scala/org/apache/spark/sql/internal/ColumnTableBulkOps.scala
@@ -18,11 +18,13 @@ package org.apache.spark.sql.internal
 
 import io.snappydata.Property
 
+import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, EqualTo, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, OverwriteOptions, Project}
 import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti}
 import org.apache.spark.sql.collection.Utils
+import org.apache.spark.sql.execution.columnar.ExternalStoreUtils
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, LongType}
@@ -33,13 +35,13 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Row, SnappyContext, Sna
  * This class takes the logical plans from SnappyParser
  * and converts it into another plan.
  */
-object ColumnTableBulkOps {
+object ColumnTableBulkOps extends Logging {
 
   val CACHED_PUTINTO_UPDATE_PLAN = "cached_putinto_logical_plan"
 
-  def validateOp(originalPlan: PutIntoTable) {
+  def validateOp(originalPlan: BulkUpdate) {
     originalPlan match {
-      case PutIntoTable(LogicalRelation(t: BulkPutRelation, _, _), query) =>
+      case BulkUpdate(LogicalRelation(t: BulkUpdateRelation, _, _), query, _) =>
         val srcRelations = query.collect {
           case LogicalRelation(src: BaseRelation, _, _) => src
         }
@@ -53,49 +55,73 @@
     }
   }
 
-  def transformPutPlan(sparkSession: SparkSession, originalPlan: PutIntoTable): LogicalPlan = {
+  def transformUpdatePlan(sparkSession: SparkSession, originalPlan: BulkUpdate): LogicalPlan = {
     validateOp(originalPlan)
     val table = originalPlan.table
+    val isPutInto = originalPlan.isPutInto
     val subQuery = originalPlan.child
     var transFormedPlan: LogicalPlan = originalPlan
 
     table.collectFirst {
-      case LogicalRelation(mutable: BulkPutRelation, _, _) => {
-        val putKeys = mutable.getPutKeys()
-        if (!putKeys.isDefined) {
+      case LogicalRelation(mutable: BulkUpdateRelation, _, _) => {
+        val keys = mutable.getUpdateKeys()
+        if (!keys.isDefined) {
           throw new AnalysisException(
-            s"PutInto in a column table requires key column(s) but got empty string")
+            s"Bulk update/putInto on a column table requires key column(s) but got empty string")
         }
-        val condition = prepareCondition(sparkSession, table, subQuery, putKeys.get)
+        val condition = prepareCondition(sparkSession, table, subQuery, keys.get)
         val keyColumns = getKeyColumns(table)
         val updateSubQuery = Join(table, subQuery, Inner, condition)
         val updateColumns = table.output.filterNot(a => keyColumns.contains(a.name))
-        val cacheSize = Property.PutIntoInnerJoinCacheSize
-            .getOption(sparkSession.sparkContext.conf) match {
-          case Some(size) => size.toInt
-          case None => Property.PutIntoInnerJoinCacheSize.defaultValue.get
-        }
-        if (updateSubQuery.statistics.sizeInBytes <= cacheSize) {
-          sparkSession.sharedState.cacheManager.
-              cacheQuery(new Dataset(sparkSession,
-                updateSubQuery, RowEncoder(updateSubQuery.schema)))
-          sparkSession.asInstanceOf[SnappySession].
-              addContextObject(CACHED_PUTINTO_UPDATE_PLAN, updateSubQuery)
+        if (isPutInto) {
+          val cacheSize = ExternalStoreUtils.sizeAsBytes(
+            Property.PutIntoInnerJoinCacheSize.get(sparkSession.sqlContext.conf),
+            Property.PutIntoInnerJoinCacheSize.name, -1, Long.MaxValue)
+          val forceCache = Property.ForceCachePutIntoInnerJoin.get(sparkSession.sqlContext.conf)
+          if (updateSubQuery.statistics.sizeInBytes <= cacheSize || forceCache) {
+            sparkSession.sharedState.cacheManager.
+                cacheQuery(new Dataset(sparkSession,
+                  updateSubQuery, RowEncoder(updateSubQuery.schema)))
+            sparkSession.asInstanceOf[SnappySession].
+                addContextObject(CACHED_PUTINTO_UPDATE_PLAN, updateSubQuery)
+          }
+
+          val notExists = Join(subQuery, updateSubQuery, LeftAnti, condition)
+          val insertPlan = new Insert(table, Map.empty[String,
+            Option[String]], Project(subQuery.output, notExists),
+            OverwriteOptions(false), ifNotExists = false)
+
+          val updateExpressions =
+            notExists.output.filterNot(a => keyColumns.contains(a.name))
+          val updatePlan = Update(table, updateSubQuery, Seq.empty,
+            updateColumns, updateExpressions)
+
+          transFormedPlan = PutIntoColumnTable(table, insertPlan, updatePlan)
+        } else {
+          val updateExpressions =
+            subQuery.output.filterNot(a => keyColumns.contains(a.name))
+          transFormedPlan = Update(table, updateSubQuery, Seq.empty,
+            updateColumns, updateExpressions)
         }
-        val notExists = Join(subQuery, updateSubQuery, LeftAnti, condition)
-        val insertPlan = new Insert(table, Map.empty[String,
-          Option[String]], Project(subQuery.output, notExists),
-          OverwriteOptions(false), ifNotExists = false)
+      }
+      case LogicalRelation(mutable: RowPutRelation, _, _) if !isPutInto =>
+        // For row tables get the actual key columns
+        val keyColumns = getKeyColumns(table)
+        if (keyColumns.isEmpty) {
+          throw new AnalysisException(
+            s"Empty key columns for update/delete on $mutable")
+        }
+        val condition = prepareCondition(sparkSession, table, subQuery, keyColumns)
+        val updateSubQuery = Join(table, subQuery, Inner, condition)
+        val updateColumns = table.output.filterNot(a => keyColumns.contains(a.name))
 
-        val updateExpressions = notExists.output.filterNot(a => keyColumns.contains(a.name))
-        val updatePlan = Update(table, updateSubQuery, Seq.empty,
+        val updateExpressions =
+          subQuery.output.filterNot(a => keyColumns.contains(a.name))
+        transFormedPlan = Update(table, updateSubQuery, Seq.empty,
           updateColumns, updateExpressions)
-
-        transFormedPlan = PutIntoColumnTable(table, insertPlan, updatePlan)
-      }
       case _ => // Do nothing, original putInto plan is enough
     }
     transFormedPlan
@@ -146,13 +172,13 @@ object ColumnTableBulkOps {
     var transFormedPlan: LogicalPlan = originalPlan
 
     table.collectFirst {
-      case LogicalRelation(mutable: BulkPutRelation, _, _) => {
-        val putKeys = mutable.getPutKeys()
-        if (!putKeys.isDefined) {
+      case LogicalRelation(mutable: BulkUpdateRelation, _, _) => {
+        val keys = mutable.getUpdateKeys()
+        if (!keys.isDefined) {
           throw new AnalysisException(
             s"DeleteFrom in a column table requires key column(s) but got empty string")
         }
-        val condition = prepareCondition(sparkSession, table, subQuery, putKeys.get)
+        val condition = prepareCondition(sparkSession, table, subQuery, keys.get)
         val exists = Join(subQuery, table, Inner, condition)
         transFormedPlan = Delete(table, exists, Seq.empty[Attribute])
       }
diff --git a/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala b/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala
index 669bed7852..a04e53204f 100644
--- a/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala
+++ b/core/src/main/scala/org/apache/spark/sql/internal/SnappySessionState.scala
@@ -237,7 +237,7 @@ class SnappySessionState(snappySession: SnappySession)
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case i@PutIntoTable(u: UnresolvedRelation, _) =>
+      case i@BulkUpdate(u: UnresolvedRelation, _, _) =>
         i.copy(table = EliminateSubqueryAliases(getTable(u)))
       case d@DMLExternalTable(_, u: UnresolvedRelation, _) =>
         d.copy(query = EliminateSubqueryAliases(getTable(u)))
@@ -355,8 +355,8 @@ class SnappySessionState(snappySession: SnappySession)
       }
     case d@DeleteFromTable(_, child) if child.resolved =>
       ColumnTableBulkOps.transformDeletePlan(sparkSession, d)
-    case p@PutIntoTable(_, child) if child.resolved =>
-      ColumnTableBulkOps.transformPutPlan(sparkSession, p)
+    case p@BulkUpdate(_, child, _) if child.resolved =>
+      ColumnTableBulkOps.transformUpdatePlan(sparkSession, p)
   }
 
   private def analyzeQuery(query: LogicalPlan): LogicalPlan = {
@@ -839,7 +839,7 @@ private[sql] final class PreprocessTableInsertOrPut(conf: SQLConf)
     // Need to eliminate subqueries here. Unlike InsertIntoTable whose
     // subqueries have already been eliminated by special check in
     // ResolveRelations, no such special rule has been added for PUT
-    case p@PutIntoTable(table, child) if table.resolved && child.resolved =>
+    case p@BulkUpdate(table, child, _) if table.resolved && child.resolved =>
       EliminateSubqueryAliases(table) match {
         case l@LogicalRelation(ir: RowInsertableRelation, _, _) =>
           // First, make sure the data to be inserted have the same number of
@@ -993,12 +993,12 @@ private[sql] final class PreprocessTableInsertOrPut(conf: SQLConf)
 
     if (newChildOutput == child.output) {
       plan match {
-        case p: PutIntoTable => p.copy(table = newRelation).asInstanceOf[T]
+        case p: BulkUpdate => p.copy(table = newRelation).asInstanceOf[T]
         case d: DeleteFromTable => d.copy(table = newRelation).asInstanceOf[T]
         case _: InsertIntoTable => plan
       }
     } else plan match {
-      case p: PutIntoTable => p.copy(table = newRelation,
+      case p: BulkUpdate => p.copy(table = newRelation,
         child = Project(newChildOutput, child)).asInstanceOf[T]
       case d: DeleteFromTable => d.copy(table = newRelation,
         child = Project(newChildOutput, child)).asInstanceOf[T]
@@ -1036,7 +1036,7 @@ private[sql] case object PrePutCheck extends (LogicalPlan => Unit) {
 
   def apply(plan: LogicalPlan): Unit = {
     plan.foreach {
-      case PutIntoTable(LogicalRelation(t: RowPutRelation, _, _), query) =>
+      case BulkUpdate(LogicalRelation(t: RowPutRelation, _, _), query, _) =>
         // Get all input data source relations of the query.
        val srcRelations = query.collect {
           case LogicalRelation(src: BaseRelation, _, _) => src
@@ -1047,7 +1047,7 @@ private[sql] case object PrePutCheck extends (LogicalPlan => Unit) {
       } else {
         // OK
       }
-      case PutIntoTable(table, _) =>
+      case BulkUpdate(table, _, _) =>
         throw Utils.analysisException(s"$table does not allow puts.")
       case _ => // OK
     }
diff --git a/core/src/main/scala/org/apache/spark/sql/sources/StoreStrategy.scala b/core/src/main/scala/org/apache/spark/sql/sources/StoreStrategy.scala
index 249c857f70..fc4a0cee37 100644
--- a/core/src/main/scala/org/apache/spark/sql/sources/StoreStrategy.scala
+++ b/core/src/main/scala/org/apache/spark/sql/sources/StoreStrategy.scala
@@ -106,11 +106,13 @@ object StoreStrategy extends Strategy {
     case d@DMLExternalTable(_, storeRelation: LogicalRelation, insertCommand) =>
       ExecutedCommandExec(ExternalTableDMLCmd(storeRelation, insertCommand, d.output)) :: Nil
 
-    case PutIntoTable(l@LogicalRelation(p: RowPutRelation, _, _), query) =>
+    // Handle only putInto here.
+    // For row tables a bulk update is handled by the Update node.
+    case BulkUpdate(l@LogicalRelation(p: RowPutRelation, _, _), query, isPutInto) if isPutInto =>
       ExecutePlan(p.getPutPlan(l, planLater(query))) :: Nil
 
-    case PutIntoColumnTable(l@LogicalRelation(p: BulkPutRelation, _, _), left, right) =>
-      ExecutePlan(p.getPutPlan(planLater(left), planLater(right))) :: Nil
+    case PutIntoColumnTable(l@LogicalRelation(p: BulkUpdateRelation, _, _), left, right) =>
+      ExecutePlan(p.getUpdatePlan(planLater(left), planLater(right))) :: Nil
 
     case Update(l@LogicalRelation(u: MutableRelation, _, _), child,
       keyColumns, updateColumns, updateExpressions) =>
@@ -147,7 +149,8 @@ case class ExternalTableDMLCmd(
   override lazy val output: Seq[Attribute] = childOutput
 }
 
-case class PutIntoTable(table: LogicalPlan, child: LogicalPlan)
+case class BulkUpdate(table: LogicalPlan, child: LogicalPlan,
+    isPutInto: Boolean = false)
   extends LogicalPlan with TableMutationPlan {
 
   override def children: Seq[LogicalPlan] = table :: child :: Nil
diff --git a/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index acb710eb4d..1149f38580 100644
--- a/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -70,16 +70,16 @@ trait RowPutRelation extends DestroyRelation {
   def getPutPlan(relation: LogicalRelation, child: SparkPlan): SparkPlan
 }
 
-trait BulkPutRelation extends DestroyRelation {
+trait BulkUpdateRelation extends DestroyRelation {
 
-  def getPutKeys() : Option[Seq[String]]
+  def getUpdateKeys(): Option[Seq[String]]
 
   /**
    * Get a spark plan for puts. If the row is already present, it gets updated
    * otherwise it gets inserted into the table represented by this relation.
   * The result of SparkPlan execution should be a count of number of rows put.
    */
-  def getPutPlan(insertPlan: SparkPlan, updatePlan: SparkPlan): SparkPlan
+  def getUpdatePlan(insertPlan: SparkPlan, updatePlan: SparkPlan): SparkPlan
 }
 
 @DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/sql/sources/jdbcExtensions.scala b/core/src/main/scala/org/apache/spark/sql/sources/jdbcExtensions.scala
index efffc325db..09e97cef03 100644
--- a/core/src/main/scala/org/apache/spark/sql/sources/jdbcExtensions.scala
+++ b/core/src/main/scala/org/apache/spark/sql/sources/jdbcExtensions.scala
@@ -342,9 +342,9 @@ object JdbcExtendedUtils extends Logging {
     val ds = session.internalCreateDataFrame(session.sparkContext.parallelize(
       rows.map(encoder.toRow)), schema)
     val plan = if (putInto) {
-      PutIntoTable(
+      BulkUpdate(
         table = UnresolvedRelation(tableIdent),
-        child = ds.logicalPlan)
+        child = ds.logicalPlan, isPutInto = true)
     } else {
       new Insert(
         table = UnresolvedRelation(tableIdent),