Skip to content

Commit

Permalink
Interceptors set through @InInterceptors and similar annotations are …
Browse files Browse the repository at this point in the history
…not looked up the CDI container, fix #1367
  • Loading branch information
ppalaga committed May 6, 2024
1 parent 3deaedf commit a66b9d8
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,28 @@ public static <T> T getInstance(String beanRef, boolean namedBeansSupported) {

final Class<T> classObj = (Class<T>) loadClass(beanRef);
Objects.requireNonNull(classObj, "Could not load class " + beanRef);
return getInstance(classObj);
}

/**
* @param <T> a type to which the returned bean can be casted
* @param beanClass the type to look up in the CDI container or create via reflection
* @return an instance of a Bean
*/
public static <T> T getInstance(Class<? extends T> beanClass) {
try {
return CDI.current().select(classObj).get();
return CDI.current().select(beanClass).get();
} catch (UnsatisfiedResolutionException e) {
// silent fail
}
try {
return classObj.getConstructor().newInstance();
return beanClass.getConstructor().newInstance();
} catch (NoSuchMethodException e) {
throw new RuntimeException("Could not instantiate " + beanRef
throw new RuntimeException("Could not instantiate " + beanClass.getName()
+ " using the default constructor. Make sure that the constructor exists and that the class is static in case it is an inner class.",
e);
} catch (ReflectiveOperationException | RuntimeException e) {
throw new RuntimeException("Could not instantiate " + beanRef + " using the default constructor.", e);
throw new RuntimeException("Could not instantiate " + beanClass.getName() + " using the default constructor.", e);
}
}

Expand All @@ -60,6 +69,24 @@ public static <T> T getInstance(String beanRef, String beanKind, String sei, Str
}
}

public static <T> T getInstance(Class<? extends T> beanClass, String beanKind, String sei, String clientOrEndpoint) {
try {
return getInstance(beanClass);
} catch (AmbiguousResolutionException e) {
/*
* There are multiple beans of this type
* and we do not know which one to use
*/
throw new IllegalStateException("Unable to add a " + beanKind + " to CXF " + clientOrEndpoint + " " + sei + ":"
+ " there are multiple instances of " + beanClass.getName() + " available in the CDI container."
+ " Either make sure there is only one instance available in the container"
+ " or create a unique subtype of " + beanClass.getName() + " and set that one on " + sei
+ " or add @jakarta.inject.Named(\"myName\") to some of the beans and refer to that bean by #myName on "
+ sei,
e);
}
}

