diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java index 47d093deff..bf97edc7d2 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaKeysetScrollQueryCreator.java @@ -79,7 +79,7 @@ protected JpqlQueryBuilder.AbstractJpqlQuery createQuery(@Nullable JpqlQueryBuil JpqlQueryBuilder.Predicate keysetPredicate = keysetSpec.createJpqlPredicate(getFrom(), getEntity(), value -> { syntheticBindings.add(provider.nextSynthetic(value, scrollPosition)); - return JpqlQueryBuilder.expression(render(counter.incrementAndGet())); + return placeholder(counter.incrementAndGet()); }); JpqlQueryBuilder.Predicate predicateToUse = getPredicate(predicate, keysetPredicate); diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParameters.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParameters.java index 0430bb4283..d40895ab78 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParameters.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaParameters.java @@ -63,7 +63,7 @@ protected JpaParameters(ParametersSource parametersSource, super(parametersSource, parameterFactory); } - private JpaParameters(List parameters) { + JpaParameters(List parameters) { super(parameters); } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java index edac9f12d2..3e2a07fc78 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryCreator.java @@ -15,23 +15,30 @@ */ package org.springframework.data.jpa.repository.query; -import static org.springframework.data.repository.query.parser.Part.Type.*; +import static org.springframework.data.repository.query.parser.Part.Type.IS_NOT_EMPTY; +import static org.springframework.data.repository.query.parser.Part.Type.NOT_CONTAINING; +import static org.springframework.data.repository.query.parser.Part.Type.NOT_LIKE; +import static org.springframework.data.repository.query.parser.Part.Type.SIMPLE_PROPERTY; import jakarta.persistence.EntityManager; import jakarta.persistence.criteria.CriteriaQuery; import jakarta.persistence.criteria.Expression; -import jakarta.persistence.criteria.From; import jakarta.persistence.criteria.Predicate; +import jakarta.persistence.metamodel.Attribute; +import jakarta.persistence.metamodel.Bindable; import jakarta.persistence.metamodel.EntityType; +import jakarta.persistence.metamodel.Metamodel; import jakarta.persistence.metamodel.SingularAttribute; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; +import java.util.stream.Collectors; import org.springframework.data.domain.Sort; import org.springframework.data.jpa.domain.JpaSort; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.ParameterPlaceholder; import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.PathAndOrigin; import org.springframework.data.jpa.repository.query.ParameterBinding.PartTreeParameterBinding; import org.springframework.data.jpa.repository.support.JpqlQueryTemplates; @@ -56,6 +63,7 @@ * @author Moritz Becker * @author Andrey Kovalev * @author Greg Turnquist + * @author Christoph Strobl * @author Jinmyeong Kim */ class JpaQueryCreator extends AbstractQueryCreator implements JpqlQueryCreator { @@ -66,8 +74,8 @@ class JpaQueryCreator extends AbstractQueryCreator entityType; - private final From from; private final JpqlQueryBuilder.Entity entity; + private final Metamodel metamodel; /** * Create a new {@link JpaQueryCreator}. @@ -88,12 +96,12 @@ public JpaQueryCreator(PartTree tree, ReturnedType type, ParameterMetadataProvid this.templates = templates; this.escape = provider.getEscape(); this.entityType = em.getMetamodel().entity(type.getDomainType()); - this.from = em.getCriteriaBuilder().createQuery().from(type.getDomainType()); this.entity = JpqlQueryBuilder.entity(returnedType.getDomainType()); + this.metamodel = em.getMetamodel(); } - From getFrom() { - return from; + Bindable getFrom() { + return entityType; } JpqlQueryBuilder.Entity getEntity() { @@ -175,7 +183,7 @@ protected JpqlQueryBuilder.Select buildQuery(Sort sort) { QueryUtils.checkSortExpression(order); try { - expression = JpqlQueryBuilder.expression(JpqlUtils.toExpressionRecursively(entity, from, + expression = JpqlQueryBuilder.expression(JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, PropertyPath.from(order.getProperty(), entityType.getJavaType()))); } catch (PropertyReferenceException e) { @@ -210,12 +218,19 @@ private JpqlQueryBuilder.Select doSelect(Sort sort) { if (returnedType.needsCustomConstruction()) { - Collection requiredSelection = getRequiredSelection(sort, returnedType); + Collection requiredSelection = null; + if (returnedType.getReturnedType().getPackageName().startsWith("java.util") + || returnedType.getReturnedType().getPackageName().startsWith("jakarta.persistence")) { + requiredSelection = metamodel.managedType(returnedType.getDomainType()).getAttributes().stream() + .map(Attribute::getName).collect(Collectors.toList()); + } else { + requiredSelection = getRequiredSelection(sort, returnedType); + } List paths = new ArrayList<>(requiredSelection.size()); for (String selection : requiredSelection) { - paths.add( - JpqlUtils.toExpressionRecursively(entity, from, PropertyPath.from(selection, from.getJavaType()), true)); + paths.add(JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, + PropertyPath.from(selection, returnedType.getDomainType()), true)); } if (useTupleQuery()) { @@ -231,14 +246,14 @@ private JpqlQueryBuilder.Select doSelect(Sort sort) { if (entityType.hasSingleIdAttribute()) { SingularAttribute id = entityType.getId(entityType.getIdType().getJavaType()); - return selectStep.select( - JpqlUtils.toExpressionRecursively(entity, from, PropertyPath.from(id.getName(), from.getJavaType()), true)); + return selectStep.select(JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, + PropertyPath.from(id.getName(), returnedType.getDomainType()), true)); } else { List paths = entityType.getIdClassAttributes().stream()// - .map(it -> JpqlUtils.toExpressionRecursively(entity, from, - PropertyPath.from(it.getName(), from.getJavaType()), true)) + .map(it -> JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, + PropertyPath.from(it.getName(), returnedType.getDomainType()), true)) .toList(); return selectStep.select(paths); } @@ -255,12 +270,12 @@ Collection getRequiredSelection(Sort sort, ReturnedType returnedType) { return returnedType.getInputProperties(); } - String render(ParameterBinding binding) { - return render(binding.getRequiredPosition()); + JpqlQueryBuilder.Expression placeholder(ParameterBinding binding) { + return placeholder(binding.getRequiredPosition()); } - String render(int position) { - return "?" + position; + JpqlQueryBuilder.Expression placeholder(int position) { + return JpqlQueryBuilder.parameter(ParameterPlaceholder.indexed(position)); } /** @@ -305,7 +320,7 @@ public JpqlQueryBuilder.Predicate build() { PropertyPath property = part.getProperty(); Type type = part.getType(); - PathAndOrigin pas = JpqlUtils.toExpressionRecursively(entity, from, property); + PathAndOrigin pas = JpqlUtils.toExpressionRecursively(metamodel, entity, entityType, property); JpqlQueryBuilder.WhereStep where = JpqlQueryBuilder.where(pas); JpqlQueryBuilder.WhereStep whereIgnoreCase = JpqlQueryBuilder.where(potentiallyIgnoreCase(pas)); @@ -313,25 +328,25 @@ public JpqlQueryBuilder.Predicate build() { case BETWEEN: PartTreeParameterBinding first = provider.next(part); ParameterBinding second = provider.next(part); - return where.between(render(first), render(second)); + return where.between(placeholder(first), placeholder(second)); case AFTER: case GREATER_THAN: - return where.gt(render(provider.next(part))); + return where.gt(placeholder(provider.next(part))); case GREATER_THAN_EQUAL: - return where.gte(render(provider.next(part))); + return where.gte(placeholder(provider.next(part))); case BEFORE: case LESS_THAN: - return where.lt(render(provider.next(part))); + return where.lt(placeholder(provider.next(part))); case LESS_THAN_EQUAL: - return where.lte(render(provider.next(part))); + return where.lte(placeholder(provider.next(part))); case IS_NULL: return where.isNull(); case IS_NOT_NULL: return where.isNotNull(); case NOT_IN: - return whereIgnoreCase.notIn(render(provider.next(part, Collection.class))); + return whereIgnoreCase.notIn(placeholder(provider.next(part, Collection.class))); case IN: - return whereIgnoreCase.in(render(provider.next(part, Collection.class))); + return whereIgnoreCase.in(placeholder(provider.next(part, Collection.class))); case STARTING_WITH: case ENDING_WITH: case CONTAINING: @@ -340,8 +355,8 @@ public JpqlQueryBuilder.Predicate build() { if (property.getLeafProperty().isCollection()) { where = JpqlQueryBuilder.where(entity, property); - return type.equals(NOT_CONTAINING) ? where.notMemberOf(render(provider.next(part))) - : where.memberOf(render(provider.next(part))); + return type.equals(NOT_CONTAINING) ? where.notMemberOf(placeholder(provider.next(part))) + : where.memberOf(placeholder(provider.next(part))); } case LIKE: @@ -349,7 +364,7 @@ public JpqlQueryBuilder.Predicate build() { PartTreeParameterBinding parameter = provider.next(part, String.class); JpqlQueryBuilder.Expression parameterExpression = potentiallyIgnoreCase(part.getProperty(), - JpqlQueryBuilder.parameter(render(parameter))); + placeholder(parameter)); // Predicate like = builder.like(propertyExpression, parameterExpression, escape.getEscapeCharacter()); String escapeChar = Character.toString(escape.getEscapeCharacter()); return @@ -362,23 +377,16 @@ public JpqlQueryBuilder.Predicate build() { case FALSE: return where.isFalse(); case SIMPLE_PROPERTY: - PartTreeParameterBinding simple = provider.next(part); - - if (simple.isIsNullParameter()) { - return where.isNull(); - } - - return whereIgnoreCase.eq(potentiallyIgnoreCase(property, JpqlQueryBuilder.expression(render(simple)))); case NEGATING_SIMPLE_PROPERTY: - PartTreeParameterBinding negating = provider.next(part); + PartTreeParameterBinding simple = provider.next(part); - if (negating.isIsNullParameter()) { - return where.isNotNull(); + if (simple.isIsNullParameter()) { + return type.equals(SIMPLE_PROPERTY) ? where.isNull() : where.isNotNull(); } - return whereIgnoreCase - .neq(potentiallyIgnoreCase(property, JpqlQueryBuilder.expression(render(negating)))); + JpqlQueryBuilder.Expression expression = potentiallyIgnoreCase(property, placeholder(metadata)); + return type.equals(SIMPLE_PROPERTY) ? whereIgnoreCase.eq(expression) : whereIgnoreCase.neq(expression); case IS_EMPTY: case IS_NOT_EMPTY: @@ -412,8 +420,8 @@ private JpqlQueryBuilder.Expression potentiallyIgnoreCase(JpqlQueryBuilder.O * @param path must not be {@literal null}. * @return */ - private JpqlQueryBuilder.Expression potentiallyIgnoreCase(PathAndOrigin pas) { - return potentiallyIgnoreCase(pas.path(), JpqlQueryBuilder.expression(pas)); + private JpqlQueryBuilder.Expression potentiallyIgnoreCase(PathAndOrigin path) { + return potentiallyIgnoreCase(path.path(), JpqlQueryBuilder.expression(path)); } /** diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java index 42c8ee95d7..cb53998c3f 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilder.java @@ -15,7 +15,8 @@ */ package org.springframework.data.jpa.repository.query; -import static org.springframework.data.jpa.repository.query.QueryTokens.*; +import static org.springframework.data.jpa.repository.query.QueryTokens.TOKEN_ASC; +import static org.springframework.data.jpa.repository.query.QueryTokens.TOKEN_DESC; import java.util.ArrayList; import java.util.Arrays; @@ -32,7 +33,9 @@ import org.springframework.data.util.Predicates; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; /** * A Domain-Specific Language to build JPQL queries using Java code. @@ -189,7 +192,7 @@ public static Expression expression(PathAndOrigin pas) { } /** - * Create a simple expression from a string. + * Create a simple expression from a string as is. * * @param expression * @return @@ -201,11 +204,19 @@ public static Expression expression(String expression) { return new LiteralExpression(expression); } + public static Expression stringLiteral(String literal) { + return new StringLiteralExpression(literal); + } + public static Expression parameter(String parameter) { Assert.hasText(parameter, "Parameter must not be empty or null"); - return new ParameterExpression(parameter); + return new ParameterExpression(new ParameterPlaceholder(parameter)); + } + + public static Expression parameter(ParameterPlaceholder placeholder) { + return new ParameterExpression(placeholder); } public static Expression orderBy(Expression sortExpression, Sort.Order order) { @@ -279,12 +290,12 @@ public Predicate isNotNull() { @Override public Predicate isTrue() { - return new LhsPredicate(rhs, "IS TRUE"); + return new LhsPredicate(rhs, "= TRUE"); } @Override public Predicate isFalse() { - return new LhsPredicate(rhs, "IS FALSE"); + return new LhsPredicate(rhs, "= FALSE"); } @Override @@ -309,7 +320,7 @@ public Predicate notIn(Expression value) { @Override public Predicate inMultivalued(Expression value) { - return new MemberOfPredicate(rhs, "IN", value); + return new MemberOfPredicate(rhs, "IN", value); // TODO: that does not line up in my head - ahahah } @Override @@ -466,6 +477,42 @@ public String toString() { } } + static PathAndOrigin path(Origin origin, String path) { + + if(origin instanceof Entity entity) { + + try { + PropertyPath from = PropertyPath.from(path, ClassUtils.forName(entity.entity, Entity.class.getClassLoader())); + return new PathAndOrigin(from, entity, false); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + if(origin instanceof Join join) { + + Origin parent = join.source; + List segments = new ArrayList<>(); + segments.add(join.path); + while(!(parent instanceof Entity)) { + if(parent instanceof Join pj) { + parent = pj.source; + segments.add(pj.path); + } else { + parent = null; + } + } + + if(parent instanceof Entity entity) { + Collections.reverse(segments); + segments.add(path); + PathAndOrigin path1 = path(parent, StringUtils.collectionToDelimitedString(segments, ".")); + return new PathAndOrigin(path1.path().getLeafProperty(), origin, false); + } + } + throw new IllegalStateException(" oh no "); + + } + /** * Entity selection. * @@ -513,7 +560,9 @@ record ConstructorExpression(String resultType, Multiselect multiselect) impleme @Override public String render(RenderContext context) { - return "new %s(%s)".formatted(resultType, multiselect.render(context)); + + + return "new %s(%s)".formatted(resultType, multiselect.render(new ConstructorContext(context))); } @Override @@ -542,7 +591,9 @@ public String render(RenderContext context) { } builder.append(PathExpression.render(path, context)); - builder.append(" ").append(path.path().getSegment()); + if(!context.isConstructorContext()) { + builder.append(" ").append(path.path().getSegment()); + } } return builder.toString(); @@ -583,7 +634,7 @@ default Predicate or(Predicate other) { * @param other * @return a composed predicate combining this and {@code other} using the AND operator. */ - default Predicate and(Predicate other) { + default Predicate and(Predicate other) { // don't like the structuring of this and the nest() thing return new AndPredicate(this, other); } @@ -799,6 +850,22 @@ public String prefixWithAlias(Origin source, String fragment) { String alias = getAlias(source); return ObjectUtils.isEmpty(source) ? fragment : alias + "." + fragment; } + + public boolean isConstructorContext() { + return false; + } + } + + static class ConstructorContext extends RenderContext { + + ConstructorContext(RenderContext rootContext) { + super(rootContext.aliases); + } + + @Override + public boolean isConstructorContext() { + return true; + } } /** @@ -807,7 +874,7 @@ public String prefixWithAlias(Origin source, String fragment) { */ public interface Origin { - String getName(); + String getName(); // TODO: mainly used along records - shoule we call this just name()? } /** @@ -1051,11 +1118,28 @@ public String toString() { } } - record ParameterExpression(String parameter) implements Expression { + record StringLiteralExpression(String literal) implements Expression { @Override public String render(RenderContext context) { - return parameter; + return "'%s'".formatted(literal.replaceAll("'", "''")); + } + + public String raw() { + return literal; + } + + @Override + public String toString() { + return render(RenderContext.EMPTY); + } + } + + record ParameterExpression(ParameterPlaceholder parameter) implements Expression { + + @Override + public String render(RenderContext context) { + return parameter.placeholder; } @Override @@ -1158,6 +1242,8 @@ record InPredicate(Expression path, String operator, Expression predicate) imple @Override public String render(RenderContext context) { + + //TODO: should we rather wrap it with nested or check if its a nested predicate before we call render return "%s %s (%s)".formatted(path.render(context), operator, predicate.render(context)); } @@ -1216,4 +1302,21 @@ public String toString() { public record PathAndOrigin(PropertyPath path, Origin origin, boolean onTheJoin) { } + + public record ParameterPlaceholder(String placeholder) { + + public ParameterPlaceholder { + Assert.hasText(placeholder, "Placeholder must not be null nor empty"); + } + + public static ParameterPlaceholder indexed(int index) { + return new ParameterPlaceholder("?%s".formatted(index)); + } + + public static ParameterPlaceholder named(String name) { + + Assert.hasText(name, "Placeholder name must not be empty"); + return new ParameterPlaceholder(":%s".formatted(name)); + } + } } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlUtils.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlUtils.java index 50da5558bb..d3b32380cd 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlUtils.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JpqlUtils.java @@ -15,27 +15,64 @@ */ package org.springframework.data.jpa.repository.query; +import static jakarta.persistence.metamodel.Attribute.PersistentAttributeType.ELEMENT_COLLECTION; +import static jakarta.persistence.metamodel.Attribute.PersistentAttributeType.MANY_TO_MANY; +import static jakarta.persistence.metamodel.Attribute.PersistentAttributeType.MANY_TO_ONE; +import static jakarta.persistence.metamodel.Attribute.PersistentAttributeType.ONE_TO_MANY; +import static jakarta.persistence.metamodel.Attribute.PersistentAttributeType.ONE_TO_ONE; + +import jakarta.persistence.ManyToOne; +import jakarta.persistence.OneToOne; import jakarta.persistence.criteria.From; import jakarta.persistence.criteria.Join; import jakarta.persistence.criteria.JoinType; +import jakarta.persistence.metamodel.Attribute; +import jakarta.persistence.metamodel.Attribute.PersistentAttributeType; +import jakarta.persistence.metamodel.Bindable; +import jakarta.persistence.metamodel.ManagedType; +import jakarta.persistence.metamodel.Metamodel; +import jakarta.persistence.metamodel.PluralAttribute; +import jakarta.persistence.metamodel.SingularAttribute; +import java.lang.annotation.Annotation; +import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Member; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; +import org.springframework.core.annotation.AnnotationUtils; import org.springframework.data.mapping.PropertyPath; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; /** * @author Mark Paluch */ class JpqlUtils { - static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(JpqlQueryBuilder.Origin source, From from, - PropertyPath property) { - return toExpressionRecursively(source, from, property, false); + private static final Map> ASSOCIATION_TYPES; + + static { + Map> persistentAttributeTypes = new HashMap<>(); + persistentAttributeTypes.put(ONE_TO_ONE, OneToOne.class); + persistentAttributeTypes.put(ONE_TO_MANY, null); + persistentAttributeTypes.put(MANY_TO_ONE, ManyToOne.class); + persistentAttributeTypes.put(MANY_TO_MANY, null); + persistentAttributeTypes.put(ELEMENT_COLLECTION, null); + + ASSOCIATION_TYPES = Collections.unmodifiableMap(persistentAttributeTypes); } - static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(JpqlQueryBuilder.Origin source, From from, - PropertyPath property, boolean isForSelection) { - return toExpressionRecursively(source, from, property, isForSelection, false); + static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(Metamodel metamodel, JpqlQueryBuilder.Origin source, + Bindable from, PropertyPath property) { + return toExpressionRecursively(metamodel, source, from, property, false); + } + + static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(Metamodel metamodel, JpqlQueryBuilder.Origin source, + Bindable from, PropertyPath property, boolean isForSelection) { + return toExpressionRecursively(metamodel, source, from, property, isForSelection, false); } /** @@ -45,18 +82,18 @@ static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(JpqlQueryBuilder.O * @param property the property path * @param isForSelection is the property navigated for the selection or ordering part of the query? * @param hasRequiredOuterJoin has a parent already required an outer join? - * @param the type of the expression * @return the expression */ @SuppressWarnings("unchecked") - static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(JpqlQueryBuilder.Origin source, From from, - PropertyPath property, boolean isForSelection, boolean hasRequiredOuterJoin) { + static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(Metamodel metamodel, JpqlQueryBuilder.Origin source, + Bindable from, PropertyPath property, boolean isForSelection, boolean hasRequiredOuterJoin) { String segment = property.getSegment(); boolean isLeafProperty = !property.hasNext(); - boolean requiresOuterJoin = QueryUtils.requiresOuterJoin(from, property, isForSelection, hasRequiredOuterJoin); + boolean requiresOuterJoin = requiresOuterJoin(metamodel, source, from, property, isForSelection, + hasRequiredOuterJoin); // if it does not require an outer join and is a leaf, simply get the segment if (!requiresOuterJoin && isLeafProperty) { @@ -66,9 +103,10 @@ static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(JpqlQueryBuilder.O // get or create the join JpqlQueryBuilder.Join joinSource = requiresOuterJoin ? JpqlQueryBuilder.leftJoin(source, segment) : JpqlQueryBuilder.innerJoin(source, segment); - JoinType joinType = requiresOuterJoin ? JoinType.LEFT : JoinType.INNER; - Join join = QueryUtils.getOrCreateJoin(from, segment, joinType); +// JoinType joinType = requiresOuterJoin ? JoinType.LEFT : JoinType.INNER; +// Join join = QueryUtils.getOrCreateJoin(from, segment, joinType); +// // if it's a leaf, return the join if (isLeafProperty) { return new JpqlQueryBuilder.PathAndOrigin(property, joinSource, true); @@ -76,7 +114,110 @@ static JpqlQueryBuilder.PathAndOrigin toExpressionRecursively(JpqlQueryBuilder.O PropertyPath nextProperty = Objects.requireNonNull(property.next(), "An element of the property path is null"); +// ManagedType managedType = ; + Bindable managedTypeForModel = (Bindable) getManagedTypeForModel(from); +// Attribute joinAttribute = getModelForPath(metamodel, property, getManagedTypeForModel(from), null); // recurse with the next property - return toExpressionRecursively(joinSource, join, nextProperty, isForSelection, requiresOuterJoin); + return toExpressionRecursively(metamodel, joinSource, managedTypeForModel, nextProperty, isForSelection, requiresOuterJoin); + } + + /** + * Checks if this attribute requires an outer join. This is the case e.g. if it hadn't already been fetched with an + * inner join and if it's an optional association, and if previous paths has already required outer joins. It also + * ensures outer joins are used even when Hibernate defaults to inner joins (HHH-12712 and HHH-12999) + * + * @param metamodel + * @param source + * @param bindable + * @param propertyPath + * @param isForSelection + * @param hasRequiredOuterJoin + * @return + */ + static boolean requiresOuterJoin(Metamodel metamodel, JpqlQueryBuilder.Origin source, Bindable bindable, + PropertyPath propertyPath, boolean isForSelection, boolean hasRequiredOuterJoin) { + + ManagedType managedType = getManagedTypeForModel(bindable); + Attribute attribute = getModelForPath(metamodel, propertyPath, managedType, bindable); + + boolean isPluralAttribute = bindable instanceof PluralAttribute; + if (attribute == null) { + return isPluralAttribute; + } + + if (!ASSOCIATION_TYPES.containsKey(attribute.getPersistentAttributeType())) { + return false; + } + + boolean isCollection = attribute.isCollection(); + + // if this path is an optional one to one attribute navigated from the not owning side we also need an + // explicit outer join to avoid https://hibernate.atlassian.net/browse/HHH-12712 + // and https://github.com/eclipse-ee4j/jpa-api/issues/170 + boolean isInverseOptionalOneToOne = PersistentAttributeType.ONE_TO_ONE == attribute.getPersistentAttributeType() + && StringUtils.hasText(getAnnotationProperty(attribute, "mappedBy", "")); + + boolean isLeafProperty = !propertyPath.hasNext(); + if (isLeafProperty && !isForSelection && !isCollection && !isInverseOptionalOneToOne && !hasRequiredOuterJoin) { + return false; + } + + return hasRequiredOuterJoin || getAnnotationProperty(attribute, "optional", true); + } + + @Nullable + private static T getAnnotationProperty(Attribute attribute, String propertyName, T defaultValue) { + + Class associationAnnotation = ASSOCIATION_TYPES.get(attribute.getPersistentAttributeType()); + + if (associationAnnotation == null) { + return defaultValue; + } + + Member member = attribute.getJavaMember(); + + if (!(member instanceof AnnotatedElement annotatedMember)) { + return defaultValue; + } + + Annotation annotation = AnnotationUtils.getAnnotation(annotatedMember, associationAnnotation); + return annotation == null ? defaultValue : (T) AnnotationUtils.getValue(annotation, propertyName); + } + + @Nullable + private static ManagedType getManagedTypeForModel(Bindable model) { + + if (model instanceof ManagedType managedType) { + return managedType; + } + + if (!(model instanceof SingularAttribute singularAttribute)) { + return null; + } + + return singularAttribute.getType() instanceof ManagedType managedType ? managedType : null; + } + + @Nullable + private static Attribute getModelForPath(Metamodel metamodel, PropertyPath path, + @Nullable ManagedType managedType, Bindable fallback) { + + String segment = path.getSegment(); + if (managedType != null) { + try { + return managedType.getAttribute(segment); + } catch (IllegalArgumentException ex) { + // ManagedType may be erased for some vendor if the attribute is declared as generic + } + } + + Class fallbackType = fallback.getBindableJavaType(); + try { + return metamodel.managedType(fallbackType).getAttribute(segment); + } catch (IllegalArgumentException e) { + + } + + return null; } } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/KeysetScrollSpecification.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/KeysetScrollSpecification.java index ede9516b05..df4106f35e 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/KeysetScrollSpecification.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/KeysetScrollSpecification.java @@ -24,6 +24,8 @@ import java.util.List; +import jakarta.persistence.metamodel.Bindable; +import jakarta.persistence.metamodel.Metamodel; import org.springframework.data.domain.KeysetScrollPosition; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Order; @@ -77,11 +79,11 @@ public Predicate createPredicate(Root root, CriteriaBuilder criteriaBuilder) } @Nullable - public JpqlQueryBuilder.Predicate createJpqlPredicate(From from, JpqlQueryBuilder.Entity entity, + public JpqlQueryBuilder.Predicate createJpqlPredicate(Bindable from, JpqlQueryBuilder.Entity entity, ParameterFactory factory) { KeysetScrollDelegate delegate = KeysetScrollDelegate.of(position.getDirection()); - return delegate.createPredicate(position, sort, new JpqlStrategy(from, entity, factory)); + return delegate.createPredicate(position, sort, new JpqlStrategy(null, from, entity, factory)); } @SuppressWarnings("rawtypes") @@ -128,22 +130,24 @@ public Predicate or(List intermediate) { private static class JpqlStrategy implements QueryStrategy { - private final From from; + private final Bindable from; private final JpqlQueryBuilder.Entity entity; private final ParameterFactory factory; + private final Metamodel metamodel; - public JpqlStrategy(From from, JpqlQueryBuilder.Entity entity, ParameterFactory factory) { + public JpqlStrategy(Metamodel metamodel, Bindable from, JpqlQueryBuilder.Entity entity, ParameterFactory factory) { this.from = from; this.entity = entity; this.factory = factory; + this.metamodel = metamodel; } @Override public JpqlQueryBuilder.Expression createExpression(String property) { - PropertyPath path = PropertyPath.from(property, from.getJavaType()); - return JpqlQueryBuilder.expression(JpqlUtils.toExpressionRecursively(entity, from, path)); + PropertyPath path = PropertyPath.from(property, from.getBindableJavaType()); + return JpqlQueryBuilder.expression(JpqlUtils.toExpressionRecursively(metamodel, entity, from, path)); } @Override diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java index c03384cb48..e6eedeede5 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/ParameterBinding.java @@ -42,6 +42,7 @@ * * @author Thomas Darimont * @author Mark Paluch + * @author Christoph Strobl */ class ParameterBinding { @@ -213,7 +214,10 @@ public PartTreeParameterBinding(BindingIdentifier identifier, ParameterOrigin or this.templates = templates; this.escape = escape; - this.type = value == null && Type.SIMPLE_PROPERTY.equals(part.getType()) ? Type.IS_NULL : part.getType(); + this.type = value == null + && (Type.SIMPLE_PROPERTY.equals(part.getType()) || Type.NEGATING_SIMPLE_PROPERTY.equals(part.getType())) + ? Type.IS_NULL + : part.getType(); this.ignoreCase = Part.IgnoreCaseType.ALWAYS.equals(part.shouldIgnoreCase()); this.noWildcards = part.getProperty().getLeafProperty().isCollection(); } diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java index 2aa982f174..725a97a7d9 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeJpaQuery.java @@ -62,7 +62,7 @@ public class PartTreeJpaQuery extends AbstractJpaQuery { private final PartTree tree; private final JpaParameters parameters; - private final QueryPreparer query; + private final QueryPreparer queryPreparer; private final QueryPreparer countQuery; private final EntityManager em; private final EscapeCharacter escape; @@ -102,7 +102,7 @@ public class PartTreeJpaQuery extends AbstractJpaQuery { this.tree = new PartTree(method.getName(), domainClass); validate(tree, parameters, method.toString()); this.countQuery = new CountQueryPreparer(); - this.query = tree.isCountProjection() ? countQuery : new QueryPreparer(); + this.queryPreparer = tree.isCountProjection() ? countQuery : new QueryPreparer(); } catch (Exception o_O) { throw new IllegalArgumentException( @@ -112,7 +112,7 @@ public class PartTreeJpaQuery extends AbstractJpaQuery { @Override public Query doCreateQuery(JpaParametersParameterAccessor accessor) { - return query.createQuery(accessor); + return queryPreparer.createQuery(accessor); } @Override @@ -210,12 +210,7 @@ private static boolean expectsCollection(Type type) { */ private class QueryPreparer { - private final Map cache = new LinkedHashMap() { - @Override - protected boolean removeEldestEntry(Map.Entry eldest) { - return size() > 256; - } - }; + private final PartTreeQueryCache cache = new PartTreeQueryCache(); /** * Creates a new {@link Query} for the given parameter values. @@ -279,7 +274,7 @@ private Query restrictMaxResultsIfNecessary(Query query, @Nullable ScrollPositio protected JpqlQueryCreator createCreator(Sort sort, JpaParametersParameterAccessor accessor) { synchronized (cache) { - JpqlQueryCreator jpqlQueryCreator = cache.get(sort); + JpqlQueryCreator jpqlQueryCreator = cache.get(sort, accessor); // this caching thingy is broken due to IS NULL rendering for simple properties if (jpqlQueryCreator != null) { return jpqlQueryCreator; } @@ -304,7 +299,7 @@ protected JpqlQueryCreator createCreator(Sort sort, JpaParametersParameterAccess } synchronized (cache) { - cache.put(sort, creator); + cache.put(sort, accessor, creator); } return creator; diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeQueryCache.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeQueryCache.java new file mode 100644 index 0000000000..71f952c2c8 --- /dev/null +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/PartTreeQueryCache.java @@ -0,0 +1,100 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.repository.query; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +import org.springframework.data.domain.Sort; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; + +/** + * @author Christoph Strobl + */ +class PartTreeQueryCache { + + private final Map cache = new LinkedHashMap() { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 256; + } + }; + + @Nullable + JpqlQueryCreator get(Sort sort, JpaParametersParameterAccessor accessor) { + return cache.get(CacheKey.of(sort, accessor)); + } + + @Nullable + JpqlQueryCreator put(Sort sort, JpaParametersParameterAccessor accessor, JpqlQueryCreator creator) { + return cache.put(CacheKey.of(sort, accessor), creator); + } + + static class CacheKey { + + private final Sort sort; + private final Map params; + + public CacheKey(Sort sort, Map params) { + this.sort = sort; + this.params = params; + } + + static CacheKey of(Sort sort, JpaParametersParameterAccessor accessor) { + + Object[] values = accessor.getValues(); + if (ObjectUtils.isEmpty(values)) { + return new CacheKey(sort, Map.of()); + } + + return new CacheKey(sort, toNullableMap(values)); + } + + static Map toNullableMap(Object[] args) { + + Map paramMap = new HashMap<>(args.length); + for (int i = 0; i < args.length; i++) { + paramMap.put(i, args[i] != null ? Nulled.NO : Nulled.YES); + } + return paramMap; + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CacheKey cacheKey = (CacheKey) o; + return sort.equals(cacheKey.sort) && params.equals(cacheKey.params); + } + + @Override + public int hashCode() { + return Objects.hash(sort, params); + } + } + + enum Nulled { + YES, NO + } + +} diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java index 11010a7aca..4b05507bc5 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/QueryUtils.java @@ -914,7 +914,7 @@ private static T getAnnotationProperty(Attribute attribute, String pro * @param attribute the attribute name to check. * @return true if the attribute has already been inner joined */ - private static boolean isAlreadyInnerJoined(From from, String attribute) { + static boolean isAlreadyInnerJoined(From from, String attribute) { for (Fetch fetch : from.getFetches()) { diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryFinderTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryFinderTests.java index fdf3330ca5..e1dada1234 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryFinderTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/UserRepositoryFinderTests.java @@ -405,4 +405,5 @@ void findByNegatingSimplePropertyUsingMixedNullNonNullArgument() { result = userRepository.findUserByLastname(carter.getLastname()); assertThat(result).containsExactly(carter); } + } diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java new file mode 100644 index 0000000000..dc2866fa8b --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryCreatorTests.java @@ -0,0 +1,1039 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.repository.query; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import jakarta.persistence.ElementCollection; +import jakarta.persistence.EntityManager; +import jakarta.persistence.Id; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.OneToMany; +import jakarta.persistence.Tuple; +import jakarta.persistence.metamodel.Metamodel; + +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.FieldSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Sort; +import org.springframework.data.jpa.repository.support.JpqlQueryTemplates; +import org.springframework.data.jpa.util.TestMetaModel; +import org.springframework.data.projection.ProjectionFactory; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.query.ParameterAccessor; +import org.springframework.data.repository.query.ReturnedType; +import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.data.util.Lazy; +import org.springframework.lang.Nullable; + +/** + * @author Christoph Strobl + */ +class JpaQueryCreatorTests { + + private static final TestMetaModel ORDER = TestMetaModel.hibernateModel(Order.class, LineItem.class, Product.class); + private static final TestMetaModel PERSON = TestMetaModel.hibernateModel(Person.class); + + static List ignoreCaseTemplates = List.of(JpqlQueryTemplates.LOWER, JpqlQueryTemplates.UPPER); + + @Test + void simpleProperty() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountry") // + .withParameters("AT") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country = ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void simpleNullProperty() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountry") // + .withParameterTypes(String.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country IS NULL", Order.class.getName()) // + .validateQuery(); + } + + @Test + void negatingSimpleProperty() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryNot") // + .withParameters("US") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country != ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void negatingSimpleNullProperty() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryIsNot") // + .withParameterTypes(String.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country IS NOT NULL", Order.class.getName()) // + .validateQuery(); + } + + @Test + void simpleAnd() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryAndDate") // + .withParameters("GB", new Date()) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country = ?1 AND o.date = ?2", Order.class.getName()) // + .validateQuery(); + } + + @Test + void simpleOr() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryOrDate") // + .withParameters("BE", new Date()) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country = ?1 OR o.date = ?2", Order.class.getName()) // + .validateQuery(); + } + + @Test + void simpleAndOr() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryAndDateOrCompleted") // + .withParameters("IT", new Date(), Boolean.FALSE) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country = ?1 AND o.date = ?2 OR o.completed = ?3", + Order.class.getName()) // + .validateQuery(); + } + + @Test + void distinct() { + + queryCreator(ORDER) // + .forTree(Order.class, "findDistinctOrderByCountry") // + .withParameters("AU") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT DISTINCT o FROM %s o WHERE o.country = ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void count() { + + queryCreator(ORDER) // + .forTree(Order.class, "countOrderByCountry") // + .returing(Long.class) // + .withParameters("AU") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT COUNT(o) FROM %s o WHERE o.country = ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void countWithJoins() { + + queryCreator(ORDER) // + .forTree(Order.class, "countOrderByLineItemsQuantityGreaterThan") // + .returing(Long.class) // + .withParameterTypes(Integer.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT COUNT(o) FROM %s o LEFT JOIN o.lineItems l WHERE l.quantity > ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void countDistinct() { + + queryCreator(ORDER) // + .forTree(Order.class, "countDistinctOrderByCountry") // + .returing(Long.class) // + .withParameters("AU") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT COUNT(DISTINCT o) FROM %s o WHERE o.country = ?1", Order.class.getName()) // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void simplePropertyIgnoreCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("BB") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE %s(o.country) = %s(?1)", Order.class.getName(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void simplePropertyAllIgnoreCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameAndProductTypeAllIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("spring", "data") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE %s(p.name) = %s(?1) AND %s(p.productType) = %s(?2)", + Product.class.getName(), ingnoreCaseTemplate.getIgnoreCaseOperator(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator(), + ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void simplePropertyMixedCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameAndProductTypeIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("spring", "data") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name = ?1 AND %s(p.productType) = %s(?2)", Product.class.getName(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator(), + ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .validateQuery(); + } + + @Test + void lessThan() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateLessThan") // + .withParameterTypes(Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date < ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void lessThanEqual() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateLessThanEqual") // + .withParameterTypes(Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date <= ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void greaterThan() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateGreaterThan") // + .withParameterTypes(Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date > ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void before() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateBefore") // + .withParameterTypes(Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date < ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void after() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateAfter") // + .withParameterTypes(Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date > ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void between() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateBetween") // + .withParameterTypes(Date.class, Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date BETWEEN ?1 AND ?2", Order.class.getName()) // + .validateQuery(); + } + + @Test + void isNull() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateIsNull") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date IS NULL", Order.class.getName()) // + .validateQuery(); + } + + @Test + void isNotNull() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateIsNotNull") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date IS NOT NULL", Order.class.getName()) // + .validateQuery(); + } + + @ParameterizedTest + @ValueSource(strings = { "", "spring", "%spring", "spring%", "%spring%" }) + void like(String parameterValue) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameLike") // + .withParameters(parameterValue) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name LIKE ?1 ESCAPE '\\'", Product.class.getName()) // + .expectPlaceholderValue("?1", parameterValue) // + .validateQuery(); + } + + @Test + void containingString() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameContaining") // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name LIKE ?1 ESCAPE '\\'", Product.class.getName()) // + .expectPlaceholderValue("?1", "%spring%") // + .validateQuery(); + } + + @Test + void notContainingString() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameNotContaining") // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name NOT LIKE ?1 ESCAPE '\\'", Product.class.getName()) // + .expectPlaceholderValue("?1", "%spring%") // + .validateQuery(); + } + + @Test + void in() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameIn") // + .withParameters(List.of("spring", "data")) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name IN (?1)", Product.class.getName()) // + .expectPlaceholderValue("?1", List.of("spring", "data")) // + .validateQuery(); + } + + @Test + void notIn() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameNotIn") // + .withParameters(List.of("spring", "data")) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name NOT IN (?1)", Product.class.getName()) // + .expectPlaceholderValue("?1", List.of("spring", "data")) // + .validateQuery(); + } + + @Test + void containingSingleEntryElementCollection() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByCategoriesContaining") // + .withParameterTypes(String.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE ?1 MEMBER OF p.categories", Product.class.getName()) // + .validateQuery(); + } + + @Test + void notContainingSingleEntryElementCollection() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByCategoriesNotContaining") // + .withParameterTypes(String.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE ?1 NOT MEMBER OF p.categories", Product.class.getName()) // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void likeWithIgnoreCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameLikeIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("%spring%") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE %s(p.name) LIKE %s(?1) ESCAPE '\\'", Product.class.getName(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .expectPlaceholderValue("?1", "%spring%") // + .validateQuery(); + } + + @ParameterizedTest + @ValueSource(strings = { "", "spring", "%spring", "spring%", "%spring%" }) + void notLike(String parameterValue) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameNotLike") // + .withParameters(parameterValue) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name NOT LIKE ?1 ESCAPE '\\'", Product.class.getName()) // + .expectPlaceholderValue("?1", parameterValue) // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void notLikeWithIgnoreCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameNotLikeIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("%spring%") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE %s(p.name) NOT LIKE %s(?1) ESCAPE '\\'", Product.class.getName(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .expectPlaceholderValue("?1", "%spring%") // + .validateQuery(); + } + + @Test + void startingWith() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameStartingWith") // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name LIKE ?1 ESCAPE '\\'", Product.class.getName()) // + .expectPlaceholderValue("?1", "spring%") // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void startingWithIgnoreCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameStartingWithIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE %s(p.name) LIKE %s(?1) ESCAPE '\\'", Product.class.getName(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .expectPlaceholderValue("?1", "spring%") // + .validateQuery(); + } + + @Test + void endingWith() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameEndingWith") // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.name LIKE ?1 ESCAPE '\\'", Product.class.getName()) // + .expectPlaceholderValue("?1", "%spring") // + .validateQuery(); + } + + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void endingWithIgnoreCase(JpqlQueryTemplates ingnoreCaseTemplate) { + + queryCreator(ORDER) // + .forTree(Product.class, "findProductByNameEndingWithIgnoreCase") // + .ingnoreCaseAs(ingnoreCaseTemplate) // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE %s(p.name) LIKE %s(?1) ESCAPE '\\'", Product.class.getName(), + ingnoreCaseTemplate.getIgnoreCaseOperator(), ingnoreCaseTemplate.getIgnoreCaseOperator()) // + .expectPlaceholderValue("?1", "%spring") // + .validateQuery(); + } + + @Test + void greaterThanEqual() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByDateGreaterThanEqual") // + .withParameterTypes(Date.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.date >= ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void isTrue() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCompletedIsTrue") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.completed = TRUE", Order.class.getName()) // + .validateQuery(); + } + + @Test + void isFalse() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCompletedIsFalse") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.completed = FALSE", Order.class.getName()) // + .validateQuery(); + } + + @Test + void empty() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsEmpty") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.lineItems IS EMPTY", Order.class.getName()) // + .validateQuery(); + } + + @Test + void notEmpty() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsNotEmpty") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.lineItems IS NOT EMPTY", Order.class.getName()) // + .validateQuery(); + } + + @Test + void sortBySingle() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByCountryOrderByDate") // + .withParameters("CA") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o WHERE o.country = ?1 ORDER BY o.date asc", Order.class.getName()) // + .validateQuery(); + } + + @Test + void sortByMulti() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByOrderByCountryAscDateDesc") // + .withParameters() // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o ORDER BY o.country asc, o.date desc", Order.class.getName()) // + .validateQuery(); + } + + @Disabled("should we support this?") + @ParameterizedTest + @FieldSource("ignoreCaseTemplates") + void sortBySingleIngoreCase(JpqlQueryTemplates ingoreCase) { + + String jpql = queryCreator(ORDER) // + .forTree(Order.class, "findOrderByOrderByCountryAscAllIgnoreCase") // + .render(); + + assertThat(jpql).isEqualTo("SELECT o FROM %s o ORDER BY %s(o.date) asc", Order.class.getName(), + ingoreCase.getIgnoreCaseOperator()); + } + + @Test + void matchSimpleJoin() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsQuantityGreaterThan") // + .withParameterTypes(Integer.class) // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o LEFT JOIN o.lineItems l WHERE l.quantity > ?1", Order.class.getName()) // + .validateQuery(); + } + + @Test + void matchSimpleNestedJoin() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsProductNameIs") // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT o FROM %s o LEFT JOIN o.lineItems l INNER JOIN l.product p WHERE p.name = ?1", + Order.class.getName()) // + .validateQuery(); + } + + @Test + void matchMultiOnNestedJoin() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsQuantityGreaterThanAndLineItemsProductNameIs") // + .withParameters(10, "spring") // + .as(QueryCreatorTester::create) // + .expectJpql( + "SELECT o FROM %s o LEFT JOIN o.lineItems l INNER JOIN l.product p WHERE l.quantity > ?1 AND p.name = ?2", + Order.class.getName()) // + .validateQuery(); + } + + @Test + void matchSameEntityMultipleTimes() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsProductNameIsAndLineItemsProductNameIsNot") // + .withParameters("spring", "sukrauq") // + .as(QueryCreatorTester::create) // + .expectJpql( + "SELECT o FROM %s o LEFT JOIN o.lineItems l INNER JOIN l.product p WHERE p.name = ?1 AND p.name != ?2", + Order.class.getName()) // + .validateQuery(); + } + + @Test + void matchSameEntityMultipleTimesViaDifferentProperties() { + + queryCreator(ORDER) // + .forTree(Order.class, "findOrderByLineItemsProductNameIsAndLineItemsProduct2NameIs") // + .withParameters(10, "spring") // + .as(QueryCreatorTester::create) // + .expectJpql( + "SELECT o FROM %s o LEFT JOIN o.lineItems l INNER JOIN l.product p INNER JOIN l.product2 join_0 WHERE p.name = ?1 AND join_0.name = ?2", + Order.class.getName()) // + .validateQuery(); + } + + @Test + void dtoProjection() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProjectionByNameIs") // + .returing(DtoProductProjection.class) // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT new %s(p.name, p.productType) FROM %s p WHERE p.name = ?1", + DtoProductProjection.class.getName(), Product.class.getName()) // + .validateQuery(); + } + + @Test + void interfaceProjection() { + + queryCreator(ORDER) // + .forTree(Product.class, "findProjectionByNameIs") // + .returing(InterfaceProductProjection.class) // + .withParameters("spring") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p.name name, p.productType productType FROM %s p WHERE p.name = ?1", + Product.class.getName()) // + .validateQuery(); + } + + @ParameterizedTest + @ValueSource(classes = { Tuple.class, Map.class }) + void tupleProjection(Class resultType) { + + queryCreator(PERSON) // + .forTree(Person.class, "findProjectionByFirstnameIs") // + .returing(resultType) // + .withParameters("chris") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p.id id, p.firstname firstname, p.lastname lastname FROM %s p WHERE p.firstname = ?1", + Person.class.getName()) // + .validateQuery(); + } + + @ParameterizedTest + @ValueSource(classes = { Long.class, List.class, Person.class }) + void delete(Class resultType) { + + queryCreator(PERSON) // + .forTree(Person.class, "deletePersonByFirstname") // + .returing(resultType) // + .withParameters("chris") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p FROM %s p WHERE p.firstname = ?1", Person.class.getName()) // + .validateQuery(); + } + + @Test + void exists() { + + queryCreator(PERSON) // + .forTree(Person.class, "existsPersonByFirstname") // + .returing(Long.class).withParameters("chris") // + .as(QueryCreatorTester::create) // + .expectJpql("SELECT p.id id FROM %s p WHERE p.firstname = ?1", Person.class.getName()) // + .validateQuery(); + } + + QueryCreatorBuilder queryCreator(Metamodel metamodel) { + return new DefaultCreatorBuilder(metamodel); + } + + JpaQueryCreator queryCreator(PartTree tree, ReturnedType returnedType, Metamodel metamodel, Object... arguments) { + return queryCreator(tree, returnedType, metamodel, JpqlQueryTemplates.UPPER, arguments); + } + + JpaQueryCreator queryCreator(PartTree tree, ReturnedType returnedType, Metamodel metamodel, + JpqlQueryTemplates templates, Object... arguments) { + + ParameterMetadataProvider parameterMetadataProvider = new ParameterMetadataProvider( + StubJpaParameterParameterAccessor.accessor(arguments), EscapeCharacter.DEFAULT, templates); + return queryCreator(tree, returnedType, metamodel, templates, parameterMetadataProvider); + } + + JpaQueryCreator queryCreator(PartTree tree, ReturnedType returnedType, Metamodel metamodel, + JpqlQueryTemplates templates, Class... argumentTypes) { + + ParameterMetadataProvider parameterMetadataProvider = new ParameterMetadataProvider( + StubJpaParameterParameterAccessor.accessor(argumentTypes), EscapeCharacter.DEFAULT, templates); + return queryCreator(tree, returnedType, metamodel, templates, parameterMetadataProvider); + } + + JpaQueryCreator queryCreator(PartTree tree, ReturnedType returnedType, Metamodel metamodel, + JpqlQueryTemplates templates, JpaParametersParameterAccessor parameterAccessor) { + + EntityManager entityManager = mock(EntityManager.class); + when(entityManager.getMetamodel()).thenReturn(metamodel); + + ParameterMetadataProvider parameterMetadataProvider = new ParameterMetadataProvider(parameterAccessor, + EscapeCharacter.DEFAULT, templates); + return new JpaQueryCreator(tree, returnedType, parameterMetadataProvider, templates, entityManager); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private JpaParametersParameterAccessor accessor(Class... argumentTypes) { + + return StubJpaParameterParameterAccessor.accessor(argumentTypes); + } + + @jakarta.persistence.Entity + static class Order { + + @Id Long id; + Date date; + String country; + Boolean completed; + + @OneToMany List lineItems; + } + + @jakarta.persistence.Entity + static class LineItem { + + @Id Long id; + + @ManyToOne Product product; + @ManyToOne Product product2; + int quantity; + + } + + @jakarta.persistence.Entity + static class Person { + @Id Long id; + String firstname; + String lastname; + } + + @jakarta.persistence.Entity + static class Product { + + @Id Long id; + + String name; + String productType; + + @ElementCollection List categories; + } + + static class DtoProductProjection { + + String name; + String productType; + + DtoProductProjection(String name, String productType) { + this.name = name; + this.productType = productType; + } + } + + interface InterfaceProductProjection { + String getName(); + + String getProductType(); + } + + static class QueryCreatorTester { + + QueryCreatorBuilder builder; + Lazy jpql; + + private QueryCreatorTester(QueryCreatorBuilder builder) { + this.builder = builder; + this.jpql = Lazy.of(builder::render); + } + + static QueryCreatorTester create(QueryCreatorBuilder builder) { + return new QueryCreatorTester(builder); + } + + QueryCreatorTester expectJpql(String jpql, Object... args) { + + assertThat(this.jpql.get()).isEqualTo(jpql, args); + return this; + } + + QueryCreatorTester expectPlaceholderValue(String placeholder, Object value) { + return expectBindingAt(builder.bindingIndexFor(placeholder), value); + } + + QueryCreatorTester expectBindingAt(int position, Object value) { + + Object current = builder.bindableParameters().getBindableValue(position - 1); + assertThat(current).isEqualTo(value); + return this; + } + + QueryCreatorTester validateQuery() { + + if (builder instanceof DefaultCreatorBuilder dcb && dcb.metamodel instanceof TestMetaModel tmm) { + return validateQuery(tmm.entityManager()); + } + + throw new IllegalStateException("No EntityManager found, plase provide one via [verify(EntityManager)]"); + } + + QueryCreatorTester validateQuery(EntityManager entityManager) { + + if (builder instanceof DefaultCreatorBuilder dcb) { + entityManager.createQuery(this.jpql.get(), dcb.returnedType.getReturnedType()); + } else { + entityManager.createQuery(this.jpql.get()); + } + return this; + } + + } + + interface QueryCreatorBuilder { + + QueryCreatorBuilder returing(ReturnedType returnedType); + + QueryCreatorBuilder forTree(Class root, String querySource); + + QueryCreatorBuilder withParameters(Object... arguments); + + QueryCreatorBuilder withParameterTypes(Class... argumentTypes); + + QueryCreatorBuilder ingnoreCaseAs(JpqlQueryTemplates queryTemplate); + + default T as(Function transformer) { + return transformer.apply(this); + } + + default String render() { + return render(null); + } + + ParameterAccessor bindableParameters(); + + int bindingIndexFor(String placeholder); + + String render(@Nullable Sort sort); + + QueryCreatorBuilder returing(Class type); + } + + private class DefaultCreatorBuilder implements QueryCreatorBuilder { + + private static final ProjectionFactory PROJECTION_FACTORY = new SpelAwareProxyProjectionFactory(); + + private final Metamodel metamodel; + private ReturnedType returnedType; + private PartTree partTree; + private Object[] arguments; + private Class[] argumentTypes; + private JpqlQueryTemplates queryTemplates; + private Lazy queryCreator = Lazy.of(this::initJpaQueryCreator); + private Lazy parameterAccessor = Lazy.of(this::initParameterAccessor); + + public DefaultCreatorBuilder(Metamodel metamodel) { + this.metamodel = metamodel; + arguments = new Object[0]; + queryTemplates = JpqlQueryTemplates.UPPER; + } + + @Override + public QueryCreatorBuilder returing(ReturnedType returnedType) { + this.returnedType = returnedType; + return this; + } + + @Override + public QueryCreatorBuilder returing(Class type) { + + if (this.returnedType != null) { + return returing(ReturnedType.of(type, returnedType.getDomainType(), PROJECTION_FACTORY)); + } + + return returing(ReturnedType.of(type, type, PROJECTION_FACTORY)); + } + + @Override + public QueryCreatorBuilder forTree(Class root, String querySource) { + + this.partTree = new PartTree(querySource, root); + if (returnedType == null) { + returnedType = ReturnedType.of(root, root, PROJECTION_FACTORY); + } + return this; + } + + @Override + public QueryCreatorBuilder withParameters(Object... arguments) { + this.arguments = arguments; + return this; + } + + @Override + public QueryCreatorBuilder withParameterTypes(Class... argumentTypes) { + this.argumentTypes = argumentTypes; + return this; + } + + @Override + public QueryCreatorBuilder ingnoreCaseAs(JpqlQueryTemplates queryTemplate) { + this.queryTemplates = queryTemplate; + return this; + } + + @Override + public String render(@Nullable Sort sort) { + return queryCreator.get().createQuery(sort != null ? sort : Sort.unsorted()); + } + + @Override + public int bindingIndexFor(String placeholder) { + + return queryCreator.get().getBindings().stream().filter(binding -> { + + if (binding.getIdentifier().hasPosition() && placeholder.startsWith("?")) { + return binding.getPosition() == Integer.parseInt(placeholder.substring(1)); + } + + if (!binding.getIdentifier().hasName()) { + return false; + } + + return binding.getIdentifier().getName().equals(placeholder); + }).findFirst().map(ParameterBinding::getPosition).orElse(-1); + } + + @Override + public ParameterAccessor bindableParameters() { + + return new ParameterAccessor() { + @Nullable + @Override + public ScrollPosition getScrollPosition() { + return null; + } + + @Override + public Pageable getPageable() { + return null; + } + + @Override + public Sort getSort() { + return null; + } + + @Nullable + @Override + public Class findDynamicProjection() { + return null; + } + + @Nullable + @Override + public Object getBindableValue(int index) { + + ParameterBinding parameterBinding = queryCreator.get().getBindings().get(index); + return parameterBinding.prepare(parameterAccessor.get().getBindableValue(index)); + } + + @Override + public boolean hasBindableNullValue() { + return false; + } + + @Override + public Iterator iterator() { + return null; + } + }; + + } + + JpaParametersParameterAccessor initParameterAccessor() { + + if (arguments.length > 0 || argumentTypes == null) { + return StubJpaParameterParameterAccessor.accessor(arguments); + } + return StubJpaParameterParameterAccessor.accessor(argumentTypes); + } + + JpaQueryCreator initJpaQueryCreator() { + + if (arguments.length > 0 || argumentTypes == null) { + return queryCreator(partTree, returnedType, metamodel, queryTemplates, parameterAccessor.get()); + } + return queryCreator(partTree, returnedType, metamodel, queryTemplates, parameterAccessor.get()); + } + } +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java new file mode 100644 index 0000000000..04fb7079de --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JpqlQueryBuilderUnitTests.java @@ -0,0 +1,265 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.repository.query; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import jakarta.persistence.Id; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.OneToMany; + +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.springframework.data.domain.Sort; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.AbstractJpqlQuery; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.Entity; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.Expression; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.Join; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.OrderExpression; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.Origin; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.ParameterPlaceholder; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.PathAndOrigin; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.Predicate; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.RenderContext; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.SelectStep; +import org.springframework.data.jpa.repository.query.JpqlQueryBuilder.WhereStep; + +/** + * @author Christoph Strobl + */ +class JpqlQueryBuilderUnitTests { + + @Test + void placeholdersRenderCorrectly() { + + assertThat(JpqlQueryBuilder.parameter(ParameterPlaceholder.indexed(1)).render(RenderContext.EMPTY)).isEqualTo("?1"); + assertThat(JpqlQueryBuilder.parameter(ParameterPlaceholder.named("arg1")).render(RenderContext.EMPTY)) + .isEqualTo(":arg1"); + assertThat(JpqlQueryBuilder.parameter("?1").render(RenderContext.EMPTY)).isEqualTo("?1"); + } + + @Test + void placeholdersErrorOnInvaludInput() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JpqlQueryBuilder.parameter((String) null)); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JpqlQueryBuilder.parameter("")); + } + + @Test + void stringLiteralRendersAsQuotedString() { + + assertThat(JpqlQueryBuilder.stringLiteral("literal").render(RenderContext.EMPTY)).isEqualTo("'literal'"); + + /* JPA Spec - 4.6.1 Literals: + > A string literal that includes a single quote is represented by two single quotes--for example: 'literal''s'. */ + assertThat(JpqlQueryBuilder.stringLiteral("literal's").render(RenderContext.EMPTY)).isEqualTo("'literal''s'"); + } + + @Test + void entity() { + + Entity entity = JpqlQueryBuilder.entity(Order.class); + assertThat(entity.alias()).isEqualTo("o"); + assertThat(entity.entity()).isEqualTo(Order.class.getName()); + assertThat(entity.getName()).isEqualTo(Order.class.getSimpleName()); // TODO: this really confusing + assertThat(entity.simpleName()).isEqualTo(Order.class.getSimpleName()); + } + + @Test + void literalExpressionRendersAsIs() { + Expression expression = JpqlQueryBuilder.expression("CONCAT(person.lastName, ‘, ’, person.firstName))"); + assertThat(expression.render(RenderContext.EMPTY)).isEqualTo("CONCAT(person.lastName, ‘, ’, person.firstName))"); + } + + @Test + void xxx() { + + Entity entity = JpqlQueryBuilder.entity(Order.class); + PathAndOrigin orderDate = JpqlQueryBuilder.path(entity, "date"); + + String fragment = JpqlQueryBuilder.where(orderDate).eq("{d '2024-11-05'}").render(ctx(entity)); + + assertThat(fragment).isEqualTo("o.date = {d '2024-11-05'}"); + + // JpqlQueryBuilder.where(PathAndOrigin) + } + + @Test + void predicateRendering() { + + + Entity entity = JpqlQueryBuilder.entity(Order.class); + WhereStep where = JpqlQueryBuilder.where(JpqlQueryBuilder.path(entity, "country")); + + assertThat(where.between("'AT'", "'DE'").render(ctx(entity))).isEqualTo("o.country BETWEEN 'AT' AND 'DE'"); + assertThat(where.eq("'AT'").render(ctx(entity))).isEqualTo("o.country = 'AT'"); + assertThat(where.eq(JpqlQueryBuilder.stringLiteral("AT")).render(ctx(entity))).isEqualTo("o.country = 'AT'"); + assertThat(where.gt("'AT'").render(ctx(entity))).isEqualTo("o.country > 'AT'"); + assertThat(where.gte("'AT'").render(ctx(entity))).isEqualTo("o.country >= 'AT'"); + // TODO: that is really really bad + // lange namen + assertThat(where.in("'AT', 'DE'").render(ctx(entity))).isEqualTo("o.country IN ('AT', 'DE')"); + + // 1 in age - cleanup what is not used - remove everything eles + // assertThat(where.inMultivalued("'AT', 'DE'").render(ctx(entity))).isEqualTo("o.country IN ('AT', 'DE')"); // + assertThat(where.isEmpty().render(ctx(entity))).isEqualTo("o.country IS EMPTY"); + assertThat(where.isNotEmpty().render(ctx(entity))).isEqualTo("o.country IS NOT EMPTY"); + assertThat(where.isTrue().render(ctx(entity))).isEqualTo("o.country = TRUE"); + assertThat(where.isFalse().render(ctx(entity))).isEqualTo("o.country = FALSE"); + assertThat(where.isNull().render(ctx(entity))).isEqualTo("o.country IS NULL"); + assertThat(where.isNotNull().render(ctx(entity))).isEqualTo("o.country IS NOT NULL"); + assertThat(where.like("'\\_%'", "" + EscapeCharacter.DEFAULT.getEscapeCharacter()).render(ctx(entity))) + .isEqualTo("o.country LIKE '\\_%' ESCAPE '\\'"); + assertThat(where.notLike("'\\_%'", "" + EscapeCharacter.DEFAULT.getEscapeCharacter()).render(ctx(entity))) + .isEqualTo("o.country NOT LIKE '\\_%' ESCAPE '\\'"); + assertThat(where.lt("'AT'").render(ctx(entity))).isEqualTo("o.country < 'AT'"); + assertThat(where.lte("'AT'").render(ctx(entity))).isEqualTo("o.country <= 'AT'"); + assertThat(where.memberOf("'AT'").render(ctx(entity))).isEqualTo("'AT' MEMBER OF o.country"); + // TODO: can we have this where.value(foo).memberOf(pathAndOrigin); + assertThat(where.notMemberOf("'AT'").render(ctx(entity))).isEqualTo("'AT' NOT MEMBER OF o.country"); + assertThat(where.neq("'AT'").render(ctx(entity))).isEqualTo("o.country != 'AT'"); + } + + @Test + void selectRendering() { + + // make sure things are immutable + SelectStep select = JpqlQueryBuilder.selectFrom(Order.class); // the select step is mutable - not sure i like it + // hm, I somehow exepect this to render only the selection part + assertThat(select.count().render()).startsWith("SELECT COUNT(o)"); + assertThat(select.distinct().entity().render()).startsWith("SELECT DISTINCT o "); + assertThat(select.distinct().count().render()).startsWith("SELECT COUNT(DISTINCT o) "); + assertThat(JpqlQueryBuilder.selectFrom(Order.class).select(JpqlQueryBuilder.path(JpqlQueryBuilder.entity(Order.class), "country")).render()) + .startsWith("SELECT o.country "); + } + +// @Test +// void sorting() { +// +// JpqlQueryBuilder.orderBy(new OrderExpression() , Sort.Order.asc("country")); +// +// Entity entity = JpqlQueryBuilder.entity(Order.class); +// +// AbstractJpqlQuery query = JpqlQueryBuilder.selectFrom(Order.class) +// .entity() +// .orderBy() +// .where(context -> "1 = 1"); +// +// } + + @Test + void joins() { + + Entity entity = JpqlQueryBuilder.entity(LineItem.class); + Join li_pr = JpqlQueryBuilder.innerJoin(entity, "product"); + Join li_pr2 = JpqlQueryBuilder.innerJoin(entity, "product2"); + + PathAndOrigin productName = JpqlQueryBuilder.path(li_pr, "name"); + PathAndOrigin personName = JpqlQueryBuilder.path(li_pr2, "name"); + + String fragment = JpqlQueryBuilder.where(productName).eq(JpqlQueryBuilder.stringLiteral("ex30")) + .and(JpqlQueryBuilder.where(personName).eq(JpqlQueryBuilder.stringLiteral("ex40"))).render(ctx(entity)); + + assertThat(fragment).isEqualTo("p.name = 'ex30' AND join_0.name = 'ex40'"); + } + + @Test + void x2() { + + Entity entity = JpqlQueryBuilder.entity(LineItem.class); + Join li_pr = JpqlQueryBuilder.innerJoin(entity, "product"); + Join li_pe = JpqlQueryBuilder.innerJoin(entity, "person"); + + PathAndOrigin productName = JpqlQueryBuilder.path(li_pr, "name"); + PathAndOrigin personName = JpqlQueryBuilder.path(li_pe, "name"); + + String fragment = JpqlQueryBuilder.where(productName).eq(JpqlQueryBuilder.stringLiteral("ex30")) + .and(JpqlQueryBuilder.where(personName).eq(JpqlQueryBuilder.stringLiteral("cstrobl"))).render(ctx(entity)); + + assertThat(fragment).isEqualTo("p.name = 'ex30' AND join_0.name = 'cstrobl'"); + } + + @Test + void x3() { + + Entity entity = JpqlQueryBuilder.entity(LineItem.class); + Join li_pr = JpqlQueryBuilder.innerJoin(entity, "product"); + Join li_pe = JpqlQueryBuilder.innerJoin(entity, "person"); + + PathAndOrigin productName = JpqlQueryBuilder.path(li_pr, "name"); + PathAndOrigin personName = JpqlQueryBuilder.path(li_pe, "name"); + + // JpqlQueryBuilder.and("x = y", "a = b"); -> x = y AND a = b + + // JpqlQueryBuilder.nested(JpqlQueryBuilder.and("x = y", "a = b")) (x = y AND a = b) + + String fragment = JpqlQueryBuilder.where(productName).eq(JpqlQueryBuilder.stringLiteral("ex30")) + .and(JpqlQueryBuilder.where(personName).eq(JpqlQueryBuilder.stringLiteral("cstrobl"))).render(ctx(entity)); + + assertThat(fragment).isEqualTo("p.name = 'ex30' AND join_0.name = 'cstrobl'"); + } + + static RenderContext ctx(Entity... entities) { + Map aliases = new LinkedHashMap<>(entities.length); + for (Entity entity : entities) { + aliases.put(entity, entity.alias()); + } + + return new RenderContext(aliases); + } + + @jakarta.persistence.Entity + static class Order { + + @Id Long id; + Date date; + String country; + + @OneToMany List lineItems; + } + + @jakarta.persistence.Entity + static class LineItem { + + @Id Long id; + + @ManyToOne Product product; + @ManyToOne Product product2; + @ManyToOne Product person; + + } + + @jakarta.persistence.Entity + static class Person { + @Id Long id; + String name; + } + + @jakarta.persistence.Entity + static class Product { + + @Id Long id; + + String name; + String productType; + + } +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeJpaQueryIntegrationTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeJpaQueryIntegrationTests.java index fe8f239a60..552006172e 100644 --- a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeJpaQueryIntegrationTests.java +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeJpaQueryIntegrationTests.java @@ -112,7 +112,7 @@ void recreatesQueryIfNullValueIsGiven(String criteria) throws Exception { Query query = jpaQuery.createQuery(getAccessor(queryMethod, new Object[] { "Matthews", PageRequest.of(0, 1) })); assertThat(HibernateUtils.getHibernateQuery(query.unwrap(HIBERNATE_NATIVE_QUERY))) - .contains("firstname %s :".formatted(criteria.endsWith("Not") ? "<>" : "=")); + .contains("firstname %s ?".formatted(criteria.endsWith("Not") ? "!=" : "=")); query = jpaQuery.createQuery(getAccessor(queryMethod, new Object[] { null, PageRequest.of(0, 1) })); diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeQueryCacheUnitTests.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeQueryCacheUnitTests.java new file mode 100644 index 0000000000..aa3911473f --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/PartTreeQueryCacheUnitTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.repository.query; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.function.Supplier; +import java.util.stream.Stream; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.FieldSource; +import org.mockito.Mockito; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Sort.Direction; + +/** + * @author Christoph Strobl + */ +public class PartTreeQueryCacheUnitTests { + + PartTreeQueryCache cache; + + static Supplier> cacheInput = () -> Stream.of( + Arguments.arguments(Sort.unsorted(), StubJpaParameterParameterAccessor.accessor()), // + Arguments.arguments(Sort.by(Direction.ASC, "one"), StubJpaParameterParameterAccessor.accessor()), // + Arguments.arguments(Sort.by(Direction.DESC, "one"), StubJpaParameterParameterAccessor.accessor()), // + Arguments.arguments(Sort.unsorted(), + StubJpaParameterParameterAccessor.accessorFor(String.class).withValues("value")), // + Arguments.arguments(Sort.unsorted(), + StubJpaParameterParameterAccessor.accessorFor(String.class).withValues(new Object[] { null })), // + Arguments.arguments(Sort.by(Direction.ASC, "one"), + StubJpaParameterParameterAccessor.accessorFor(String.class).withValues("value")), // + Arguments.arguments(Sort.by(Direction.ASC, "one"), + StubJpaParameterParameterAccessor.accessorFor(String.class).withValues(new Object[] { null }))); + + @BeforeEach + void beforeEach() { + cache = new PartTreeQueryCache(); + } + + @ParameterizedTest + @FieldSource("cacheInput") + void getReturnsNullForEmptyCache(Sort sort, JpaParametersParameterAccessor accessor) { + assertThat(cache.get(sort, accessor)).isNull(); + } + + @ParameterizedTest + @FieldSource("cacheInput") + void getReturnsCachedInstance(Sort sort, JpaParametersParameterAccessor accessor) { + + JpaQueryCreator queryCreator = Mockito.mock(JpaQueryCreator.class); + + assertThat(cache.put(sort, accessor, queryCreator)).isNull(); + assertThat(cache.get(sort, accessor)).isSameAs(queryCreator); + } + + @ParameterizedTest + @FieldSource("cacheInput") + void cacheGetWithSort(Sort sort, JpaParametersParameterAccessor accessor) { + + JpaQueryCreator queryCreator = Mockito.mock(JpaQueryCreator.class); + assertThat(cache.put(Sort.by("not-in-cache"), accessor, queryCreator)).isNull(); + + assertThat(cache.get(sort, accessor)).isNull(); + } + + @ParameterizedTest + @FieldSource("cacheInput") + void cacheGetWithccessor(Sort sort, JpaParametersParameterAccessor accessor) { + + JpaQueryCreator queryCreator = Mockito.mock(JpaQueryCreator.class); + assertThat(cache.put(sort, StubJpaParameterParameterAccessor.accessor("spring", "data"), queryCreator)).isNull(); + + assertThat(cache.get(sort, accessor)).isNull(); + } + + @Test + void cachesOnNullableNotArgumentType() { + + JpaQueryCreator queryCreator = Mockito.mock(JpaQueryCreator.class); + Sort sort = Sort.unsorted(); + assertThat(cache.put(sort, StubJpaParameterParameterAccessor.accessor("spring", "data"), queryCreator)).isNull(); + + assertThat(cache.get(sort, + StubJpaParameterParameterAccessor.accessor(new Class[] { String.class, String.class }, "spring", null))) + .isNull(); + + assertThat(cache.get(sort, + StubJpaParameterParameterAccessor.accessor(new Class[] { String.class, String.class }, null, "data"))).isNull(); + + assertThat(cache.get(sort, + StubJpaParameterParameterAccessor.accessor(new Class[] { String.class, String.class }, "data", "spring"))) + .isSameAs(queryCreator); + + assertThat(cache.get(Sort.by("not-in-cache"), + StubJpaParameterParameterAccessor.accessor(new Class[] { String.class, String.class }, "data", "spring"))) + .isNull(); + } + +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/StubJpaParameterParameterAccessor.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/StubJpaParameterParameterAccessor.java new file mode 100644 index 0000000000..c5794c9644 --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/StubJpaParameterParameterAccessor.java @@ -0,0 +1,93 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.repository.query; + +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.mockito.Mockito; +import org.springframework.core.MethodParameter; +import org.springframework.data.jpa.repository.query.JpaParameters.JpaParameter; +import org.springframework.data.util.TypeInformation; + +/** + * @author Christoph Strobl + */ +public class StubJpaParameterParameterAccessor extends JpaParametersParameterAccessor { + + private StubJpaParameterParameterAccessor(JpaParameters parameters, Object[] values) { + super(parameters, values); + } + + static JpaParametersParameterAccessor accessor(Object... values) { + + Class[] parameterTypes = Arrays.stream(values).map(it -> it != null ? it.getClass() : Object.class) + .toArray(Class[]::new); + return accessor(parameterTypes, values); + } + + static JpaParametersParameterAccessor accessor(Class... parameterTypes) { + return accessor(parameterTypes, new Object[parameterTypes.length]); + } + + static AccessorBuilder accessorFor(Class... parameterTypes) { + return arguments -> accessor(parameterTypes, arguments); + + } + + interface AccessorBuilder { + JpaParametersParameterAccessor withValues(Object... arguments); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + static JpaParametersParameterAccessor accessor(Class[] parameterTypes, Object... parameters) { + + List parametersList = new ArrayList<>(parameterTypes.length); + List valueList = new ArrayList<>(parameterTypes.length); + + for (int i = 0; i < parameterTypes.length; i++) { + + if (i < parameters.length) { + valueList.add(parameters[i]); + } + + Class parameterType = parameterTypes[i]; + MethodParameter mock = Mockito.mock(MethodParameter.class); + when(mock.getParameterType()).thenReturn((Class) parameterType); + JpaParameter parameter = new JpaParameter(mock, TypeInformation.of(parameterType)); + parametersList.add(parameter); + } + + return new StubJpaParameterParameterAccessor(new JpaParameters(parametersList), valueList.toArray()); + } + + @Override + public String toString() { + List parameters = new ArrayList<>(getParameters().getNumberOfParameters()); + + for (int i = 0; i < getParameters().getNumberOfParameters(); i++) { + Object value = getValue(i); + if (value == null) { + value = "null"; + } + parameters.add("%s: %s (%s)".formatted(i, value, getParameters().getParameter(i).getType().getSimpleName())); + } + return "%s".formatted(parameters); + } +} diff --git a/spring-data-jpa/src/test/java/org/springframework/data/jpa/util/TestMetaModel.java b/spring-data-jpa/src/test/java/org/springframework/data/jpa/util/TestMetaModel.java new file mode 100644 index 0000000000..a755ba222b --- /dev/null +++ b/spring-data-jpa/src/test/java/org/springframework/data/jpa/util/TestMetaModel.java @@ -0,0 +1,119 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jpa.util; + +import jakarta.persistence.EntityManager; +import jakarta.persistence.EntityManagerFactory; +import jakarta.persistence.metamodel.EmbeddableType; +import jakarta.persistence.metamodel.EntityType; +import jakarta.persistence.metamodel.ManagedType; +import jakarta.persistence.metamodel.Metamodel; +import jakarta.persistence.spi.ClassTransformer; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.hibernate.jpa.HibernatePersistenceProvider; +import org.hibernate.jpa.boot.internal.EntityManagerFactoryBuilderImpl; +import org.hibernate.jpa.boot.internal.PersistenceUnitInfoDescriptor; +import org.springframework.data.util.Lazy; +import org.springframework.instrument.classloading.SimpleThrowawayClassLoader; +import org.springframework.orm.jpa.persistenceunit.MutablePersistenceUnitInfo; + +/** + * @author Christoph Strobl + */ +public class TestMetaModel implements Metamodel { + + private final String persistenceUnit; + private final Set> managedTypes; + private final Lazy entityManagerFactory = Lazy.of(this::init); + private final Lazy metamodel = Lazy.of(() -> entityManagerFactory.get().getMetamodel()); + private Lazy enityManager = Lazy.of(() -> entityManagerFactory.get().createEntityManager()); + + TestMetaModel(Set> managedTypes) { + this("dynamic-tests", managedTypes); + } + + TestMetaModel(String persistenceUnit, Set> managedTypes) { + this.persistenceUnit = persistenceUnit; + this.managedTypes = managedTypes; + } + + public static TestMetaModel hibernateModel(Class... types) { + return new TestMetaModel(Set.of(types)); + } + + public static TestMetaModel hibernateModel(String persistenceUnit, Class... types) { + return new TestMetaModel(persistenceUnit, Set.of(types)); + } + + public EntityType entity(Class cls) { + return metamodel.get().entity(cls); + } + + public ManagedType managedType(Class cls) { + return metamodel.get().managedType(cls); + } + + public EmbeddableType embeddable(Class cls) { + return metamodel.get().embeddable(cls); + } + + public Set> getManagedTypes() { + return metamodel.get().getManagedTypes(); + } + + public Set> getEntities() { + return metamodel.get().getEntities(); + } + + public Set> getEmbeddables() { + return metamodel.get().getEmbeddables(); + } + + public EntityManager entityManager() { + return enityManager.get(); + } + + EntityManagerFactory init() { + + MutablePersistenceUnitInfo persistenceUnitInfo = new MutablePersistenceUnitInfo() { + @Override + public ClassLoader getNewTempClassLoader() { + return new SimpleThrowawayClassLoader(this.getClass().getClassLoader()); + } + + @Override + public void addTransformer(ClassTransformer classTransformer) { + // just ingnore it + } + }; + + persistenceUnitInfo.setPersistenceUnitName(persistenceUnit); + this.managedTypes.stream().map(Class::getName).forEach(persistenceUnitInfo::addManagedClassName); + + persistenceUnitInfo.setPersistenceProviderClassName(HibernatePersistenceProvider.class.getName()); + + return new EntityManagerFactoryBuilderImpl(new PersistenceUnitInfoDescriptor(persistenceUnitInfo) { + @Override + public List getManagedClassNames() { + return persistenceUnitInfo.getManagedClassNames(); + } + }, Map.of("hibernate.dialect", "org.hibernate.dialect.H2Dialect")).build(); + } +} diff --git a/spring-data-jpa/src/test/resources/META-INF/persistence.xml b/spring-data-jpa/src/test/resources/META-INF/persistence.xml index 1c3be472e0..4f904373c3 100644 --- a/spring-data-jpa/src/test/resources/META-INF/persistence.xml +++ b/spring-data-jpa/src/test/resources/META-INF/persistence.xml @@ -102,6 +102,14 @@ + + org.hibernate.jpa.HibernatePersistenceProvider + true + + + + +