Skip to content

Commit

Permalink
test: add test cases for template param (#38867)
Browse files Browse the repository at this point in the history
issue: #33419
- add case for Get
- add case for TemplateParam

Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao authored Jan 10, 2025
1 parent 22b8f6a commit 461c376
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 8 deletions.
6 changes: 6 additions & 0 deletions tests/go_client/base/milvus_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,9 @@ func (mc *MilvusClient) Query(ctx context.Context, option client.QueryOption, ca
resultSet, err := mc.mClient.Query(ctx, option, callOptions...)
return resultSet, err
}

// Get get from collection
func (mc *MilvusClient) Get(ctx context.Context, option client.QueryOption, callOptions ...grpc.CallOption) (client.ResultSet, error) {
resultSet, err := mc.mClient.Get(ctx, option, callOptions...)
return resultSet, err
}
2 changes: 1 addition & 1 deletion tests/go_client/testcases/helper/data_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ func GenColumnData(nb int, fieldType entity.FieldType, option GenDataOption) col
case entity.FieldTypeBool:
boolValues := make([]bool, 0, nb)
for i := start; i < start+nb; i++ {
boolValues = append(boolValues, i/2 == 0)
boolValues = append(boolValues, i%2 == 0)
}
return column.NewColumnBool(fieldName, boolValues)

Expand Down
6 changes: 3 additions & 3 deletions tests/go_client/testcases/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

func TestIndexVectorDefault(t *testing.T) {
t.Parallel()
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
mc := createDefaultMilvusClient(ctx, t)

cp := hp.NewCreateCollectionParams(hp.Int64MultiVec)
Expand Down Expand Up @@ -51,7 +51,7 @@ func TestIndexVectorDefault(t *testing.T) {

func TestIndexVectorIP(t *testing.T) {
t.Parallel()
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
mc := createDefaultMilvusClient(ctx, t)

cp := hp.NewCreateCollectionParams(hp.Int64MultiVec)
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestIndexVectorIP(t *testing.T) {

func TestIndexVectorCosine(t *testing.T) {
t.Parallel()
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
mc := createDefaultMilvusClient(ctx, t)

cp := hp.NewCreateCollectionParams(hp.Int64MultiVec)
Expand Down
181 changes: 181 additions & 0 deletions tests/go_client/testcases/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ func TestQueryDefault(t *testing.T) {
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{insertRes.IDs.Slice(0, 100)})

// query with limit
LimitRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithLimit(10))
common.CheckErr(t, err, true)
require.Equal(t, 10, LimitRes.ResultCount)
require.Equal(t, 10, LimitRes.GetColumn(common.DefaultInt64FieldName).Len())

// get ids -> same result with query
ids := hp.GenColumnData(100, entity.FieldTypeInt64, *hp.TNewDataOption().TWithFieldName(common.DefaultInt64FieldName))
getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids))
common.CheckErr(t, errGet, true)
common.CheckQueryResult(t, getRes.Fields, []column.Column{insertRes.IDs.Slice(0, 100)})
}

// test query with varchar field filter
Expand All @@ -55,6 +67,46 @@ func TestQueryVarcharPkDefault(t *testing.T) {
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)})

// get ids -> same result with query
varcharValues := []string{"0", "1", "2", "3", "4"}
ids := column.NewColumnVarChar(common.DefaultVarcharFieldName, varcharValues)
getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids))
common.CheckErr(t, errGet, true)
common.CheckQueryResult(t, getRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)})
}

// test get with invalid ids
func TestGetInvalid(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)

// create and insert
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption())
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption())

// flush -> index -> load
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

// get ids with varchar ids -> error
varcharValues := []string{"0", "1", "2", "3", "4"}
ids := column.NewColumnVarChar(common.DefaultVarcharFieldName, varcharValues)
_, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids))
common.CheckErr(t, errGet, false, "field varchar not exist: invalid parameter")

// get ids with varchar ids -> error
ids = column.NewColumnVarChar(common.DefaultInt64FieldName, varcharValues)
_, errGet = mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids))
common.CheckErr(t, errGet, false, "cannot parse expression: int64 in")

// get ids with non-pk column -> error for empty filter
t.Log("https://github.com/milvus-io/milvus/issues/38859")
values := []float32{0.0, 1.0}
ids2 := column.NewColumnFloat(common.DefaultInt64FieldName, values)
_, errGet = mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids2))
common.CheckErr(t, errGet, false, "empty expression should be used with limit")
}

// query from not existed collection name and partition name
Expand Down Expand Up @@ -635,3 +687,132 @@ func TestQueryOutputInvalidOutputFieldCount(t *testing.T) {
common.CheckErr(t, err, false, invalidCount.errMsg)
}
}

func TestQueryWithTemplateParam(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)

prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields),
hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption())
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

// query
int64Values := make([]int64, 0, 1000)
for i := 10; i < 10+1000; i++ {
int64Values = append(int64Values, int64(i))
}
// default
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{column.NewColumnInt64(common.DefaultInt64FieldName, int64Values)})

// cover keys
res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5))
common.CheckErr(t, err, true)
require.Equal(t, 5, res.ResultCount)

// array contains
anyValues := []int64{0.0, 100.0, 10000.0}
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("json_contains_any (%s, {any_values})", common.DefaultFloatArrayField)).WithTemplateParam("any_values", anyValues).
WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 101, count)

// dynamic
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 1500, count)

// json['bool']
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s['bool'] == {v}", common.DefaultJSONFieldName)).
WithTemplateParam("v", false).
WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 1500/2, count)

// bool
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s == {v}", common.DefaultBoolFieldName)).
WithTemplateParam("v", true).
WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb/2, count)

// and {expr: fmt.Sprintf("%s >= 1000 && %s < 2000", common.DefaultInt64FieldName, common.DefaultInt64FieldName), count: 1000},
res, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s >= {k1} && %s < {k2}", common.DefaultInt64FieldName, common.DefaultInt64FieldName)).
WithTemplateParam("v", 0).WithTemplateParam("k1", 1000).
WithTemplateParam("k2", 2000))
common.CheckErr(t, err, true)
require.EqualValues(t, 1000, res.ResultCount)
}

func TestQueryWithTemplateParamInvalid(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)

prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec),
hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption())
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

// query with invalid template
// expr := "varchar like 'a%' "
_, err2 := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("varchar like {key1}").WithTemplateParam("key1", "'a%'"))
common.CheckErr(t, err2, false, "mismatched input '{' expecting StringLiteral")

// no template param
_, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 in {key1}"))
common.CheckErr(t, err, false, "the value of expression template variable name {key1} is not found")

// template param with empty expr
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("").WithTemplateParam("a", 12))
common.CheckErr(t, err, false, "empty expression should be used with limit")

// *** template param with field name key -> error ***
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("{field} < 10").WithTemplateParam("field", "int64"))
common.CheckErr(t, err, false, "cannot parse expression")
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("{field} < {v}").WithTemplateParam("field", "int64").WithTemplateParam("v", 10))
common.CheckErr(t, err, false, "placeholder was not supported between two constants with operator")
// exists x
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("exists {x}").WithTemplateParam("x", "json"))
common.CheckErr(t, err, false, "exists operations are only supported on single fields now")

// compare two fields
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("{f1} > {f2}").WithTemplateParam("f1", "f1").WithTemplateParam("f2", "f2"))
common.CheckErr(t, err, false, "placeholder was not supported between two constants with operator")

// expr key != template key
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 in {key1}").WithTemplateParam("key2", []int64{0, 1, 2}))
common.CheckErr(t, err, false, "the value of expression template variable name {key1} is not found")

// template missing some keys
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("{k1} < int64 < {k2}").WithTemplateParam("k1", 10))
common.CheckErr(t, err, false, "the upper value of expression template variable name {k2} is not found")

// template value type is valid
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k1}").WithTemplateParam("k1", []int64{0, 1, 3}))
common.CheckErr(t, err, false, "cannot cast value to Int64")

_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k1}").WithTemplateParam("k1", "10"))
common.CheckErr(t, err, false, "cannot cast value to Int64")

// invalid expr
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("{name} == 'O'Reilly'").WithTemplateParam("name", common.DefaultVarcharFieldName))
common.CheckErr(t, err, false, "cannot parse expression")

// invalid expr
_, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("{_123} > 10").WithTemplateParam("_123", common.DefaultInt64FieldName))
common.CheckErr(t, err, false, "cannot parse expression")
}
19 changes: 15 additions & 4 deletions tests/go_client/testcases/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,17 +610,28 @@ func TestSearchExpr(t *testing.T) {
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

type mExprExpected struct {
expr string
ids []int64
expr string
ids []int64
value any
}

vectors := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
for _, _mExpr := range []mExprExpected{
{expr: fmt.Sprintf("%s < 10", common.DefaultInt64FieldName), ids: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}},
{expr: fmt.Sprintf("%s in [10, 100]", common.DefaultInt64FieldName), ids: []int64{10, 100}},
} {
resSearch, errSearch := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, vectors).WithConsistencyLevel(entity.ClStrong).
WithFilter(_mExpr.expr))
resSearch, errSearch := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, vectors).WithConsistencyLevel(entity.ClStrong).WithFilter(_mExpr.expr))
common.CheckErr(t, errSearch, true)
for _, res := range resSearch {
require.ElementsMatch(t, _mExpr.ids, res.IDs.(*column.ColumnInt64).Data())
}
}
// search with template param
for _, _mExpr := range []mExprExpected{
{expr: fmt.Sprintf("%s < {v}", common.DefaultInt64FieldName), ids: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, value: 10},
{expr: fmt.Sprintf("%s in {v}", common.DefaultInt64FieldName), ids: []int64{10, 100}, value: []int64{10, 100}},
} {
resSearch, errSearch := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, vectors).WithFilter(_mExpr.expr).WithTemplateParam("v", _mExpr.value))
common.CheckErr(t, errSearch, true)
for _, res := range resSearch {
require.ElementsMatch(t, _mExpr.ids, res.IDs.(*column.ColumnInt64).Data())
Expand Down

0 comments on commit 461c376

Please sign in to comment.