[SNAP-2158] DataFrame update API. #941

Open · wants to merge 10 commits into master
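This PR adds an update(table) operation to the DataFrameWriter extensions, alongside the existing putInto: rows of the target table whose key columns match rows of the incoming DataFrame are updated, and an AnalysisException is raised when the table defines no key columns or primary key (see the tests below). A minimal usage sketch based on those tests, with illustrative table and column names and assuming the implicit extensions are brought in via org.apache.spark.sql.snappy:

// Minimal sketch; names are illustrative, semantics as suggested by the new tests.
import org.apache.spark.sql.{Row, SnappySession}
import org.apache.spark.sql.snappy._        // assumed import for the writer extensions

val snc = new SnappySession(sc)             // sc: an existing SparkContext
import snc.implicits._

val df1 = Seq((1, 1, 1), (8, 8, 8)).toDF("col1", "col2", "col3")
val df2 = Seq((88, 8, 88)).toDF("col1", "col2", "col3")   // same key (col2 = 8), new values

snc.createTable("col_table", "column", df1.schema, Map("key_columns" -> "col2"))
df1.write.insertInto("col_table")
df2.write.update("col_table")               // the row keyed by col2 = 8 becomes (88, 8, 88)
assert(snc.table("col_table").collect().contains(Row(88, 8, 88)))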
@@ -585,6 +585,58 @@ class SnappyTableMutableAPISuite extends SnappyFunSuite with Logging with Before
assert(resultdf.contains(Row(8, 8, 8)))
}

test("update with dataframe API col tables") {
val snc = new SnappySession(sc)
val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
val df1 = snc.createDataFrame(rdd)
val rdd2 = sc.parallelize(data1, 2).map(s => DataDiffCol(s(0), s(1), s(2)))
val df2 = snc.createDataFrame(rdd2)

snc.createTable("col_table", "column",
df1.schema, Map("key_columns" -> "col2"))

df1.write.insertInto("col_table")
df2.write.update("col_table")

val resultdf = snc.table("col_table").collect()
assert(resultdf.length == 7)
assert(resultdf.contains(Row(88, 88, 88)))
}

test("update with dataframe API row tables") {
val snc = new SnappySession(sc)
val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
val df1 = snc.createDataFrame(rdd)
val rdd2 = sc.parallelize(data1, 2).map(s => DataDiffCol(s(0), s(1), s(2)))
val df2 = snc.createDataFrame(rdd2)

snc.sql("create table row_table (col1 int, col2 int, col3 int, PRIMARY KEY (col2))")

df1.write.insertInto("row_table")
df2.write.update("row_table")

val resultdf = snc.table("row_table").collect()
assert(resultdf.length == 7)
assert(resultdf.contains(Row(88, 88, 88)))
}

test("Update row tables Key columns validation") {
val snc = new SnappySession(sc)
val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
val df1 = snc.createDataFrame(rdd)
val rdd2 = sc.parallelize(data1, 2).map(s => DataDiffCol(s(0), s(1), s(2)))
val df2 = snc.createDataFrame(rdd2)

snc.createTable("row_table", "row",
df1.schema, Map.empty[String, String])

df1.write.insertInto("row_table")

intercept[AnalysisException]{
df2.write.update("row_table")
}
}

test("DeleteFrom Key columns validation") {
val snc = new SnappySession(sc)
val rdd = sc.parallelize(data2, 2).map(s => Data(s(0), s(1), s(2)))
14 changes: 10 additions & 4 deletions core/src/main/scala/io/snappydata/Literals.scala
@@ -355,10 +355,16 @@ object Property extends Enumeration {
s". Default is true.", Some(true), null, false)

  val PutIntoInnerJoinCacheSize =
    SQLVal[Long](s"${Constant.PROPERTY_PREFIX}cache.putIntoInnerJoinResultSize",
      "The putInto inner join would be cached if the table is of size less " +
      "than PutIntoInnerJoinCacheSize. Default value is 100 MB.", Some(100L * 1024 * 1024))

    SQLVal[String](s"${Constant.PROPERTY_PREFIX}cache.putIntoInnerJoinResultSize",
      "The putInto inner join is cached if the result of the " +
      "join with the incoming Dataset is smaller " +
      "than PutIntoInnerJoinCacheSize. Default value is 100m (100 MB).", Some("100m"))

  val ForceCachePutIntoInnerJoin =
    SQLVal[Boolean](s"${Constant.PROPERTY_PREFIX}cache.putIntoInnerJoinResult",
      "If this property is set, the putInto inner join is cached irrespective of its size. " +
      "Primarily used with streaming sources, where correct stats cannot be calculated.",
      Some(false))

}
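These are SQL-level properties; a hedged sketch of how they might be set on a session, assuming Constant.PROPERTY_PREFIX resolves to "snappydata." (the prefix itself is not shown in this diff):

val session = new SnappySession(sc)   // sc: an existing SparkContext
// Cache the putInto inner join only when the join result is estimated below 64 MB.
session.sql("set snappydata.cache.putIntoInnerJoinResultSize=64m")
// Force caching of the putInto inner join regardless of estimated size,
// e.g. for streaming sources where size stats are unreliable.
session.sql("set snappydata.cache.putIntoInnerJoinResult=true")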

@@ -35,4 +35,8 @@ class DataFrameWriterJavaFunctions(val dfWriter: DataFrameWriter[_]) {
  def deleteFrom(tableName: String): Unit = {
    new DataFrameWriterExtensions(dfWriter).deleteFrom(tableName)
  }

  def update(tableName: String): Unit = {
    new DataFrameWriterExtensions(dfWriter).update(tableName)
  }
}
45 changes: 35 additions & 10 deletions core/src/main/scala/org/apache/spark/sql/SnappyImplicits.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, SubqueryAlias}
import org.apache.spark.sql.internal.ColumnTableBulkOps
import org.apache.spark.sql.sources.{DeleteFromTable, PutIntoTable}
import org.apache.spark.sql.sources.{DeleteFromTable, BulkUpdate}
import org.apache.spark.{Partition, TaskContext}

/**
@@ -165,12 +165,12 @@ object snappy extends Serializable {
extends Serializable {

/**
* "Puts" the content of the [[DataFrame]] to the specified table. It
* requires that the schema of the [[DataFrame]] is the same as the schema
* of the table. If some rows are already present then they are updated.
*
* This ignores all SaveMode.
*/
* "Puts" the content of the [[DataFrame]] to the specified table. It
* requires that the schema of the [[DataFrame]] is the same as the schema
* of the table. If some rows are already present then they are updated.
*
* This ignores all SaveMode.
*/
def putInto(tableName: String): Unit = {
val df: DataFrame = dfField.get(writer).asInstanceOf[DataFrame]
val session = df.sparkSession match {
@@ -191,13 +191,13 @@
}.getOrElse(df.logicalPlan)

try {
df.sparkSession.sessionState.executePlan(PutIntoTable(UnresolvedRelation(
session.sessionState.catalog.newQualifiedTableName(tableName)), input))
df.sparkSession.sessionState.executePlan(BulkUpdate(UnresolvedRelation(
session.sessionState.catalog.newQualifiedTableName(tableName)), input, isPutInto = true))
.executedPlan.executeCollect()
} finally {
df.sparkSession.asInstanceOf[SnappySession].
getContextObject[LogicalPlan](ColumnTableBulkOps.CACHED_PUTINTO_UPDATE_PLAN).
map { cachedPlan =>
foreach { cachedPlan =>
df.sparkSession.
sharedState.cacheManager.uncacheQuery(df.sparkSession, cachedPlan, true)
}
@@ -228,7 +228,32 @@ object snappy extends Serializable {
.executedPlan.executeCollect()
}

    def update(tableName: String): Unit = {
      val df: DataFrame = dfField.get(writer).asInstanceOf[DataFrame]
      val session = df.sparkSession match {
        case sc: SnappySession => sc
        case _ => sys.error("Expected a SnappySession for the update operation")
      }
      val normalizedParCols = parColsMethod.invoke(writer)
          .asInstanceOf[Option[Seq[String]]]
      // A partitioned relation's schema can be different from the input
      // logicalPlan, since partition columns are all moved after data columns.
      // We add a Project to adjust the ordering.
      // TODO: this belongs to the analyzer.
      val input = normalizedParCols.map { parCols =>
        val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition {
          attr => parCols.contains(attr.name)
        }
        Project(inputDataCols ++ inputPartCols, df.logicalPlan)
      }.getOrElse(df.logicalPlan)

      df.sparkSession.sessionState.executePlan(BulkUpdate(UnresolvedRelation(
        session.sessionState.catalog.newQualifiedTableName(tableName)), input, isPutInto = false))
          .executedPlan.executeCollect()
    }
}

}

private[sql] case class SnappyDataFrameOperations(session: SnappySession,
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/sql/SnappyParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.collection.Utils
import org.apache.spark.sql.sources.{Delete, Insert, PutIntoTable, Update}
import org.apache.spark.sql.sources.{Delete, Insert, BulkUpdate, Update}
import org.apache.spark.sql.streaming.WindowLogicalPlan
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SnappyParserConsts => Consts}
@@ -991,7 +991,8 @@ class SnappyParser(session: SnappySession) extends SnappyDDLParser(session) {
}

protected final def put: Rule1[LogicalPlan] = rule {
PUT ~ INTO ~ TABLE.? ~ relationFactor ~ subSelectQuery ~> PutIntoTable
PUT ~ INTO ~ TABLE.? ~ relationFactor ~ subSelectQuery ~>
((t: LogicalPlan, s: LogicalPlan) => BulkUpdate(t, s, true))
}
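The rule keeps the existing PUT INTO surface syntax and now plans it as BulkUpdate with isPutInto = true. For illustration (table names are hypothetical), it parses statements of this shape:

// snc is a SnappySession; TABLE is optional and the source is any sub-select.
snc.sql("PUT INTO TABLE col_table SELECT * FROM staging_table")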

protected final def update: Rule1[LogicalPlan] = rule {
@@ -517,7 +517,7 @@ class ColumnFormatRelation(
_externalStore,
_partitioningColumns,
_context)
with ParentRelation with DependentRelation with BulkPutRelation {
with ParentRelation with DependentRelation with BulkUpdateRelation {
val tableOptions = new CaseInsensitiveMutableHashMap(_origOptions)

override def withKeyColumns(relation: LogicalRelation,
@@ -670,11 +670,11 @@ class ColumnFormatRelation(
* otherwise it gets inserted into the table represented by this relation.
* The result of SparkPlan execution should be a count of number of rows put.
*/
override def getPutPlan(insertPlan: SparkPlan, updatePlan: SparkPlan): SparkPlan = {
override def getUpdatePlan(insertPlan: SparkPlan, updatePlan: SparkPlan): SparkPlan = {
ColumnPutIntoExec(insertPlan, updatePlan)
}

override def getPutKeys(): Option[Seq[String]] = {
override def getUpdateKeys(): Option[Seq[String]] = {
val keys = _origOptions.get(ExternalStoreUtils.KEY_COLUMNS)
keys match {
case Some(x) => Some(x.split(",").map(s => s.trim).toSeq)
@@ -0,0 +1,99 @@
/*
* Copyright (c) 2017 SnappyData, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You
* may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License. See accompanying
* LICENSE file.
*/

package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, DriverManager}

import org.apache.spark.Partition
import org.apache.spark.jdbc.{ConnectionConf, ConnectionConfBuilder, ConnectionUtil}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_DRIVER_CLASS
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Dataset, Row, SnappySession}

/**
* This object is used to get an optimized JDBC RDD, which uses pooled connections.
*/
object JDBCRelationUtil {

def jdbcDF(sparkSession: SnappySession,
parameters: Map[String, String],
table: String,
schema: StructType,
requiredColumns: Array[String],
parts: Array[Partition],
conf : ConnectionConf,
filters: Array[Filter]): Dataset[Row] = {
val url = parameters("url")
val options = new JDBCOptions(url, table, parameters)

val dialect = JdbcDialects.get(url)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
val rdd = new JDBCRDD(
sparkSession.sparkContext,
createConnectionFactory(url, conf),
schema,
quotedColumns,
filters,
parts,
url,
options).asInstanceOf[RDD[InternalRow]]

sparkSession.internalCreateDataFrame(rdd, schema)
}

def createConnectionFactory(url: String, conf: ConnectionConf): () => Connection = {
() => {
ConnectionUtil.getPooledConnection(url, conf)
}
}

def buildConf(snappySession: SnappySession,
parameters: Map[String, String]): ConnectionConf = {
val url = parameters("url")
val driverClass = {
val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS)
userSpecifiedDriverClass.foreach(DriverRegistry.register)

// Performing this part of the logic on the driver guards against the corner-case where the
// driver returned for a URL is different on the driver and executors due to classpath
// differences.
userSpecifiedDriverClass.getOrElse {
DriverManager.getDriver(url).getClass.getCanonicalName
}
}
new ConnectionConfBuilder(snappySession)
.setPoolProvider(parameters.getOrElse("poolImpl", "hikari"))
.setPoolConf("maximumPoolSize", parameters.getOrElse("maximumPoolSize", "10"))
.setPoolConf("minimumIdle", parameters.getOrElse("minimumIdle", "5"))
.setDriver(driverClass)
.setConf("user", parameters.getOrElse("user", ""))
.setConf("password", parameters.getOrElse("password", ""))
.setURL(url)
.build()
}

def schema(table: String, parameters: Map[String, String]): StructType = {
val url = parameters("url")
val options = new JDBCOptions(url, table, parameters)
JDBCRDD.resolveTable(options)
}
}
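A rough usage sketch of the new helper; the connection parameters below (URL, credentials, pool size) are illustrative assumptions, not part of this diff:

import org.apache.spark.sql.SnappySession
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelationUtil

val snappySession = new SnappySession(sc)         // sc: an existing SparkContext
val params = Map(
  "url" -> "jdbc:snappydata://localhost:1527/",   // illustrative URL
  "user" -> "app",
  "password" -> "app",
  "maximumPoolSize" -> "20")

val conf = JDBCRelationUtil.buildConf(snappySession, params)        // pooled-connection config
val tableSchema = JDBCRelationUtil.schema("APP.MY_TABLE", params)   // resolve the table schema over JDBC
// jdbcDF(...) can then build a Dataset[Row] over pooled connections, given the required
// columns, the JDBC partitions and any push-down filters for the scan.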
@@ -476,6 +476,10 @@ case class HashJoinExec(leftKeys: Seq[Expression],
val produced = streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)

val beforeMap = ctx.freshName("beforeMap")
val skipScan = skipProcessForEmptyMap()
val skipCondition = if (skipScan) {
s"if($hashMapTerm.size() == 0) return;"
} else ""

s"""
boolean $keyIsUniqueTerm = true;
@@ -491,13 +495,21 @@ case class HashJoinExec(leftKeys: Seq[Expression],
final $entryClass[] $mapDataTerm = ($entryClass[])$hashMapTerm.data();
long $numRowsTerm = 0L;
try {
$skipCondition
${session.evaluateFinallyCode(ctx, produced)}
} finally {
$numOutputRows.${metricAdd(numRowsTerm)};
}
"""
}

  private def skipProcessForEmptyMap(): Boolean = {
    joinType match {
      case Inner | LeftSemi => true
      case _ => false
    }
  }
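The effect of the new skip is easiest to see at the DataFrame level: with an empty build-side map, inner and left-semi joins cannot produce any output, so the streamed side need not be scanned at all, whereas other join types must still emit rows. A small illustration with the plain Spark API (not the generated code):

val left = spark.range(5).toDF("id")                // spark: an existing SparkSession
val right = spark.range(0).toDF("id")               // empty build side

left.join(right, Seq("id")).count()                 // 0: inner join can be skipped early
left.join(right, Seq("id"), "left_semi").count()    // 0: left-semi can be skipped early
left.join(right, Seq("id"), "left_outer").count()   // 5: outer join must still scan the left side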

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode],
row: ExprCode): String = {
// variable that holds if relation is unique to optimize iteration