diff --git a/core/src/main/scala/org/broadinstitute/dsde/rawls/Boot.scala b/core/src/main/scala/org/broadinstitute/dsde/rawls/Boot.scala index 9050b2052f..8660c48225 100644 --- a/core/src/main/scala/org/broadinstitute/dsde/rawls/Boot.scala +++ b/core/src/main/scala/org/broadinstitute/dsde/rawls/Boot.scala @@ -482,7 +482,8 @@ object Boot extends IOApp with LazyLogging { val spendReportingServiceConfig = SpendReportingServiceConfig( gcsConfig.getString("billingExportTableName"), gcsConfig.getString("billingExportTimePartitionColumn"), - gcsConfig.getConfig("spendReporting").getInt("maxDateRange") + gcsConfig.getConfig("spendReporting").getInt("maxDateRange"), + metricsPrefix, ) val spendReportingServiceConstructor: RawlsRequestContext => SpendReportingService = diff --git a/core/src/main/scala/org/broadinstitute/dsde/rawls/config/SpendReportingServiceConfig.scala b/core/src/main/scala/org/broadinstitute/dsde/rawls/config/SpendReportingServiceConfig.scala index b526cbba16..48fea06da7 100644 --- a/core/src/main/scala/org/broadinstitute/dsde/rawls/config/SpendReportingServiceConfig.scala +++ b/core/src/main/scala/org/broadinstitute/dsde/rawls/config/SpendReportingServiceConfig.scala @@ -1,6 +1,8 @@ package org.broadinstitute.dsde.rawls.config -final case class SpendReportingServiceConfig(defaultTableName: String, - defaultTimePartitionColumn: String, - maxDateRange: Int +final case class SpendReportingServiceConfig( + defaultTableName: String, + defaultTimePartitionColumn: String, + maxDateRange: Int, + workbenchMetricBaseName: String ) diff --git a/core/src/main/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingService.scala b/core/src/main/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingService.scala index 00080af621..0ced5eb054 100644 --- a/core/src/main/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingService.scala +++ b/core/src/main/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingService.scala @@ -1,7 +1,6 @@ package org.broadinstitute.dsde.rawls.spendreporting import java.util.Currency - import akka.http.scaladsl.model.StatusCodes import cats.effect.IO import cats.effect.unsafe.implicits.global @@ -9,9 +8,9 @@ import com.google.cloud.bigquery.{Option => _, _} import com.typesafe.scalalogging.LazyLogging import org.broadinstitute.dsde.rawls.config.SpendReportingServiceConfig import org.broadinstitute.dsde.rawls.dataaccess.{SamDAO, SlickDataSource} +import org.broadinstitute.dsde.rawls.metrics.RawlsInstrumented import org.broadinstitute.dsde.rawls.model.SpendReportingAggregationKeys.SpendReportingAggregationKey -import org.broadinstitute.dsde.rawls.model.TerraSpendCategories.TerraSpendCategory -import org.broadinstitute.dsde.rawls.model._ +import org.broadinstitute.dsde.rawls.model.{SpendReportingAggregationKeyWithSub, _} import org.broadinstitute.dsde.workbench.google2.GoogleBigQueryService import org.broadinstitute.dsde.rawls.{RawlsException, RawlsExceptionWithErrorReport} import org.broadinstitute.dsde.workbench.model.google.GoogleProject @@ -23,377 +22,271 @@ import scala.jdk.CollectionConverters._ import scala.math.BigDecimal.RoundingMode object SpendReportingService { - def constructor(dataSource: SlickDataSource, - bigQueryService: cats.effect.Resource[IO, GoogleBigQueryService[IO]], - samDAO: SamDAO, - spendReportingServiceConfig: SpendReportingServiceConfig + def constructor( + dataSource: SlickDataSource, + bigQueryService: cats.effect.Resource[IO, GoogleBigQueryService[IO]], + samDAO: SamDAO, + spendReportingServiceConfig: SpendReportingServiceConfig )(ctx: RawlsRequestContext)(implicit executionContext: ExecutionContext): SpendReportingService = new SpendReportingService(ctx, dataSource, bigQueryService, samDAO, spendReportingServiceConfig) -} -class SpendReportingService(ctx: RawlsRequestContext, - dataSource: SlickDataSource, - bigQueryService: cats.effect.Resource[IO, GoogleBigQueryService[IO]], - samDAO: SamDAO, - spendReportingServiceConfig: SpendReportingServiceConfig -)(implicit val executionContext: ExecutionContext) - extends LazyLogging { - private def requireProjectAction[T](projectName: RawlsBillingProjectName, action: SamResourceAction)( - op: => Future[T] - ): Future[T] = - samDAO.userHasAction(SamResourceTypeNames.billingProject, projectName.value, action, ctx.userInfo).flatMap { - case true => op - case false => - Future.failed( - new RawlsExceptionWithErrorReport( - errorReport = ErrorReport( - StatusCodes.Forbidden, - s"${ctx.userInfo.userEmail.value} cannot perform ${action.value} on project ${projectName.value}" - ) - ) + def extractSpendReportingResults( + allRows: List[FieldValueList], + start: DateTime, + end: DateTime, + workspaceProjectsToNames: Map[String, WorkspaceName], + aggregations: Set[SpendReportingAggregationKeyWithSub] + ): SpendReportingResults = { + + val currency = allRows.map(_.get("currency").getStringValue).distinct match { + case head :: List() => Currency.getInstance(head) + case head :: tail => + throw RawlsExceptionWithErrorReport( + StatusCodes.BadGateway, // todo: Probably the wrong status code + s"Inconsistent currencies found while aggregating spend data: $head and ${tail.head} cannot be combined" ) + case List() => throw RawlsExceptionWithErrorReport(StatusCodes.NotFound, "No currencies found for spend data") + } - private def requireAlphaUser[T]()(op: => Future[T]): Future[T] = - samDAO - .userHasAction(SamResourceTypeNames.managedGroup, - "Alpha_Spend_Report_Users", - SamResourceAction("use"), - ctx.userInfo - ) - .flatMap { - case true => op - case false => - Future.failed( - new RawlsExceptionWithErrorReport(ErrorReport(StatusCodes.Forbidden, "This API is not live yet.")) + def sum(rows: List[FieldValueList], field: String): String = rows + .map(row => BigDecimal(row.get(field).getDoubleValue)) + .sum + .setScale(currency.getDefaultFractionDigits, RoundingMode.HALF_EVEN) + .toString() + + type SubKey = Option[SpendReportingAggregationKey] + + def byDate(rows: List[FieldValueList], subKey: SubKey): List[SpendReportingForDateRange] = rows + .groupBy(row => DateTime.parse(row.get("date").getStringValue)) + .map { case (startTime, rowsForStartTime) => + SpendReportingForDateRange( + sum(rowsForStartTime, "cost"), + sum(rowsForStartTime, "credits"), + currency.getCurrencyCode, + Option(startTime), + endTime = Option(startTime.plusDays(1).minusMillis(1)), + subAggregation = subKey.map(key => aggregate(rowsForStartTime, SpendReportingAggregationKeyWithSub(key))) + ) + } + .toList + + def byWorkspace(rows: List[FieldValueList], subKey: SubKey): List[SpendReportingForDateRange] = rows + .groupBy(row => row.get("googleProjectId").getStringValue) + .map { case (googleProjectId, projectRows) => + val workspaceName = workspaceProjectsToNames.getOrElse( + googleProjectId, + throw RawlsExceptionWithErrorReport( + StatusCodes.BadGateway, + s"unexpected project ${googleProjectId} returned by BigQuery" ) + ) + SpendReportingForDateRange( + sum(projectRows, "cost"), + sum(projectRows, "credits"), + currency.getCurrencyCode, + workspace = Option(workspaceName), + googleProjectId = Option(GoogleProject(googleProjectId)), + subAggregation = subKey.map(key => aggregate(projectRows, SpendReportingAggregationKeyWithSub(key))) + ) + } + .toList + + def byCategory(rows: List[FieldValueList], subKey: SubKey): List[SpendReportingForDateRange] = rows + .groupBy(row => TerraSpendCategories.categorize(row.get("service").getStringValue)) + .map { case (category, categoryRows) => + SpendReportingForDateRange( + sum(categoryRows, "cost"), + sum(categoryRows, "credits"), + currency.getCurrencyCode, + category = Option(category), + subAggregation = subKey.map(key => aggregate(categoryRows, SpendReportingAggregationKeyWithSub(key))) + ) } + .toList - def extractSpendReportingResults(rows: List[FieldValueList], - startTime: DateTime, - endTime: DateTime, - workspaceProjectsToNames: Map[GoogleProject, WorkspaceName], - aggregationKeys: Set[SpendReportingAggregationKeyWithSub] - ): SpendReportingResults = { - val currency = getCurrency(rows) - val spendAggregations = aggregationKeys.map { aggregationKey => - extractSpendAggregation(rows, currency, aggregationKey, workspaceProjectsToNames) + def aggregate(rows: List[FieldValueList], key: SpendReportingAggregationKeyWithSub): SpendReportingAggregation = { + val spend = key match { + case SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Category, sub) => byCategory(rows, sub) + case SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace, sub) => byWorkspace(rows, sub) + case SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily, sub) => byDate(rows, sub) + } + SpendReportingAggregation(key.key, spend) } - val spendSummary = extractSpendSummary(rows, currency, startTime, endTime) - SpendReportingResults(spendAggregations.toList, spendSummary) + val summary = SpendReportingForDateRange( + sum(allRows, "cost"), + sum(allRows, "credits"), + currency.getCurrencyCode, + Option(start), + Option(end) + ) + SpendReportingResults(aggregations.map(aggregate(allRows, _)).toList, summary) } +} + +class SpendReportingService( + ctx: RawlsRequestContext, + dataSource: SlickDataSource, + bigQueryService: cats.effect.Resource[IO, GoogleBigQueryService[IO]], + samDAO: SamDAO, + spendReportingServiceConfig: SpendReportingServiceConfig +)(implicit val executionContext: ExecutionContext) + extends LazyLogging + with RawlsInstrumented { + /** - * Ensure that BigQuery results only include one type of currency and return that currency. + * Base name for all metrics. This will be prepended to all generated metric names. + * Example: dev.firecloud.rawls */ - private def getCurrency(rows: List[FieldValueList]): Currency = { - val currencies = rows.map(_.get("currency").getStringValue) + override val workbenchMetricBaseName: String = spendReportingServiceConfig.workbenchMetricBaseName - Currency.getInstance(currencies.reduce { (x, y) => - if (x.equals(y)) { - x - } else { - throw new RawlsExceptionWithErrorReport( - ErrorReport(StatusCodes.BadGateway, - s"Inconsistent currencies found while aggregating spend data: $x and $y cannot be combined" - ) + private def requireProjectAction[T](projectName: RawlsBillingProjectName, action: SamResourceAction)( + op: => Future[T] + ): Future[T] = + samDAO.userHasAction(SamResourceTypeNames.billingProject, projectName.value, action, ctx.userInfo).flatMap { + case true => op + case false => + throw RawlsExceptionWithErrorReport( + StatusCodes.Forbidden, + s"${ctx.userInfo.userEmail.value} cannot perform ${action.value} on project ${projectName.value}" ) - } - }) - } - - private def extractSpendAggregation(rows: List[FieldValueList], - currency: Currency, - aggregationKey: SpendReportingAggregationKeyWithSub, - workspaceProjectsToNames: Map[GoogleProject, WorkspaceName] = Map.empty - ): SpendReportingAggregation = - aggregationKey match { - case SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Category, subAggregationKey) => - extractCategorySpendAggregation(rows, currency, subAggregationKey, workspaceProjectsToNames) - case SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace, subAggregationKey) => - extractWorkspaceSpendAggregation(rows, currency, subAggregationKey, workspaceProjectsToNames) - case SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily, subAggregationKey) => - extractDailySpendAggregation(rows, currency, subAggregationKey, workspaceProjectsToNames) } - private def extractSpendSummary(rows: List[FieldValueList], - currency: Currency, - startTime: DateTime, - endTime: DateTime - ): SpendReportingForDateRange = { - val (cost, credits) = sumCostsAndCredits(rows, currency) - - SpendReportingForDateRange( - cost.toString(), - credits.toString(), - currency.getCurrencyCode, - Option(startTime), - Option(endTime) + def requireAlphaUser[T]()(op: => Future[T]): Future[T] = samDAO + .userHasAction( + SamResourceTypeNames.managedGroup, + "Alpha_Spend_Report_Users", + SamResourceAction("use"), + ctx.userInfo ) - } + .flatMap { + case true => op + case false => + throw RawlsExceptionWithErrorReport(StatusCodes.Forbidden, "This API is not live yet.") + } - private def sumCostsAndCredits(rows: List[FieldValueList], currency: Currency): (BigDecimal, BigDecimal) = - ( - rows - .map(row => BigDecimal(row.get("cost").getDoubleValue)) - .sum - .setScale(currency.getDefaultFractionDigits, RoundingMode.HALF_EVEN), - rows - .map(row => BigDecimal(row.get("credits").getDoubleValue)) - .sum - .setScale(currency.getDefaultFractionDigits, RoundingMode.HALF_EVEN) - ) + private def dateTimeToISODateString(dt: DateTime): String = dt.toString(ISODateTimeFormat.date()) - private def extractWorkspaceSpendAggregation(rows: List[FieldValueList], - currency: Currency, - subAggregationKey: Option[SpendReportingAggregationKey] = None, - workspaceProjectsToNames: Map[GoogleProject, WorkspaceName] - ): SpendReportingAggregation = { - val spendByGoogleProjectId: Map[GoogleProject, List[FieldValueList]] = - rows.groupBy(row => GoogleProject(row.get("googleProjectId").getStringValue)) - val workspaceSpend = spendByGoogleProjectId.map { case (googleProjectId, rowsForGoogleProjectId) => - val (cost, credits) = sumCostsAndCredits(rowsForGoogleProjectId, currency) - val subAggregation = subAggregationKey.map { key => - extractSpendAggregation(rowsForGoogleProjectId, - currency, - aggregationKey = SpendReportingAggregationKeyWithSub(key), - workspaceProjectsToNames = workspaceProjectsToNames - ) - } - val workspaceName = workspaceProjectsToNames.getOrElse( - googleProjectId, - throw new RawlsExceptionWithErrorReport( - ErrorReport(StatusCodes.BadGateway, s"unexpected project ${googleProjectId.value} returned by BigQuery") - ) + def getSpendExportConfiguration(project: RawlsBillingProjectName): Future[BillingProjectSpendExport] = dataSource + .inTransaction(_.rawlsBillingProjectQuery.getBillingProjectSpendConfiguration(project)) + .recover { case _: RawlsException => + throw RawlsExceptionWithErrorReport( + StatusCodes.BadRequest, + s"billing account not found on billing project ${project.value}" ) - SpendReportingForDateRange( - cost.toString(), - credits.toString(), - currency.getCurrencyCode, - workspace = Option(workspaceName), - googleProjectId = Option(googleProjectId), - subAggregation = subAggregation + } + .map { + _.getOrElse( + throw RawlsExceptionWithErrorReport(StatusCodes.NotFound, s"billing project ${project.value} not found") ) - }.toList - - SpendReportingAggregation( - SpendReportingAggregationKeys.Workspace, - workspaceSpend - ) - } + } - private def extractCategorySpendAggregation(rows: List[FieldValueList], - currency: Currency, - subAggregationKey: Option[SpendReportingAggregationKey] = None, - workspaceProjectsToNames: Map[GoogleProject, WorkspaceName] = Map.empty - ): SpendReportingAggregation = { - val spendByCategory: Map[TerraSpendCategory, List[FieldValueList]] = - rows.groupBy(row => TerraSpendCategories.categorize(row.get("service").getStringValue)) - val categorySpend = spendByCategory.map { case (category, rowsForCategory) => - val (cost, credits) = sumCostsAndCredits(rowsForCategory, currency) - val subAggregation = subAggregationKey.map { key => - extractSpendAggregation(rowsForCategory, - currency, - aggregationKey = SpendReportingAggregationKeyWithSub(key), - workspaceProjectsToNames = workspaceProjectsToNames - ) - } - SpendReportingForDateRange( - cost.toString, - credits.toString, - currency.getCurrencyCode, - category = Option(category), - subAggregation = subAggregation - ) - }.toList + def getWorkspaceGoogleProjects(projectName: RawlsBillingProjectName): Future[Map[String, WorkspaceName]] = + dataSource.inTransaction(_.workspaceQuery.listWithBillingProject(projectName)).map { + _.collect { + case w if w.workspaceVersion == WorkspaceVersions.V2 => w.googleProjectId.value -> w.toWorkspaceName + }.toMap + } - SpendReportingAggregation( - SpendReportingAggregationKeys.Category, - categorySpend + def validateReportParameters(startDate: DateTime, endDate: DateTime): Unit = if (startDate.isAfter(endDate)) { + throw RawlsExceptionWithErrorReport( + StatusCodes.BadRequest, + s"start date ${dateTimeToISODateString(startDate)} must be before end date ${dateTimeToISODateString(endDate)}" ) - } - - private def extractDailySpendAggregation(rows: List[FieldValueList], - currency: Currency, - subAggregationKey: Option[SpendReportingAggregationKey] = None, - workspaceProjectsToNames: Map[GoogleProject, WorkspaceName] = Map.empty - ): SpendReportingAggregation = { - val spendByStartTime: Map[DateTime, List[FieldValueList]] = - rows.groupBy(row => DateTime.parse(row.get("date").getStringValue)) - val dailySpend = spendByStartTime.map { case (startTime, rowsForStartTime) => - val (cost, credits) = sumCostsAndCredits(rowsForStartTime, currency) - val subAggregation = subAggregationKey.map { key => - extractSpendAggregation(rowsForStartTime, - currency, - aggregationKey = SpendReportingAggregationKeyWithSub(key), - workspaceProjectsToNames = workspaceProjectsToNames - ) - } - SpendReportingForDateRange( - cost.toString, - credits.toString, - currency.getCurrencyCode, - Option(startTime), - endTime = Option(startTime.plusDays(1).minusMillis(1)), - subAggregation = subAggregation - ) - }.toList - - SpendReportingAggregation( - SpendReportingAggregationKeys.Daily, - dailySpend + } else if (Days.daysBetween(startDate, endDate).getDays > spendReportingServiceConfig.maxDateRange) { + throw RawlsExceptionWithErrorReport( + StatusCodes.BadRequest, + s"provided dates exceed maximum report date range of ${spendReportingServiceConfig.maxDateRange} days" ) } - private def dateTimeToISODateString(dt: DateTime): String = dt.toString(ISODateTimeFormat.date()) - - private def getSpendExportConfiguration( - billingProjectName: RawlsBillingProjectName - ): Future[BillingProjectSpendExport] = - dataSource - .inTransaction { dataAccess => - dataAccess.rawlsBillingProjectQuery.getBillingProjectSpendConfiguration(billingProjectName) - } - .recover { case _: RawlsException => - throw new RawlsExceptionWithErrorReport( - ErrorReport(StatusCodes.BadRequest, - s"billing account not found on billing project ${billingProjectName.value}" - ) - ) - } - .map( - _.getOrElse( - throw new RawlsExceptionWithErrorReport( - ErrorReport(StatusCodes.NotFound, s"billing project ${billingProjectName.value} not found") - ) - ) - ) - - private def getWorkspaceGoogleProjects( - billingProjectName: RawlsBillingProjectName - ): Future[Map[GoogleProject, WorkspaceName]] = - dataSource - .inTransaction { dataAccess => - dataAccess.workspaceQuery.listWithBillingProject(billingProjectName) - } - .map { workspaces => - workspaces.collect { - case workspace if workspace.workspaceVersion == WorkspaceVersions.V2 => - GoogleProject(workspace.googleProjectId.value) -> workspace.toWorkspaceName - }.toMap - } - - private def validateReportParameters(startDate: DateTime, endDate: DateTime): Unit = - if (startDate.isAfter(endDate)) { - throw new RawlsExceptionWithErrorReport( - ErrorReport( - StatusCodes.BadRequest, - s"start date ${dateTimeToISODateString(startDate)} must be before end date ${dateTimeToISODateString(endDate)}" - ) - ) - } else if (Days.daysBetween(startDate, endDate).getDays > spendReportingServiceConfig.maxDateRange) { - throw new RawlsExceptionWithErrorReport( - ErrorReport( - StatusCodes.BadRequest, - s"provided dates exceed maximum report date range of ${spendReportingServiceConfig.maxDateRange} days" - ) - ) + def getQuery(aggregations: Set[SpendReportingAggregationKeyWithSub], config: BillingProjectSpendExport): String = { + // Unbox potentially many SpendReportingAggregationKeyWithSubs for query, + // all of which have optional subAggregationKeys and convert to Set[SpendReportingAggregationKey] + val queryKeys = aggregations.flatMap(a => Set(Option(a.key), a.subAggregationKey).flatten) + val tableName = config.spendExportTable.getOrElse(spendReportingServiceConfig.defaultTableName) + val timePartitionColumn: String = { + val isBroadTable = tableName == spendReportingServiceConfig.defaultTableName + // The Broad table uses a view with a different column name. + if (isBroadTable) spendReportingServiceConfig.defaultTimePartitionColumn else "_PARTITIONTIME" } + s""" + | SELECT + | SUM(cost) as cost, + | SUM(IFNULL((SELECT SUM(c.amount) FROM UNNEST(credits) c), 0)) as credits, + | currency ${queryKeys.map(_.bigQueryAliasClause()).mkString} + | FROM `$tableName` + | WHERE billing_account_id = @billingAccountId + | AND $timePartitionColumn BETWEEN @startDate AND @endDate + | AND project.id in UNNEST(@projects) + | GROUP BY currency ${queryKeys.map(_.bigQueryGroupByClause()).mkString} + |""".stripMargin + .replace("REPLACE_TIME_PARTITION_COLUMN", timePartitionColumn) + } - private def stringQueryParameterValue(parameterValue: String): QueryParameterValue = - QueryParameterValue - .newBuilder() - .setType(StandardSQLTypeName.STRING) - .setValue(parameterValue) - .build() - - private def stringArrayQueryParameterValue(parameterValues: List[String]): QueryParameterValue = { - val queryParameterArrayValues = parameterValues.map { parameterValue => + def setUpQuery( + query: String, + exportConf: BillingProjectSpendExport, + start: DateTime, + end: DateTime, + projectNames: Map[String, WorkspaceName] + ): QueryJobConfiguration = { + def queryParam(value: String): QueryParameterValue = + QueryParameterValue.newBuilder().setType(StandardSQLTypeName.STRING).setValue(value).build() + + val projectNamesParam: QueryParameterValue = QueryParameterValue .newBuilder() - .setType(StandardSQLTypeName.STRING) - .setValue(parameterValue) + .setType(StandardSQLTypeName.ARRAY) + .setArrayType(StandardSQLTypeName.STRING) + .setArrayValues(projectNames.keySet.map(name => queryParam(name)).toList.asJava) .build() - }.asJava - QueryParameterValue - .newBuilder() - .setType(StandardSQLTypeName.ARRAY) - .setArrayType(StandardSQLTypeName.STRING) - .setArrayValues(queryParameterArrayValues) + QueryJobConfiguration + .newBuilder(query) + .addNamedParameter("billingAccountId", queryParam(exportConf.billingAccountId.withoutPrefix())) + .addNamedParameter("startDate", queryParam(dateTimeToISODateString(start))) + .addNamedParameter("endDate", queryParam(dateTimeToISODateString(end))) + .addNamedParameter("projects", projectNamesParam) .build() } - def getQuery(aggregationKeys: Set[SpendReportingAggregationKey], - tableName: String, - customTimePartitionColumn: Option[String] - ): String = { - // The Broad table uses a view with a different column name. - val timePartitionColumn = customTimePartitionColumn.getOrElse("_PARTITIONTIME") - val queryClause = s""" - | SELECT - | SUM(cost) as cost, - | SUM(IFNULL((SELECT SUM(c.amount) FROM UNNEST(credits) c), 0)) as credits, - | currency ${aggregationKeys.map(_.bigQueryAliasClause()).mkString} - | FROM `$tableName` - | WHERE billing_account_id = @billingAccountId - | AND $timePartitionColumn BETWEEN @startDate AND @endDate - | AND project.id in UNNEST(@projects) - | GROUP BY currency ${aggregationKeys.map(_.bigQueryGroupByClause()).mkString} - |""".stripMargin - queryClause.replace("REPLACE_TIME_PARTITION_COLUMN", timePartitionColumn) - } - - def getSpendForBillingProject(billingProjectName: RawlsBillingProjectName, - startDate: DateTime, - endDate: DateTime, - aggregationKeyParameters: Set[SpendReportingAggregationKeyWithSub] = Set.empty + def getSpendForBillingProject( + projectName: RawlsBillingProjectName, + startDate: DateTime, + endDate: DateTime, + aggregationKeys: Set[SpendReportingAggregationKeyWithSub] = Set.empty ): Future[SpendReportingResults] = { validateReportParameters(startDate, endDate) requireAlphaUser() { - requireProjectAction(billingProjectName, SamBillingProjectActions.readSpendReport) { + requireProjectAction(projectName, SamBillingProjectActions.readSpendReport) { for { - spendExportConf <- getSpendExportConfiguration(billingProjectName) - workspaceProjectsToNames <- getWorkspaceGoogleProjects(billingProjectName) + spendExportConf <- getSpendExportConfiguration(projectName) + workspaceProjectsToNames <- getWorkspaceGoogleProjects(projectName) - // Unbox potentially many SpendReportingAggregationKeyWithSubs, all of which have optional subAggregationKeys and convert to Set[SpendReportingAggregationKey] - aggregationKeys = aggregationKeyParameters.flatMap(maybeKeys => - Set(Option(maybeKeys.key), maybeKeys.subAggregationKey).flatten - ) - - spendReportTableName = spendExportConf.spendExportTable.getOrElse( - spendReportingServiceConfig.defaultTableName - ) - isBroadTable = spendReportTableName == spendReportingServiceConfig.defaultTableName - timePartitionColumn = if (isBroadTable) Some(spendReportingServiceConfig.defaultTimePartitionColumn) else None - query = getQuery(aggregationKeys, spendReportTableName, timePartitionColumn) - - queryJobConfiguration = QueryJobConfiguration - .newBuilder(query) - .addNamedParameter("billingAccountId", - stringQueryParameterValue(spendExportConf.billingAccountId.withoutPrefix()) - ) - .addNamedParameter("startDate", stringQueryParameterValue(dateTimeToISODateString(startDate))) - .addNamedParameter("endDate", stringQueryParameterValue(dateTimeToISODateString(endDate))) - .addNamedParameter("projects", - stringArrayQueryParameterValue(workspaceProjectsToNames.keySet.map(_.value).toList) - ) - .build() + query = getQuery(aggregationKeys, spendExportConf) + queryJobConfiguration = setUpQuery(query, spendExportConf, startDate, endDate, workspaceProjectsToNames) result <- bigQueryService.use(_.query(queryJobConfiguration)).unsafeToFuture() } yield result.getValues.asScala.toList match { case Nil => - throw new RawlsExceptionWithErrorReport( - ErrorReport( - StatusCodes.NotFound, - s"no spend data found for billing project ${billingProjectName.value} between dates ${dateTimeToISODateString(startDate)} and ${dateTimeToISODateString(endDate)}" - ) + throw RawlsExceptionWithErrorReport( + StatusCodes.NotFound, + s"no spend data found for billing project ${projectName.value} between dates ${dateTimeToISODateString(startDate)} and ${dateTimeToISODateString(endDate)}" ) case rows => - extractSpendReportingResults(rows, startDate, endDate, workspaceProjectsToNames, aggregationKeyParameters) + SpendReportingService.extractSpendReportingResults( + rows, + startDate, + endDate, + workspaceProjectsToNames, + aggregationKeys + ) } } } } + } diff --git a/core/src/test/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingServiceSpec.scala b/core/src/test/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingServiceSpec.scala index a42fc10c33..4a80b32cb0 100644 --- a/core/src/test/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingServiceSpec.scala +++ b/core/src/test/scala/org/broadinstitute/dsde/rawls/spendreporting/SpendReportingServiceSpec.scala @@ -1,21 +1,24 @@ package org.broadinstitute.dsde.rawls.spendreporting import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers.OAuth2BearerToken +import cats.effect.{IO, Resource} import com.google.cloud.PageImpl import com.google.cloud.bigquery.{Option => _, _} import org.broadinstitute.dsde.rawls.config.SpendReportingServiceConfig -import org.broadinstitute.dsde.rawls.dataaccess.slick.TestDriverComponent -import org.broadinstitute.dsde.rawls.dataaccess.{MockBigQueryServiceFactory, SamDAO, SlickDataSource} +import org.broadinstitute.dsde.rawls.dataaccess.{SamDAO, SlickDataSource} +import org.broadinstitute.dsde.rawls.model.Attributable.AttributeMap import org.broadinstitute.dsde.rawls.model._ import org.broadinstitute.dsde.rawls.util.MockitoTestUtils -import org.broadinstitute.dsde.rawls.{model, RawlsExceptionWithErrorReport} +import org.broadinstitute.dsde.rawls.{model, RawlsException, RawlsExceptionWithErrorReport, TestExecutionContext} +import org.broadinstitute.dsde.workbench.google2.GoogleBigQueryService import org.broadinstitute.dsde.workbench.model.google.GoogleProject import org.joda.time.DateTime +import org.mockito.ArgumentMatchers import org.mockito.ArgumentMatchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{when, RETURNS_SMART_NULLS} +import org.mockito.Mockito.{doReturn, spy, when, RETURNS_SMART_NULLS} import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers -import org.scalatestplus.mockito.MockitoSugar import java.util.UUID import scala.concurrent.duration.Duration @@ -23,13 +26,69 @@ import scala.concurrent.{Await, Future} import scala.jdk.CollectionConverters._ import scala.math.BigDecimal.RoundingMode -class SpendReportingServiceSpec - extends AnyFlatSpecLike - with TestDriverComponent - with MockitoSugar - with Matchers - with MockitoTestUtils { - object SpendReportingTestData { +class SpendReportingServiceSpec extends AnyFlatSpecLike with Matchers with MockitoTestUtils { + + implicit val executionContext: TestExecutionContext = TestExecutionContext.testExecutionContext + + val userInfo: UserInfo = UserInfo(RawlsUserEmail("owner-access"), + OAuth2BearerToken("token"), + 123, + RawlsUserSubjectId("123456789876543212345") + ) + val wsName: WorkspaceName = WorkspaceName("myNamespace", "myWorkspace") + + val billingAccountName: RawlsBillingAccountName = RawlsBillingAccountName("fakeBillingAcct") + + val billingProject: RawlsBillingProject = RawlsBillingProject(RawlsBillingProjectName(wsName.namespace), + CreationStatuses.Ready, + Option(billingAccountName), + None + ) + + val testContext: RawlsRequestContext = RawlsRequestContext(userInfo) + object TestData { + val workspaceGoogleProject1 = "project1" + val workspaceGoogleProject2 = "project2" + val workspace1: Workspace = workspace("workspace1", GoogleProjectId(workspaceGoogleProject1)) + val workspace2: Workspace = workspace("workspace2", GoogleProjectId(workspaceGoogleProject2)) + + val googleProjectsToWorkspaceNames: Map[String, WorkspaceName] = Map( + workspaceGoogleProject1 -> workspace1.toWorkspaceName, + workspaceGoogleProject2 -> workspace2.toWorkspaceName + ) + + def workspace( + name: String, + googleProjectId: GoogleProjectId, + version: WorkspaceVersions.WorkspaceVersion = WorkspaceVersions.V2, + namespace: String = RawlsBillingProjectName(wsName.namespace).value, + workspaceId: String = UUID.randomUUID().toString, + bucketName: String = "bucketName", + workflowCollectionName: Option[String] = None, + attributes: AttributeMap = Map.empty, + googleProjectNumber: Option[GoogleProjectNumber] = None, + currentBillingAccountOnGoogleProject: Option[RawlsBillingAccountName] = None, + billingAccountErrorMessage: Option[String] = None + ): Workspace = model.Workspace( + namespace, + name, + workspaceId, + bucketName, + workflowCollectionName, + DateTime.now, + DateTime.now, + "creator", + attributes, + isLocked = false, + workspaceVersion = version, + googleProjectId, + googleProjectNumber, + currentBillingAccountOnGoogleProject, + billingAccountErrorMessage, + None, + WorkspaceType.RawlsWorkspace + ) + object Daily { val firstRowCost = 2.4 val secondRowCost = 0.10111 @@ -58,47 +117,6 @@ class SpendReportingServiceSpec } object Workspace { - val workspaceGoogleProject1 = "project1" - val workspaceGoogleProject2 = "project2" - val workspace1: Workspace = model.Workspace( - testData.billingProject.projectName.value, - "workspace1", - UUID.randomUUID().toString, - "bucketName", - None, - DateTime.now, - DateTime.now, - "creator", - Map.empty, - isLocked = false, - WorkspaceVersions.V2, - GoogleProjectId(workspaceGoogleProject1), - None, - None, - None, - None, - WorkspaceType.RawlsWorkspace - ) - val workspace2: Workspace = model.Workspace( - testData.billingProject.projectName.value, - "workspace2", - UUID.randomUUID().toString, - "bucketName", - None, - DateTime.now, - DateTime.now, - "creator", - Map.empty, - isLocked = false, - WorkspaceVersions.V2, - GoogleProjectId(workspaceGoogleProject2), - None, - None, - None, - None, - WorkspaceType.RawlsWorkspace - ) - val firstRowCost = 100.582 val secondRowCost = 0.10111 val firstRowCostRounded: BigDecimal = BigDecimal(firstRowCost).setScale(2, RoundingMode.HALF_EVEN) @@ -127,11 +145,11 @@ class SpendReportingServiceSpec val otherRowCost = 204.1025 val computeRowCost = 50.20 val storageRowCost = 2.5 - val otherRowCostRounded: BigDecimal = BigDecimal(otherRowCost).setScale(2, RoundingMode.HALF_EVEN) - val computeRowCostRounded: BigDecimal = BigDecimal(computeRowCost).setScale(2, RoundingMode.HALF_EVEN) - val storageRowCostRounded: BigDecimal = BigDecimal(storageRowCost).setScale(2, RoundingMode.HALF_EVEN) - val totalCostRounded: BigDecimal = - BigDecimal(otherRowCost + computeRowCost + storageRowCost).setScale(2, RoundingMode.HALF_EVEN) + val otherRowCostRounded: String = BigDecimal(otherRowCost).setScale(2, RoundingMode.HALF_EVEN).toString + val computeRowCostRounded: String = BigDecimal(computeRowCost).setScale(2, RoundingMode.HALF_EVEN).toString + val storageRowCostRounded: String = BigDecimal(storageRowCost).setScale(2, RoundingMode.HALF_EVEN).toString + val totalCostRounded: String = + BigDecimal(otherRowCost + computeRowCost + storageRowCost).setScale(2, RoundingMode.HALF_EVEN).toString val table: List[Map[String, String]] = List( Map( @@ -157,46 +175,6 @@ class SpendReportingServiceSpec } object SubAggregation { - val workspaceGoogleProject1 = "project1" - val workspaceGoogleProject2 = "project2" - val workspace1: Workspace = model.Workspace( - testData.billingProject.projectName.value, - "workspace1", - UUID.randomUUID().toString, - "bucketName", - None, - DateTime.now, - DateTime.now, - "creator", - Map.empty, - isLocked = false, - WorkspaceVersions.V2, - GoogleProjectId(workspaceGoogleProject1), - None, - None, - None, - None, - WorkspaceType.RawlsWorkspace - ) - val workspace2: Workspace = model.Workspace( - testData.billingProject.projectName.value, - "workspace2", - UUID.randomUUID().toString, - "bucketName", - None, - DateTime.now, - DateTime.now, - "creator", - Map.empty, - isLocked = false, - WorkspaceVersions.V2, - GoogleProjectId(workspaceGoogleProject2), - None, - None, - None, - None, - WorkspaceType.RawlsWorkspace - ) val workspace1OtherRowCost = 204.1025 val workspace1ComputeRowCost = 50.20 @@ -212,19 +190,19 @@ class SpendReportingServiceSpec val workspace2OtherRowCostRounded: BigDecimal = BigDecimal(workspace2OtherRowCost).setScale(2, RoundingMode.HALF_EVEN) - val otherTotalCostRounded: BigDecimal = - BigDecimal(workspace1OtherRowCost + workspace2OtherRowCost).setScale(2, RoundingMode.HALF_EVEN) - val storageTotalCostRounded: BigDecimal = workspace2StorageRowCostRounded - val computeTotalCostRounded: BigDecimal = workspace1ComputeRowCostRounded + val otherTotalCostRounded: String = + BigDecimal(workspace1OtherRowCost + workspace2OtherRowCost).setScale(2, RoundingMode.HALF_EVEN).toString() + val storageTotalCostRounded: String = workspace2StorageRowCostRounded.toString + val computeTotalCostRounded: String = workspace1ComputeRowCostRounded.toString val workspace1TotalCostRounded: BigDecimal = BigDecimal(workspace1OtherRowCost + workspace1ComputeRowCost).setScale(2, RoundingMode.HALF_EVEN) val workspace2TotalCostRounded: BigDecimal = BigDecimal(workspace2StorageRowCost + workspace2OtherRowCost).setScale(2, RoundingMode.HALF_EVEN) - val totalCostRounded: BigDecimal = BigDecimal( + val totalCostRounded: String = BigDecimal( workspace1OtherRowCost + workspace1ComputeRowCost + workspace2StorageRowCost + workspace2OtherRowCost - ).setScale(2, RoundingMode.HALF_EVEN) + ).setScale(2, RoundingMode.HALF_EVEN).toString() val table: List[Map[String, String]] = List( Map( @@ -261,131 +239,57 @@ class SpendReportingServiceSpec } } - def createTableResult(values: List[Map[String, String]]): TableResult = { - val rawSchemas: List[Set[String]] = values.map(_.keySet) - val rawFields: List[String] = if (rawSchemas.nonEmpty) { - rawSchemas.reduce { (x, y) => - if (x.equals(y)) { - x - } else { - fail(s"inconsistent schema found when comparing rows $x and $y") - } - }.toList - } else { - List.empty - } - val fields: List[Field] = rawFields.map { field => - Field.of(field, StandardSQLTypeName.STRING) - } - val schema: Schema = Schema.of(fields: _*) - - val fieldValues: List[List[FieldValue]] = values.map { row => - row.values.toList.map { value => + def createTableResult(data: List[Map[String, String]]): TableResult = { + val fields = data.flatMap(_.keySet).distinct.map(field => Field.of(field, StandardSQLTypeName.STRING)) + val values: List[FieldValueList] = data.map { row => + val rowValues = row.values.toList.map { value => FieldValue.of(FieldValue.Attribute.PRIMITIVE, value) } + FieldValueList.of(rowValues.asJava, fields: _*) } - val fieldValueLists: List[FieldValueList] = fieldValues.map { row => - FieldValueList.of(row.asJava, fields: _*) - } - val page: PageImpl[FieldValueList] = new PageImpl[FieldValueList](null, null, fieldValueLists.asJava) - - new TableResult(schema, fieldValueLists.length, page) + val page: PageImpl[FieldValueList] = new PageImpl[FieldValueList](null, null, values.asJava) + new TableResult(Schema.of(fields: _*), values.length, page) } val defaultServiceProject: GoogleProject = GoogleProject("project") val spendReportingServiceConfig: SpendReportingServiceConfig = SpendReportingServiceConfig( "fakeTable", "fakeTimePartitionColumn", - 90 + 90, + "test.rawls" ) - // Create Spend Reporting Service with Sam and BQ DAOs that mock happy-path responses and return SpendReportingTestData.Workspace.tableResult. Override Sam and BQ responses as needed - def createSpendReportingService( - dataSource: SlickDataSource, - samDAO: SamDAO = mock[SamDAO](RETURNS_SMART_NULLS), - tableResult: TableResult = SpendReportingTestData.Workspace.tableResult - ): SpendReportingService = { - when( - samDAO.userHasAction(SamResourceTypeNames.managedGroup, - "Alpha_Spend_Report_Users", - SamResourceAction("use"), - userInfo - ) + "SpendReportingService.extractSpendReportingResults" should "break down results from Google by day" in { + val reportingResults = SpendReportingService.extractSpendReportingResults( + TestData.Daily.tableResult.getValues.asScala.toList, + DateTime.now().minusDays(1), + DateTime.now(), + Map(), + Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily)) ) - .thenReturn(Future.successful(true)) - when( - samDAO.userHasAction(mockitoEq(SamResourceTypeNames.billingProject), - any[String], - mockitoEq(SamBillingProjectActions.readSpendReport), - mockitoEq(userInfo) - ) - ) - .thenReturn(Future.successful(true)) - val mockServiceFactory = MockBigQueryServiceFactory.ioFactory(Right(tableResult)) - - new SpendReportingService(testContext, - dataSource, - mockServiceFactory.getServiceFromJson("json", defaultServiceProject), - samDAO, - spendReportingServiceConfig - ) - } - - "SpendReportingService" should "break down results from Google by day" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource, tableResult = SpendReportingTestData.Daily.tableResult) - - val reportingResults = Await.result( - service.getSpendForBillingProject( - testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily)) - ), - Duration.Inf - ) - reportingResults.spendSummary.cost shouldBe SpendReportingTestData.Daily.totalCostRounded.toString - val dailyAggregation = - reportingResults.spendDetails.headOption.getOrElse(fail("daily results not parsed correctly")) - dailyAggregation.aggregationKey shouldBe SpendReportingAggregationKeys.Daily - - dailyAggregation.spendData.map { spendForDay => - if ( - spendForDay.startTime - .getOrElse(fail("daily results not parsed correctly")) - .toLocalDate - .equals(SpendReportingTestData.Daily.firstRowDate.toLocalDate) - ) { - spendForDay.cost shouldBe SpendReportingTestData.Daily.firstRowCostRounded.toString - } else if ( - spendForDay.startTime - .getOrElse(fail("daily results not parsed correctly")) - .toLocalDate - .equals(SpendReportingTestData.Daily.secondRowDate.toLocalDate) - ) { - spendForDay.cost shouldBe SpendReportingTestData.Daily.secondRowCostRounded.toString - } else { - fail(s"unexpected day found in spend results - $spendForDay") - } + reportingResults.spendSummary.cost shouldBe TestData.Daily.totalCostRounded.toString + reportingResults.spendDetails.head.aggregationKey shouldBe SpendReportingAggregationKeys.Daily + reportingResults.spendDetails.head.spendData.foreach { spendForDay => + spendForDay.startTime match { + case Some(date) if date.toLocalDate.equals(TestData.Daily.firstRowDate.toLocalDate) => + spendForDay.cost shouldBe TestData.Daily.firstRowCostRounded.toString + case Some(date) if date.toLocalDate.equals(TestData.Daily.secondRowDate.toLocalDate) => + spendForDay.cost shouldBe TestData.Daily.secondRowCostRounded.toString + case _ => fail(s"unexpected day found in spend results - $spendForDay") } + } } - it should "break down results from Google by workspace" in withDefaultTestDatabase { dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource, tableResult = SpendReportingTestData.Workspace.tableResult) - - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace1)) - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace2)) - - val reportingResults = Await.result( - service.getSpendForBillingProject( - testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace)) - ), - Duration.Inf + it should "break down results from Google by workspace" in { + val reportingResults = SpendReportingService.extractSpendReportingResults( + TestData.Workspace.tableResult.getValues.asScala.toList, + DateTime.now().minusDays(1), + DateTime.now(), + TestData.googleProjectsToWorkspaceNames, + Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace)) ) - reportingResults.spendSummary.cost shouldBe SpendReportingTestData.Workspace.totalCostRounded.toString + + reportingResults.spendSummary.cost shouldBe TestData.Workspace.totalCostRounded.toString val workspaceAggregation = reportingResults.spendDetails.headOption.getOrElse(fail("workspace results not parsed correctly")) @@ -395,10 +299,10 @@ class SpendReportingServiceSpec val workspaceGoogleProject = spendForWorkspace.googleProjectId.getOrElse(fail("workspace results not parsed correctly")).value - if (workspaceGoogleProject.equals(SpendReportingTestData.Workspace.workspace1.googleProjectId.value)) { - spendForWorkspace.cost shouldBe SpendReportingTestData.Workspace.firstRowCostRounded.toString - } else if (workspaceGoogleProject.equals(SpendReportingTestData.Workspace.workspace2.googleProjectId.value)) { - spendForWorkspace.cost shouldBe SpendReportingTestData.Workspace.secondRowCostRounded.toString + if (workspaceGoogleProject.equals(TestData.workspace1.googleProjectId.value)) { + spendForWorkspace.cost shouldBe TestData.Workspace.firstRowCostRounded.toString + } else if (workspaceGoogleProject.equals(TestData.workspace2.googleProjectId.value)) { + spendForWorkspace.cost shouldBe TestData.Workspace.secondRowCostRounded.toString } else { fail(s"unexpected workspace found in spend results - $spendForWorkspace") } @@ -411,101 +315,73 @@ class SpendReportingServiceSpec .workspace shouldBe defined } - it should "break down results from Google by Terra spend category" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource, tableResult = SpendReportingTestData.Category.tableResult) - - val reportingResults = Await.result( - service.getSpendForBillingProject( - testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Category)) - ), - Duration.Inf - ) - reportingResults.spendSummary.cost shouldBe SpendReportingTestData.Category.totalCostRounded.toString - val categoryAggregation = - reportingResults.spendDetails.headOption.getOrElse(fail("workspace results not parsed correctly")) - - categoryAggregation.aggregationKey shouldBe SpendReportingAggregationKeys.Category + it should "break down results from Google by Terra spend category" in { + val reportingResults = SpendReportingService.extractSpendReportingResults( + TestData.Category.tableResult.getValues.asScala.toList, + DateTime.now().minusDays(1), + DateTime.now(), + Map(), + Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Category)) + ) - verifyCategoryAggregation( - categoryAggregation, - expectedComputeCost = SpendReportingTestData.Category.computeRowCostRounded, - expectedStorageCost = SpendReportingTestData.Category.storageRowCostRounded, - expectedOtherCost = SpendReportingTestData.Category.otherRowCostRounded - ) + reportingResults.spendSummary.cost shouldBe TestData.Category.totalCostRounded + val categoryAggregation = reportingResults.spendDetails.headOption.get + categoryAggregation.aggregationKey shouldBe SpendReportingAggregationKeys.Category + verifyCategoryAggregation( + categoryAggregation, + expectedCompute = TestData.Category.computeRowCostRounded, + expectedStorage = TestData.Category.storageRowCostRounded, + expectedOther = TestData.Category.otherRowCostRounded + ) } - it should "return summary data only if aggregation key is omitted" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource, tableResult = SpendReportingTestData.Workspace.tableResult) - - val reportingResults = Await.result(service.getSpendForBillingProject(testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set.empty - ), - Duration.Inf - ) - reportingResults.spendSummary.cost shouldBe SpendReportingTestData.Workspace.totalCostRounded.toString - reportingResults.spendDetails shouldBe empty + it should "should return only summary data if aggregation keys are omitted" in { + val reportingResults = SpendReportingService.extractSpendReportingResults( + TestData.Workspace.tableResult.getValues.asScala.toList, + DateTime.now().minusDays(1), + DateTime.now(), + Map(), + Set.empty + ) + reportingResults.spendSummary.cost shouldBe TestData.Workspace.totalCostRounded.toString + reportingResults.spendDetails shouldBe empty } - it should "support sub-aggregations" in withDefaultTestDatabase { dataSource: SlickDataSource => - val service = - createSpendReportingService(dataSource, tableResult = SpendReportingTestData.SubAggregation.tableResult) - - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace1)) - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace2)) - - val reportingResults = Await.result( - service.getSpendForBillingProject( - testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set( - SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace, - Option(SpendReportingAggregationKeys.Category) - ) + it should "support sub-aggregations" in { + val reportingResults = SpendReportingService.extractSpendReportingResults( + TestData.SubAggregation.tableResult.getValues.asScala.toList, + DateTime.now().minusDays(1), + DateTime.now(), + TestData.googleProjectsToWorkspaceNames, + Set( + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace, + Option(SpendReportingAggregationKeys.Category) ) - ), - Duration.Inf + ) ) - val topLevelAggregation = - reportingResults.spendDetails.headOption.getOrElse(fail("spend results not parsed correctly")) - withClue("total cost was incorrect") { - reportingResults.spendSummary.cost shouldBe SpendReportingTestData.SubAggregation.totalCostRounded.toString - } - verifyWorkspaceCategorySubAggregation(topLevelAggregation) + reportingResults.spendSummary.cost shouldBe TestData.SubAggregation.totalCostRounded + verifyWorkspaceCategorySubAggregation(reportingResults.spendDetails.headOption.get) } - it should "support multiple aggregations" in withDefaultTestDatabase { dataSource: SlickDataSource => - val service = - createSpendReportingService(dataSource, tableResult = SpendReportingTestData.SubAggregation.tableResult) - - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace1)) - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace2)) - - val reportingResults = Await.result( - service.getSpendForBillingProject( - testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set( - SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace, - Option(SpendReportingAggregationKeys.Category) - ), - SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Category) - ) + it should "support multiple aggregations" in { + val reportingResults = SpendReportingService.extractSpendReportingResults( + TestData.SubAggregation.tableResult.getValues.asScala.toList, + DateTime.now().minusDays(1), + DateTime.now(), + Map( + TestData.workspace1.googleProjectId.value -> TestData.workspace1.toWorkspaceName, + TestData.workspace2.googleProjectId.value -> TestData.workspace2.toWorkspaceName ), - Duration.Inf + Set( + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace, + Option(SpendReportingAggregationKeys.Category) + ), + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Category) + ) ) - withClue("total cost was incorrect") { - reportingResults.spendSummary.cost shouldBe SpendReportingTestData.SubAggregation.totalCostRounded.toString - } + + reportingResults.spendSummary.cost shouldBe TestData.SubAggregation.totalCostRounded reportingResults.spendDetails.map { case workspaceAggregation @ SpendReportingAggregation(SpendReportingAggregationKeys.Workspace, _) => @@ -513,86 +389,66 @@ class SpendReportingServiceSpec case categoryAggregation @ SpendReportingAggregation(SpendReportingAggregationKeys.Category, _) => verifyCategoryAggregation( categoryAggregation, - expectedComputeCost = SpendReportingTestData.SubAggregation.computeTotalCostRounded, - expectedStorageCost = SpendReportingTestData.SubAggregation.storageTotalCostRounded, - expectedOtherCost = SpendReportingTestData.SubAggregation.otherTotalCostRounded + expectedCompute = TestData.SubAggregation.computeTotalCostRounded, + expectedStorage = TestData.SubAggregation.storageTotalCostRounded, + expectedOther = TestData.SubAggregation.otherTotalCostRounded ) case _ => fail("unexpected aggregation key found") } } - private def verifyCategoryAggregation(categoryAggregation: SpendReportingAggregation, - expectedComputeCost: BigDecimal, - expectedStorageCost: BigDecimal, - expectedOtherCost: BigDecimal - ) = - categoryAggregation.spendData.map { spendDataForCategory => + def verifyCategoryAggregation( + aggregation: SpendReportingAggregation, + expectedCompute: String, + expectedStorage: String, + expectedOther: String + ): Unit = + aggregation.spendData.foreach { spendDataForCategory => val category = spendDataForCategory.category.getOrElse(fail("results not parsed correctly")) withClue(s"total $category cost was incorrect") { if (category.equals(TerraSpendCategories.Compute)) { - spendDataForCategory.cost shouldBe expectedComputeCost.toString + spendDataForCategory.cost shouldBe expectedCompute } else if (category.equals(TerraSpendCategories.Storage)) { - spendDataForCategory.cost shouldBe expectedStorageCost.toString + spendDataForCategory.cost shouldBe expectedStorage } else if (category.equals(TerraSpendCategories.Other)) { - spendDataForCategory.cost shouldBe expectedOtherCost.toString + spendDataForCategory.cost shouldBe expectedOther } else { fail(s"unexpected category found in spend results - $spendDataForCategory") } } } - private def verifyWorkspaceCategorySubAggregation(topLevelAggregation: SpendReportingAggregation) = { + def verifyWorkspaceCategorySubAggregation(topLevelAggregation: SpendReportingAggregation): Unit = { topLevelAggregation.aggregationKey shouldBe SpendReportingAggregationKeys.Workspace - topLevelAggregation.spendData.map { spendData => - val workspaceGoogleProject = spendData.googleProjectId.getOrElse(fail("spend results not parsed correctly")).value - val subAggregation = spendData.subAggregation.getOrElse(fail("spend results not parsed correctly")) + topLevelAggregation.spendData.foreach { spendData => + val workspaceGoogleProject = spendData.googleProjectId.get.value + val subAggregation = spendData.subAggregation.get subAggregation.aggregationKey shouldBe SpendReportingAggregationKeys.Category - if (workspaceGoogleProject.equals(SpendReportingTestData.SubAggregation.workspace1.googleProjectId.value)) { - withClue(s"cost for ${SpendReportingTestData.SubAggregation.workspace1.toWorkspaceName} was incorrect") { - spendData.cost shouldBe SpendReportingTestData.SubAggregation.workspace1TotalCostRounded.toString - } - subAggregation.spendData.map { subAggregatedSpendData => - if (subAggregatedSpendData.category == Option(TerraSpendCategories.Compute)) { - withClue( - s"${subAggregatedSpendData.category} category cost for ${SpendReportingTestData.SubAggregation.workspace1.toWorkspaceName} was incorrect" - ) { - subAggregatedSpendData.cost shouldBe SpendReportingTestData.SubAggregation.workspace1ComputeRowCostRounded.toString - } - } else if (subAggregatedSpendData.category == Option(TerraSpendCategories.Other)) { - withClue( - s"${subAggregatedSpendData.category} category cost for ${SpendReportingTestData.SubAggregation.workspace1.toWorkspaceName} was incorrect" - ) { - subAggregatedSpendData.cost shouldBe SpendReportingTestData.SubAggregation.workspace1OtherRowCostRounded.toString - } - } else { - fail(s"unexpected category found in spend results - $subAggregatedSpendData") + if (workspaceGoogleProject.equals(TestData.workspace1.googleProjectId.value)) { + spendData.cost shouldBe TestData.SubAggregation.workspace1TotalCostRounded.toString + subAggregation.spendData.foreach { data => + data.category match { + case Some(TerraSpendCategories.Compute) => + data.cost shouldBe TestData.SubAggregation.workspace1ComputeRowCostRounded.toString + case Some(TerraSpendCategories.Other) => + data.cost shouldBe TestData.SubAggregation.workspace1OtherRowCostRounded.toString + case _ => fail(s"unexpected category found in spend results - $data") } } - } else if ( - workspaceGoogleProject.equals(SpendReportingTestData.SubAggregation.workspace2.googleProjectId.value) - ) { - withClue(s"cost for ${SpendReportingTestData.SubAggregation.workspace2.toWorkspaceName} was incorrect") { - spendData.cost shouldBe SpendReportingTestData.SubAggregation.workspace2TotalCostRounded.toString - } - subAggregation.spendData.map { subAggregatedSpendData => - if (subAggregatedSpendData.category == Option(TerraSpendCategories.Storage)) { - withClue( - s"${subAggregatedSpendData.category} category cost for ${SpendReportingTestData.SubAggregation.workspace2.toWorkspaceName} was incorrect" - ) { - subAggregatedSpendData.cost shouldBe SpendReportingTestData.SubAggregation.workspace2StorageRowCostRounded.toString - } - } else if (subAggregatedSpendData.category == Option(TerraSpendCategories.Other)) { - withClue( - s"${subAggregatedSpendData.category} category cost for ${SpendReportingTestData.SubAggregation.workspace2.toWorkspaceName} was incorrect" - ) { - subAggregatedSpendData.cost shouldBe SpendReportingTestData.SubAggregation.workspace2OtherRowCostRounded.toString - } - } else { - fail(s"unexpected category found in spend results - $subAggregatedSpendData") + } else if (workspaceGoogleProject.equals(TestData.workspace2.googleProjectId.value)) { + spendData.cost shouldBe TestData.SubAggregation.workspace2TotalCostRounded.toString + subAggregation.spendData.foreach { data => + data.category match { + case Some(TerraSpendCategories.Storage) => + data.cost shouldBe TestData.SubAggregation.workspace2StorageRowCostRounded.toString + case Some(TerraSpendCategories.Other) => + data.cost shouldBe TestData.SubAggregation.workspace2OtherRowCostRounded.toString + case Some(_) => fail(s"unexpected category found in spend results - $data") + case None => fail(s"no category found in spend results - $data") } } } else { @@ -601,231 +457,330 @@ class SpendReportingServiceSpec } } - it should "throw an exception when BQ returns zero rows" in withDefaultTestDatabase { dataSource: SlickDataSource => - val emptyTableResult = createTableResult(List[Map[String, String]]()) - val service = createSpendReportingService(dataSource, tableResult = emptyTableResult) + it should "throw an exception if the query result contains multiple kinds of currencies" in { + val table = createTableResult( + List( + Map("cost" -> "0.10111", "credits" -> "0.0", "currency" -> "CAD", "date" -> DateTime.now().toString), + Map("cost" -> "0.10111", "credits" -> "0.0", "currency" -> "USD", "date" -> DateTime.now().toString) + ) + ).getValues.asScala.toList val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now() - ), - Duration.Inf + SpendReportingService.extractSpendReportingResults( + table, + DateTime.now().minusDays(1), + DateTime.now(), + Map(), + Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily)) ) } - e.errorReport.statusCode shouldBe Option(StatusCodes.NotFound) + e.errorReport.statusCode shouldBe Option(StatusCodes.BadGateway) } - it should "throw an exception when user does not have read_spend_report" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val samDAO = mock[SamDAO](RETURNS_SMART_NULLS) - val service = createSpendReportingService(dataSource, samDAO = samDAO) + "getSpendForBillingProject" should "throw an exception when BQ returns zero rows" in { + val samDAO = mock[SamDAO] + when(samDAO.userHasAction(any(), any(), any(), any())).thenReturn(Future.successful(true)) + + val bigQueryService = mock[GoogleBigQueryService[IO]](RETURNS_SMART_NULLS) + when(bigQueryService.query(any(), any[BigQuery.JobOption]())) + .thenReturn(IO(createTableResult(List[Map[String, String]]()))) + val service = spy( + new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](bigQueryService), + samDAO, + spendReportingServiceConfig + ) + ) + val billingProject = BillingProjectSpendExport(RawlsBillingProjectName(""), RawlsBillingAccountName(""), None) + doReturn(Future.successful(billingProject)).when(service).getSpendExportConfiguration(any()) + doReturn(Future.successful(TestData.googleProjectsToWorkspaceNames)).when(service).getWorkspaceGoogleProjects(any()) - when( - samDAO.userHasAction(mockitoEq(SamResourceTypeNames.billingProject), - any[String], - mockitoEq(SamBillingProjectActions.readSpendReport), - mockitoEq(userInfo) - ) + val e = intercept[RawlsExceptionWithErrorReport] { + Await.result( + service.getSpendForBillingProject( + RawlsBillingProjectName(""), + DateTime.now().minusDays(1), + DateTime.now() + ), + Duration.Inf ) - .thenReturn(Future.successful(false)) - - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now() - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.Forbidden) + } + e.errorReport.statusCode shouldBe Option(StatusCodes.NotFound) } - it should "throw an exception when user is not in alpha group" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val samDAO = mock[SamDAO](RETURNS_SMART_NULLS) - val service = createSpendReportingService(dataSource, samDAO = samDAO) + it should "throw an exception when user does not have read_spend_report" in { + val samDAO = mock[SamDAO](RETURNS_SMART_NULLS) + when(samDAO.userHasAction(any(), any(), any(), any())).thenReturn(Future.successful(true)) + when( + samDAO.userHasAction( + mockitoEq(SamResourceTypeNames.billingProject), + any(), + mockitoEq(SamBillingProjectActions.readSpendReport), + mockitoEq(userInfo) + ) + ).thenReturn(Future.successful(false)) + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + samDAO, + spendReportingServiceConfig + ) - when( - samDAO.userHasAction(SamResourceTypeNames.managedGroup, - "Alpha_Spend_Report_Users", - SamResourceAction("use"), - userInfo - ) + val e = intercept[RawlsExceptionWithErrorReport] { + Await.result( + service.getSpendForBillingProject( + billingProject.projectName, + DateTime.now().minusDays(1), + DateTime.now() + ), + Duration.Inf ) - .thenReturn(Future.successful(false)) - - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now() - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.Forbidden) + } + e.errorReport.statusCode shouldBe Option(StatusCodes.Forbidden) } - it should "throw an exception when start date is after end date" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource) + it should "throw an exception when user is not in alpha group" in { + val samDAO = mock[SamDAO] + when(samDAO.userHasAction(any(), any(), any(), any())).thenReturn(Future.successful(true)) + when( + samDAO.userHasAction( + ArgumentMatchers.eq(SamResourceTypeNames.managedGroup), + ArgumentMatchers.eq("Alpha_Spend_Report_Users"), + ArgumentMatchers.eq(SamResourceAction("use")), + ArgumentMatchers.eq(testContext.userInfo) + ) + ).thenReturn(Future.successful(false)) + + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + samDAO, + spendReportingServiceConfig + ) - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(testData.billingProject.projectName, - startDate = DateTime.now(), - endDate = DateTime.now().minusDays(1) - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.BadRequest) + val e = intercept[RawlsExceptionWithErrorReport] { + Await.result( + service.getSpendForBillingProject(billingProject.projectName, DateTime.now().minusDays(1), DateTime.now()), + Duration.Inf + ) + fail("action was run without an exception being thrown") + } + e.errorReport.statusCode shouldBe Option(StatusCodes.Forbidden) } - it should s"throw an exception when date range is larger than ${spendReportingServiceConfig.maxDateRange} days" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource) - - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result( - service.getSpendForBillingProject(testData.billingProject.projectName, - startDate = - DateTime.now().minusDays(spendReportingServiceConfig.maxDateRange + 1), - endDate = DateTime.now() - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.BadRequest) - } + it should "throw an exception if the billing project cannot be found" in { + val samDAO = mock[SamDAO] + when(samDAO.userHasAction(any(), any(), any(), any())).thenReturn(Future.successful(true)) + val dataSource = mock[SlickDataSource] + when(dataSource.inTransaction[Option[BillingProjectSpendExport]](any(), any())).thenReturn(Future.successful(None)) + val service = new SpendReportingService( + testContext, + dataSource, + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + samDAO, + spendReportingServiceConfig + ) - it should "throw an exception if the billing project cannot be found" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource) + val e = intercept[RawlsExceptionWithErrorReport] { + Await.result(service.getSpendExportConfiguration(RawlsBillingProjectName("fakeProject")), Duration.Inf) + } - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(RawlsBillingProjectName("fakeProject"), - DateTime.now().minusDays(1), - DateTime.now() - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.NotFound) + e.errorReport.statusCode shouldBe Option(StatusCodes.NotFound) + e.errorReport.message shouldBe s"billing project fakeProject not found" } - it should "throw an exception if the billing project does not have a linked billing account" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val service = createSpendReportingService(dataSource) - val projectName = RawlsBillingProjectName("fakeProject") - runAndWait( - dataSource.dataAccess.rawlsBillingProjectQuery.create( - RawlsBillingProject(projectName, CreationStatuses.Ready, billingAccount = None, None) - ) - ) + it should "throw an exception if the billing project does not have a linked billing account" in { + val samDAO = mock[SamDAO] + when(samDAO.userHasAction(any(), any(), any(), any())).thenReturn(Future.successful(true)) + val dataSource = mock[SlickDataSource] + when(dataSource.inTransaction[Option[BillingProjectSpendExport]](any(), any())) + .thenReturn(Future.failed(new RawlsException())) + val service = new SpendReportingService( + testContext, + dataSource, + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + samDAO, + spendReportingServiceConfig + ) + val projectName = RawlsBillingProjectName("fakeProject") - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(projectName, DateTime.now().minusDays(1), DateTime.now()), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.BadRequest) + val e = intercept[RawlsExceptionWithErrorReport] { + Await.result(service.getSpendExportConfiguration(projectName), Duration.Inf) + } + + e.errorReport.statusCode shouldBe Option(StatusCodes.BadRequest) + e.errorReport.message shouldBe s"billing account not found on billing project ${projectName.value}" } - it should "throw an exception if BigQuery returns multiple kinds of currencies" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val cadRow = Map( - "cost" -> "0.10111", - "credits" -> "0.0", - "currency" -> "CAD", - "date" -> DateTime.now().toString + it should "throw an exception if BigQuery results include an unexpected Google project" in { + val samDAO = mock[SamDAO] + when(samDAO.userHasAction(any(), any(), any(), any())).thenReturn(Future.successful(true)) + val badRow = Map( + "cost" -> "0.10111", + "credits" -> "0.0", + "currency" -> "USD", + "googleProjectId" -> "fakeProject" + ) + + val badTable = createTableResult(badRow :: TestData.Workspace.table) + + val bigQueryService = mock[GoogleBigQueryService[IO]](RETURNS_SMART_NULLS) + when(bigQueryService.query(any(), any[BigQuery.JobOption]())).thenReturn(IO(badTable)) + val service = spy( + new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](bigQueryService), + samDAO, + spendReportingServiceConfig ) + ) + val billingProject = BillingProjectSpendExport(RawlsBillingProjectName(""), RawlsBillingAccountName(""), None) + doReturn(Future.successful(billingProject)).when(service).getSpendExportConfiguration(any()) + doReturn(Future.successful(TestData.googleProjectsToWorkspaceNames)).when(service).getWorkspaceGoogleProjects(any()) - val internationalTable = createTableResult(cadRow :: SpendReportingTestData.Daily.table) + val e = intercept[RawlsExceptionWithErrorReport] { + Await.result( + service.getSpendForBillingProject( + RawlsBillingProjectName(""), + DateTime.now().minusDays(1), + DateTime.now(), + Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace)) + ), + Duration.Inf + ) + } + e.errorReport.statusCode shouldBe Option(StatusCodes.BadGateway) + } - val service = createSpendReportingService(dataSource, tableResult = internationalTable) + "validateReportParameters" should "not throw an exception when validating max start and end date range" in { + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + mock[SamDAO], + spendReportingServiceConfig + ) + val startDate = DateTime.now().minusDays(spendReportingServiceConfig.maxDateRange) + val endDate = DateTime.now() + service.validateReportParameters(startDate, endDate) + } - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result(service.getSpendForBillingProject(testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now() - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.BadGateway) + it should "throw an exception when start date is after end date" in { + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + mock[SamDAO], + spendReportingServiceConfig + ) + val startDate = DateTime.now() + val endDate = DateTime.now().minusDays(1) + val e = intercept[RawlsExceptionWithErrorReport](service.validateReportParameters(startDate, endDate)) + e.errorReport.statusCode shouldBe Option(StatusCodes.BadRequest) } - it should "throw an exception if BigQuery results include an unexpected Google project" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val badRow = Map( - "cost" -> "0.10111", - "credits" -> "0.0", - "currency" -> "USD", - "googleProjectId" -> "fakeProject" - ) + it should "throw an exception when date range is larger than the max date range" in { + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + mock[SamDAO], + spendReportingServiceConfig + ) + val startDate = DateTime.now().minusDays(spendReportingServiceConfig.maxDateRange + 1) + val endDate = DateTime.now() + val e = intercept[RawlsExceptionWithErrorReport](service.validateReportParameters(startDate, endDate)) + e.errorReport.statusCode shouldBe Option(StatusCodes.BadRequest) + } - val badTable = createTableResult(badRow :: SpendReportingTestData.Workspace.table) - - val service = createSpendReportingService(dataSource, tableResult = badTable) - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace1)) - runAndWait(dataSource.dataAccess.workspaceQuery.createOrUpdate(SpendReportingTestData.Workspace.workspace2)) - - val e = intercept[RawlsExceptionWithErrorReport] { - Await.result( - service.getSpendForBillingProject( - testData.billingProject.projectName, - DateTime.now().minusDays(1), - DateTime.now(), - Set(SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace)) - ), - Duration.Inf - ) - } - e.errorReport.statusCode shouldBe Option(StatusCodes.BadGateway) + "getQuery" should "use the constant _PARTITIONTIME as the time partition name for non-broad tables" in { + val expectedQuery = + s""" + | SELECT + | SUM(cost) as cost, + | SUM(IFNULL((SELECT SUM(c.amount) FROM UNNEST(credits) c), 0)) as credits, + | currency , project.id as googleProjectId, DATE(_PARTITIONTIME) as date + | FROM `NonBroadTable` + | WHERE billing_account_id = @billingAccountId + | AND _PARTITIONTIME BETWEEN @startDate AND @endDate + | AND project.id in UNNEST(@projects) + | GROUP BY currency , googleProjectId, date + |""".stripMargin + + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + mock[SamDAO], + spendReportingServiceConfig + ) + val result = service.getQuery( + Set( + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace), + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily) + ), + BillingProjectSpendExport(RawlsBillingProjectName(""), RawlsBillingAccountName(""), Some("NonBroadTable")) + ) + result shouldBe expectedQuery } - it should "use the custom time partition column name if specified" in withDefaultTestDatabase { - dataSource: SlickDataSource => - val expectedNoCustom = - s""" - | SELECT - | SUM(cost) as cost, - | SUM(IFNULL((SELECT SUM(c.amount) FROM UNNEST(credits) c), 0)) as credits, - | currency , project.id as googleProjectId, DATE(_PARTITIONTIME) as date - | FROM `fakeTable` - | WHERE billing_account_id = @billingAccountId - | AND _PARTITIONTIME BETWEEN @startDate AND @endDate - | AND project.id in UNNEST(@projects) - | GROUP BY currency , googleProjectId, date - |""".stripMargin - assertResult(expectedNoCustom) { - val service = createSpendReportingService(dataSource) - service.getQuery( - Set(SpendReportingAggregationKeys.Workspace, SpendReportingAggregationKeys.Daily), - "fakeTable", - None - ) - } + it should "use the configured default time partition name for the broad table" in { + val expectedQuery = + s""" + | SELECT + | SUM(cost) as cost, + | SUM(IFNULL((SELECT SUM(c.amount) FROM UNNEST(credits) c), 0)) as credits, + | currency , project.id as googleProjectId, DATE(fakeTimePartitionColumn) as date + | FROM `fakeTable` + | WHERE billing_account_id = @billingAccountId + | AND fakeTimePartitionColumn BETWEEN @startDate AND @endDate + | AND project.id in UNNEST(@projects) + | GROUP BY currency , googleProjectId, date + |""".stripMargin + + val service = new SpendReportingService( + testContext, + mock[SlickDataSource], + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + mock[SamDAO], + spendReportingServiceConfig + ) + val result = service.getQuery( + Set( + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Workspace), + SpendReportingAggregationKeyWithSub(SpendReportingAggregationKeys.Daily) + ), + BillingProjectSpendExport(RawlsBillingProjectName(""), RawlsBillingAccountName(""), None) + ) + result shouldBe expectedQuery + } - val expectedCustom = - s""" - | SELECT - | SUM(cost) as cost, - | SUM(IFNULL((SELECT SUM(c.amount) FROM UNNEST(credits) c), 0)) as credits, - | currency , project.id as googleProjectId, DATE(custom_time_partition) as date - | FROM `fakeTable` - | WHERE billing_account_id = @billingAccountId - | AND custom_time_partition BETWEEN @startDate AND @endDate - | AND project.id in UNNEST(@projects) - | GROUP BY currency , googleProjectId, date - |""".stripMargin - assertResult(expectedCustom) { - val service = createSpendReportingService(dataSource) - service.getQuery( - Set(SpendReportingAggregationKeys.Workspace, SpendReportingAggregationKeys.Daily), - "fakeTable", - Some("custom_time_partition") - ) - } + "getWorkspaceGoogleProjects" should "only map v2 workspaces to project names" in { + val v1Workspace = TestData.workspace("v1name", GoogleProjectId("v1ProjectId"), WorkspaceVersions.V1) + val v2Workspace = TestData.workspace("v2name", GoogleProjectId("v2ProjectId"), WorkspaceVersions.V2) + val workspaces = Seq( + v1Workspace, + v2Workspace + ) + + val dataSource = mock[SlickDataSource] + when(dataSource.inTransaction[Seq[Workspace]](any(), any())).thenReturn(Future.successful(workspaces)) + val service = new SpendReportingService( + testContext, + dataSource, + Resource.pure[IO, GoogleBigQueryService[IO]](mock[GoogleBigQueryService[IO]]), + mock[SamDAO], + spendReportingServiceConfig + ) + + val result = Await.result(service.getWorkspaceGoogleProjects(RawlsBillingProjectName("")), Duration.Inf) + + result shouldBe Map("v2ProjectId" -> v2Workspace.toWorkspaceName) } + } diff --git a/core/src/test/scala/org/broadinstitute/dsde/rawls/webservice/ApiServiceSpec.scala b/core/src/test/scala/org/broadinstitute/dsde/rawls/webservice/ApiServiceSpec.scala index 925ccd8637..7c436ca816 100644 --- a/core/src/test/scala/org/broadinstitute/dsde/rawls/webservice/ApiServiceSpec.scala +++ b/core/src/test/scala/org/broadinstitute/dsde/rawls/webservice/ApiServiceSpec.scala @@ -232,7 +232,7 @@ trait ApiServiceSpec ) _ val spendReportingBigQueryService = bigQueryServiceFactory.getServiceFromJson("json", GoogleProject("test-project")) - val spendReportingServiceConfig = SpendReportingServiceConfig("fakeTableName", "fakeTimePartitionColumn", 90) + val spendReportingServiceConfig = SpendReportingServiceConfig("fakeTableName", "fakeTimePartitionColumn", 90, "test.metrics") override val spendReportingConstructor = SpendReportingService.constructor( slickDataSource, spendReportingBigQueryService, diff --git a/model/src/main/scala/org/broadinstitute/dsde/rawls/RawlsException.scala b/model/src/main/scala/org/broadinstitute/dsde/rawls/RawlsException.scala index eecd02b663..cf6247be62 100644 --- a/model/src/main/scala/org/broadinstitute/dsde/rawls/RawlsException.scala +++ b/model/src/main/scala/org/broadinstitute/dsde/rawls/RawlsException.scala @@ -1,11 +1,25 @@ package org.broadinstitute.dsde.rawls -import org.broadinstitute.dsde.rawls.model.ErrorReport +import akka.http.scaladsl.model.StatusCode +import org.broadinstitute.dsde.rawls.model.{ErrorReport, ErrorReportSource} class RawlsException(message: String = null, cause: Throwable = null) extends Exception(message, cause) class RawlsExceptionWithErrorReport(val errorReport: ErrorReport) extends RawlsException(errorReport.toString) +object RawlsExceptionWithErrorReport { + def apply(errorReport: ErrorReport): RawlsExceptionWithErrorReport = new RawlsExceptionWithErrorReport(errorReport) + + def apply(message: String)(implicit source: ErrorReportSource): RawlsExceptionWithErrorReport = + RawlsExceptionWithErrorReport(ErrorReport(message)) + + def apply(message: String, cause: ErrorReport)(implicit source: ErrorReportSource): RawlsExceptionWithErrorReport = + RawlsExceptionWithErrorReport(ErrorReport(message, cause)) + + def apply(status: StatusCode, message: String)(implicit source: ErrorReportSource): RawlsExceptionWithErrorReport = + RawlsExceptionWithErrorReport(ErrorReport(status, message)) +} + /** * An exception where retrying will not help. *