-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
truss predict
bug fixes / unit tests
#723
Conversation
* adds API for returning versions given model_id * separate functions for getting model and version IDs and finding matching version * service URL only uses model_versions endpoint * check is_primary to find production deployment
BT-8803 Fix `truss predict`
Problem to Fix
Correct truss predict logic:
Repro steps:
|
@@ -157,23 +157,22 @@ def models(self): | |||
def get_model(self, model_name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to get_model
should be safe--the current callsites are https://github.com/basetenlabs/truss/blob/8901a0121dda36b822fbabe62f8a172ba72c6c6e/truss/remote/baseten/core.py#L50C1-L61 and are updated in this PR
truss/remote/baseten/api.py
Outdated
@@ -184,18 +183,19 @@ def get_model_by_id(self, model_id: str): | |||
model(id: "{model_id}") {{ | |||
name | |||
id | |||
primary_version{{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_model_by_id
changes should also be safe, since this was the only callsite:
truss/truss/remote/baseten/remote.py
Lines 104 to 106 in 8901a01
model = self._api.get_model_by_id(model_identifier.value) | |
model_id = model["model"]["id"] | |
model_version_id = model["model"]["primary_version"]["id"] |
truss/remote/baseten/core.py
Outdated
for version in versions: | ||
if version["is_draft"] is True: | ||
if version["is_primary"] and not version["is_draft"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's possible for is_primary
to be true for development models, but thought it'd be safer to check is_draft
in case the is_primary
definition ever changes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right -- it's not possible for is_primary
to be true for development models. I think it's a little cleaner to leave this as an is_primary
check just to have a cleaner definition of what "production" means (ie: is_primary = True), which is easier to understand. But I dont' feel strongly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
truss predict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is very complex code, so awesome work working through it! Left some code comments, but also in the PR description, it would be awesome if you could drop a few CLI invocations that are worth testing again
truss/remote/baseten/core.py
Outdated
for version in versions: | ||
if version["is_draft"] is True: | ||
return version | ||
raise ValueError("No development version found") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
something to consider here -- what if instead of using exceptions to represent the cases where there's no version, we returned None (so this function return type becomes Optional[dict]
)? And then the callers can decide how they want to handle that case (vs. having to catch exceptions)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's nicer--changed these functions except for get_dev_version_info
, which has some callsites in truss watch / patch logic. I'll leave updating that function and its callsites for a future PR to avoid increasing the scope of this one
truss/cli/cli.py
Outdated
raise click.UsageError( | ||
"Cannot use --published with --model or --model-version." | ||
) | ||
if published and model_version_id: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now that we're using the /models endpoint if the person passes --model, we're ignoring the published flag again. So I think we should keep this error message the same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG, reverted this change
truss/remote/baseten/core.py
Outdated
for version in versions: | ||
if version["is_draft"] is True: | ||
if version["is_primary"] and not version["is_draft"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right -- it's not possible for is_primary
to be true for development models. I think it's a little cleaner to leave this as an is_primary
check just to have a cleaner definition of what "production" means (ie: is_primary = True), which is easier to understand. But I dont' feel strongly
truss/remote/baseten/remote.py
Outdated
# Return the production deployment version. | ||
try: | ||
return get_prod_version_info_from_versions(model_versions) | ||
except ValueError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
per comment in the other file, let's change get_prod_version_info_from_versions
to return None
in the case where there's no prod version, so we don't ahve to use exceptions for control flow
_TEST_REMOTE_URL = "http://test_remote.com" | ||
|
||
|
||
def test_get_service_by_version_id(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we also have test for a no model w/ version exists?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call, this made me realize that we should probably raise a UsageError rather than propagating the Baseten ApiError if the model version doesn't exist. Added a try / except there and a unit test (let me know if there's a better way to mock errors)
This reverts commit 2d7bc49.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review! Added CLI commands to the description.
Following up on our conversation about model_id
vs model_name
inconsistency--I updated the PR description with our decision to stick with the current behavior. However, I do think we should consider unifying these code paths in the future.
I think it might make model version resolution less of a black box if we logged which model version is being called in truss push
, e.g. f"Calling {'production' if published else 'development'} version {model_version_id} of model {model_id}"
. That could make changing truss predict
by model_id
to match model_name
behavior less surprising to users--but could also be useful regardless
truss/remote/baseten/core.py
Outdated
for version in versions: | ||
if version["is_draft"] is True: | ||
return version | ||
raise ValueError("No development version found") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's nicer--changed these functions except for get_dev_version_info
, which has some callsites in truss watch / patch logic. I'll leave updating that function and its callsites for a future PR to avoid increasing the scope of this one
truss/remote/baseten/core.py
Outdated
for version in versions: | ||
if version["is_draft"] is True: | ||
if version["is_primary"] and not version["is_draft"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
_TEST_REMOTE_URL = "http://test_remote.com" | ||
|
||
|
||
def test_get_service_by_version_id(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call, this made me realize that we should probably raise a UsageError rather than propagating the Baseten ApiError if the model version doesn't exist. Added a try / except there and a unit test (let me know if there's a better way to mock errors)
truss/cli/cli.py
Outdated
raise click.UsageError( | ||
"Cannot use --published with --model or --model-version." | ||
) | ||
if published and model_version_id: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG, reverted this change
truss/remote/baseten/remote.py
Outdated
# TODO(helen): make this consistent with getting the service via | ||
# model_name and respect --published in service_url_path. | ||
model_version = BasetenRemote._get_matching_version( | ||
model_versions, published | ||
) | ||
model_version_id = model_version["id"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation is now weirdly in between the original implementation and my attempt to unify the model_id
and model_name
code paths. It might be more readable to revert this back to the original model_id
code path and leave these changes for if / when we unify model_id
and model_name
paths
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think for now it might better to revert back for now. It's one less call at least.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG, done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, LGTM! Awesome work here. Just left a couple more comments, feel free to resolve & ship
truss/remote/baseten/remote.py
Outdated
# TODO(helen): make this consistent with getting the service via | ||
# model_name and respect --published in service_url_path. | ||
model_version = BasetenRemote._get_matching_version( | ||
model_versions, published | ||
) | ||
model_version_id = model_version["id"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think for now it might better to revert back for now. It's one less call at least.
truss/remote/baseten/core.py
Outdated
return (query_result["id"], query_result["versions"]) | ||
|
||
|
||
def get_dev_version_info_from_versions(versions: List[dict]) -> Optional[dict]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make this private (ie: def _get_dev_version_from_versions()...
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left this as public so we can get the dev version without requiring another GraphQL query (api.get_model
is called inside get_dev_version_info
). I added docstrings to try to make this more clear, let me know if that makes sense
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, sounds good
truss/remote/baseten/core.py
Outdated
return (query_result["id"], query_result["versions"]) | ||
|
||
|
||
def get_model_versions_info_by_id( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A lot of these functions have info
which I think doesn't add a lot, I think this could just be def get_model_versions_by_id
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
versions = model["model"]["versions"] | ||
dev_version = get_dev_version_info_from_versions(versions) | ||
if not dev_version: | ||
# TODO(helen): return dev_version in all cases rather than raising an error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's try to do this todo in a follow-up should be fairly straightforward
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG!
truss predict
truss predict
bug fixes / unit tests
Follow-up to #723. This updates `core.get_dev_version` to return None if a development version doesn’t exist and updates callsites to raise errors.
This PR addresses several issues with
truss predict
:Unifies logic for getting the deployment version given amodel_name
ormodel_id
Supports using the--published
flag with--model
(i.e. model ID)predict
truss predict --published
returns production version instead of the first non-development versioncore
and forBasetenRemote
Testing
Unit tests
poetry run pytest truss/tests/remote/baseten/test_core.py
poetry run pytest truss/tests/remote/baseten/test_remote.py
Truss CLI
Tested manually on a model with four published deployments and one development deployment:
![image](https://private-user-images.githubusercontent.com/14193683/282966177-89880616-791a-4b38-a8f7-4bd5f984fc74.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzg4MTQyMDEsIm5iZiI6MTczODgxMzkwMSwicGF0aCI6Ii8xNDE5MzY4My8yODI5NjYxNzctODk4ODA2MTYtNzkxYS00YjM4LWE4ZjctNGJkNWY5ODRmYzc0LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDYlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA2VDAzNTE0MVomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWRiYmJhNDk3ODI5NzIxYjU4ZjIyOWNhOTVlYjgxNzg5N2ZhNzk4N2JiNmRiZmMzNzNkYTQ1YTFlMTAxZGQ0ZjUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.UDnHWEG77i4bMlQGGOrqICDeRK4QCk-CDr5t8mALfcE)
For the production deployment, I set the
predict
function of myModel
to return"production model"
as its output. For the development deployment, I return"development model"
. The other deployments simply return themodel_input
.Model version ID:
Model name:
Model ID:
Testing on a second model with no production deployment: