Skip to content

Commit

Permalink
ArC: improve validation of interceptor method signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Ladicek committed Feb 15, 2024
1 parent 754f4e4 commit 2b9ae4d
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jakarta.enterprise.inject.UnsatisfiedResolutionException;
import jakarta.enterprise.inject.spi.DefinitionException;
import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.enterprise.inject.spi.InterceptionType;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
Expand Down Expand Up @@ -710,8 +711,17 @@ static void addImplicitQualifiers(Set<AnnotationInstance> qualifiers) {
}

static List<MethodInfo> getCallbacks(ClassInfo beanClass, DotName annotation, IndexView index) {
InterceptionType interceptionType = null;
if (DotNames.POST_CONSTRUCT.equals(annotation)) {
interceptionType = InterceptionType.POST_CONSTRUCT;
} else if (DotNames.PRE_DESTROY.equals(annotation)) {
interceptionType = InterceptionType.PRE_DESTROY;
} else {
throw new IllegalArgumentException("Unexpected callback annotation: " + annotation);
}

List<MethodInfo> callbacks = new ArrayList<>();
collectCallbacks(beanClass, callbacks, annotation, index, new HashSet<>());
collectCallbacks(beanClass, callbacks, annotation, index, new HashSet<>(), interceptionType);
Collections.reverse(callbacks);
return callbacks;
}
Expand All @@ -729,7 +739,8 @@ static List<MethodInfo> getAroundInvokes(ClassInfo beanClass, BeanDeployment dep
continue;
}
if (store.hasAnnotation(method, DotNames.AROUND_INVOKE)) {
InterceptorInfo.addInterceptorMethod(allMethods, methods, method);
InterceptorInfo.addInterceptorMethod(allMethods, methods, method, InterceptionType.AROUND_INVOKE,
InterceptorPlacement.TARGET_CLASS);
if (++aroundInvokesFound > 1) {
throw new DefinitionException(
"Multiple @AroundInvoke interceptor methods declared on class: " + aClass);
Expand Down Expand Up @@ -1042,24 +1053,18 @@ private static void fetchType(Type type, BeanDeployment beanDeployment) {
}

private static void collectCallbacks(ClassInfo clazz, List<MethodInfo> callbacks, DotName annotation, IndexView index,
Set<String> knownMethods) {
Set<String> knownMethods, InterceptionType interceptionType) {
for (MethodInfo method : clazz.methods()) {
if (method.hasAnnotation(annotation) && !knownMethods.contains(method.name())) {
if (method.returnType().kind() == Kind.VOID && method.parameterTypes().isEmpty()) {
callbacks.add(method);
} else {
// invalid signature - build a meaningful message.
throw new DefinitionException("Invalid signature for the method `" + method + "` from class `"
+ method.declaringClass() + "`. Methods annotated with `" + annotation + "` must return" +
" `void` and cannot have parameters.");
}
InterceptorInfo.validateSignature(method, interceptionType, InterceptorPlacement.TARGET_CLASS);
callbacks.add(method);
}
knownMethods.add(method.name());
}
if (clazz.superName() != null) {
ClassInfo superClass = getClassByName(index, clazz.superName());
if (superClass != null) {
collectCallbacks(superClass, callbacks, annotation, index, knownMethods);
collectCallbacks(superClass, callbacks, annotation, index, knownMethods, interceptionType);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -106,28 +107,32 @@ public class InterceptorInfo extends BeanInfo implements Comparable<InterceptorI
+ aClass);
}
if (store.hasAnnotation(method, DotNames.AROUND_INVOKE)) {
addInterceptorMethod(allMethods, aroundInvokes, method);
addInterceptorMethod(allMethods, aroundInvokes, method, InterceptionType.AROUND_INVOKE,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++aroundInvokesFound > 1) {
throw new DefinitionException(
"Multiple @AroundInvoke interceptor methods declared on class: " + aClass);
}
}
if (store.hasAnnotation(method, DotNames.AROUND_CONSTRUCT)) {
addInterceptorMethod(allMethods, aroundConstructs, method);
addInterceptorMethod(allMethods, aroundConstructs, method, InterceptionType.AROUND_CONSTRUCT,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++aroundConstructsFound > 1) {
throw new DefinitionException(
"Multiple @AroundConstruct interceptor methods declared on class: " + aClass);
}
}
if (store.hasAnnotation(method, DotNames.POST_CONSTRUCT)) {
addInterceptorMethod(allMethods, postConstructs, method);
addInterceptorMethod(allMethods, postConstructs, method, InterceptionType.POST_CONSTRUCT,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++postConstructsFound > 1) {
throw new DefinitionException(
"Multiple @PostConstruct interceptor methods declared on class: " + aClass);
}
}
if (store.hasAnnotation(method, DotNames.PRE_DESTROY)) {
addInterceptorMethod(allMethods, preDestroys, method);
addInterceptorMethod(allMethods, preDestroys, method, InterceptionType.PRE_DESTROY,
InterceptorPlacement.INTERCEPTOR_CLASS);
if (++preDestroysFound > 1) {
throw new DefinitionException(
"Multiple @PreDestroy interceptor methods declared on class: " + aClass);
Expand Down Expand Up @@ -297,8 +302,9 @@ public int compareTo(InterceptorInfo other) {
return getTarget().toString().compareTo(other.getTarget().toString());
}

static void addInterceptorMethod(List<MethodInfo> allMethods, List<MethodInfo> interceptorMethods, MethodInfo method) {
validateSignature(method);
static void addInterceptorMethod(List<MethodInfo> allMethods, List<MethodInfo> interceptorMethods, MethodInfo method,
InterceptionType interceptionType, InterceptorPlacement interceptorPlacement) {
validateSignature(method, interceptionType, interceptorPlacement);
if (!isInterceptorMethodOverriden(allMethods, method)) {
interceptorMethods.add(method);
}
Expand All @@ -319,19 +325,105 @@ static boolean hasInterceptorMethodParameter(MethodInfo method) {
|| method.parameterType(0).name().equals(DotNames.ARC_INVOCATION_CONTEXT));
}

private static MethodInfo validateSignature(MethodInfo method) {
if (!hasInterceptorMethodParameter(method)) {
throw new IllegalStateException(
"An interceptor method must accept exactly one parameter of type jakarta.interceptor.InvocationContext: "
+ method + " declared on " + method.declaringClass());
private enum InterceptorMethodError {
MUST_HAVE_PARAMETER,
MUST_NOT_HAVE_PARAMETER,
WRONG_RETURN_TYPE,
}

static void validateSignature(MethodInfo method, InterceptionType interceptionType,
InterceptorPlacement interceptorPlacement) {
boolean isLifecycleCallback = interceptionType == InterceptionType.AROUND_CONSTRUCT
|| interceptionType == InterceptionType.POST_CONSTRUCT
|| interceptionType == InterceptionType.PRE_DESTROY;

boolean mustHaveParameter = !isLifecycleCallback || interceptorPlacement == InterceptorPlacement.INTERCEPTOR_CLASS;
boolean mustNotHaveParameter = isLifecycleCallback && interceptorPlacement == InterceptorPlacement.TARGET_CLASS;
boolean mayReturnVoid = isLifecycleCallback;
boolean mayReturnObject = !isLifecycleCallback || interceptorPlacement == InterceptorPlacement.INTERCEPTOR_CLASS;

Set<InterceptorMethodError> errors = EnumSet.noneOf(InterceptorMethodError.class);
if (mustHaveParameter && !hasInterceptorMethodParameter(method)) {
errors.add(InterceptorMethodError.MUST_HAVE_PARAMETER);
}
if (mustNotHaveParameter && method.parametersCount() > 0) {
errors.add(InterceptorMethodError.MUST_NOT_HAVE_PARAMETER);
}

boolean wrongReturnType = true;
if (mayReturnVoid && method.returnType().kind().equals(Kind.VOID)) {
wrongReturnType = false;
}
if (mayReturnObject && method.returnType().name().equals(DotNames.OBJECT)) {
wrongReturnType = false;
}
if (wrongReturnType) {
errors.add(InterceptorMethodError.WRONG_RETURN_TYPE);
}
if (!method.returnType().kind().equals(Type.Kind.VOID) &&
!method.returnType().name().equals(DotNames.OBJECT)) {
throw new IllegalStateException(
"The return type of an interceptor method must be java.lang.Object or void: "
+ method + " declared on " + method.declaringClass());

if (!errors.isEmpty()) {
StringBuilder msg = new StringBuilder();
switch (interceptionType) {
case AROUND_CONSTRUCT:
msg.append("@AroundConstruct");
break;
case AROUND_INVOKE:
msg.append("@AroundInvoke");
break;
case POST_CONSTRUCT:
msg.append("@PostConstruct");
break;
case PRE_DESTROY:
msg.append("@PreDestroy");
break;
default:
throw new IllegalArgumentException("Unknown interception type: " + interceptionType);
}
if (isLifecycleCallback) {
msg.append(" lifecycle callback method");
} else {
msg.append(" interceptor method");
}
msg.append(" declared in ");
switch (interceptorPlacement) {
case INTERCEPTOR_CLASS:
msg.append("an interceptor class");
break;
case TARGET_CLASS:
msg.append("a target class");
break;
default:
throw new IllegalArgumentException("Unknown interceptor placement: " + interceptorPlacement);
}
msg.append(" must ");

if (errors.contains(InterceptorMethodError.MUST_HAVE_PARAMETER)) {
msg.append("have exactly one parameter of type jakarta.interceptor.InvocationContext");
} else if (errors.contains(InterceptorMethodError.MUST_NOT_HAVE_PARAMETER)) {
msg.append("have zero parameters");
}

if (errors.contains(InterceptorMethodError.WRONG_RETURN_TYPE)) {
if (errors.contains(InterceptorMethodError.MUST_HAVE_PARAMETER)
|| errors.contains(InterceptorMethodError.MUST_NOT_HAVE_PARAMETER)) {
msg.append(" and must ");
}
msg.append("have a return type of ");
if (mayReturnVoid) {
msg.append("void");
}
if (mayReturnVoid && mayReturnObject) {
msg.append(" or ");
}
if (mayReturnObject) {
msg.append("java.lang.Object");
}
}

msg.append(": ").append(method).append(" declared in ").append(method.declaringClass().name());

throw new DefinitionException(msg.toString());
}
return method;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.quarkus.arc.processor;

enum InterceptorPlacement {
INTERCEPTOR_CLASS,
TARGET_CLASS,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.quarkus.arc.test.interceptors.illegal;

import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import jakarta.annotation.Priority;
import jakarta.enterprise.inject.spi.DefinitionException;
import jakarta.interceptor.AroundInvoke;
import jakarta.interceptor.Interceptor;
import jakarta.interceptor.InterceptorBinding;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.test.ArcTestContainer;

public class InterceptorWithoutParameterTest {
@RegisterExtension
public ArcTestContainer container = ArcTestContainer.builder()
.beanClasses(MyInterceptor.class, MyInterceptorBinding.class)
.shouldFail()
.build();

@Test
public void trigger() {
Throwable error = container.getFailure();
assertNotNull(error);
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@AroundInvoke interceptor method declared in an interceptor class must have exactly one parameter"));
assertTrue(error.getMessage().contains("intercept()"));
assertTrue(error.getMessage().contains("InterceptorWithoutParameterTest$MyInterceptor"));
}

@Target({ TYPE, METHOD, FIELD, PARAMETER })
@Retention(RUNTIME)
@InterceptorBinding
@interface MyInterceptorBinding {
}

@MyInterceptorBinding
@Interceptor
@Priority(1)
static class MyInterceptor {
@AroundInvoke
Object intercept() throws Exception {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.arc.test.validation;

import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand All @@ -23,7 +24,11 @@ public class InvalidPostConstructTest {
public void testFailure() {
Throwable error = container.getFailure();
assertNotNull(error);
assertTrue(error instanceof DefinitionException);
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@PostConstruct lifecycle callback method declared in a target class must have a return type of void"));
assertTrue(error.getMessage().contains("invalid()"));
assertTrue(error.getMessage().contains("InvalidPostConstructTest$InvalidBean"));
}

@ApplicationScoped
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package io.quarkus.arc.test.validation;

import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.spi.DefinitionException;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

Expand All @@ -23,10 +23,11 @@ public class InvalidPostConstructWithParametersTest {
public void testFailure() {
Throwable error = container.getFailure();
assertNotNull(error);
assertTrue(error instanceof DefinitionException);
Assertions.assertTrue(error.getMessage().contains("invalid(java.lang.String ignored)"));
Assertions.assertTrue(error.getMessage().contains("$InvalidBean"));
Assertions.assertTrue(error.getMessage().contains("PostConstruct"));
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@PostConstruct lifecycle callback method declared in a target class must have zero parameters"));
assertTrue(error.getMessage().contains("invalid(java.lang.String ignored)"));
assertTrue(error.getMessage().contains("InvalidPostConstructWithParametersTest$InvalidBean"));
}

@ApplicationScoped
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package io.quarkus.arc.test.validation;

import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.spi.DefinitionException;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

Expand All @@ -24,10 +24,11 @@ public class InvalidPreDestroyTest {
public void testFailure() {
Throwable error = container.getFailure();
assertNotNull(error);
assertTrue(error instanceof DefinitionException);
Assertions.assertTrue(error.getMessage().contains("invalid()"));
Assertions.assertTrue(error.getMessage().contains("$InvalidBean"));
Assertions.assertTrue(error.getMessage().contains("PreDestroy"));
assertInstanceOf(DefinitionException.class, error);
assertTrue(error.getMessage().contains(
"@PreDestroy lifecycle callback method declared in a target class must have a return type of void"));
assertTrue(error.getMessage().contains("invalid()"));
assertTrue(error.getMessage().contains("InvalidPreDestroyTest$InvalidBean"));
}

@ApplicationScoped
Expand Down

0 comments on commit 2b9ae4d

Please sign in to comment.