Skip to content

Commit

Permalink
feat(java): support add columns via sql expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Dec 24, 2024
1 parent bccf35b commit c331840
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 54 deletions.
46 changes: 14 additions & 32 deletions java/core/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,26 +818,18 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeAddColumnsBySqlExpre
mut env: JNIEnv,
java_dataset: JObject,
sql_expressions: JObject, // SqlExpressions
read_columns: JObject, // Optional<List<String>>
batch_size: JObject, // Optional<Long>
) {
ok_or_throw_without_return!(
env,
inner_add_columns_by_sql_expressions(
&mut env,
java_dataset,
sql_expressions,
read_columns,
batch_size
)
inner_add_columns_by_sql_expressions(&mut env, java_dataset, sql_expressions, batch_size)
)
}

fn inner_add_columns_by_sql_expressions(
env: &mut JNIEnv,
java_dataset: JObject,
sql_expressions: JObject, // SqlExpressions
read_columns: JObject, // Optional<List<String>>
batch_size: JObject, // Optional<Long>
) -> Result<()> {
let sql_expressions_obj = env
Expand All @@ -847,32 +839,22 @@ fn inner_add_columns_by_sql_expressions(
let sql_expressions_obj_list = env.get_list(&sql_expressions_obj)?;
let mut expressions: Vec<(String, String)> = Vec::new();

for i in 0..sql_expressions_obj_list.size(env)? {
if let Ok(Some(item)) = sql_expressions_obj_list.get(env, i) {
let name = env
.call_method(&item, "getName", "()Ljava/lang/String;", &[])?
.l()?;
let value = env
.call_method(&item, "getExpression", "()Ljava/lang/String;", &[])?
.l()?;
let key_str: String = env.get_string(&JString::from(name))?.into();
let value_str: String = env.get_string(&JString::from(value))?.into();
expressions.push((key_str, value_str));
}
let mut iterator = sql_expressions_obj_list.iter(env)?;

while let Some(item) = iterator.next(env)? {
let name = env
.call_method(&item, "getName", "()Ljava/lang/String;", &[])?
.l()?;
let value = env
.call_method(&item, "getExpression", "()Ljava/lang/String;", &[])?
.l()?;
let key_str: String = env.get_string(&JString::from(name))?.into();
let value_str: String = env.get_string(&JString::from(value))?.into();
expressions.push((key_str, value_str));
}

let rust_transform = NewColumnTransform::SqlExpressions(expressions);

let read_cols = if env
.call_method(&read_columns, "isPresent", "()Z", &[])?
.z()?
{
let columns: Vec<String> = env.get_strings(&read_columns)?;
Some(columns.iter().map(|s| s.to_string()).collect())
} else {
None
};

let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? {
let batch_size_value = env.get_long_opt(&batch_size)?;
match batch_size_value {
Expand All @@ -893,7 +875,7 @@ fn inner_add_columns_by_sql_expressions(
RT.block_on(
dataset_guard
.inner
.add_columns(rust_transform, read_cols, batch_size),
.add_columns(rust_transform, None, batch_size),
)?;
Ok(())
}
14 changes: 10 additions & 4 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,22 @@ public static native Dataset commitAppend(
*/
public static native void drop(String path, Map<String, String> storageOptions);

public void addColumns(
SqlExpressions sqlExpressions, Optional<List<String>> readColumns, Optional<Long> batchSize) {
/**
* Add columns to the dataset.
*
* @param sqlExpressions The SQL expressions that define the columns to add
* @param batchSize The number of rows to read at a time from the source dataset when applying the
* transform.
*/
public void addColumns(SqlExpressions sqlExpressions, Optional<Long> batchSize) {
try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) {
Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
nativeAddColumnsBySqlExpressions(sqlExpressions, readColumns, batchSize);
nativeAddColumnsBySqlExpressions(sqlExpressions, batchSize);
}
}

private native void nativeAddColumnsBySqlExpressions(
SqlExpressions sqlExpressions, Optional<List<String>> readColumns, Optional<Long> batchSize);
SqlExpressions sqlExpressions, Optional<Long> batchSize);

/**
* Drop columns from the dataset.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
package com.lancedb.lance.schema;

import java.util.ArrayList;
import java.util.List;

public class SqlExpressions {

private List<SqlExpression> sqlExpressions;
private final List<SqlExpression> sqlExpressions;

public SqlExpressions() {}
private SqlExpressions(List<SqlExpression> sqlExpressions) {
this.sqlExpressions = sqlExpressions;
}

public List<SqlExpression> getSqlExpressions() {
return sqlExpressions;
}

public void setSqlExpressions(List<SqlExpression> sqlExpressions) {
this.sqlExpressions = sqlExpressions;
}

public static class SqlExpression {

private String name;
Expand All @@ -39,4 +38,25 @@ public void setExpression(String expression) {
this.expression = expression;
}
}

/**
 * Fluent builder for {@link SqlExpressions}.
 *
 * <p>Expressions are kept in the order they were added. Every call to {@link #build()} returns a
 * new {@link SqlExpressions} backed by its own list, so the builder can keep being used afterwards
 * without altering the list of an instance it has already produced.
 */
public static class Builder {

  // Expressions collected so far, in insertion order.
  private final List<SqlExpression> expressions;

  public Builder() {
    this.expressions = new ArrayList<>();
  }

  /**
   * Append a named SQL expression.
   *
   * @param name value stored as the expression's name
   * @param expr value stored as the expression's SQL text
   * @return this builder, to allow chaining
   */
  public Builder withExpression(String name, String expr) {
    SqlExpression expression = new SqlExpression();
    expression.setName(name);
    expression.setExpression(expr);
    this.expressions.add(expression);
    return this;
  }

  /**
   * Create a {@link SqlExpressions} holding the expressions added so far.
   *
   * @return a new instance whose list is a copy of the builder's current list
   */
  public SqlExpressions build() {
    // Copy the list so later withExpression calls cannot mutate an already-built result.
    return new SqlExpressions(new ArrayList<>(this.expressions));
  }
}
}
17 changes: 5 additions & 12 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,9 @@ void testAddColumnBySqlExpressions() {
new TestUtils.SimpleTestDataset(allocator, datasetPath);
dataset = testDataset.createEmptyDataset();

SqlExpressions sqlExpressions = new SqlExpressions();
SqlExpressions.SqlExpression sqlExpression = new SqlExpressions.SqlExpression();
sqlExpression.setName("double_id");
sqlExpression.setExpression("id * 2");
sqlExpressions.setSqlExpressions(Collections.singletonList(sqlExpression));
dataset.addColumns(sqlExpressions, Optional.empty(), Optional.empty());
SqlExpressions sqlExpressions =
new SqlExpressions.Builder().withExpression("double_id", "id * 2").build();
dataset.addColumns(sqlExpressions, Optional.empty());

Schema changedSchema =
new Schema(
Expand All @@ -329,12 +326,8 @@ void testAddColumnBySqlExpressions() {
.map(Field::getName)
.collect(Collectors.toList()));

sqlExpressions = new SqlExpressions();
sqlExpression = new SqlExpressions.SqlExpression();
sqlExpression.setName("triple_id");
sqlExpression.setExpression("id * 3");
sqlExpressions.setSqlExpressions(Collections.singletonList(sqlExpression));
dataset.addColumns(sqlExpressions, Optional.empty(), Optional.empty());
sqlExpressions = new SqlExpressions.Builder().withExpression("triple_id", "id * 3").build();
dataset.addColumns(sqlExpressions, Optional.empty());
changedSchema =
new Schema(
Arrays.asList(
Expand Down

0 comments on commit c331840

Please sign in to comment.