diff --git a/TestModels/Constraints/Model/Constraints.smithy b/TestModels/Constraints/Model/Constraints.smithy index bf5b2539a7..fe41d4522c 100644 --- a/TestModels/Constraints/Model/Constraints.smithy +++ b/TestModels/Constraints/Model/Constraints.smithy @@ -115,6 +115,11 @@ list ListLessThanOrEqualToTen { member: String } +@length(min: 1, max: 10) +list ListWithConstraint { + member: MyString +} + @length(min: 1, max: 10) map MyMap { key: String, @@ -133,6 +138,12 @@ map MapLessThanOrEqualToTen { value: String, } +@length(min: 1, max: 10) +map MapWithConstraint { + key: MyString, + value: MyString, +} + // we don't do patterns yet // @pattern("^[A-Za-z]+$") // string Alphabetic @@ -160,10 +171,21 @@ integer LessThanTen // member: ComplexListElement // } -// structure ComplexListElement { -// value: String, -// blob: Blob, -// } +union UnionWithConstraint { + IntegerValue: OneToTen, + StringValue: MyString, +} + +structure ComplexStructure { + InnerString: MyString, + @required + InnerBlob: MyBlob, +} + +@length(min: 1) +list ComplexStructureList { + member: ComplexStructure, +} structure GetConstraintsInput { MyString: MyString, @@ -175,9 +197,11 @@ structure GetConstraintsInput { MyList: MyList, NonEmptyList: NonEmptyList, ListLessThanOrEqualToTen: ListLessThanOrEqualToTen, + ListWithConstraint: ListWithConstraint, MyMap: MyMap, NonEmptyMap: NonEmptyMap, MapLessThanOrEqualToTen: MapLessThanOrEqualToTen, + MapWithConstraint: MapWithConstraint, // Alphabetic: Alphabetic, OneToTen: OneToTen, myTenToTen: TenToTen, @@ -187,6 +211,8 @@ structure GetConstraintsInput { // MyComplexUniqueList: MyComplexUniqueList, MyUtf8Bytes: Utf8Bytes, MyListOfUtf8Bytes: ListOfUtf8Bytes, + UnionWithConstraint: UnionWithConstraint, + ComplexStructureList: ComplexStructureList, } structure GetConstraintsOutput { @@ -199,9 +225,11 @@ structure GetConstraintsOutput { MyList: MyList, NonEmptyList: NonEmptyList, ListLessThanOrEqualToTen: ListLessThanOrEqualToTen, + ListWithConstraint: ListWithConstraint, MyMap: MyMap, NonEmptyMap: NonEmptyMap, MapLessThanOrEqualToTen: MapLessThanOrEqualToTen, + MapWithConstraint: MapWithConstraint, // Alphabetic: Alphabetic, OneToTen: OneToTen, thatTenToTen: TenToTen, @@ -211,6 +239,8 @@ structure GetConstraintsOutput { // MyComplexUniqueList: MyComplexUniqueList, MyUtf8Bytes: Utf8Bytes, MyListOfUtf8Bytes: ListOfUtf8Bytes, + UnionWithConstraint: UnionWithConstraint, + ComplexStructureList: ComplexStructureList, } // See Comment in traits.smithy diff --git a/TestModels/Constraints/runtimes/rust/src/lib.rs b/TestModels/Constraints/runtimes/rust/src/lib.rs index 916e5b91ef..d628ae079b 100644 --- a/TestModels/Constraints/runtimes/rust/src/lib.rs +++ b/TestModels/Constraints/runtimes/rust/src/lib.rs @@ -17,6 +17,7 @@ pub mod operation; mod standard_library_conversions; mod standard_library_externs; pub mod types; +mod validation; pub mod wrapped; pub(crate) use crate::implementation_from_dafny::r#_Wrappers_Compile; pub(crate) use crate::implementation_from_dafny::simple; diff --git a/TestModels/Constraints/runtimes/rust/tests/simple_constraints_test.rs b/TestModels/Constraints/runtimes/rust/tests/simple_constraints_test.rs index feda16fdd0..354e179790 100644 --- a/TestModels/Constraints/runtimes/rust/tests/simple_constraints_test.rs +++ b/TestModels/Constraints/runtimes/rust/tests/simple_constraints_test.rs @@ -4,6 +4,8 @@ extern crate simple_constraints; mod simple_constraints_test { use simple_constraints::*; + use std::collections::HashMap; + fn client() -> Client { let config = SimpleConstraintsConfig::builder() .required_string("test string") @@ -82,4 +84,87 @@ mod simple_constraints_test { let message = result.err().expect("error").to_string(); assert!(message.contains("one_to_ten")); } + + #[tokio::test] + async fn test_good_list_with_constraint() { + let vec = vec!["1".to_string(), "123".to_string(), "1234567890".to_string()]; + let result = client().get_constraints() + .list_with_constraint(vec) + .send().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_bad_list_with_constraint() { + let vec = vec!["".to_string(), "this string is too long".to_string()]; + let result = client().get_constraints() + .list_with_constraint(vec) + .send().await; + let message = result.err().expect("error").to_string(); + assert!(message.contains("member")); + } + + #[tokio::test] + async fn test_good_map_with_constraint() { + let mut map = HashMap::new(); + map.insert("foo".to_string(), "bar".to_string()); + + let result = client().get_constraints() + .map_with_constraint(map) + .send().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_bad_map_with_constraint() { + let mut map = HashMap::new(); + map.insert("this key is too long".to_string(), "bar".to_string()); + + let result = client().get_constraints() + .map_with_constraint(map) + .send().await; + let message = result.err().expect("error").to_string(); + assert!(message.contains("key")); + + let mut map = HashMap::new(); + map.insert("foo".to_string(), "this value is too long".to_string()); + + let result = client().get_constraints() + .map_with_constraint(map) + .send().await; + let message = result.err().expect("error").to_string(); + assert!(message.contains("value")); + } + + #[tokio::test] + async fn test_good_union_with_constraint() { + let union_val = types::UnionWithConstraint::IntegerValue(1); + let result = client().get_constraints() + .union_with_constraint(union_val) + .send().await; + assert!(result.is_ok()); + + let union_val = types::UnionWithConstraint::StringValue("foo".to_string()); + let result = client().get_constraints() + .union_with_constraint(union_val) + .send().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_bad_union_with_constraint() { + let union_val = types::UnionWithConstraint::IntegerValue(100); + let result = client().get_constraints() + .union_with_constraint(union_val) + .send().await; + let message = result.err().expect("error").to_string(); + assert!(message.contains("integer_value")); + + let union_val = types::UnionWithConstraint::StringValue("this string is too long".to_string()); + let result = client().get_constraints() + .union_with_constraint(union_val) + .send().await; + let message = result.err().expect("error").to_string(); + assert!(message.contains("string_value")); + } } diff --git a/TestModels/Constraints/src/SimpleConstraintsImpl.dfy b/TestModels/Constraints/src/SimpleConstraintsImpl.dfy index d7d9209b51..2e8ece61d6 100644 --- a/TestModels/Constraints/src/SimpleConstraintsImpl.dfy +++ b/TestModels/Constraints/src/SimpleConstraintsImpl.dfy @@ -25,9 +25,11 @@ module SimpleConstraintsImpl refines AbstractSimpleConstraintsOperations { MyList := input.MyList, NonEmptyList := input.NonEmptyList, ListLessThanOrEqualToTen := input.ListLessThanOrEqualToTen, + ListWithConstraint := input.ListWithConstraint, MyMap := input.MyMap, NonEmptyMap := input.NonEmptyMap, MapLessThanOrEqualToTen := input.MapLessThanOrEqualToTen, + MapWithConstraint := input.MapWithConstraint, // Alphabetic := input.Alphabetic, OneToTen := input.OneToTen, GreaterThanOne := input.GreaterThanOne, @@ -35,9 +37,11 @@ module SimpleConstraintsImpl refines AbstractSimpleConstraintsOperations { // MyUniqueList := input.MyUniqueList, // MyComplexUniqueList := input.MyComplexUniqueList, MyUtf8Bytes := input.MyUtf8Bytes, - MyListOfUtf8Bytes := input.MyListOfUtf8Bytes + MyListOfUtf8Bytes := input.MyListOfUtf8Bytes, + UnionWithConstraint := input.UnionWithConstraint, + ComplexStructureList := input.ComplexStructureList ); return Success(res); } -} \ No newline at end of file +} diff --git a/TestModels/Constraints/test/Helpers.dfy b/TestModels/Constraints/test/Helpers.dfy index 4f431398df..25e5a73439 100644 --- a/TestModels/Constraints/test/Helpers.dfy +++ b/TestModels/Constraints/test/Helpers.dfy @@ -33,9 +33,11 @@ module Helpers { MyList := Some(["00", "11"]), NonEmptyList := Some(["00", "11"]), ListLessThanOrEqualToTen := Some(["00", "11"]), + ListWithConstraint := Some(["0", "123", "MaxTenChar"]), MyMap := Some(map["0" := "1", "2" := "3"]), NonEmptyMap := Some(map["0" := "1", "2" := "3"]), MapLessThanOrEqualToTen := Some(map["0" := "1", "2" := "3"]), + MapWithConstraint := Some(map["0" := "0123456789", "abcdefghij" := "z"]), // Alphabetic := Some("alphabetic"), OneToTen := Some(3), myTenToTen := Some(3), @@ -44,7 +46,12 @@ module Helpers { // MyUniqueList := Some(["one", "two"]), // MyComplexUniqueList := Some(myComplexUniqueList), MyUtf8Bytes := Some(PROVIDER_ID), - MyListOfUtf8Bytes := Some([PROVIDER_ID, PROVIDER_ID]) + MyListOfUtf8Bytes := Some([PROVIDER_ID, PROVIDER_ID]), + UnionWithConstraint := Some(IntegerValue(1)), + ComplexStructureList := Some([ + ComplexStructure(InnerString := Some("s1"), InnerBlob := [1, 1]), + ComplexStructure(InnerString := Some("s2"), InnerBlob := [2, 3, 4]) + ]) ) } diff --git a/TestModels/Constraints/test/WrappedSimpleConstraintsTest.dfy b/TestModels/Constraints/test/WrappedSimpleConstraintsTest.dfy index d5fc2b3277..2f8423df6a 100644 --- a/TestModels/Constraints/test/WrappedSimpleConstraintsTest.dfy +++ b/TestModels/Constraints/test/WrappedSimpleConstraintsTest.dfy @@ -25,12 +25,16 @@ module WrappedSimpleConstraintsTest { TestGetConstraintWithMyList(client); TestGetConstraintWithNonEmptyList(client); TestGetConstraintWithListLessThanOrEqualToTen(client); + TestGetConstraintWithListWithConstraint(client); TestGetConstraintWithMyMap(client); TestGetConstraintWithNonEmptyMap(client); TestGetConstraintWithMapLessThanOrEqualToTen(client); + TestGetConstraintWithMapWithConstraint(client); TestGetConstraintWithGreaterThanOne(client); TestGetConstraintWithUtf8Bytes(client); TestGetConstraintWithListOfUtf8Bytes(client); + TestGetConstraintWithUnionWithConstraint(client); + TestGetConstraintWithComplexStructureList(client); var allowBadUtf8BytesFromDafny := true; if (allowBadUtf8BytesFromDafny) { @@ -373,6 +377,20 @@ module WrappedSimpleConstraintsTest { expect ret.Failure?; } + // both list and member have @length(min: 1, max: 10) + method TestGetConstraintWithListWithConstraint(client: ISimpleConstraintsClient) + requires client.ValidState() + modifies client.Modifies + ensures client.ValidState() + { + var input := GetValidInput(); + input := input.(ListWithConstraint := Some(["1", "2", "3"])); + var ret := client.GetConstraints(input := input); + expect ret.Success?; + + // TODO: Add negative tests once all languages support it + } + // @length(min: 1, max: 10) method TestGetConstraintWithMyMap(client: ISimpleConstraintsClient) requires client.ValidState() @@ -445,6 +463,20 @@ module WrappedSimpleConstraintsTest { expect ret.Failure?; } + // both map and member have @length(min: 1, max: 10) + method TestGetConstraintWithMapWithConstraint(client: ISimpleConstraintsClient) + requires client.ValidState() + modifies client.Modifies + ensures client.ValidState() + { + var input := GetValidInput(); + input := input.(MapWithConstraint := Some(map["0" := "1234", "abcd" := "j"])); + var ret := client.GetConstraints(input := input); + expect ret.Success?; + + // TODO: Add negative tests once all languages support it + } + // @range(min: 1) method TestGetConstraintWithGreaterThanOne(client: ISimpleConstraintsClient) requires client.ValidState() @@ -589,4 +621,38 @@ module WrappedSimpleConstraintsTest { ret := client.GetConstraints(input := input); expect ret.Failure?; } + + method TestGetConstraintWithUnionWithConstraint(client: ISimpleConstraintsClient) + requires client.ValidState() + modifies client.Modifies + ensures client.ValidState() + { + var input := GetValidInput(); + input := input.(UnionWithConstraint := Some(IntegerValue(1))); + var ret := client.GetConstraints(input := input); + expect ret.Success?; + + input := GetValidInput(); + input := input.(UnionWithConstraint := Some(StringValue("foo"))); + ret := client.GetConstraints(input := input); + expect ret.Success?; + + // TODO: Add negative tests once all languages support it + } + + method TestGetConstraintWithComplexStructureList(client: ISimpleConstraintsClient) + requires client.ValidState() + modifies client.Modifies + ensures client.ValidState() + { + var input := GetValidInput(); + input := input.(ComplexStructureList := Some([ + ComplexStructure(InnerString := Some("a"), InnerBlob := [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + ComplexStructure(InnerString := Some("abcdefghij"), InnerBlob := [0]) + ])); + var ret := client.GetConstraints(input := input); + expect ret.Success?; + + // TODO: Add negative tests once all languages support it + } } diff --git a/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/AbstractRustShimGenerator.java b/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/AbstractRustShimGenerator.java index 84845e84ab..aab91aa029 100644 --- a/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/AbstractRustShimGenerator.java +++ b/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/AbstractRustShimGenerator.java @@ -48,6 +48,7 @@ import software.amazon.smithy.model.traits.ErrorTrait; import software.amazon.smithy.model.traits.RequiredTrait; import software.amazon.smithy.model.traits.StringTrait; +import software.amazon.smithy.model.traits.TraitDefinition; import software.amazon.smithy.model.traits.UnitTypeTrait; public abstract class AbstractRustShimGenerator { @@ -132,8 +133,8 @@ protected boolean shouldGenerateStructForStructure( StructureShape structureShape ) { return ( + !structureShape.hasTrait(TraitDefinition.class) && !structureShape.hasTrait(ErrorTrait.class) && - !structureShape.hasTrait(ShapeId.from("smithy.api#trait")) && !structureShape.hasTrait(ReferenceTrait.class) && ModelUtils.isInServiceNamespace(structureShape, service) ); diff --git a/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/RustLibraryShimGenerator.java b/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/RustLibraryShimGenerator.java index 5f544a6289..69a969889a 100644 --- a/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/RustLibraryShimGenerator.java +++ b/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/smithyrust/generator/RustLibraryShimGenerator.java @@ -7,12 +7,16 @@ import java.math.BigDecimal; import java.nio.file.Path; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; import java.util.stream.Collectors; import java.util.stream.Stream; import software.amazon.polymorph.smithydafny.DafnyNameResolver; @@ -40,6 +44,7 @@ import software.amazon.smithy.model.traits.EnumTrait; import software.amazon.smithy.model.traits.LengthTrait; import software.amazon.smithy.model.traits.RangeTrait; +import software.amazon.smithy.model.traits.RequiredTrait; import software.amazon.smithy.model.traits.UnitTypeTrait; /** @@ -48,6 +53,7 @@ public class RustLibraryShimGenerator extends AbstractRustShimGenerator { private final boolean generateWrappedClient; + private final RustValidationGenerator validationGenerator; public RustLibraryShimGenerator( final MergedServicesGenerator mergedGenerator, @@ -57,6 +63,7 @@ public RustLibraryShimGenerator( ) { super(mergedGenerator, model, service); this.generateWrappedClient = generateWrappedClient; + this.validationGenerator = new RustValidationGenerator(); } @Override @@ -71,7 +78,6 @@ protected Set rustFiles() { result.add(typesModule()); result.add(typesConfigModule()); result.add(typesBuildersModule()); - result.addAll( streamStructuresToGenerateStructsFor() .map(this::standardStructureModule) @@ -99,6 +105,9 @@ protected Set rustFiles() { .toList() ); + // validation + result.add(validationModule()); + // errors result.add(errorModule()); result.add(sealedUnhandledErrorModule()); @@ -242,10 +251,8 @@ private RustFile clientModule() { service ); variables.put( - "inputValidations", - new InputValidationGenerator() - .generateValidations(model, configShape) - .collect(Collectors.joining("\n")) + "inputValidationFunctionName", + RustValidationGenerator.shapeValidationFunctionName(configShape) ); final String content = evalTemplateResource( @@ -772,10 +779,8 @@ private RustFile operationOuterModule( operationVariables(bindingShape, operationShape) ); variables.put( - "inputValidations", - new InputValidationGenerator(bindingShape, operationShape) - .generateValidations(model, inputShape) - .collect(Collectors.joining("\n")) + "inputValidationFunctionName", + RustValidationGenerator.shapeValidationFunctionName(inputShape) ); if (bindingShape.isServiceShape()) { if (inputShape.hasTrait(PositionalTrait.class)) { @@ -860,57 +865,304 @@ private RustFile operationOuterModule( return new RustFile(path, TokenTree.of(content)); } - class InputValidationGenerator - extends ConstrainTraitUtils.ValidationGenerator { + class RustValidationGenerator { private final Map commonVariables; + private final Set shapesToValidate; + + // map from synthetic input/output shapes to their operations + private final Map relatedOperationShapes; + + RustValidationGenerator() { + this.commonVariables = serviceVariables(); + this.shapesToValidate = new TreeSet<>(); + this.relatedOperationShapes = new HashMap<>(); + + // Start from the service config and operation input/outputs + final var rootShapes = new HashSet(); + rootShapes.add(ModelUtils.getConfigShape(model, service)); + streamAllBoundOperationShapes() + .map(BoundOperationShape::operationShape) + .forEach(operationShape -> { + final var inputShapeId = operationShape.getInputShape(); + final var outputShapeId = operationShape.getOutputShape(); + rootShapes.add(model.expectShape(inputShapeId, StructureShape.class)); + rootShapes.add( + model.expectShape(outputShapeId, StructureShape.class) + ); + + relatedOperationShapes.put(inputShapeId, operationShape); + relatedOperationShapes.put(outputShapeId, operationShape); + }); + + // Traverse to find relevant shapes + final var queue = new LinkedList(rootShapes); + while (!queue.isEmpty()) { + final var shape = queue.poll(); + if (shapesToValidate.contains(shape)) { + continue; + } + if ( + !(shape instanceof MemberShape || + shape instanceof StructureShape || + shape instanceof UnionShape || + shape instanceof ListShape || + shape instanceof MapShape) + ) { + continue; + } + + shapesToValidate.add(shape); + if (shape instanceof MemberShape memberShape) { + queue.add(model.expectShape(memberShape.getTarget())); + } else { + queue.addAll(shape.getAllMembers().values()); + } + } + } + + Set getShapesToValidate() { + return shapesToValidate; + } /** - * Generates validation expressions for operation input structures. + * Generates a validation function for the given aggregate or member shape. + *

+ * Validation of constraints (required, range, and length) + * occurs only in validation functions for member shapes. + * Validation functions for aggregate shapes + * only delegate to the validation functions for their members. + *

+ * Other modules should therefore only call validation functions for aggregate shapes. */ - InputValidationGenerator( - final Shape bindingShape, - final OperationShape operationShape - ) { - this.commonVariables = - MapUtils.merge( - serviceVariables(), - operationVariables(bindingShape, operationShape) + private String generateValidationFunction(final Shape shape) { + final var validationBlocks = new ArrayList(); + + final var isStructureMember = + shape instanceof MemberShape memberShape && + model.expectShape(memberShape.getContainer()).isStructureShape(); + + if (shape instanceof MemberShape memberShape) { + memberShape + .getTrait(RequiredTrait.class) + .ifPresent(_trait -> + validationBlocks.add(this.validateRequired(memberShape)) + ); + // For simplicity, avoid wrapping the rest of the validation in a conditional + if (isStructureMember) { + validationBlocks.add( + """ + if input.is_none() { + return ::std::result::Result::Ok(()); + } + let input = input.as_ref().unwrap(); + """ + ); + } + + memberShape + .getMemberTrait(model, RangeTrait.class) + .ifPresent(trait -> + validationBlocks.add(this.validateRange(memberShape, trait)) + ); + memberShape + .getMemberTrait(model, LengthTrait.class) + .ifPresent(trait -> + validationBlocks.add(this.validateLength(memberShape, trait)) + ); + + // validate target if necessary + final var targetShape = model.expectShape(memberShape.getTarget()); + if (this.shapesToValidate.contains(targetShape)) { + final var memberVariables = structureMemberVariables(memberShape); + memberVariables.put( + "targetValidationFunctionName", + shapeValidationFunctionName(targetShape) + ); + validationBlocks.add( + evalTemplate( + """ + $targetValidationFunctionName:L(input)?; + """, + memberVariables + ) + ); + } + } else if (shape instanceof StructureShape structureShape) { + for (final var memberShape : structureShape.getAllMembers().values()) { + final var memberVariables = structureMemberVariables(memberShape); + memberVariables.put( + "memberValidationFunctionName", + shapeValidationFunctionName(memberShape) + ); + validationBlocks.add( + evalTemplate( + "$memberValidationFunctionName:L(&input.$fieldName:L)?;", + memberVariables + ) + ); + } + } else if (shape instanceof UnionShape unionShape) { + final var unionVariables = unionVariables(unionShape); + for (final var memberShape : unionShape.getAllMembers().values()) { + final var memberVariables = MapUtils.merge( + unionVariables, + unionMemberVariables(memberShape) + ); + memberVariables.put( + "memberValidationFunctionName", + shapeValidationFunctionName(memberShape) + ); + validationBlocks.add( + evalTemplate( + """ + if let $qualifiedRustUnionName:L::$rustUnionMemberName:L(ref inner) = &input { + $memberValidationFunctionName:L(inner)?; + } + """, + memberVariables + ) + ); + } + } else if (shape instanceof ListShape listShape) { + final var memberShape = listShape.getMember(); + final Map memberVariables = Map.of( + "memberValidationFunctionName", + shapeValidationFunctionName(memberShape) + ); + validationBlocks.add( + evalTemplate( + """ + for inner in input.iter() { + $memberValidationFunctionName:L(inner)?; + } + """, + memberVariables + ) ); - this.commonVariables.put( - "inputStructureName", - commonVariables.get("pascalCaseOperationInputName") + } else if (shape instanceof MapShape mapShape) { + final var keyShape = mapShape.getKey(); + final var valueShape = mapShape.getValue(); + final Map memberVariables = Map.of( + "keyValidationFunctionName", + shapeValidationFunctionName(keyShape), + "valueValidationFunctionName", + shapeValidationFunctionName(valueShape) ); - } + validationBlocks.add( + evalTemplate( + """ + for (inner_key, inner_val) in input.iter() { + $keyValidationFunctionName:L(inner_key)?; + $valueValidationFunctionName:L(inner_val)?; + } + """, + memberVariables + ) + ); + } else { + throw new IllegalArgumentException( + "Unsupported shape: " + shape.getId() + ); + } - /** - * Generates validation expressions for this service's client config structure. - */ - InputValidationGenerator() { - this.commonVariables = serviceVariables(); - this.commonVariables.put( - "inputStructureName", - commonVariables.get("qualifiedRustConfigName") + final var variables = new HashMap(); + variables.put( + "shapeValidationFunctionName", + shapeValidationFunctionName(shape) + ); + variables.put("validationBlocks", String.join("\n", validationBlocks)); + + try { + if (shape instanceof MemberShape memberShape) { + final var targetShape = model.expectShape(memberShape.getTarget()); + final var targetType = rustTypeForShape(targetShape); + if (isStructureMember) { + variables.put( + "shapeType", + "::std::option::Option<%s>".formatted(targetType) + ); + } else { + variables.put("shapeType", targetType); + } + } else if (relatedOperationShapes.containsKey(shape.getId())) { + // TODO This is a messy way to handle synthetic operation input/outputs, + // and the assumption may not even hold for the DB ESDK. + // See also: + final var operation = relatedOperationShapes.get(shape.getId()); + final var operationEntities = operationBindingIndex.getBindingShapes( + operation + ); + if (operationEntities.size() != 1) { + throw new IllegalStateException( + "Expected exactly 1 entity for operation %s".formatted( + operation.getId() + ) + ); + } + final var operationVariables = operationVariables( + operationEntities.iterator().next(), + operation + ); + final String shapeType; + if (operation.getInputShape().equals(shape.getId())) { + shapeType = operationVariables.get("operationInputType"); + } else { + shapeType = operationVariables.get("operationOutputType"); + } + variables.put("shapeType", shapeType); + } else { + variables.put("shapeType", rustTypeForShape(shape)); + } + } catch (Exception e) { + throw new RuntimeException( + "Failed on shape %s".formatted(shape.getId()), + e ); + } + + return evalTemplate( + """ + pub(crate) fn $shapeValidationFunctionName:L(input: &$shapeType:L) + -> ::std::result::Result<(), ::aws_smithy_types::error::operation::BuildError> + { + $validationBlocks:L + Ok(()) + } + """, + variables + ); } - @Override - protected String validateRequired(final MemberShape memberShape) { + public static String shapeValidationFunctionName(final Shape shape) { + // the ID foo.bar_baz.quux#My_ShapeName$the_member + // becomes foo_Pbar__baz_Pquux_HMy__ShapeName_Dthe__member + final var escapedId = shape + .getId() + .toString() + .replace("_", "__") + .replace(".", "_P") + .replace("#", "_H") + .replace("$", "_D"); + + return "validate_" + escapedId; + } + + private String validateRequired(final MemberShape memberShape) { return evalTemplate( """ - if input.$fieldName:L.is_none() { + if input.is_none() { return ::std::result::Result::Err(::aws_smithy_types::error::operation::BuildError::missing_field( "$fieldName:L", - "$fieldName:L was not specified but it is required when building $inputStructureName:L", - )).map_err($qualifiedRustServiceErrorType:L::wrap_validation_err); + "$fieldName:L is required but was not specified", + )); } """, MapUtils.merge(commonVariables, structureMemberVariables(memberShape)) ); } - @Override - protected String validateRange( + private String validateRange( final MemberShape memberShape, final RangeTrait rangeTrait ) { @@ -925,43 +1177,45 @@ protected String validateRange( final var max = rangeTrait .getMax() .map(bound -> asLiteral(bound, targetShape)); - final var conditionTemplate = - "!(%s..%s).contains(&x)".formatted( + + variables.put( + "condition", + "!(%s..%s).contains(input)".formatted( min.orElse(""), max.map(val -> "=" + val).orElse("") - ); - final var rangeDescription = describeMinMax(min, max); + ) + ); + variables.put("rangeDescription", describeMinMax(min, max)); return evalTemplate( """ - if matches!(input.$fieldName:L, Some(x) if %s) { + if $condition:L { return ::std::result::Result::Err(::aws_smithy_types::error::operation::BuildError::invalid_field( "$fieldName:L", - "$fieldName:L failed to satisfy constraint: Member must be %s", - )).map_err($qualifiedRustServiceErrorType:L::wrap_validation_err); + "$fieldName:L failed to satisfy constraint: Member must be $rangeDescription:L", + )); } - """.formatted(conditionTemplate, rangeDescription), + """, variables ); } - @Override - protected String validateLength( + private String validateLength( final MemberShape memberShape, final LengthTrait lengthTrait ) { final var targetShape = model.expectShape(memberShape.getTarget()); final var len = switch (targetShape.getType()) { - case BLOB -> "x.as_ref().len()"; + case BLOB -> "input.as_ref().len()"; case STRING -> targetShape.hasTrait(DafnyUtf8BytesTrait.class) // scalar values - ? "x.chars().count()" + ? "input.chars().count()" // The Smithy spec says that this should count scalar values, // but for consistency with the existing Java and .NET implementations, // we instead count UTF-16 code points. // See . - : "x.chars().map(::std::primitive::char::len_utf16).fold(0usize, ::std::ops::Add::add)"; - default -> "x.len()"; + : "input.chars().map(::std::primitive::char::len_utf16).fold(0usize, ::std::ops::Add::add)"; + default -> "input.len()"; }; final var variables = MapUtils.merge( commonVariables, @@ -978,11 +1232,11 @@ protected String validateLength( final var rangeDescription = describeMinMax(min, max); return evalTemplate( """ - if matches!(input.$fieldName:L, Some(ref x) if %s) { + if %s { return ::std::result::Result::Err(::aws_smithy_types::error::operation::BuildError::invalid_field( "$fieldName:L", "$fieldName:L failed to satisfy constraint: Member must have length %s", - )).map_err($qualifiedRustServiceErrorType:L::wrap_validation_err); + )); } """.formatted(conditionTemplate, rangeDescription), variables @@ -1820,6 +2074,19 @@ private String wrappedClientOperationImpl( ); } + private RustFile validationModule() { + final var validationFunctions = validationGenerator + .getShapesToValidate() + .stream() + .map(validationGenerator::generateValidationFunction) + .collect(Collectors.joining("\n")); + + return new RustFile( + rootPathForShape(service).resolve("validation.rs"), + TokenTree.of(validationFunctions) + ); + } + private Path operationsModuleFilePath(final Shape bindingShape) { return rootPathForShape(bindingShape).resolve("operation"); } diff --git a/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/utils/ConstrainTraitUtils.java b/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/utils/ConstrainTraitUtils.java index 439fea679c..7d9d586933 100644 --- a/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/utils/ConstrainTraitUtils.java +++ b/codegen/smithy-dafny-codegen/src/main/java/software/amazon/polymorph/utils/ConstrainTraitUtils.java @@ -8,12 +8,10 @@ import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.Shape; -import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.model.traits.LengthTrait; import software.amazon.smithy.model.traits.RangeTrait; import software.amazon.smithy.model.traits.RequiredTrait; import software.amazon.smithy.model.traits.Trait; -import software.amazon.smithy.utils.Pair; // TODO: Support idRef, pattern, uniqueItems public class ConstrainTraitUtils { @@ -36,75 +34,6 @@ public static boolean hasConstraintTrait(Shape shape) { ); } - /** - * Utility class to generate validation expressions for all members of a structure. - * - * @param type of validation expressions, typically {@link String} or {@link TokenTree} - */ - public abstract static class ValidationGenerator { - - protected abstract V validateRequired(MemberShape memberShape); - - protected abstract V validateRange( - MemberShape memberShape, - RangeTrait rangeTrait - ); - - protected abstract V validateLength( - MemberShape memberShape, - LengthTrait lengthTrait - ); - - /** - * Returns a stream of constraint traits that Polymorph-generated code should enforce - * on any code path that invokes a service or resource operation, - * from either the given shape or the targeted shape (if the given shape is a member shape). - */ - private Stream enforcedConstraints( - final Model model, - final Shape shape - ) { - return Stream - .of( - shape.getMemberTrait(model, RequiredTrait.class), - shape.getMemberTrait(model, RangeTrait.class), - shape.getMemberTrait(model, LengthTrait.class) - ) - .flatMap(Optional::stream); - } - - public Stream generateValidations( - final Model model, - final StructureShape structureShape - ) { - return structureShape - .getAllMembers() - .values() - .stream() - .flatMap(memberShape -> - enforcedConstraints(model, memberShape) - .map(trait -> Pair.of(memberShape, trait)) - ) - .map(memberTrait -> { - final MemberShape memberShape = memberTrait.left; - final Trait trait = memberTrait.right; - if (trait instanceof RequiredTrait) { - return validateRequired(memberShape); - } else if (trait instanceof RangeTrait rangeTrait) { - return validateRange(memberShape, rangeTrait); - } else if (trait instanceof LengthTrait lengthTrait) { - return validateLength(memberShape, lengthTrait); - } - throw new UnsupportedOperationException( - "Unsupported constraint trait %s on shape %s".formatted( - trait, - structureShape.getId() - ) - ); - }); - } - } - public static class RangeTraitUtils { /** Return the trait's min as an accurate string representation diff --git a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/client.rs b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/client.rs index 3c8ac32a00..2eaa150e45 100644 --- a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/client.rs +++ b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/client.rs @@ -12,7 +12,8 @@ impl Client { pub fn from_conf( input: $qualifiedRustConfigName:L, ) -> Result { - $inputValidations:L + crate::validation::$inputValidationFunctionName:L(&input) + .map_err($qualifiedRustServiceErrorType:L::wrap_validation_err)?; let inner = crate::$dafnyInternalModuleName:L::_default::$sdkId:L( &$rustConversionsModuleName:L::$snakeCaseConfigName:L::_$snakeCaseConfigName:L::to_dafny(input), diff --git a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/operation/outer.rs b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/operation/outer.rs index caf57050cb..6a6fa57acc 100644 --- a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/operation/outer.rs +++ b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/operation/outer.rs @@ -15,7 +15,8 @@ impl $pascalCaseOperationName:L { $operationOutputType:L, $qualifiedRustServiceErrorType:L, > { - $inputValidations:L + crate::validation::$inputValidationFunctionName:L(&input) + .map_err($qualifiedRustServiceErrorType:L::wrap_validation_err)?; $operationSendBody:L } } diff --git a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/types.rs b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/types.rs index 28580f12c3..8354ec1976 100644 --- a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/types.rs +++ b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/types.rs @@ -1,5 +1,6 @@ /// Types for the `$configName:L` pub mod $snakeCaseConfigName:L; +pub use $qualifiedRustConfigName:L; pub mod builders; diff --git a/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/validation.rs b/codegen/smithy-dafny-codegen/src/main/resources/templates/runtimes/rust/validation.rs new file mode 100644 index 0000000000..e69de29bb2