Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Append filterSql to source variant_to_person in cohort count #1130

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/generated/UNDERLAY_CONFIG.md
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ Pointer to SQL that returns entity id - rollup count (= number of related entity

True to skip copying the id-pairs SQL into a new index table, and use the source SQL directly.

Ignored if the [id pairs SQL](#szgroupitemsidpairssqlfile) is undefined.
When set filter conditions are directly applied on the id-pairs SQL to fetch cohort counts. Ignored if the [id pairs SQL](#szgroupitemsidpairssqlfile) is undefined.

*Default value:* `false`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.text.StringSubstitutor;

public class BQTable extends SqlTable {
Expand All @@ -24,8 +25,14 @@ public BQTable(String sql) {

@Override
public String render() {
return render(null);
}

@Override
public String render(String appendToSql) {
String appendNotNull = Optional.ofNullable(appendToSql).orElse("");
if (isRawSql()) {
return "(" + sql + ")";
return "(" + sql + appendNotNull + ")";
} else {
String template = "`${projectId}.${datasetId}`.${tableName}";
Map<String, String> params =
Expand All @@ -34,7 +41,7 @@ public String render() {
.put("datasetId", datasetId)
.put("tableName", tableName)
.build();
return StringSubstitutor.replace(template, params);
return StringSubstitutor.replace(template, params) + appendNotNull;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public String buildSql(SqlParams sqlParams, String tableAlias) {
ancestorDescendantTable.getTablePointer(),
ancestorIdFilterSql,
null,
false,
sqlParams,
hierarchyHasAncestorFilter.getAncestorIds().toArray(new Literal[0]));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public String buildSql(SqlParams sqlParams, String tableAlias) {
childParentIndexTable.getTablePointer(),
parentIdFilterSql,
null,
false,
sqlParams);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ private String foreignKeyOnSelectEntity(SqlParams sqlParams, String tableAlias)
filterEntityTable.getTablePointer(),
inSelectFilterSql,
null,
relationshipFilter.getEntityGroup().isUseSourceIdPairsSql(),
sqlParams);
}
}
Expand Down Expand Up @@ -182,6 +183,7 @@ private String foreignKeyOnFilterEntity(SqlParams sqlParams, String tableAlias)
filterEntityTable.getTablePointer(),
inSelectFilterSql,
null,
relationshipFilter.getEntityGroup().isUseSourceIdPairsSql(),
sqlParams);
}

Expand Down Expand Up @@ -264,6 +266,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getEntityIdField(relationshipFilter.getSelectEntity().getName());
SqlField filterIdIntTable =
idPairsTable.getEntityIdField(relationshipFilter.getFilterEntity().getName());
boolean appendSqlToTable = relationshipFilter.getEntityGroup().isUseSourceIdPairsSql();

if (!relationshipFilter.hasSubFilter()
&& !relationshipFilter.hasGroupByFilter()
Expand Down Expand Up @@ -313,6 +316,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getTablePointer(),
subFilterSql.isEmpty() ? null : subFilterSql,
havingSql.isEmpty() ? null : havingSql,
appendSqlToTable,
sqlParams);
} else {
// id IN (SELECT selectId FROM
Expand Down Expand Up @@ -342,6 +346,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
filterEntityTable.getTablePointer(),
subFilterSql,
null,
appendSqlToTable,
sqlParams);
}

Expand All @@ -353,6 +358,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getTablePointer(),
filterIdInSelectSql.isEmpty() ? null : filterIdInSelectSql,
null,
appendSqlToTable,
sqlParams);
}

Expand All @@ -374,6 +380,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getTablePointer(),
filterIdInSelectSql.isEmpty() ? null : filterIdInSelectSql,
havingSql,
appendSqlToTable,
sqlParams);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public String getSql() {
}

public abstract String render();

public abstract String render(String append);
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,18 @@ default String inSelectFilterSql(
SqlTable table,
@Nullable String filterSql,
@Nullable String havingSql,
boolean appendSqlToTable,
SqlParams sqlParams,
Literal... unionAllLiterals) {
List<String> selectSqls = new ArrayList<>();
String sqlClause =
(filterSql != null ? " WHERE " + filterSql : "")
+ (havingSql != null ? ' ' + havingSql : "");
selectSqls.add(
"SELECT "
+ SqlQueryField.of(selectField).renderForSelect()
+ " FROM "
+ table.render()
+ (filterSql != null ? " WHERE " + filterSql : "")
+ (havingSql != null ? ' ' + havingSql : ""));
+ (appendSqlToTable ? table.render(sqlClause) : table.render() + sqlClause));
Arrays.stream(unionAllLiterals)
.forEach(literal -> selectSqls.add("SELECT @" + sqlParams.addParam("val", literal)));
return SqlQueryField.of(whereField).renderForWhere(tableAlias)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,12 @@ private static GroupItems fromConfigGroupItems(SZGroupItems szGroupItems, List<E
szGroupItems.foreignKeyAttributeItemsEntity == null
? null
: itemsEntity.getAttribute(szGroupItems.foreignKeyAttributeItemsEntity));
return new GroupItems(szGroupItems.name, groupEntity, itemsEntity, groupItemsRelationship);
return new GroupItems(
szGroupItems.name,
szGroupItems.useSourceIdPairsSql,
groupEntity,
itemsEntity,
groupItemsRelationship);
}

