Skip to content

Commit

Permalink
Bug bash improvements to Python DX, better error messages (#1346)
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen authored Jan 29, 2025
1 parent ddbbb3c commit 7812b1d
Show file tree
Hide file tree
Showing 10 changed files with 302 additions and 187 deletions.
7 changes: 7 additions & 0 deletions truss-chains/tests/import/model_without_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class ClassWithoutModelInheritance:
def __init__(self):
self._call_count = 0

async def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count
19 changes: 19 additions & 0 deletions truss-chains/tests/import/standalone_with_multiple_entrypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import truss_chains as chains


class FirstModel(chains.ModelBase):
def __init__(self):
self._call_count = 0

async def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count


class SecondModel(chains.ModelBase):
def __init__(self):
self._call_count = 0

async def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count
24 changes: 17 additions & 7 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
def test_chain():
with ensure_kill_all():
chain_root = TEST_ROOT / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "ItestChain"
) as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
Expand Down Expand Up @@ -109,7 +111,9 @@ def test_chain():
@pytest.mark.asyncio
async def test_chain_local():
chain_root = TEST_ROOT / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "ItestChain"
) as entrypoint:
with public_api.run_local():
with pytest.raises(ValueError):
# First time `SplitTextFailOnce` raises an error and
Expand Down Expand Up @@ -143,7 +147,9 @@ def test_streaming_chain():
with ensure_kill_all():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "Consumer"
) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand Down Expand Up @@ -179,7 +185,7 @@ def test_streaming_chain():
async def test_streaming_chain_local():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with framework.ChainletImporter.import_target(chain_root, "Consumer") as entrypoint:
with public_api.run_local():
result = await entrypoint().run_remote(cause_error=False)
print(result)
Expand All @@ -201,7 +207,7 @@ def test_numpy_chain(mode):
target = "HostBinary"
with ensure_kill_all():
chain_root = TEST_ROOT / "numpy_and_binary" / "chain.py"
with framework.import_target(chain_root, target) as entrypoint:
with framework.ChainletImporter.import_target(chain_root, target) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand All @@ -221,7 +227,9 @@ def test_numpy_chain(mode):
async def test_timeout():
with ensure_kill_all():
chain_root = TEST_ROOT / "timeout" / "timeout_chain.py"
with framework.import_target(chain_root, "TimeoutChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "TimeoutChain"
) as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
Expand Down Expand Up @@ -288,7 +296,9 @@ def test_traditional_truss():
def test_custom_health_checks_chain():
with ensure_kill_all():
chain_root = TEST_ROOT / "custom_health_checks" / "custom_health_checks.py"
with framework.import_target(chain_root, "CustomHealthChecks") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "CustomHealthChecks"
) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand Down
18 changes: 18 additions & 0 deletions truss-chains/tests/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import contextlib
import logging
import pathlib
import re
from typing import AsyncIterator, Iterator, List

Expand All @@ -12,6 +13,7 @@

utils.setup_dev_logging(logging.DEBUG)

TEST_ROOT = pathlib.Path(__file__).parent.resolve()

# Assert that naive chainlet initialization is detected and prevented. #################

Expand Down Expand Up @@ -668,3 +670,19 @@ def is_healthy(self) -> str: # type: ignore[misc]

async def run_remote(self) -> str:
return ""


def test_import_model_requires_entrypoint():
model_src = TEST_ROOT / "import" / "model_without_inheritance.py"
match = r"No Model class in `.+` inherits from"
with pytest.raises(ValueError, match=match), _raise_errors():
with framework.ModelImporter.import_target(model_src):
pass


def test_import_model_requires_single_entrypoint():
model_src = TEST_ROOT / "import" / "standalone_with_multiple_entrypoints.py"
match = r"Multiple Model classes in `.+` inherit from"
with pytest.raises(ValueError, match=match), _raise_errors():
with framework.ModelImporter.import_target(model_src):
pass
3 changes: 1 addition & 2 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
RemoteErrorDetail,
RPCOptions,
)
from truss_chains.framework import ChainletBase, ModelBase
from truss_chains.public_api import (
ChainletBase,
ModelBase,
depends,
depends_context,
mark_entrypoint,
Expand Down
4 changes: 2 additions & 2 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def gen_truss_model_from_source(
# TODO(nikhil): Improve detection of directory structure, since right now
# we assume a flat structure
root_dir = model_src.absolute().parent
with framework.import_target(model_src) as entrypoint_cls:
with framework.ModelImporter.import_target(model_src) as entrypoint_cls:
descriptor = framework.get_descriptor(entrypoint_cls)
return gen_truss_model(
model_root=root_dir,
Expand Down Expand Up @@ -771,7 +771,7 @@ def gen_truss_chainlet(
gen_root = pathlib.Path(tempfile.gettempdir())
chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
logging.info(
f"Code generation for Chainlet `{chainlet_descriptor.name}` "
f"Code generation for {chainlet_descriptor.chainlet_cls.entity_type} `{chainlet_descriptor.name}` "
f"in `{chainlet_dir}`."
)
_write_truss_config_yaml(
Expand Down
6 changes: 4 additions & 2 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,9 @@ def __init__(
self._remote_provider = cast(
b10_remote.BasetenRemote, remote_factory.RemoteFactory.create(remote=remote)
)
with framework.import_target(source, entrypoint) as entrypoint_cls:
with framework.ChainletImporter.import_target(
source, entrypoint
) as entrypoint_cls:
self._deployed_chain_name = name or entrypoint_cls.__name__
self._chain_root = _get_chain_root(entrypoint_cls)
chainlet_names = set(
Expand Down Expand Up @@ -733,7 +735,7 @@ def _patch(self, executor: concurrent.futures.Executor) -> None:
# Handle import errors gracefully (e.g. if user saved file, but there
# are syntax errors, undefined symbols etc.).
try:
with framework.import_target(
with framework.ChainletImporter.import_target(
self._source, self._entrypoint
) as entrypoint_cls:
chainlet_descriptors = _get_ordered_dependencies([entrypoint_cls])
Expand Down
Loading

0 comments on commit 7812b1d

Please sign in to comment.