Skip to content

Commit

Permalink
Adding token auth and model api to workflow and https (pytorch#3234)
Browse files Browse the repository at this point in the history
* adding workflow2 test

* adding token

* adding token

* testing workflow

* fixing tests

* adding comments to test

* adding two failure test

* testing docker file

* adding https tests

* fixing docker

* adding https to model control

* updating
  • Loading branch information
udaij12 authored Jul 12, 2024
1 parent 1125bb1 commit f717497
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && \
apt-get upgrade -y && \
apt-get install software-properties-common -y && \
add-apt-repository -y ppa:deadsnakes/ppa && \
add-apt-repository ppa:deadsnakes/ppa -y && \
apt remove python-pip python3-pip && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
python$PYTHON_VERSION \
Expand Down
63 changes: 63 additions & 0 deletions test/pytest/test_model_control_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

ROOT_DIR = os.path.join(tempfile.gettempdir(), "workspace")
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
config_file_https = os.path.join(REPO_ROOT, "../resources/config_https.properties")

expected_output = {
"code": 405,
Expand Down Expand Up @@ -44,6 +45,35 @@ def setup_torchserve_api_enabled():
test_utils.stop_torchserve()


@pytest.fixture(scope="module")
def setup_torchserve_https():
MODEL_STORE = os.path.join(ROOT_DIR, "model_store/")
PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path")

Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

test_utils.start_torchserve(
snapshot_file=config_file_https,
models="mnist=mnist.mar",
no_config_snapshots=True,
enable_model_api=False,
)

params = (
("model_name", "mnist"),
("url", "mnist.mar"),
("initial_workers", "1"),
("synchronous", "true"),
)
response = requests.post(
"https://localhost:8081/models", params=params, verify=False
)

yield "test"

test_utils.stop_torchserve()


# Test register a model after startup - Model control mode: default
def test_register_model_failing(setup_torchserve):
response = requests.get("http://localhost:8081/models/mnist")
Expand Down Expand Up @@ -158,3 +188,36 @@ def test_priority_env(monkeypatch):
test_utils.stop_torchserve()

assert response.status_code == 200, "model control check failed"


# Test register a model after startup - Model control mode: default
def test_register_model_failing_https(setup_torchserve_https):
response = requests.get("https://localhost:8081/models/mnist", verify=False)
assert response.status_code == 200, "management check failed"
params = (
("model_name", "resnet-18"),
("url", "resnet-18.mar"),
("initial_workers", "1"),
("synchronous", "true"),
)
response = requests.post(
"https://localhost:8081/models", params=params, verify=False
)

assert response.status_code == 405, "model control check failed"
assert response.json() == expected_output, "unexpected exception"
response = requests.get("https://localhost:8081/models/resnet-18", verify=False)
assert response.status_code == 404, "management check failed"


# Test deleting a model after startup - Model control mode: default
def test_delete_model_failing_https(setup_torchserve_https):
response = requests.get("https://localhost:8081/models/mnist", verify=False)
assert response.status_code == 200, "management check failed"

response = requests.delete("https://localhost:8081/models/mnist", verify=False)

assert response.status_code == 405, "model control check failed"
assert response.json() == expected_output, "unexpected exception"
response = requests.get("https://localhost:8081/models/mnist", verify=False)
assert response.status_code == 200, "management check failed"
126 changes: 125 additions & 1 deletion test/pytest/test_token_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
CURR_DIR = os.path.dirname(os.path.abspath(__file__))
data_file_zero = os.path.join(CURR_DIR, "test_data/0.png")
config_file = os.path.join(CURR_DIR, "../resources/config_token.properties")
config_file_workflow = os.path.join(
CURR_DIR, "../resources/config_token_workflow.properties"
)
config_file_https = os.path.join(CURR_DIR, "../resources/config_https.properties")


# Parse json file and return key
Expand All @@ -36,7 +40,11 @@ def setup_torchserve():

Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

test_utils.start_torchserve(no_config_snapshots=True, disable_token=False)
test_utils.start_torchserve(
snapshot_file=config_file_workflow,
no_config_snapshots=True,
disable_token=False,
)

key = read_key_file("management")
header = {"Authorization": f"Bearer {key}"}
Expand Down Expand Up @@ -89,6 +97,39 @@ def setup_torchserve_expiration():
test_utils.stop_torchserve()


@pytest.fixture(scope="module")
def setup_torchserve_https():
MODEL_STORE = os.path.join(ROOT_DIR, "model_store/")
PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path")

Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

test_utils.start_torchserve(
snapshot_file=config_file_https,
no_config_snapshots=True,
disable_token=False,
)

key = read_key_file("management")
header = {"Authorization": f"Bearer {key}"}

params = (
("model_name", "mnist"),
("url", "mnist.mar"),
("initial_workers", "1"),
("synchronous", "true"),
)
response = requests.post(
"https://localhost:8081/models", params=params, headers=header, verify=False
)
file_content = Path(f"{CURR_DIR}/key_file.json").read_text()
print(file_content)

yield "test"

test_utils.stop_torchserve()


# Test describe model API with token enabled
def test_managament_api_with_token(setup_torchserve):
key = read_key_file("management")
Expand Down Expand Up @@ -182,6 +223,52 @@ def test_token_management_api(setup_torchserve):
assert response.status_code == 200, "Token check failed"


# Test to register workflow using managment token
def test_workflow(setup_torchserve):
key = read_key_file("management")
header = {"Authorization": f"Bearer {key}"}

response = requests.post(
url="http://localhost:8081/workflows?url=smtest.war&workflow_name=smtest",
headers=header,
)

assert response.status_code == 200, "Token check failed"


# Test workflow register without token so it fails
def test_workflow_fail(setup_torchserve):
response = requests.post(
url="http://localhost:8081/workflows?url=smtest.war&workflow_name=smtest"
)

assert response.status_code == 400, "Token check failed"


# Test workflow inference using inference token
def test_workflow_inference(setup_torchserve):
key = read_key_file("inference")
header = {"Authorization": f"Bearer {key}"}

response = requests.post(
url="http://localhost:8080/wfpredict/smtest",
files={"data": open(data_file_zero, "rb")},
headers=header,
)

assert response.status_code == 200, "Token check failed"


# Test workflow inference without token so it fails
def test_workflow_inference_fail(setup_torchserve):
response = requests.post(
url="http://localhost:8080/wfpredict/smtest",
files={"data": open(data_file_zero, "rb")},
)

assert response.status_code == 400, "Token check failed"


# Test expiration time and regenerating new management and inference keys
def test_token_expiration_time(setup_torchserve_expiration):
key = read_key_file("management")
Expand Down Expand Up @@ -292,3 +379,40 @@ def test_priority_env_cmd(monkeypatch):
test_utils.stop_torchserve()

assert response.status_code == 400, "Token check failed"


# Test https management api
def test_management_api_https(setup_torchserve_https):
key = read_key_file("management")
header = {"Authorization": f"Bearer {key}"}
response = requests.get(
"https://localhost:8081/models/mnist", headers=header, verify=False
)

assert response.status_code == 200, "Token check failed"


# Test https describe model API with incorrect token and no token
def test_managament_api_with_incorrect_token_https(setup_torchserve_https):
# Using random key
header = {"Authorization": "Bearer abcd1234"}
response = requests.get(
f"https://localhost:8081/models/mnist", headers=header, verify=False
)

assert response.status_code == 400, "Token check failed"


# Test https inference API with token enabled
def test_inference_api_with_token_https(setup_torchserve_https):
key = read_key_file("inference")
header = {"Authorization": f"Bearer {key}"}

response = requests.post(
url="https://localhost:8080/predictions/mnist",
files={"data": open(data_file_zero, "rb")},
headers=header,
verify=False,
)

assert response.status_code == 200, "Token check failed"
4 changes: 4 additions & 0 deletions test/resources/config_https.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
inference_address=https://127.0.0.1:8080
management_address=https://127.0.0.1:8081
private_key_file=../resources/key.pem
certificate_file=../resources/certs.pem
2 changes: 2 additions & 0 deletions test/resources/config_token_workflow.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
workflow_store=../resources/workflow
model_store=../resources/models
Binary file added test/resources/models/mnist.mar
Binary file not shown.
Binary file added test/resources/workflow/smtest.war
Binary file not shown.

0 comments on commit f717497

Please sign in to comment.