Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better example code for models using chains framework #1347

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,36 @@ def _example_chainlet_code() -> str:
# called on erroneous code branches (which will not be triggered if
# `example_chainlet` is free of errors).
try:
from truss_chains import example_chainlet
from truss_chains.reference_code import reference_chainlet
# If `example_chainlet` fails validation and `_example_chainlet_code` is
# called as a result of that, we have a circular import ("partially initialized
# module 'truss_chains.example_chainlet' ...").
except AttributeError:
logging.error("example_chainlet` is broken.", exc_info=True, stack_info=True)
logging.error("`reference_chainlet` is broken.", exc_info=True, stack_info=True)
return "<EXAMPLE CODE MISSING/BROKEN>"

example_name = example_chainlet.HelloWorld.name
source = pathlib.Path(example_chainlet.__file__).read_text()
example_name = reference_chainlet.HelloWorld.name
return _get_cls_source(reference_chainlet.__file__, example_name)


@functools.cache
def _example_model_code() -> str:
try:
from truss_chains.reference_code import reference_model
except AttributeError:
logging.error("`reference_model` is broken.", exc_info=True, stack_info=True)
return "<EXAMPLE CODE MISSING/BROKEN>"

example_name = reference_model.HelloWorld.name
return _get_cls_source(reference_model.__file__, example_name)


def _get_cls_source(src_path: str, target_class_name: str) -> str:
source = pathlib.Path(src_path).read_text()
tree = ast.parse(source)
class_code = ""
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == example_name:
if isinstance(node, ast.ClassDef) and node.name == target_class_name:
# Extract the source code of the class definition
lines = source.splitlines()
class_code = "\n".join(lines[node.lineno - 1 : node.end_lineno])
Expand Down Expand Up @@ -605,7 +621,7 @@ def make_context_error_msg():
f"`{definitions.DeploymentContext}` as the last argument.\n"
f"Got arguments: `{params}`.\n"
"Example of correct `__init__` with context:\n"
f"{_example_chainlet_code()}"
f"{self._example_code()}"
)

if not params:
Expand Down Expand Up @@ -637,7 +653,7 @@ def make_context_error_msg():
_collect_error(
f"Incorrect default value `{param.default}` for `context` argument. "
"Example of correct `__init__` with dependencies:\n"
f"{_example_chainlet_code()}",
f"{self._example_code()}",
_ErrorKind.TYPE_ERROR,
self._location,
)
Expand Down Expand Up @@ -686,7 +702,7 @@ def _validate_dependency_param(
f"The init argument name `{definitions.CONTEXT_ARG_NAME}` is reserved for "
"the optional context argument, which must be trailing if used. Example "
"of correct `__init__` with context:\n"
f"{_example_chainlet_code()}",
f"{self._example_code()}",
_ErrorKind.TYPE_ERROR,
self._location,
)
Expand All @@ -697,7 +713,7 @@ def _validate_dependency_param(
"dependency Chainlets with default values from `chains.depends`-directive. "
f"Got `{param}`.\n"
f"Example of correct `__init__` with dependencies:\n"
f"{_example_chainlet_code()}",
f"{self._example_code()}",
_ErrorKind.TYPE_ERROR,
self._location,
)
Expand Down Expand Up @@ -731,6 +747,12 @@ def _validate_dependency_param(
)
return param.default # The Marker.

@functools.cache
def _example_code(self) -> str:
if self._cls.entity_type == "Model":
return _example_model_code()
return _example_chainlet_code()


def _validate_remote_config(
cls: Type[definitions.ABCChainlet], location: _ErrorLocation
Expand Down
10 changes: 10 additions & 0 deletions truss-chains/truss_chains/reference_code/reference_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import truss_chains as chains


class HelloWorld(chains.ModelBase):
def __init__(self, context: chains.DeploymentContext = chains.depends_context()):
self._call_count = 0

def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count
4 changes: 2 additions & 2 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,12 +827,12 @@ def init_chain(directory: Optional[Path]) -> None:

def _load_example_chainlet_code() -> str:
try:
from truss_chains import example_chainlet
from truss_chains.reference_code import reference_chainlet
# if the example is faulty, a validation error would be raised
except Exception as e:
raise Exception("Failed to load starter code. Please notify support.") from e

source = Path(example_chainlet.__file__).read_text()
source = Path(reference_chainlet.__file__).read_text()
return source


Expand Down
Loading