Skip to content

Commit

Permalink
Append filterSql to source variant_to_person in cohort count (#1130)
Browse files Browse the repository at this point in the history
  • Loading branch information
dexamundsen authored Jan 16, 2025
1 parent a9e4c0c commit 2556ab4
Show file tree
Hide file tree
Showing 16 changed files with 65 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/generated/UNDERLAY_CONFIG.md
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ Pointer to SQL that returns entity id - rollup count (= number of related entity

True to skip copying the id-pairs SQL into a new index table, and use the source SQL directly.

Ignored if the [id pairs SQL](#szgroupitemsidpairssqlfile) is undefined.
When set filter conditions are directly applied on the id-pairs SQL to fetch cohort counts. Ignored if the [id pairs SQL](#szgroupitemsidpairssqlfile) is undefined.

*Default value:* `false`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.text.StringSubstitutor;

public class BQTable extends SqlTable {
Expand All @@ -24,8 +25,14 @@ public BQTable(String sql) {

@Override
public String render() {
return render(null);
}

@Override
public String render(String appendToSql) {
String appendNotNull = Optional.ofNullable(appendToSql).orElse("");
if (isRawSql()) {
return "(" + sql + ")";
return "(" + sql + appendNotNull + ")";
} else {
String template = "`${projectId}.${datasetId}`.${tableName}";
Map<String, String> params =
Expand All @@ -34,7 +41,7 @@ public String render() {
.put("datasetId", datasetId)
.put("tableName", tableName)
.build();
return StringSubstitutor.replace(template, params);
return StringSubstitutor.replace(template, params) + appendNotNull;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public String buildSql(SqlParams sqlParams, String tableAlias) {
ancestorDescendantTable.getTablePointer(),
ancestorIdFilterSql,
null,
false,
sqlParams,
hierarchyHasAncestorFilter.getAncestorIds().toArray(new Literal[0]));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public String buildSql(SqlParams sqlParams, String tableAlias) {
childParentIndexTable.getTablePointer(),
parentIdFilterSql,
null,
false,
sqlParams);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ private String foreignKeyOnSelectEntity(SqlParams sqlParams, String tableAlias)
filterEntityTable.getTablePointer(),
inSelectFilterSql,
null,
relationshipFilter.getEntityGroup().isUseSourceIdPairsSql(),
sqlParams);
}
}
Expand Down Expand Up @@ -182,6 +183,7 @@ private String foreignKeyOnFilterEntity(SqlParams sqlParams, String tableAlias)
filterEntityTable.getTablePointer(),
inSelectFilterSql,
null,
relationshipFilter.getEntityGroup().isUseSourceIdPairsSql(),
sqlParams);
}

Expand Down Expand Up @@ -264,6 +266,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getEntityIdField(relationshipFilter.getSelectEntity().getName());
SqlField filterIdIntTable =
idPairsTable.getEntityIdField(relationshipFilter.getFilterEntity().getName());
boolean appendSqlToTable = relationshipFilter.getEntityGroup().isUseSourceIdPairsSql();

if (!relationshipFilter.hasSubFilter()
&& !relationshipFilter.hasGroupByFilter()
Expand Down Expand Up @@ -313,6 +316,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getTablePointer(),
subFilterSql.isEmpty() ? null : subFilterSql,
havingSql.isEmpty() ? null : havingSql,
appendSqlToTable,
sqlParams);
} else {
// id IN (SELECT selectId FROM
Expand Down Expand Up @@ -342,6 +346,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
filterEntityTable.getTablePointer(),
subFilterSql,
null,
appendSqlToTable,
sqlParams);
}

Expand All @@ -353,6 +358,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getTablePointer(),
filterIdInSelectSql.isEmpty() ? null : filterIdInSelectSql,
null,
appendSqlToTable,
sqlParams);
}

Expand All @@ -374,6 +380,7 @@ private String intermediateTable(SqlParams sqlParams, String tableAlias) {
idPairsTable.getTablePointer(),
filterIdInSelectSql.isEmpty() ? null : filterIdInSelectSql,
havingSql,
appendSqlToTable,
sqlParams);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public String getSql() {
}

public abstract String render();

public abstract String render(String append);
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,18 @@ default String inSelectFilterSql(
SqlTable table,
@Nullable String filterSql,
@Nullable String havingSql,
boolean appendSqlToTable,
SqlParams sqlParams,
Literal... unionAllLiterals) {
List<String> selectSqls = new ArrayList<>();
String sqlClause =
(filterSql != null ? " WHERE " + filterSql : "")
+ (havingSql != null ? ' ' + havingSql : "");
selectSqls.add(
"SELECT "
+ SqlQueryField.of(selectField).renderForSelect()
+ " FROM "
+ table.render()
+ (filterSql != null ? " WHERE " + filterSql : "")
+ (havingSql != null ? ' ' + havingSql : ""));
+ (appendSqlToTable ? table.render(sqlClause) : table.render() + sqlClause));
Arrays.stream(unionAllLiterals)
.forEach(literal -> selectSqls.add("SELECT @" + sqlParams.addParam("val", literal)));
return SqlQueryField.of(whereField).renderForWhere(tableAlias)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,12 @@ private static GroupItems fromConfigGroupItems(SZGroupItems szGroupItems, List<E
szGroupItems.foreignKeyAttributeItemsEntity == null
? null
: itemsEntity.getAttribute(szGroupItems.foreignKeyAttributeItemsEntity));
return new GroupItems(szGroupItems.name, groupEntity, itemsEntity, groupItemsRelationship);
return new GroupItems(
szGroupItems.name,
szGroupItems.useSourceIdPairsSql,
groupEntity,
itemsEntity,
groupItemsRelationship);
}

