Skip to content

Commit

Permalink
[core] fix FieldCountAgg should return the same data type as fieldTyp…
Browse files Browse the repository at this point in the history
…e. (#3924)
  • Loading branch information
LinMingQiang authored Aug 9, 2024
1 parent a1c7589 commit 3e61612
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,63 +36,65 @@ String name() {

@Override
public Object agg(Object accumulator, Object inputField) {
Object count;
if (accumulator == null || inputField == null) {
count = (accumulator == null ? (inputField == null ? 0 : 1) : accumulator);
} else {
// ordered by type root definition
switch (fieldType.getTypeRoot()) {
case TINYINT:
count = (byte) ((byte) accumulator + 1);
break;
case SMALLINT:
count = (short) ((short) accumulator + 1);
break;
case INTEGER:
count = (int) accumulator + 1;
break;
case BIGINT:
count = (long) accumulator + 1L;
break;
default:
String msg =
String.format(
"type %s not support in %s",
fieldType.getTypeRoot().toString(), this.getClass().getName());
throw new IllegalArgumentException(msg);
}

if (accumulator != null && inputField == null) {
return accumulator;
}
// ordered by type root definition
switch (fieldType.getTypeRoot()) {
case TINYINT:
return accumulator == null
? (inputField == null ? (byte) 0 : (byte) 1)
: (byte) ((byte) accumulator + 1);
case SMALLINT:
return accumulator == null
? (inputField == null ? (short) 0 : (short) 1)
: (short) ((short) accumulator + 1);
case INTEGER:
return accumulator == null ? (inputField == null ? 0 : 1) : (int) accumulator + 1;
case BIGINT:
return accumulator == null
? (inputField == null ? 0L : 1L)
: (long) accumulator + 1L;
default:
String msg =
String.format(
"type %s not support in %s",
fieldType.getTypeRoot().toString(), this.getClass().getName());
throw new IllegalArgumentException(msg);
}
return count;
}

@Override
public Object retract(Object accumulator, Object inputField) {
Object count;
if (accumulator == null || inputField == null) {
count = (accumulator == null ? (inputField == null ? 0 : -1) : accumulator);
} else {
// ordered by type root definition
switch (fieldType.getTypeRoot()) {
case TINYINT:
count = (byte) ((byte) accumulator - 1);
break;
case SMALLINT:
count = (short) ((short) accumulator - 1);
break;
case INTEGER:
count = (int) accumulator - 1;
break;
case BIGINT:
count = (long) accumulator - 1L;
break;
default:
String msg =
String.format(
"type %s not support in %s",
fieldType.getTypeRoot().toString(), this.getClass().getName());
throw new IllegalArgumentException(msg);
}

if (accumulator != null && inputField == null) {
return accumulator;
}

// ordered by type root definition
switch (fieldType.getTypeRoot()) {
case TINYINT:
return accumulator == null
? (inputField == null ? (byte) 0 : (byte) -1)
: (byte) ((byte) accumulator - 1);
case SMALLINT:
return accumulator == null
? (inputField == null ? (short) 0 : (short) -1)
: (short) ((short) accumulator - 1);
case INTEGER:
return accumulator == null ? (inputField == null ? 0 : -1) : (int) accumulator - 1;

case BIGINT:
return accumulator == null
? (inputField == null ? 0L : -1L)
: (long) accumulator - 1L;
default:
String msg =
String.format(
"type %s not support in %s",
fieldType.getTypeRoot().toString(), this.getClass().getName());
throw new IllegalArgumentException(msg);
}
return count;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,23 @@ public void testFieldCountIntAgg() {
assertThat(fieldCountAggInt.agg(3, 6)).isEqualTo(4);

FieldCountAgg fieldCountAggLong = new FieldCountAgg(new BigIntType());
assertThat(fieldCountAggLong.agg(null, null)).isEqualTo(0);
assertThat(fieldCountAggLong.agg(null, null)).isEqualTo(0L);
assertThat(fieldCountAggLong.agg((long) 1, null)).isEqualTo((long) 1);
assertThat(fieldCountAggLong.agg(null, (long) 15)).isEqualTo(1);
assertThat(fieldCountAggLong.agg(null, (long) 15)).isEqualTo(1L);
assertThat(fieldCountAggLong.agg((long) 1, 0)).isEqualTo((long) 2);
assertThat(fieldCountAggLong.agg((long) 3, (long) 6)).isEqualTo((long) 4);

FieldCountAgg fieldCountAggByte = new FieldCountAgg(new TinyIntType());
assertThat(fieldCountAggByte.agg(null, null)).isEqualTo(0);
assertThat(fieldCountAggByte.agg(null, null)).isEqualTo((byte) 0);
assertThat(fieldCountAggByte.agg((byte) 1, null)).isEqualTo((byte) 1);
assertThat(fieldCountAggByte.agg(null, (byte) 15)).isEqualTo(1);
assertThat(fieldCountAggByte.agg(null, (byte) 15)).isEqualTo((byte) 1);
assertThat(fieldCountAggByte.agg((byte) 1, 0)).isEqualTo((byte) 2);
assertThat(fieldCountAggByte.agg((byte) 3, (byte) 6)).isEqualTo((byte) 4);

FieldCountAgg fieldCountAggShort = new FieldCountAgg(new SmallIntType());
assertThat(fieldCountAggShort.agg(null, null)).isEqualTo(0);
assertThat(fieldCountAggShort.agg(null, null)).isEqualTo((short) 0);
assertThat(fieldCountAggShort.agg((short) 1, null)).isEqualTo((short) 1);
assertThat(fieldCountAggShort.agg(null, (short) 15)).isEqualTo(1);
assertThat(fieldCountAggShort.agg(null, (short) 15)).isEqualTo((short) 1);
assertThat(fieldCountAggShort.agg((short) 1, 0)).isEqualTo((short) 2);
assertThat(fieldCountAggShort.agg((short) 3, (short) 6)).isEqualTo((short) 4);
}
Expand Down

0 comments on commit 3e61612

Please sign in to comment.