private static CriteriaOccurrence fromConfigCriteriaOccurrence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CriteriaOccurrence(
Relationship primaryCriteriaRelationship,
Map<String, Set<String>> occurrenceAttributesWithInstanceLevelDisplayHints,
Map<String, Set<String>> occurrenceAttributesWithRollupInstanceLevelDisplayHints) {
super(name);
super(name, false);
this.criteriaEntity = criteriaEntity;
this.occurrenceEntities = ImmutableList.copyOf(occurrenceEntities);
this.primaryEntity = primaryEntity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@ public enum Type {
}

private final String name;
private final boolean useSourceIdPairsSql;

protected EntityGroup(String name) {
protected EntityGroup(String name, boolean useSourceIdPairsSql) {
this.name = name;
this.useSourceIdPairsSql = useSourceIdPairsSql;
}

public String getName() {
return name;
}

public boolean isUseSourceIdPairsSql() {
return useSourceIdPairsSql;
}

public abstract Type getType();

public abstract boolean includesEntity(String name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ public class GroupItems extends EntityGroup {
private final Relationship groupItemsRelationship;

public GroupItems(
String name, Entity groupEntity, Entity itemsEntity, Relationship groupItemsRelationship) {
super(name);
String name,
boolean useSourceIdPairsSql,
Entity groupEntity,
Entity itemsEntity,
Relationship groupItemsRelationship) {
super(name, useSourceIdPairsSql);
this.groupEntity = groupEntity;
this.itemsEntity = itemsEntity;
this.groupItemsRelationship = groupItemsRelationship;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class SZGroupItems {
name = "SZGroupItems.useSourceIdPairsSql",
markdown =
"True to skip copying the id-pairs SQL into a new index table, and use the source SQL directly.\n\n"
+ "When set filter conditions are directly applied on the id-pairs SQL to fetch cohort counts. "
+ "Ignored if the [id pairs SQL](${SZGroupItems.idPairsSqlFile}) is undefined.",
optional = true,
defaultValue = "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"itemsEntity": "person",
"idPairsSqlFile": "idPairs.sql",
"useSourceIdPairsSql": true,
"groupEntityIdFieldName": "variant_id",
"groupEntityIdFieldName": "vid",
"itemsEntityIdFieldName": "flattened_person_id",
"rollupCountsSql": {
"sqlFile": "rollupCounts.sql",
"entityIdFieldName": "variant_id",
"entityIdFieldName": "vid",
"rollupCountFieldName": "num_persons"
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT DISTINCT vid AS variant_id, flattened_person_id
SELECT DISTINCT vid, flattened_person_id
FROM `${omopDataset}.variant_to_person`
CROSS JOIN UNNEST(person_ids) AS flattened_person_id
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT vid AS variant_id, ARRAY_LENGTH(person_ids) AS num_persons
SELECT vid, ARRAY_LENGTH(person_ids) AS num_persons
/* Wrap variant_to_person table in a SELECT DISTINCT because there is a duplicate row in the test data. */
FROM (SELECT DISTINCT vid, person_ids FROM `${omopDataset}.variant_to_person` WHERE REGEXP_CONTAINS(vid, r"{indexIdRegex}"))
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void inSelectFilter() {
apiTranslator.binaryFilterSql(joinField, BinaryOperator.EQUALS, val, null, sqlParams);
String sql =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, joinFilterSql, null, sqlParams);
field, tableAlias, joinField, joinTable, joinFilterSql, null, false, sqlParams);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName WHERE joinColumnName = @val0)",
sql);
Expand All @@ -186,7 +186,16 @@ void inSelectFilter() {
Literal val1 = Literal.forInt64(25L);
sql =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, joinFilterSql, null, sqlParams, val0, val1);
field,
tableAlias,
joinField,
joinTable,
joinFilterSql,
null,
false,
sqlParams,
val0,
val1);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName WHERE joinColumnName = @val0 UNION ALL SELECT @val1 UNION ALL SELECT @val2)",
sql);
Expand All @@ -197,7 +206,7 @@ void inSelectFilter() {
val = Literal.forInt64(38L);
sql =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, null, null, sqlParams, val);
field, tableAlias, joinField, joinTable, null, null, false, sqlParams, val);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName UNION ALL SELECT @val0)",
sql);
Expand All @@ -213,6 +222,7 @@ void inSelectFilter() {
joinTable,
null,
"GROUP BY joinColumnName HAVING COUNT(*) > 1",
false,
sqlParams);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName GROUP BY joinColumnName HAVING COUNT(*) > 1)",
Expand All @@ -237,7 +247,7 @@ void booleanAndOrFilter() {
apiTranslator.binaryFilterSql(joinField, BinaryOperator.EQUALS, val2, null, sqlParams);
String filterSql2 =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, joinFilterSql, null, sqlParams);
field, tableAlias, joinField, joinTable, joinFilterSql, null, false, sqlParams);

String sql =
apiTranslator.booleanAndOrFilterSql(
Expand Down
Loading