diff --git a/docs/index.md b/docs/index.md index 1d858105..1a5cdcdf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1143,6 +1143,10 @@ expose a namespace that other code can access through dot-notation. PluginResult initialized with either a dict or an object that exposes the namespace through Python getattr(). +If your plugin generates some special kind of data value which should be serializable +as a primitive type (usually a string), subclass PluginResult and add a `simplify` +method to your PluginResult. That method should return a Python primitive value. + In the rare event that a plugin has a function which need its arguments to be passed to it unevaluated, for later (perhaps conditional) evaluation, you can use the `@snowfakery.lazy decorator`. Then you can evaluate the arguments with `self.context.evaluate()`. For example: @@ -1179,7 +1183,12 @@ This would output an `OBJ` row with values: {'id': 1, 'some_value': 'abc : abc', 'some_value_2': '1 : 2'}) ``` -## Using Snowfakery within CumulusC +Occasionally you might write a plugin which needs to evaluate its +parameters lazily but doesn't care about the internals of the values +because it just returns it to some parent context. In that case, +use `context.evaluate_raw` instead of `context.evaluate`. + +## Using Snowfakery within CumulusCI You can verify that a Snowfakery-compatible version of CumulusCI is installed like this: diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py index 29daaae7..2f8f512c 100644 --- a/snowfakery/data_generator_runtime.py +++ b/snowfakery/data_generator_runtime.py @@ -1,10 +1,9 @@ from collections import defaultdict -from datetime import date, datetime +from datetime import date from contextlib import contextmanager from enum import Enum, auto -from typing import Optional, Dict, List, Sequence, Mapping, NamedTuple, Union -from numbers import Number +from typing import Optional, Dict, List, Sequence, Mapping, NamedTuple import jinja2 import yaml @@ -625,9 +624,3 @@ def output_batches( continuing = bool(continuation_data) interpreter.loop_over_templates_until_finished(runtimecontext, continuing) return interpreter.globals - - -Scalar = Union[str, Number, date, datetime, None] -FieldValue = Union[ - None, Scalar, ObjectRow, tuple, NicknameSlot, snowfakery.plugins.PluginResult -] diff --git a/snowfakery/data_generator_runtime_object_model.py b/snowfakery/data_generator_runtime_object_model.py index 90390eed..df104b44 100644 --- a/snowfakery/data_generator_runtime_object_model.py +++ b/snowfakery/data_generator_runtime_object_model.py @@ -3,8 +3,7 @@ evaluate_function, ObjectRow, RuntimeContext, - FieldValue, - Scalar, + NicknameSlot, ) from contextlib import contextmanager from typing import Union, Dict, Sequence, Optional, cast @@ -19,10 +18,12 @@ DataGenValueError, fix_exception, ) +from .plugins import Scalar, PluginResult # objects that represent the hierarchy of a data generator. # roughly similar to the YAML structure but with domain-specific objects Definition = Union["ObjectTemplate", "SimpleValue", "StructuredValue"] +FieldValue = Union[None, Scalar, ObjectRow, tuple, NicknameSlot, PluginResult] class FieldDefinition(ABC): diff --git a/snowfakery/output_streams.py b/snowfakery/output_streams.py index a9058a85..40de2b4b 100644 --- a/snowfakery/output_streams.py +++ b/snowfakery/output_streams.py @@ -70,6 +70,11 @@ def cleanup(self, field_name, field_value, sourcetable, row): return self.flatten(sourcetable, field_name, row, field_value) else: encoder = self.encoders.get(type(field_value)) + if not encoder and hasattr(field_value, "simplify"): + + def encoder(field_value): + return field_value.simplify() + if not encoder: raise TypeError( f"No encoder found for {type(field_value)} in {self.__class__.__name__} " diff --git a/snowfakery/plugins.py b/snowfakery/plugins.py index 09322d89..a654766e 100644 --- a/snowfakery/plugins.py +++ b/snowfakery/plugins.py @@ -1,5 +1,6 @@ -from typing import Any, Callable, Mapping +from typing import Any, Callable, Mapping, Union from importlib import import_module +from datetime import date, datetime import yaml from yaml.representer import Representer @@ -7,6 +8,11 @@ import snowfakery.data_gen_exceptions as exc +from numbers import Number + + +Scalar = Union[str, Number, date, datetime, None] + class SnowfakeryPlugin: """Base class for all plugins. @@ -74,9 +80,20 @@ def context_vars(self): self.plugin.__class__.__name__ ) - def evaluate(self, field_definition): + def evaluate_raw(self, field_definition): + """Evaluate the contents of a field definition""" return field_definition.render(self.interpreter.current_context) + def evaluate(self, field_definition): + """Evaluate the contents of a field definition and simplify to a primitive value.""" + rc = self.evaluate_raw(field_definition) + if isinstance(rc, Scalar.__args__): + return rc + elif hasattr(rc, "simplify"): + return rc.simplify() + else: + raise f"Cannot simplify {field_definition}. Perhaps should have used evaluate_raw?" + def lazy(func: Any) -> Callable: """A lazy function is one that expects its arguments to be unparsed""" @@ -112,7 +129,7 @@ class PluginResult: PluginResults can be initialized with a dict or dict-like object. - PluginResults are serialized to contniuation files as dicts.""" + PluginResults are serialized to continuation files as dicts.""" def __init__(self, result: Mapping): self.result = result @@ -123,6 +140,12 @@ def __getattr__(self, name): def __reduce__(self): return (self.__class__, (dict(self.result),)) + def __repr__(self): + return f"<{self.__class__} {repr(self.result)}>" + + def __str__(self): + return str(self.result) + # round-trip PluginResult objects through continuation YAML if needed. yaml.SafeDumper.add_representer(PluginResult, Representer.represent_object) diff --git a/snowfakery/template_funcs.py b/snowfakery/template_funcs.py index 4118e7f1..d864ea71 100644 --- a/snowfakery/template_funcs.py +++ b/snowfakery/template_funcs.py @@ -177,12 +177,12 @@ def random_choice(self, *choices): raise ValueError("No choices supplied!") if getattr(choices[0], "function_name", None) == "choice": - choices = [self.context.evaluate(choice) for choice in choices] + choices = [self.context.evaluate_raw(choice) for choice in choices] rc = weighted_choice(choices) else: rc = random.choice(choices) if hasattr(rc, "render"): - rc = self.context.evaluate(rc) + rc = self.context.evaluate_raw(rc) return rc @lazy @@ -218,7 +218,7 @@ def if_(self, *choices: FieldDefinition): if not choices: raise ValueError("No choices supplied!") - choices = [self.context.evaluate(choice) for choice in choices] + choices = [self.context.evaluate_raw(choice) for choice in choices] for when, choice in choices[:-1]: if when is None: raise SyntaxError( @@ -231,7 +231,7 @@ def if_(self, *choices: FieldDefinition): ) rc = next(true_choices, choices[-1][-1]) # default to last choice if hasattr(rc, "render"): - rc = self.context.evaluate(rc) + rc = self.context.evaluate_raw(rc) return rc setattr(Functions, "if", Functions.if_) diff --git a/tests/test_custom_plugins_and_providers.py b/tests/test_custom_plugins_and_providers.py index a1c77633..aa2c18db 100644 --- a/tests/test_custom_plugins_and_providers.py +++ b/tests/test_custom_plugins_and_providers.py @@ -1,13 +1,17 @@ from io import StringIO import math +import operator from snowfakery.data_generator import generate from snowfakery import SnowfakeryPlugin, lazy +from snowfakery.plugins import PluginResult from snowfakery.data_gen_exceptions import ( DataGenError, DataGenTypeError, DataGenImportError, ) +from snowfakery.output_streams import JSONOutputStream + from unittest import mock import pytest @@ -42,6 +46,38 @@ def return_bad_type(self, value): return int # function +class MyEvaluator(PluginResult): + def __init__(self, operator, *operands): + super().__init__({"operator": operator, "operands": operands}) + + def _eval(self): + op = getattr(operator, self.result["operator"]) + vals = self.result["operands"] + rc = op(*vals) + return self.result.setdefault("value", str(rc)) + + def __str__(self): + return str(self._eval()) + + def simplify(self): + return int(self._eval()) + + +class EvalPlugin(SnowfakeryPlugin): + class Functions: + @lazy + def add(self, val1, val2): + return MyEvaluator( + "add", self.context.evaluate(val1), self.context.evaluate(val2) + ) + + @lazy + def sub(self, val1, val2): + return MyEvaluator( + "sub", self.context.evaluate(val1), self.context.evaluate(val2) + ) + + class TestCustomFakerProvider: @mock.patch(write_row_path) def test_custom_faker_provider(self, write_row_mock): @@ -129,6 +165,25 @@ def test_math_deconstructed(self, write_row_mock): generate(StringIO(yaml), {}) assert row_values(write_row_mock, 0, "twelve") == 12 + @mock.patch(write_row_path) + def test_stringification(self, write_row): + yaml = """ + - plugin: tests.test_custom_plugins_and_providers.EvalPlugin + - object: OBJ + fields: + some_value: + - EvalPlugin.add: + - 1 + - EvalPlugin.sub: + - 5 + - 3 + """ + with StringIO() as s: + output_stream = JSONOutputStream(s) + generate(StringIO(yaml), {}, output_stream) + output_stream.close() + assert eval(s.getvalue())[0]["some_value"] == 3 + class PluginThatNeedsState(SnowfakeryPlugin): class Functions: