Skip to content

Commit

Permalink
[WOR-1794] Enforce submission cost cap (#3034)
Browse files Browse the repository at this point in the history
  • Loading branch information
marctalbott authored Sep 18, 2024
1 parent 54010c7 commit 231e0cf
Show file tree
Hide file tree
Showing 17 changed files with 295 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,5 @@
<include file="changesets/20240418_workspace_monitor_args.xml" relativeToChangelogFile="true"/>
<include file="changesets/20240723_workspace_settings.xml" relativeToChangelogFile="true"/>
<include file="changesets/20240820_submission_cost_cap_threshold.xml" relativeToChangelogFile="true"/>
<include file="changesets/20240830_workflow_cost.xml" relativeToChangelogFile="true"/>
</databaseChangeLog>
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<databaseChangeLog logicalFilePath="dummy" xmlns="http://www.liquibase.org/xml/ns/dbchangelog" xmlns:ext="http://www.liquibase.org/xml/ns/dbchangelog-ext" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog-ext http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-ext.xsd http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.4.xsd">
<changeSet id="add_workflow_cost" author="mtalbott" logicalFilePath="dummy">
<addColumn tableName="WORKFLOW">
<column name="cost" type="NUMBER(10,2)" />
</addColumn>
</changeSet>
</databaseChangeLog>
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ trait ExecutionServiceCluster extends ErrorReportable {

def abort(workflowRec: WorkflowRecord, userInfo: UserInfo): Future[Try[ExecutionServiceStatus]]

def getCost(workflowRec: WorkflowRecord, userInfo: UserInfo): Future[WorkflowCostBreakdown]

def version: Future[ExecutionServiceVersion]

// ====================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ trait ExecutionServiceDAO extends ErrorReportable {
def abort(id: String, userInfo: UserInfo): Future[Try[ExecutionServiceStatus]]
def getLabels(id: String, userInfo: UserInfo): Future[ExecutionServiceLabelResponse]
def patchLabels(id: String, userInfo: UserInfo, labels: Map[String, String]): Future[ExecutionServiceLabelResponse]
def getCost(id: String, userInfo: UserInfo): Future[WorkflowCostBreakdown]

// get the version of the execution service itself
def version(): Future[ExecutionServiceVersion]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ class HttpExecutionServiceDAO(executionServiceURL: String, override val workbenc
retry(when5xx)(() => pipeline[ExecutionServiceLabelResponse](userInfo) apply Patch(url, labels))
}

override def getCost(id: String, userInfo: UserInfo): Future[WorkflowCostBreakdown] = {
val url = executionServiceURL + s"/api/workflows/v1/$id/cost"
retry(when5xx)(() => pipeline[WorkflowCostBreakdown](userInfo) apply Get(url))
}

override def version(): Future[ExecutionServiceVersion] = {
val url = executionServiceURL + s"/engine/v1/version"
retry(when5xx)(() => httpClientUtils.executeRequestUnmarshalResponse[ExecutionServiceVersion](http, Get(url)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class ShardedHttpExecutionServiceCluster(readMembers: Set[ClusterMember],
// the abort operation is special, it needs to go to a specific cromwell
getMember(workflowRec).dao.abort(workflowRec.externalId.get, userInfo)

def getCost(workflowRec: WorkflowRecord, userInfo: UserInfo): Future[WorkflowCostBreakdown] =
getMember(workflowRec).dao.getCost(workflowRec.externalId.get, userInfo)

def version: Future[ExecutionServiceVersion] =
getRandomReadMember.dao.version

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,18 @@ trait SubmissionComponent {
})
)

def listActiveSubmissionIdsWithWorkspace(limit: FiniteDuration): ReadAction[Seq[(UUID, WorkspaceName)]] = {
def listActiveSubmissionIdsWithWorkspaceAndCostCapThreshold(
limit: FiniteDuration
): ReadAction[Seq[(UUID, WorkspaceName, Option[BigDecimal])]] = {
// Exclude submissions from monitoring if they are ancient/stuck [WX-820]
val cutoffTime = new Timestamp(DateTime.now().minusDays(limit.toDays.toInt).getMillis)
val query = findActiveSubmissionsAfterTime(cutoffTime) join workspaceQuery on (_.workspaceId === _.id)
val result = query.map { case (sub, ws) => (sub.id, ws.namespace, ws.name) }.result
result.map(rows => rows.map { case (subId, wsNs, wsName) => (subId, WorkspaceName(wsNs, wsName)) })
val result = query.map { case (sub, ws) => (sub.id, ws.namespace, ws.name, sub.costCapThreshold) }.result
result.map(rows =>
rows.map { case (subId, wsNs, wsName, costCapThreshold) =>
(subId, WorkspaceName(wsNs, wsName), costCapThreshold)
}
)
}

def getSubmissionMethodConfigId(workspaceContext: Workspace, submissionId: UUID): ReadAction[Option[Long]] =
Expand Down Expand Up @@ -568,7 +574,7 @@ trait SubmissionComponent {
)

implicit val getWorkflowMessagesListResult: GetResult[WorkflowMessagesListResult] = GetResult { r =>
val workflowRec = WorkflowRecord(r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<)
val workflowRec = WorkflowRecord(r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<)
val rootEntityTypeOption: Option[String] = r.<<

val messageOption: Option[String] = r.<<
Expand All @@ -588,7 +594,7 @@ trait SubmissionComponent {
}

def action(submissionId: UUID) =
sql"""select w.ID, w.EXTERNAL_ID, w.SUBMISSION_ID, w.STATUS, w.STATUS_LAST_CHANGED, w.ENTITY_ID, w.record_version, w.EXEC_SERVICE_KEY, w.EXTERNAL_ENTITY_ID,
sql"""select w.ID, w.EXTERNAL_ID, w.SUBMISSION_ID, w.STATUS, w.STATUS_LAST_CHANGED, w.ENTITY_ID, w.record_version, w.EXEC_SERVICE_KEY, w.EXTERNAL_ENTITY_ID, w.COST,
s.ROOT_ENTITY_TYPE,
m.MESSAGE,
e.name, e.entity_type, e.workspace_id, e.record_version, e.deleted, e.deleted_date
Expand All @@ -607,7 +613,7 @@ trait SubmissionComponent {
)

implicit val getWorkflowInputResolutionListResult: GetResult[WorkflowInputResolutionListResult] = GetResult { r =>
val workflowRec = WorkflowRecord(r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<)
val workflowRec = WorkflowRecord(r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<, r.<<)
val (submissionValidation, attribute) = r.nextLongOption() match {
case Some(submissionValidationId) =>
(Option(SubmissionValidationRecord(submissionValidationId, workflowRec.id, r.<<, r.<<)),
Expand Down Expand Up @@ -635,7 +641,7 @@ trait SubmissionComponent {
}

def action(submissionId: UUID) =
sql"""select w.ID, w.EXTERNAL_ID, w.SUBMISSION_ID, w.STATUS, w.STATUS_LAST_CHANGED, w.ENTITY_ID, w.record_version, w.EXEC_SERVICE_KEY, w.EXTERNAL_ENTITY_ID,
sql"""select w.ID, w.EXTERNAL_ID, w.SUBMISSION_ID, w.STATUS, w.STATUS_LAST_CHANGED, w.ENTITY_ID, w.record_version, w.EXEC_SERVICE_KEY, w.EXTERNAL_ENTITY_ID, w.COST,
sv.id, sv.ERROR_TEXT, sv.INPUT_NAME,
sa.id, sa.namespace, sa.name, sa.value_string, sa.value_number, sa.value_boolean, sa.value_json, sa.value_entity_ref, sa.list_index, sa.list_length, sa.deleted, sa.deleted_date
from WORKFLOW w
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ case class WorkflowRecord(id: Long,
workflowEntityId: Option[Long],
recordVersion: Long,
executionServiceKey: Option[String],
externalEntityId: Option[String]
externalEntityId: Option[String],
cost: Option[BigDecimal]
)

case class WorkflowMessageRecord(workflowId: Long, message: String)
Expand All @@ -54,6 +55,7 @@ trait WorkflowComponent {
def version = column[Long]("record_version")
def executionServiceKey = column[Option[String]]("EXEC_SERVICE_KEY")
def externalEntityId = column[Option[String]]("EXTERNAL_ENTITY_ID")
def cost = column[Option[BigDecimal]]("COST")

def * = (id,
externalId,
Expand All @@ -63,7 +65,8 @@ trait WorkflowComponent {
workflowEntityId,
version,
executionServiceKey,
externalEntityId
externalEntityId,
cost
) <> (WorkflowRecord.tupled, WorkflowRecord.unapply)

def submission = foreignKey("FK_WF_SUB", submissionId, submissionQuery)(_.id)
Expand Down Expand Up @@ -272,6 +275,21 @@ trait WorkflowComponent {
): ReadWriteAction[Int] =
batchUpdateStatus(Seq(workflow), newStatus)

def updateStatusAndCost(workflow: WorkflowRecord, newStatus: WorkflowStatus, newCost: BigDecimal)(implicit
wfStatusCounter: WorkflowStatus => Option[Counter]
): ReadWriteAction[Int] =
UpdateWorkflowStatusAndCostRawSql.actionForWorkflowRecs(Seq(workflow), newStatus, newCost) map { rows =>
if (rows.head != 1) {
throw new RawlsConcurrentModificationException(
s"could not update workflow because its record version(s) has changed"
)
}
rows.head
} map { result =>
wfStatusCounter(newStatus).foreach(_ += result)
result
}

// input: old workflow records, and the status that we want to apply to all of them
def batchUpdateStatus(workflows: Seq[WorkflowRecord], newStatus: WorkflowStatus)(implicit
wfStatusCounter: WorkflowStatus => Option[Counter]
Expand Down Expand Up @@ -614,7 +632,8 @@ trait WorkflowComponent {
entityId,
0,
None,
externalEntityId
externalEntityId,
workflow.cost.map(BigDecimal(_))
)

private def unmarshalWorkflow(workflowRec: WorkflowRecord,
Expand All @@ -628,7 +647,8 @@ trait WorkflowComponent {
new DateTime(workflowRec.statusLastChangedDate.getTime),
entity,
inputResolutions,
messages
messages,
workflowRec.cost.map(_.floatValue)
)

private def marshalInputResolution(value: SubmissionValidationValue, workflowId: Long): SubmissionValidationRecord =
Expand Down Expand Up @@ -721,6 +741,23 @@ trait WorkflowComponent {
).as[Int].map(_.head)
}

private object UpdateWorkflowStatusAndCostRawSql extends RawSqlQuery {
val driver: JdbcProfile = WorkflowComponent.this.driver

private def update(newStatus: WorkflowStatus, cost: BigDecimal) =
sql"update WORKFLOW set status = ${newStatus.toString}, cost = ${cost}, status_last_changed = ${new Timestamp(System.currentTimeMillis())}, record_version = record_version + 1, rawls_hostname = ${hostname} "

def actionForWorkflowRecs(workflows: Seq[WorkflowRecord], newStatus: WorkflowStatus, cost: BigDecimal) = {
val where = sql"where ("
val workflowTuples = reduceSqlActionsWithDelim(workflows.map { case wf =>
sql"(id = ${wf.id} AND record_version = ${wf.recordVersion})"
},
sql" OR "
)
concatSqlActions(update(newStatus, cost), where, workflowTuples, sql")").as[Int]
}
}

private object UpdateWorkflowStatusAndExecutionIdRawSql extends RawSqlQuery {
val driver: JdbcProfile = WorkflowComponent.this.driver

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,22 @@ object SubmissionMonitorActor {
executionServiceCluster: ExecutionServiceCluster,
config: SubmissionMonitorConfig,
queryTimeout: Duration,
workbenchMetricBaseName: String
workbenchMetricBaseName: String,
costCapThreshold: Option[BigDecimal] = None
): Props =
Props(
new SubmissionMonitorActor(workspaceName,
submissionId,
datasource,
samDAO,
googleServicesDAO,
notificationDAO,
executionServiceCluster,
config,
queryTimeout,
workbenchMetricBaseName
new SubmissionMonitorActor(
workspaceName,
submissionId,
datasource,
samDAO,
googleServicesDAO,
notificationDAO,
executionServiceCluster,
config,
queryTimeout,
workbenchMetricBaseName,
costCapThreshold
)
)

Expand Down Expand Up @@ -122,7 +125,8 @@ class SubmissionMonitorActor(val workspaceName: WorkspaceName,
val executionServiceCluster: ExecutionServiceCluster,
val config: SubmissionMonitorConfig,
val queryTimeout: Duration,
override val workbenchMetricBaseName: String
override val workbenchMetricBaseName: String,
val costCapThreshold: Option[BigDecimal]
) extends Actor
with SubmissionMonitor
with LazyLogging {
Expand Down Expand Up @@ -184,6 +188,7 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
val executionServiceCluster: ExecutionServiceCluster
val config: SubmissionMonitorConfig
val queryTimeout: Duration
val costCapThreshold: Option[BigDecimal]

// Cache these metric builders since they won't change for this SubmissionMonitor
protected lazy val workspaceMetricBuilder: ExpandedMetricBuilder =
Expand Down Expand Up @@ -253,8 +258,9 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
Future.traverse(externalWorkflowIds) { workflowRec =>
// for each workflow query the exec service for status and if has Succeeded query again for outputs
toFutureTry(execServiceStatus(workflowRec, petUser) flatMap {
case Some(updatedWorkflowRec) => execServiceOutputs(updatedWorkflowRec, petUser)
case None => Future.successful(None)
case Some(updatedWorkflowRec) =>
execServiceOutputs(updatedWorkflowRec, petUser)
case None => Future.successful(None)
})
}

Expand Down Expand Up @@ -299,6 +305,12 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
executionContext: ExecutionContext
): Future[Option[WorkflowRecord]] =
workflowRec.externalId match {
// fetch cost information for the workflow if submission has a cost cap threshold defined
case Some(externalId) if costCapThreshold.isDefined =>
executionServiceCluster.getCost(workflowRec, petUser).map { costBreakdown =>
Option(workflowRec.copy(status = costBreakdown.status, cost = costBreakdown.cost.some))
}
// fetch workflow status only if cost cap threshold is not defined
case Some(externalId) =>
executionServiceCluster.status(workflowRec, petUser).map { newStatus =>
if (newStatus.status != workflowRec.status) Option(workflowRec.copy(status = newStatus.status))
Expand All @@ -311,11 +323,8 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
executionContext: ExecutionContext
): Future[Option[(WorkflowRecord, Option[ExecutionServiceOutputs])]] =
WorkflowStatuses.withName(workflowRec.status) match {
case status if WorkflowStatuses.terminalStatuses.contains(status) =>
if (status == WorkflowStatuses.Succeeded)
executionServiceCluster.outputs(workflowRec, petUser).map(outputs => Option((workflowRec, Option(outputs))))
else
Future.successful(Some((workflowRec, None)))
case status if status == WorkflowStatuses.Succeeded =>
executionServiceCluster.outputs(workflowRec, petUser).map(outputs => Option((workflowRec, Option(outputs))))
case _ => Future.successful(Some((workflowRec, None)))
}

Expand All @@ -329,6 +338,7 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
* @param executionContext
* @return
*/

def handleStatusResponses(
response: ExecutionServiceStatusResponse
)(implicit executionContext: ExecutionContext): Future[StatusCheckComplete] =
Expand Down Expand Up @@ -444,9 +454,15 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
numRowsUpdated <-
if (doRecordUpdate) {
for {
updateResult <- dataAccess.workflowQuery.updateStatus(currentRec,
WorkflowStatuses.withName(workflowRec.status)
)
updateResult <-
if (costCapThreshold.isDefined) {
dataAccess.workflowQuery.updateStatusAndCost(currentRec,
WorkflowStatuses.withName(workflowRec.status),
workflowRec.cost.getOrElse(BigDecimal(0))
)
} else {
dataAccess.workflowQuery.updateStatus(currentRec, WorkflowStatuses.withName(workflowRec.status))
}
_ = logger.info(
s"workflow ${externalId(currentRec)} status change ${currentRec.status} -> ${workflowRec.status} in submission ${submissionId}"
)
Expand Down Expand Up @@ -552,19 +568,32 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum

/**
* When there are no workflows with a running or queued status, mark the submission as done or aborted as appropriate.
* If there are still non-terminal workflows and a cost cap threshold is defined, check the current cost of all workflows in the submission and abort the submission if threshold has been exceeded.
*
* @param dataAccess
* @param executionContext
* @return true if the submission is done/aborted
*/
def updateSubmissionStatus(
dataAccess: DataAccess
)(implicit executionContext: ExecutionContext): ReadWriteAction[Boolean] =
dataAccess.workflowQuery.listWorkflowRecsForSubmissionAndStatuses(
submissionId,
(WorkflowStatuses.queuedStatuses ++ WorkflowStatuses.runningStatuses): _*
) flatMap { workflowRecs =>
if (workflowRecs.isEmpty) {
)(implicit executionContext: ExecutionContext): ReadWriteAction[Boolean] = {
val workflowRecsAction = if (costCapThreshold.isDefined) {
dataAccess.workflowQuery.listWorkflowRecsForSubmission(submissionId)
} else {
dataAccess.workflowQuery.listWorkflowRecsForSubmissionAndStatuses(
submissionId,
(WorkflowStatuses.queuedStatuses ++ WorkflowStatuses.runningStatuses): _*
)
}

workflowRecsAction.flatMap { workflowRecs =>
val nonTerminalWorkflows =
if (costCapThreshold.isDefined)
workflowRecs
.filterNot(wf => WorkflowStatuses.terminalStatuses.contains(WorkflowStatuses.withName(wf.status)))
else workflowRecs

if (nonTerminalWorkflows.isEmpty) {
dataAccess.submissionQuery.findById(submissionId).map(_.status).result.head.flatMap { status =>
val finalStatus = SubmissionStatuses.withName(status) match {
case SubmissionStatuses.Aborting => SubmissionStatuses.Aborted
Expand All @@ -577,10 +606,16 @@ trait SubmissionMonitor extends FutureSupport with LazyLogging with RawlsInstrum
logger.debug(s"submission $submissionId terminating to status $newStatus")
dataAccess.submissionQuery.updateStatus(submissionId, newStatus)
} map (_ => true)
} else if (costCapThreshold.isDefined && costCapThreshold.get <= workflowRecs.flatMap(_.cost).sum) {
logger.info(
s"Submission $submissionId exceeded its cost cap and will be aborted. [costCap=${costCapThreshold.get},currentSubmissionCost=${workflowRecs.flatMap(_.cost).sum}]"
)
dataAccess.submissionQuery.updateStatus(submissionId, SubmissionStatuses.Aborting).map(_ => false)
} else {
DBIO.successful(false)
}
}
}

def handleOutputs(workflowsWithOutputs: Seq[(WorkflowRecord, ExecutionServiceOutputs)],
dataAccess: DataAccess,
Expand Down
Loading

0 comments on commit 231e0cf

Please sign in to comment.