feat(java): support add columns via sql expressions
yanghua committed Dec 24, 2024
1 parent d06488e commit bccf35b
Showing 4 changed files with 196 additions and 1 deletion.
87 changes: 86 additions & 1 deletion java/core/lance-jni/src/blocking_dataset.rs
@@ -30,7 +30,7 @@ use jni::sys::{jboolean, jint};
use jni::{objects::JObject, JNIEnv};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::transaction::Operation;
-use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
+use lance::dataset::{ColumnAlteration, Dataset, NewColumnTransform, ReadParams, WriteParams};
use lance::io::{ObjectStore, ObjectStoreParams};
use lance::table::format::Fragment;
use lance::table::format::Index;
@@ -812,3 +812,88 @@ fn inner_alter_columns(
RT.block_on(dataset_guard.inner.alter_columns(&column_alterations))?;
Ok(())
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeAddColumnsBySqlExpressions(
mut env: JNIEnv,
java_dataset: JObject,
sql_expressions: JObject, // SqlExpressions
read_columns: JObject, // Optional<List<String>>
batch_size: JObject, // Optional<Long>
) {
ok_or_throw_without_return!(
env,
inner_add_columns_by_sql_expressions(
&mut env,
java_dataset,
sql_expressions,
read_columns,
batch_size
)
)
}

fn inner_add_columns_by_sql_expressions(
env: &mut JNIEnv,
java_dataset: JObject,
sql_expressions: JObject, // SqlExpressions
read_columns: JObject, // Optional<List<String>>
batch_size: JObject, // Optional<Long>
) -> Result<()> {
let sql_expressions_obj = env
.get_field(sql_expressions, "sqlExpressions", "Ljava/util/List;")?
.l()?;

let sql_expressions_obj_list = env.get_list(&sql_expressions_obj)?;
let mut expressions: Vec<(String, String)> = Vec::new();

for i in 0..sql_expressions_obj_list.size(env)? {
if let Ok(Some(item)) = sql_expressions_obj_list.get(env, i) {
let name = env
.call_method(&item, "getName", "()Ljava/lang/String;", &[])?
.l()?;
let value = env
.call_method(&item, "getExpression", "()Ljava/lang/String;", &[])?
.l()?;
let key_str: String = env.get_string(&JString::from(name))?.into();
let value_str: String = env.get_string(&JString::from(value))?.into();
expressions.push((key_str, value_str));
}
}

let rust_transform = NewColumnTransform::SqlExpressions(expressions);

let read_cols = if env
.call_method(&read_columns, "isPresent", "()Z", &[])?
.z()?
{
let columns: Vec<String> = env.get_strings(&read_columns)?;
Some(columns.iter().map(|s| s.to_string()).collect())
} else {
None
};

let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? {
let batch_size_value = env.get_long_opt(&batch_size)?;
match batch_size_value {
Some(value) => Some(
value
.try_into()
.map_err(|_| Error::input_error("Batch size conversion error".to_string()))?,
),
None => None,
}
} else {
None
};

let mut dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;

RT.block_on(
dataset_guard
.inner
.add_columns(rust_transform, read_cols, batch_size),
)?;
Ok(())
}
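
For reference, a minimal Java-side sketch of a call that reaches this entry point. The helper method, column name, expression, read-column list, and batch-size value below are illustrative assumptions rather than part of the commit, and `dataset` is assumed to be an open, writable Dataset:

import com.lancedb.lance.Dataset;
import com.lancedb.lance.schema.SqlExpressions;

import java.util.Arrays;
import java.util.Collections;
import java.util.Optional;

// Sketch: derive one new column from an existing column via a SQL expression.
static void addDerivedColumn(Dataset dataset) {
  SqlExpressions exprs = new SqlExpressions();
  SqlExpressions.SqlExpression expr = new SqlExpressions.SqlExpression();
  expr.setName("id_plus_one");   // illustrative name of the new column
  expr.setExpression("id + 1");  // SQL expression evaluated over existing columns
  exprs.setSqlExpressions(Collections.singletonList(expr));

  // Both trailing arguments are optional: the read-column list limits which existing
  // columns are scanned while evaluating the expressions, and the batch size controls
  // how many rows are processed per evaluation batch.
  dataset.addColumns(exprs, Optional.of(Arrays.asList("id")), Optional.of(1024L));
}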
12 changes: 12 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -17,6 +17,7 @@
import com.lancedb.lance.ipc.LanceScanner;
import com.lancedb.lance.ipc.ScanOptions;
import com.lancedb.lance.schema.ColumnAlteration;
import com.lancedb.lance.schema.SqlExpressions;

import org.apache.arrow.c.ArrowArrayStream;
import org.apache.arrow.c.ArrowSchema;
@@ -254,6 +255,17 @@ public static native Dataset commitAppend(
*/
public static native void drop(String path, Map<String, String> storageOptions);

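/**
 * Add columns to the dataset, deriving the values of each new column from a SQL expression
 * evaluated against the existing columns.
 *
 * @param sqlExpressions the (column name, SQL expression) pairs describing the new columns
 * @param readColumns optional subset of existing columns to read while evaluating the expressions
 * @param batchSize optional number of rows to process per batch while computing the new values
 */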
public void addColumns(
SqlExpressions sqlExpressions, Optional<List<String>> readColumns, Optional<Long> batchSize) {
try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) {
Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
nativeAddColumnsBySqlExpressions(sqlExpressions, readColumns, batchSize);
}
}

private native void nativeAddColumnsBySqlExpressions(
SqlExpressions sqlExpressions, Optional<List<String>> readColumns, Optional<Long> batchSize);

/**
* Drop columns from the dataset.
*
42 changes: 42 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/schema/SqlExpressions.java
@@ -0,0 +1,42 @@
package com.lancedb.lance.schema;

import java.util.List;

public class SqlExpressions {

private List<SqlExpression> sqlExpressions;

public SqlExpressions() {}

public List<SqlExpression> getSqlExpressions() {
return sqlExpressions;
}

public void setSqlExpressions(List<SqlExpression> sqlExpressions) {
this.sqlExpressions = sqlExpressions;
}

public static class SqlExpression {

private String name;
private String expression;

public SqlExpression() {}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String getExpression() {
return expression;
}

public void setExpression(String expression) {
this.expression = expression;
}
}
}
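
The holder is a plain name/expression POJO, and the Rust side collects every entry of the list into a single NewColumnTransform::SqlExpressions, so multiple derived columns can be requested in one call; the test below only adds one at a time. A sketch under that reading (the helper method, column names, and expressions are illustrative assumptions):

import com.lancedb.lance.Dataset;
import com.lancedb.lance.schema.SqlExpressions;
import com.lancedb.lance.schema.SqlExpressions.SqlExpression;

import java.util.Arrays;
import java.util.Optional;

// Sketch: request two derived columns in a single addColumns call.
static void addTwoDerivedColumns(Dataset dataset) {
  SqlExpression doubled = new SqlExpression();
  doubled.setName("double_id");
  doubled.setExpression("id * 2");

  SqlExpression upperName = new SqlExpression();
  upperName.setName("name_upper");
  upperName.setExpression("upper(name)"); // assumes `upper` exists in the SQL dialect used

  SqlExpressions exprs = new SqlExpressions();
  exprs.setSqlExpressions(Arrays.asList(doubled, upperName));

  // Empty Optionals leave the read columns and batch size to their defaults.
  dataset.addColumns(exprs, Optional.empty(), Optional.empty());
}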
56 changes: 56 additions & 0 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
@@ -12,6 +12,7 @@
package com.lancedb.lance;

import com.lancedb.lance.schema.ColumnAlteration;
import com.lancedb.lance.schema.SqlExpressions;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@@ -29,6 +30,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.junit.jupiter.api.Assertions.*;
@@ -296,6 +298,60 @@ void testAlterColumns() {
}
}

@Test
void testAddColumnBySqlExpressions() {
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
String datasetPath = tempDir.resolve(testMethodName).toString();
try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
TestUtils.SimpleTestDataset testDataset =
new TestUtils.SimpleTestDataset(allocator, datasetPath);
dataset = testDataset.createEmptyDataset();

SqlExpressions sqlExpressions = new SqlExpressions();
SqlExpressions.SqlExpression sqlExpression = new SqlExpressions.SqlExpression();
sqlExpression.setName("double_id");
sqlExpression.setExpression("id * 2");
sqlExpressions.setSqlExpressions(Collections.singletonList(sqlExpression));
dataset.addColumns(sqlExpressions, Optional.empty(), Optional.empty());

Schema changedSchema =
new Schema(
Arrays.asList(
Field.nullable("id", new ArrowType.Int(32, true)),
Field.nullable("name", new ArrowType.Utf8()),
Field.nullable("double_id", new ArrowType.Int(32, true))),
null);

assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size());
assertEquals(
changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()),
dataset.getSchema().getFields().stream()
.map(Field::getName)
.collect(Collectors.toList()));

sqlExpressions = new SqlExpressions();
sqlExpression = new SqlExpressions.SqlExpression();
sqlExpression.setName("triple_id");
sqlExpression.setExpression("id * 3");
sqlExpressions.setSqlExpressions(Collections.singletonList(sqlExpression));
dataset.addColumns(sqlExpressions, Optional.empty(), Optional.empty());
changedSchema =
new Schema(
Arrays.asList(
Field.nullable("id", new ArrowType.Int(32, true)),
Field.nullable("name", new ArrowType.Utf8()),
Field.nullable("double_id", new ArrowType.Int(32, true)),
Field.nullable("triple_id", new ArrowType.Int(32, true))),
null);
assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size());
assertEquals(
changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()),
dataset.getSchema().getFields().stream()
.map(Field::getName)
.collect(Collectors.toList()));
}
}

@Test
void testDropPath() {
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
