Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

truss predict bug fixes / unit tests #723

Merged
merged 22 commits into from
Nov 16, 2023

Conversation

helenlyang
Copy link
Contributor

@helenlyang helenlyang commented Nov 8, 2023

This PR addresses several issues with truss predict:

  • Unifies logic for getting the deployment version given a model_name or model_id
  • Supports using the --published flag with --model (i.e. model ID)
    • The above two points result in different and potentially unexpected behavior for users. Will leave this for a future PR--it should be bundled with changes to the API calls on the Baseten UI and more transparency in Truss CLI on what version is being used in predict
  • Calling truss predict --published returns production version instead of the first non-development version
  • Adds unit test coverage for helper functions in core and for BasetenRemote

Testing

Unit tests

  • poetry run pytest truss/tests/remote/baseten/test_core.py
  • poetry run pytest truss/tests/remote/baseten/test_remote.py

Truss CLI

Tested manually on a model with four published deployments and one development deployment:
image

For the production deployment, I set the predict function of my Model to return "production model" as its output. For the development deployment, I return "development model". The other deployments simply return the model_input.

Model version ID:

@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # production version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --model-deployment qrj6d03 --data {}
? 🎮 Which remote do you want to connect to? prod
{
  "message": "production model"
}
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # development version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --model-deployment qjd502q --data {}
? 🎮 Which remote do you want to connect to? prod
{
  "message": "development model"
}
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # published, non-production version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --model-deployment qvvd0eq --data {}
? 🎮 Which remote do you want to connect to? prod
{}

Model name:

@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # without --published; should use dev version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --data {}
? 🎮 Which remote do you want to connect to? prod
{
  "message": "development model"
}
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # with --published; should use production version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --published --data {}
? 🎮 Which remote do you want to connect to? prod
{
  "message": "production model"
}

Model ID:

@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # should use production version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --model 03ydnk43 --data {}
? 🎮 Which remote do you want to connect to? prod
{
  "message": "production model"
}

Testing on a second model with no production deployment:

@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ # if no production version exists, use the development version
@helenlyang ➜ /workspaces/truss/my-test-model (helenyang/bt-8803-fix-truss-predict) $ poetry run truss predict --model yqvvy0jq --data {}
? 🎮 Which remote do you want to connect to? prod
{
  "message": "development model"
}

* adds API for returning versions given model_id
* separate functions for getting model and version IDs and finding matching version
* service URL only uses model_versions endpoint
* check is_primary to find production deployment
Copy link

linear bot commented Nov 8, 2023

BT-8803 Fix `truss predict`

Problem to Fix

  1. Model name logic is different to the model ID logic
  2. Not respecting the —published flag if you pass a model-id
  3. Errors are very hard to read (separate PR)
  4. Logic is not unit tested

Correct truss predict logic:

  • If not published, there is only one development deployment, so use that
  • If it is published, use the production deployment (primary_version is set)
  • If there is no production deployment, use the latest published deployment

Repro steps:

  • I did truss push to build my truss. The build failed
  • I made a change, and did truss push again. The build succeeded, the model successfully deployed
  • I did truss predict , and got “Model is unhealthy, it is not ready to make predictions”

@@ -157,23 +157,22 @@ def models(self):
def get_model(self, model_name):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to get_model should be safe--the current callsites are https://github.com/basetenlabs/truss/blob/8901a0121dda36b822fbabe62f8a172ba72c6c6e/truss/remote/baseten/core.py#L50C1-L61 and are updated in this PR

