Skip to content

Commit

Permalink
Configurable startup time (pytorch#3262)
Browse files Browse the repository at this point in the history
* configurable start up time, minimum working example

* remove startuptimeout from async worker for now before I confirm what model_load_timeout is

* doc updates

* remove extra spaces in model manager

* apply formatting

* remove worker command logging

* add tests for long startup timeout

* worker thread add logging response timeout if worker state isn't worker_started

* add startuptimeout to registerWorkflow function

* add startuptimeout to the correct word in spellchecker

* working example

* small refactor

* small refactor

* added default value for model status

* Update ts_scripts/spellcheck_conf/wordlist.txt

* Fix java unit tests

* Fix regression test

* add startup_timeout for test to cast it to int

---------

Co-authored-by: Matthias Reso <[email protected]>
  • Loading branch information
Isalia20 and mreso authored Aug 12, 2024
1 parent 30eb13d commit ef196c0
Show file tree
Hide file tree
Showing 32 changed files with 322 additions and 24 deletions.
2 changes: 2 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ A model's parameters are defined in [model source code](https://github.com/pytor
* `maxWorkers`: the maximum number of workers of a model
* `batchSize`: the batch size of a model
* `maxBatchDelay`: the maximum delay in msec of a batch of a model
* `startupTimeout`: the timeout in sec of a specific model's startup. This setting takes priority over `default_startup_timeout` which is a default timeout over all models
* `responseTimeout`: the timeout in sec of a specific model's response. This setting takes priority over `default_response_timeout` which is a default timeout over all models
* `defaultVersion`: the default version of a model
* `marName`: the mar file name of a model
Expand Down Expand Up @@ -295,6 +296,7 @@ Most of the following properties are designed for performance tuning. Adjusting
* `job_queue_size`: Number inference jobs that frontend will queue before backend can serve. Default: 100.
* `async_logging`: Enable asynchronous logging for higher throughput, log output may be delayed if this is enabled. Default: false.
* `default_response_timeout`: Timeout, in seconds, used for all models backend workers before they are deemed unresponsive and rebooted. Default: 120 seconds.
* `default_startup_timeout`: Specifies the maximum time, in seconds, allowed for model backend workers to initialize and become ready. If a worker fails to start within this timeframe, it is considered unresponsive and will be restarted. Default: 120 seconds.
* `unregister_model_timeout`: Timeout, in seconds, used when handling an unregister model request when cleaning a process before it is deemed unresponsive and an error response is sent. Default: 120 seconds.
* `decode_input_request`: Configuration to let backend workers to decode requests, when the content type is known.
If this is set to "true", backend workers do "Bytearray to JSON object" conversion when the content type is "application/json" and
Expand Down
3 changes: 2 additions & 1 deletion docs/large_model_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ To reduce model latency we recommend:
#### Tune [model config YAML file](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/model-archiver/README.md)

You can tune the model config YAML file to get better performance in the following ways:
* Update the [responseTimeout](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/docs/configuration.md?plain=1#L216) if high model loading or inference latency causes response timeout.
* Update the [responseTimeout](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/docs/configuration.md?plain=1#L216) if high model inference latency causes response timeout.
* Update the [startupTimeout](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/docs/configuration.md?plain=1#L216) if high model loading latency causes startup timeout.
* Tune the [torchrun parameters](https://github.com/pytorch/serve/blob/2f1f52f553e83703b5c380c2570a36708ee5cafa/model-archiver/README.md?plain=1#L179). The supported parameters are defined at [here](https://github.com/pytorch/serve/blob/2f1f52f553e83703b5c380c2570a36708ee5cafa/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java#L329). For example, by default, `OMP_NUMBER_THREADS` is 1. This can be modified in the YAML file.
```yaml
#frontend settings
Expand Down
1 change: 1 addition & 0 deletions docs/management_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ To use this API after TorchServe starts, model API control has to be enabled. Ad
* `initial_workers` - the number of initial workers to create. The default value is `0`. TorchServe will not run inference until there is at least one work assigned.
* `synchronous` - whether or not the creation of worker is synchronous. The default value is false. TorchServe will create new workers without waiting for acknowledgement that the previous worker is online.
* `response_timeout` - If the model's backend worker doesn't respond with inference response within this timeout period, the worker will be deemed unresponsive and rebooted. The units is seconds. The default value is 120 seconds.
* `startup_timeout` - If the model's backend worker doesn't load the model within this timeout period, the worker will be deemed unresponsive and rebooted. The units is seconds. The default value is 120 seconds.

```bash
curl -X POST "http://localhost:8081/models?url=https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public class ModelConfig {
private int maxBatchDelay;
/** the timeout in sec of a specific model's response. */
private int responseTimeout = 120; // unit: sec
/** the timeout in sec of a specific model's startup. */
private int startupTimeout = 120; // unit: sec
/**
* the device type where the model is loaded. It can be gpu, cpu. The model is loaded on CPU if
* deviceType: "cpu" is set on a GPU host.
Expand Down Expand Up @@ -122,6 +124,13 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
logger.warn("Invalid responseTimeout: {}, should be integer", v);
}
break;
case "startupTimeout":
if (v instanceof Integer) {
modelConfig.setStartupTimeout((int) v);
} else {
logger.warn("Invalid startupTimeout: {}, should be integer", v);
}
break;
case "deviceType":
if (v instanceof String) {
modelConfig.setDeviceType((String) v);
Expand Down Expand Up @@ -319,6 +328,18 @@ public void setResponseTimeout(int responseTimeout) {
this.responseTimeout = responseTimeout;
}

public int getStartupTimeout() {
return startupTimeout;
}

public void setStartupTimeout(int startupTimeout) {
if (startupTimeout <= 0) {
logger.warn("Invalid startupTimeout:{}", startupTimeout);
return;
}
this.startupTimeout = startupTimeout;
}

public List<Integer> getDeviceIds() {
return deviceIds;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public void testValidYamlConfig() throws InvalidModelException, IOException {
Assert.assertEquals(modelConfig.getBatchSize(), 1);
Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100);
Assert.assertEquals(modelConfig.getResponseTimeout(), 120);
Assert.assertEquals(modelConfig.getStartupTimeout(), 120);
Assert.assertEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU);
Assert.assertEquals(modelConfig.getParallelLevel(), 4);
Assert.assertEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PP);
Expand All @@ -42,6 +43,7 @@ public void testInvalidYamlConfig() throws InvalidModelException, IOException {
Assert.assertEquals(modelConfig.getBatchSize(), 1);
Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100);
Assert.assertEquals(modelConfig.getResponseTimeout(), 120);
Assert.assertEquals(modelConfig.getStartupTimeout(), 120);
Assert.assertNotEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU);
Assert.assertEquals(modelConfig.getParallelLevel(), 0);
Assert.assertNotEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PPTP);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ private void initModelStore() throws InvalidSnapshotException, IOException {
-1 * RegisterModelRequest.DEFAULT_BATCH_SIZE,
-1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY,
configManager.getDefaultResponseTimeout(),
configManager.getDefaultStartupTimeout(),
defaultModelName,
false,
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class DescribeModelResponse {
private int batchSize;
private int maxBatchDelay;
private int responseTimeout;
private int startupTimeout;
private long maxRetryTimeoutInSec;
private long clientTimeoutInMills;
private String parallelType;
Expand Down Expand Up @@ -132,10 +133,18 @@ public int getResponseTimeout() {
return responseTimeout;
}

public int getStartupTimeout() {
return startupTimeout;
}

public void setResponseTimeout(int responseTimeout) {
this.responseTimeout = responseTimeout;
}

public void setStartupTimeout(int startupTimeout) {
this.startupTimeout = startupTimeout;
}

public long getMaxRetryTimeoutInSec() {
return maxRetryTimeoutInSec;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class RegisterModelRequest {
@SerializedName("response_timeout")
private int responseTimeout;

@SerializedName("startup_timeout")
private int startupTimeout;

@SerializedName("url")
private String modelUrl;

Expand All @@ -56,6 +59,7 @@ public RegisterModelRequest(QueryStringDecoder decoder) {
ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel());
synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", "true"));
responseTimeout = NettyUtils.getIntParameter(decoder, "response_timeout", -1);
startupTimeout = NettyUtils.getIntParameter(decoder, "startup_timeout", -1);
modelUrl = NettyUtils.getParameter(decoder, "url", null);
s3SseKms = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "s3_sse_kms", "false"));
}
Expand All @@ -74,6 +78,7 @@ public RegisterModelRequest(org.pytorch.serve.grpc.management.RegisterModelReque
ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel());
synchronous = request.getSynchronous();
responseTimeout = GRPCUtils.getRegisterParam(request.getResponseTimeout(), -1);
startupTimeout = GRPCUtils.getRegisterParam(request.getStartupTimeout(), -1);
modelUrl = GRPCUtils.getRegisterParam(request.getUrl(), null);
s3SseKms = request.getS3SseKms();
}
Expand All @@ -84,6 +89,7 @@ public RegisterModelRequest() {
synchronous = true;
initialWorkers = ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel();
responseTimeout = -1;
startupTimeout = -1;
s3SseKms = false;
}

Expand Down Expand Up @@ -119,6 +125,10 @@ public Integer getResponseTimeout() {
return responseTimeout;
}

public Integer getStartupTimeout() {
return startupTimeout;
}

public String getModelUrl() {
return modelUrl;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ private static Operation getRegisterOperation() {
"integer",
"2",
"Maximum time, in seconds, the TorchServe waits for a response from the model inference code, default: 120."));
operation.addParameter(
new QueryParameter(
"startup_timeout",
"integer",
"120",
"Maximum time, in seconds, the TorchServe waits for the model to startup/initialize, default: 120."));
operation.addParameter(
new QueryParameter(
"initial_workers",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,14 @@ public static StatusResponse registerModel(RegisterModelRequest registerModelReq
int maxBatchDelay = registerModelRequest.getMaxBatchDelay();
int initialWorkers = registerModelRequest.getInitialWorkers();
int responseTimeout = registerModelRequest.getResponseTimeout();
int startupTimeout = registerModelRequest.getStartupTimeout();
boolean s3SseKms = registerModelRequest.getS3SseKms();
if (responseTimeout == -1) {
responseTimeout = ConfigManager.getInstance().getDefaultResponseTimeout();
}
if (startupTimeout == -1) {
startupTimeout = ConfigManager.getInstance().getDefaultStartupTimeout();
}

Manifest.RuntimeType runtimeType = null;
if (runtime != null) {
Expand All @@ -144,6 +148,7 @@ public static StatusResponse registerModel(RegisterModelRequest registerModelReq
batchSize,
maxBatchDelay,
responseTimeout,
startupTimeout,
initialWorkers,
registerModelRequest.getSynchronous(),
false,
Expand All @@ -158,6 +163,7 @@ public static StatusResponse handleRegister(
int batchSize,
int maxBatchDelay,
int responseTimeout,
int startupTimeout,
int initialWorkers,
boolean isSync,
boolean isWorkflowModel,
Expand All @@ -177,6 +183,7 @@ public static StatusResponse handleRegister(
batchSize,
maxBatchDelay,
responseTimeout,
startupTimeout,
null,
false,
isWorkflowModel,
Expand Down Expand Up @@ -403,6 +410,7 @@ private static DescribeModelResponse createModelResponse(
resp.setModelVersion(manifest.getModel().getModelVersion());
resp.setRuntime(manifest.getRuntime().getValue());
resp.setResponseTimeout(model.getResponseTimeout());
resp.setStartupTimeout(model.getStartupTimeout());
resp.setMaxRetryTimeoutInSec(model.getMaxRetryTimeoutInMill() / 1000);
resp.setClientTimeoutInMills(model.getClientTimeoutInMills());
resp.setParallelType(model.getParallelType().getParallelType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public final class ConfigManager {
private static final String TS_BLACKLIST_ENV_VARS = "blacklist_env_vars";
private static final String TS_DEFAULT_WORKERS_PER_MODEL = "default_workers_per_model";
private static final String TS_DEFAULT_RESPONSE_TIMEOUT = "default_response_timeout";
private static final String TS_DEFAULT_STARTUP_TIMEOUT = "default_startup_timeout";
private static final String TS_UNREGISTER_MODEL_TIMEOUT = "unregister_model_timeout";
private static final String TS_NUMBER_OF_NETTY_THREADS = "number_of_netty_threads";
private static final String TS_NETTY_CLIENT_THREADS = "netty_client_threads";
Expand Down Expand Up @@ -879,6 +880,10 @@ public int getDefaultResponseTimeout() {
return Integer.parseInt(prop.getProperty(TS_DEFAULT_RESPONSE_TIMEOUT, "120"));
}

public int getDefaultStartupTimeout() {
return Integer.parseInt(prop.getProperty(TS_DEFAULT_STARTUP_TIMEOUT, "120"));
}

public int getUnregisterModelTimeout() {
return Integer.parseInt(prop.getProperty(TS_UNREGISTER_MODEL_TIMEOUT, "120"));
}
Expand Down
12 changes: 12 additions & 0 deletions frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class Model {
public static final String BATCH_SIZE = "batchSize";
public static final String MAX_BATCH_DELAY = "maxBatchDelay";
public static final String RESPONSE_TIMEOUT = "responseTimeout";
public static final String STARTUP_TIMEOUT = "startupTimeout";
public static final String PARALLEL_LEVEL = "parallelLevel";
public static final String DEFAULT_VERSION = "defaultVersion";
public static final String MAR_NAME = "marName";
Expand All @@ -57,6 +58,7 @@ public class Model {
private ReentrantLock lock;
private ReentrantLock jobGroupLock;
private int responseTimeout;
private int startupTimeout;
private long sequenceMaxIdleMSec;
private long sequenceTimeoutMSec;
private int maxNumSequence;
Expand Down Expand Up @@ -178,6 +180,7 @@ public JsonObject getModelState(boolean isDefaultVersion) {
modelInfo.addProperty(BATCH_SIZE, getBatchSize());
modelInfo.addProperty(MAX_BATCH_DELAY, getMaxBatchDelay());
modelInfo.addProperty(RESPONSE_TIMEOUT, getResponseTimeout());
modelInfo.addProperty(STARTUP_TIMEOUT, getStartupTimeout());
modelInfo.addProperty(RUNTIME_TYPE, getRuntimeType().getValue());
if (parallelLevel > 0) {
modelInfo.addProperty(PARALLEL_LEVEL, parallelLevel);
Expand All @@ -191,6 +194,7 @@ public void setModelState(JsonObject modelInfo) {
maxWorkers = modelInfo.get(MAX_WORKERS).getAsInt();
maxBatchDelay = modelInfo.get(MAX_BATCH_DELAY).getAsInt();
responseTimeout = modelInfo.get(RESPONSE_TIMEOUT).getAsInt();
startupTimeout = modelInfo.get(STARTUP_TIMEOUT).getAsInt();
batchSize = modelInfo.get(BATCH_SIZE).getAsInt();

JsonElement runtime = modelInfo.get(RUNTIME_TYPE);
Expand Down Expand Up @@ -537,10 +541,18 @@ public int getResponseTimeout() {
return ConfigManager.getInstance().isDebug() ? Integer.MAX_VALUE : responseTimeout;
}

public int getStartupTimeout() {
return ConfigManager.getInstance().isDebug() ? Integer.MAX_VALUE : startupTimeout;
}

public void setResponseTimeout(int responseTimeout) {
this.responseTimeout = responseTimeout;
}

public void setStartupTimeout(int startupTimeout) {
this.startupTimeout = startupTimeout;
}

public List<Integer> getDeviceIds() {
return this.deviceIds;
}
Expand Down
Loading

0 comments on commit ef196c0

Please sign in to comment.