private static CriteriaOccurrence fromConfigCriteriaOccurrence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CriteriaOccurrence(
Relationship primaryCriteriaRelationship,
Map<String, Set<String>> occurrenceAttributesWithInstanceLevelDisplayHints,
Map<String, Set<String>> occurrenceAttributesWithRollupInstanceLevelDisplayHints) {
super(name);
super(name, false);
this.criteriaEntity = criteriaEntity;
this.occurrenceEntities = ImmutableList.copyOf(occurrenceEntities);
this.primaryEntity = primaryEntity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@ public enum Type {
}

private final String name;
private final boolean useSourceIdPairsSql;

protected EntityGroup(String name) {
protected EntityGroup(String name, boolean useSourceIdPairsSql) {
this.name = name;
this.useSourceIdPairsSql = useSourceIdPairsSql;
}

public String getName() {
return name;
}

public boolean isUseSourceIdPairsSql() {
return useSourceIdPairsSql;
}

public abstract Type getType();

public abstract boolean includesEntity(String name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ public class GroupItems extends EntityGroup {
private final Relationship groupItemsRelationship;

public GroupItems(
String name, Entity groupEntity, Entity itemsEntity, Relationship groupItemsRelationship) {
super(name);
String name,
boolean useSourceIdPairsSql,
Entity groupEntity,
Entity itemsEntity,
Relationship groupItemsRelationship) {
super(name, useSourceIdPairsSql);
this.groupEntity = groupEntity;
this.itemsEntity = itemsEntity;
this.groupItemsRelationship = groupItemsRelationship;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class SZGroupItems {
name = "SZGroupItems.useSourceIdPairsSql",
markdown =
"True to skip copying the id-pairs SQL into a new index table, and use the source SQL directly.\n\n"
+ "When set filter conditions are directly applied on the id-pairs SQL to fetch cohort counts. "
+ "Ignored if the [id pairs SQL](${SZGroupItems.idPairsSqlFile}) is undefined.",
optional = true,
defaultValue = "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"itemsEntity": "person",
"idPairsSqlFile": "idPairs.sql",
"useSourceIdPairsSql": true,
"groupEntityIdFieldName": "variant_id",
"groupEntityIdFieldName": "vid",
"itemsEntityIdFieldName": "flattened_person_id",
"rollupCountsSql": {
"sqlFile": "rollupCounts.sql",
"entityIdFieldName": "variant_id",
"entityIdFieldName": "vid",
"rollupCountFieldName": "num_persons"
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT DISTINCT vid AS variant_id, flattened_person_id
SELECT DISTINCT vid, flattened_person_id
FROM `${omopDataset}.variant_to_person`
CROSS JOIN UNNEST(person_ids) AS flattened_person_id
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT vid AS variant_id, ARRAY_LENGTH(person_ids) AS num_persons
SELECT vid, ARRAY_LENGTH(person_ids) AS num_persons
/* Wrap variant_to_person table in a SELECT DISTINCT because there is a duplicate row in the test data. */
FROM (SELECT DISTINCT vid, person_ids FROM `${omopDataset}.variant_to_person` WHERE REGEXP_CONTAINS(vid, r"{indexIdRegex}"))
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void inSelectFilter() {
apiTranslator.binaryFilterSql(joinField, BinaryOperator.EQUALS, val, null, sqlParams);
String sql =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, joinFilterSql, null, sqlParams);
field, tableAlias, joinField, joinTable, joinFilterSql, null, false, sqlParams);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName WHERE joinColumnName = @val0)",
sql);
Expand All @@ -186,7 +186,16 @@ void inSelectFilter() {
Literal val1 = Literal.forInt64(25L);
sql =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, joinFilterSql, null, sqlParams, val0, val1);
field,
tableAlias,
joinField,
joinTable,
joinFilterSql,
null,
false,
sqlParams,
val0,
val1);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName WHERE joinColumnName = @val0 UNION ALL SELECT @val1 UNION ALL SELECT @val2)",
sql);
Expand All @@ -197,7 +206,7 @@ void inSelectFilter() {
val = Literal.forInt64(38L);
sql =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, null, null, sqlParams, val);
field, tableAlias, joinField, joinTable, null, null, false, sqlParams, val);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName UNION ALL SELECT @val0)",
sql);
Expand All @@ -213,6 +222,7 @@ void inSelectFilter() {
joinTable,
null,
"GROUP BY joinColumnName HAVING COUNT(*) > 1",
false,
sqlParams);
assertEquals(
"tableAlias.columnName IN (SELECT joinColumnName FROM `projectId.datasetId`.joinTableName GROUP BY joinColumnName HAVING COUNT(*) > 1)",
Expand All @@ -237,7 +247,7 @@ void booleanAndOrFilter() {
apiTranslator.binaryFilterSql(joinField, BinaryOperator.EQUALS, val2, null, sqlParams);
String filterSql2 =
apiTranslator.inSelectFilterSql(
field, tableAlias, joinField, joinTable, joinFilterSql, null, sqlParams);
field, tableAlias, joinField, joinTable, joinFilterSql, null, false, sqlParams);

String sql =
apiTranslator.booleanAndOrFilterSql(
Expand Down

0 comments on commit 2556ab4

Please sign in to comment.