public static <T> void addBeans(List<String> beanRefs, String beanKind, String sei, String clientOrEndpoint,
List<T> destination) {
for (String beanRef : beanRefs) {
Expand All @@ -72,6 +99,19 @@ public static <T> void addBeans(List<String> beanRefs, String beanKind, String s
}
}

public static <T> void addBeansByType(List<Class<? extends T>> beanTypes, String beanKind, String sei,
String clientOrEndpoint,
List<T> destination) {
for (Class<? extends T> beanType : beanTypes) {
T item = getInstance(beanType, beanKind, sei, clientOrEndpoint);
if (item == null) {
throw new IllegalStateException("Could not lookup bean " + beanType.getName());
} else {
destination.add(item);
}
}
}

private static Class<?> loadClass(String className) {
try {
return Thread.currentThread().getContextClassLoader().loadClass(className);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package io.quarkiverse.cxf;

import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.feature.Feature;
import org.apache.cxf.feature.Features;
import org.apache.cxf.interceptor.AnnotationInterceptors;
import org.apache.cxf.interceptor.InFaultInterceptors;
import org.apache.cxf.interceptor.InInterceptors;
import org.apache.cxf.interceptor.Interceptor;
import org.apache.cxf.interceptor.OutFaultInterceptors;
import org.apache.cxf.interceptor.OutInterceptors;
import org.apache.cxf.jaxws.JaxWsServerFactoryBean;
import org.apache.cxf.jaxws.support.JaxWsServiceFactoryBean;
import org.apache.cxf.message.Message;

/**
* A JaxWsServerFactoryBean allowing to look up <code>@InInterceptors</code> in the CDI container.
*/
public class QuarkusJaxWsServerFactoryBean extends JaxWsServerFactoryBean {

private final String endpointString;

public QuarkusJaxWsServerFactoryBean(JaxWsServiceFactoryBean serviceFactory, String endpointString) {
super(serviceFactory);
this.endpointString = endpointString;
}

@Override
protected void initializeAnnotationInterceptors(Endpoint ep, Class<?>... cls) {
final Class<?> seiClass = ((JaxWsServiceFactoryBean) getServiceFactory())
.getJaxWsImplementorInfo().getSEIClass();
if (seiClass != null) {
boolean found = false;
for (Class<?> c : cls) {
if (c.equals(seiClass)) {
found = true;
}
}
if (!found) {
Class<?>[] cls2 = new Class<?>[cls.length + 1];
System.arraycopy(cls, 0, cls2, 0, cls.length);
cls2[cls.length] = seiClass;
cls = cls2;
}
}
final AnnotationInterceptors provider = new QuarkusAnnotationInterceptors(
((JaxWsServiceFactoryBean) getServiceFactory()).getJaxWsImplementorInfo().getImplementorClass().getName(),
endpointString,
cls);
initializeAnnotationInterceptors(provider, ep);
}

@Override
protected boolean initializeAnnotationInterceptors(AnnotationInterceptors provider, Endpoint ep) {
boolean hasAnnotation = false;
final List<Interceptor<? extends Message>> inFaultInterceptors = provider.getInFaultInterceptors();
if (inFaultInterceptors != null) {
ep.getInFaultInterceptors().addAll(inFaultInterceptors);
hasAnnotation = true;
}
final List<Interceptor<? extends Message>> inInterceptors = provider.getInInterceptors();
if (inInterceptors != null) {
ep.getInInterceptors().addAll(inInterceptors);
hasAnnotation = true;
}
final List<Interceptor<? extends Message>> outFaultInterceptors = provider.getOutFaultInterceptors();
if (outFaultInterceptors != null) {
ep.getOutFaultInterceptors().addAll(outFaultInterceptors);
hasAnnotation = true;
}
final List<Interceptor<? extends Message>> outInterceptors = provider.getOutInterceptors();
if (outInterceptors != null) {
ep.getOutInterceptors().addAll(outInterceptors);
hasAnnotation = true;
}
final List<Feature> features2 = provider.getFeatures();
if (features2 != null) {
getFeatures().addAll(features2);
hasAnnotation = true;
}

return hasAnnotation;
}

static class QuarkusAnnotationInterceptors extends AnnotationInterceptors {

private final Class<?>[] clazzes;
private final String implementorClass;
private final String endpointString;

public QuarkusAnnotationInterceptors(String implementorClass, String endpointString, Class<?>... clz) {
this.implementorClass = implementorClass;
this.endpointString = endpointString;
this.clazzes = clz;
}

private <T> List<T> getAnnotationObject(Class<? extends Annotation> annotationClazz, Class<T> type) {

for (Class<?> cls : clazzes) {
Annotation annotation = cls.getAnnotation(annotationClazz);
if (annotation != null) {
return initializeAnnotationObjects(annotation, type);
}
}
return null;
}

private <T> List<T> initializeAnnotationObjects(Annotation annotation,
Class<T> type) {
final List<T> result = new ArrayList<>();

CXFRuntimeUtils.addBeansByType(
Arrays.asList(getAnnotationObjectClasses(annotation, type)),
type.getName(),
implementorClass,
endpointString,
result);
CXFRuntimeUtils.addBeans(
Arrays.asList(getAnnotationObjectNames(annotation)),
type.getName(),
implementorClass,
endpointString,
result);

return result;
}

@SuppressWarnings("unchecked")
private <T> Class<? extends T>[] getAnnotationObjectClasses(Annotation ann, Class<T> type) { //NOPMD
if (ann instanceof InFaultInterceptors) {
return (Class<? extends T>[]) ((InFaultInterceptors) ann).classes();
} else if (ann instanceof InInterceptors) {
return (Class<? extends T>[]) ((InInterceptors) ann).classes();
} else if (ann instanceof OutFaultInterceptors) {
return (Class<? extends T>[]) ((OutFaultInterceptors) ann).classes();
} else if (ann instanceof OutInterceptors) {
return (Class<? extends T>[]) ((OutInterceptors) ann).classes();
} else if (ann instanceof Features) {
return (Class<? extends T>[]) ((Features) ann).classes();
}
throw new UnsupportedOperationException("Doesn't support the annotation: " + ann);
}

private String[] getAnnotationObjectNames(Annotation ann) {
if (ann instanceof InFaultInterceptors) {
return ((InFaultInterceptors) ann).interceptors();
} else if (ann instanceof InInterceptors) {
return ((InInterceptors) ann).interceptors();
} else if (ann instanceof OutFaultInterceptors) {
return ((OutFaultInterceptors) ann).interceptors();
} else if (ann instanceof OutInterceptors) {
return ((OutInterceptors) ann).interceptors();
} else if (ann instanceof Features) {
return ((Features) ann).features();
}

throw new UnsupportedOperationException("Doesn't support the annotation: " + ann);
}

private List<Interceptor<? extends Message>> getAnnotationInterceptorList(Class<? extends Annotation> t) {
@SuppressWarnings("rawtypes")
List<Interceptor> i = getAnnotationObject(t, Interceptor.class);
if (i == null) {
return null;
}
List<Interceptor<? extends Message>> m = new ArrayList<>();
for (Interceptor<?> i2 : i) {
m.add(i2);
}
return m;
}

@Override
public List<Interceptor<? extends Message>> getInFaultInterceptors() {
return getAnnotationInterceptorList(InFaultInterceptors.class);
}

@Override
public List<Interceptor<? extends Message>> getInInterceptors() {
return getAnnotationInterceptorList(InInterceptors.class);
}

@Override
public List<Interceptor<? extends Message>> getOutFaultInterceptors() {
return getAnnotationInterceptorList(OutFaultInterceptors.class);
}

@Override
public List<Interceptor<? extends Message>> getOutInterceptors() {
return getAnnotationInterceptorList(OutInterceptors.class);
}

@Override
public List<Feature> getFeatures() {
return getAnnotationObject(Features.class, Feature.class);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.quarkiverse.cxf.CXFServletInfos;
import io.quarkiverse.cxf.CxfConfig;
import io.quarkiverse.cxf.CxfFixedConfig;
import io.quarkiverse.cxf.QuarkusJaxWsServerFactoryBean;
import io.quarkiverse.cxf.QuarkusRuntimeJaxWsServiceFactoryBean;
import io.quarkiverse.cxf.auth.AuthFaultOutInterceptor;
import io.quarkiverse.cxf.logging.LoggingFactoryCustomizer;
Expand Down Expand Up @@ -103,8 +104,10 @@ public CxfHandler(CXFServletInfos cxfServletInfos, BeanContainer beanContainer,

// suboptimal because done it in loop but not a real issue...
for (CXFServletInfo servletInfo : cxfServletInfos.getInfos()) {
final String endpointString = "endpoint " + servletInfo.getPath();
QuarkusRuntimeJaxWsServiceFactoryBean jaxWsServiceFactoryBean = new QuarkusRuntimeJaxWsServiceFactoryBean();
JaxWsServerFactoryBean jaxWsServerFactoryBean = new JaxWsServerFactoryBean(jaxWsServiceFactoryBean);
JaxWsServerFactoryBean jaxWsServerFactoryBean = new QuarkusJaxWsServerFactoryBean(jaxWsServiceFactoryBean,
endpointString);
jaxWsServerFactoryBean.setDestinationFactory(destinationFactory);
jaxWsServerFactoryBean.setBus(bus);
jaxWsServerFactoryBean.setProperties(new LinkedHashMap<>());
Expand All @@ -125,7 +128,6 @@ public CxfHandler(CXFServletInfos cxfServletInfos, BeanContainer beanContainer,
if (servletInfo.getWsdlPath() != null) {
jaxWsServerFactoryBean.setWsdlLocation(servletInfo.getWsdlPath());
}
final String endpointString = "endpoint " + servletInfo.getPath();
CXFRuntimeUtils.addBeans(servletInfo.getFeatures(), "feature", endpointString, endpointType,
jaxWsServerFactoryBean.getFeatures());
CXFRuntimeUtils.addBeans(servletInfo.getHandlers(), "handler", endpointString, endpointType,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* JBoss, Home of Professional Open Source
* Copyright 2015, Red Hat, Inc. and/or its affiliates, and individual
* contributors by the @authors tag. See the copyright.txt in the
* distribution for a full listing of individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.quarkiverse.cxf.it.wss.server;

import jakarta.jws.WebService;

import org.apache.cxf.interceptor.InInterceptors;

/**
* The implementation of the {@link WssRounderService}
*/
@InInterceptors(classes = { org.apache.cxf.ws.security.wss4j.WSS4JInInterceptor.class })
@WebService(serviceName = "WssRounderService", portName = "WssRounderService", name = "WssRounderService", endpointInterface = "io.quarkiverse.cxf.it.wss.server.WssRounderService", targetNamespace = WssRounderService.TARGET_NS)
public class AnnotatedWssRounderServiceImpl implements WssRounderService {
@Override
public long round(double a) {
return Math.round(a);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ quarkus.cxf.endpoint."/rounder".implementor = io.quarkiverse.cxf.it.wss.server.W
quarkus.cxf.endpoint."/rounder".in-interceptors = org.apache.cxf.ws.security.wss4j.WSS4JInInterceptor
# end::quarkus-cxf-rt-ws-security.adoc[]

quarkus.cxf.endpoint."/annotated-rounder".implementor = io.quarkiverse.cxf.it.wss.server.AnnotatedWssRounderServiceImpl

quarkus.cxf.endpoint."/security-policy-hello".implementor = io.quarkiverse.cxf.it.wss.server.policy.WssSecurityPolicyHelloServiceImpl

quarkus.native.resources.includes = saml-keystore.jks,server/*
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.eclipse.microprofile.config.Config;
import org.eclipse.microprofile.config.ConfigProvider;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import io.quarkiverse.cxf.test.QuarkusCxfClientTestUtil;
import io.quarkus.test.common.QuarkusTestResource;
Expand All @@ -37,14 +39,18 @@ void anonymous() throws IOException {

}

@Test
void usernameToken() throws IOException {
@ParameterizedTest
@ValueSource(strings = {
"rounder",
"annotated-rounder" })
void usernameToken(String endpointRelPath) throws IOException {

final Config config = ConfigProvider.getConfig();
final String username = config.getValue("wss.username", String.class);
final String password = config.getValue("wss.password", String.class);

final WssRounderService client = QuarkusCxfClientTestUtil.getClient(WssRounderService.class, "/soap/rounder");
final WssRounderService client = QuarkusCxfClientTestUtil.getClient(WssRounderService.class,
"/soap/" + endpointRelPath);

final CallbackHandler passwordCallback = new CallbackHandler() {
@Override
Expand Down

0 comments on commit a66b9d8

Please sign in to comment.