@@ -184,18 +183,19 @@ def get_model_by_id(self, model_id: str):
model(id: "{model_id}") {{
name
id
primary_version{{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_model_by_id changes should also be safe, since this was the only callsite:

model = self._api.get_model_by_id(model_identifier.value)
model_id = model["model"]["id"]
model_version_id = model["model"]["primary_version"]["id"]

for version in versions:
if version["is_draft"] is True:
if version["is_primary"] and not version["is_draft"]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's possible for is_primary to be true for development models, but thought it'd be safer to check is_draft in case the is_primary definition ever changes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right -- it's not possible for is_primary to be true for development models. I think it's a little cleaner to leave this as an is_primary check just to have a cleaner definition of what "production" means (ie: is_primary = True), which is easier to understand. But I dont' feel strongly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@helenlyang helenlyang changed the title Helenyang/bt 8803 fix truss predict Fix inconsistencies in truss predict Nov 13, 2023
@helenlyang helenlyang marked this pull request as ready for review November 13, 2023 23:16
@helenlyang helenlyang requested a review from squidarth November 13, 2023 23:17
Copy link
Collaborator

@squidarth squidarth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very complex code, so awesome work working through it! Left some code comments, but also in the PR description, it would be awesome if you could drop a few CLI invocations that are worth testing again

for version in versions:
if version["is_draft"] is True:
return version
raise ValueError("No development version found")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something to consider here -- what if instead of using exceptions to represent the cases where there's no version, we returned None (so this function return type becomes Optional[dict])? And then the callers can decide how they want to handle that case (vs. having to catch exceptions)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's nicer--changed these functions except for get_dev_version_info, which has some callsites in truss watch / patch logic. I'll leave updating that function and its callsites for a future PR to avoid increasing the scope of this one

truss/cli/cli.py Outdated
raise click.UsageError(
"Cannot use --published with --model or --model-version."
)
if published and model_version_id:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now that we're using the /models endpoint if the person passes --model, we're ignoring the published flag again. So I think we should keep this error message the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG, reverted this change

for version in versions:
if version["is_draft"] is True:
if version["is_primary"] and not version["is_draft"]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right -- it's not possible for is_primary to be true for development models. I think it's a little cleaner to leave this as an is_primary check just to have a cleaner definition of what "production" means (ie: is_primary = True), which is easier to understand. But I dont' feel strongly

# Return the production deployment version.
try:
return get_prod_version_info_from_versions(model_versions)
except ValueError:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per comment in the other file, let's change get_prod_version_info_from_versions to return None in the case where there's no prod version, so we don't ahve to use exceptions for control flow

truss/remote/baseten/remote.py Outdated Show resolved Hide resolved
_TEST_REMOTE_URL = "http://test_remote.com"


def test_get_service_by_version_id():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also have test for a no model w/ version exists?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, this made me realize that we should probably raise a UsageError rather than propagating the Baseten ApiError if the model version doesn't exist. Added a try / except there and a unit test (let me know if there's a better way to mock errors)

Copy link
Contributor Author

@helenlyang helenlyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! Added CLI commands to the description.

Following up on our conversation about model_id vs model_name inconsistency--I updated the PR description with our decision to stick with the current behavior. However, I do think we should consider unifying these code paths in the future.

I think it might make model version resolution less of a black box if we logged which model version is being called in truss push, e.g. f"Calling {'production' if published else 'development'} version {model_version_id} of model {model_id}". That could make changing truss predict by model_id to match model_name behavior less surprising to users--but could also be useful regardless

for version in versions:
if version["is_draft"] is True:
return version
raise ValueError("No development version found")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's nicer--changed these functions except for get_dev_version_info, which has some callsites in truss watch / patch logic. I'll leave updating that function and its callsites for a future PR to avoid increasing the scope of this one

for version in versions:
if version["is_draft"] is True:
if version["is_primary"] and not version["is_draft"]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

_TEST_REMOTE_URL = "http://test_remote.com"


def test_get_service_by_version_id():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, this made me realize that we should probably raise a UsageError rather than propagating the Baseten ApiError if the model version doesn't exist. Added a try / except there and a unit test (let me know if there's a better way to mock errors)

truss/cli/cli.py Outdated
raise click.UsageError(
"Cannot use --published with --model or --model-version."
)
if published and model_version_id:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG, reverted this change

Comment on lines 131 to 136
# TODO(helen): make this consistent with getting the service via
# model_name and respect --published in service_url_path.
model_version = BasetenRemote._get_matching_version(
model_versions, published
)
model_version_id = model_version["id"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation is now weirdly in between the original implementation and my attempt to unify the model_id and model_name code paths. It might be more readable to revert this back to the original model_id code path and leave these changes for if / when we unify model_id and model_name paths

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think for now it might better to revert back for now. It's one less call at least.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG, done

Copy link
Collaborator

@squidarth squidarth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, LGTM! Awesome work here. Just left a couple more comments, feel free to resolve & ship

Comment on lines 131 to 136
# TODO(helen): make this consistent with getting the service via
# model_name and respect --published in service_url_path.
model_version = BasetenRemote._get_matching_version(
model_versions, published
)
model_version_id = model_version["id"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think for now it might better to revert back for now. It's one less call at least.

return (query_result["id"], query_result["versions"])


def get_dev_version_info_from_versions(versions: List[dict]) -> Optional[dict]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this private (ie: def _get_dev_version_from_versions()...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left this as public so we can get the dev version without requiring another GraphQL query (api.get_model is called inside get_dev_version_info). I added docstrings to try to make this more clear, let me know if that makes sense

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, sounds good

return (query_result["id"], query_result["versions"])


def get_model_versions_info_by_id(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of these functions have info which I think doesn't add a lot, I think this could just be def get_model_versions_by_id

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

versions = model["model"]["versions"]
dev_version = get_dev_version_info_from_versions(versions)
if not dev_version:
# TODO(helen): return dev_version in all cases rather than raising an error
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's try to do this todo in a follow-up should be fairly straightforward

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG!

@helenlyang helenlyang changed the title Fix inconsistencies in truss predict truss predict bug fixes / unit tests Nov 16, 2023
@helenlyang helenlyang merged commit b5cccc5 into main Nov 16, 2023
@helenlyang helenlyang deleted the helenyang/bt-8803-fix-truss-predict branch November 16, 2023 18:16
helenlyang added a commit that referenced this pull request Nov 29, 2023
Follow-up to #723. This updates `core.get_dev_version` to return None if a development version doesn’t exist and 
updates callsites to raise errors.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants