diff --git a/rocketmq-spring-boot/src/main/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainer.java b/rocketmq-spring-boot/src/main/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainer.java index fb7762e9..9460dee9 100644 --- a/rocketmq-spring-boot/src/main/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainer.java +++ b/rocketmq-spring-boot/src/main/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainer.java @@ -20,9 +20,11 @@ import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.nio.charset.Charset; import java.util.List; import java.util.Objects; + import org.apache.rocketmq.client.AccessChannel; import org.apache.rocketmq.client.consumer.DefaultMQPushConsumer; import org.apache.rocketmq.client.consumer.MessageSelector; @@ -58,6 +60,7 @@ import org.springframework.context.ApplicationContextAware; import org.springframework.context.SmartLifecycle; import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.converter.MessageConversionException; @@ -66,6 +69,7 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.MimeTypeUtils; +import org.springframework.util.ReflectionUtils; @SuppressWarnings("WeakerAccess") public class DefaultRocketMQListenerContainer implements InitializingBean, @@ -538,6 +542,8 @@ private Object doConvertMessage(MessageExt messageExt) { if (messageType instanceof Class) { //if the messageType has not Generic Parameter return this.getMessageConverter().fromMessage(MessageBuilder.withPayload(str).build(), (Class) messageType); + } else if (messageType instanceof TypeVariable) { + return this.getMessageConverter().fromMessage(MessageBuilder.withPayload(str).build(), Object.class); } else { //if the messageType has Generic Parameter, then use SmartMessageConverter#fromMessage with third parameter "conversionHint". //we have validate the MessageConverter is SmartMessageConverter in this#getMethodParameter. @@ -553,61 +559,37 @@ private Object doConvertMessage(MessageExt messageExt) { private MethodParameter getMethodParameter() { Class targetClass; + Class consumerInterface; if (rocketMQListener != null) { targetClass = AopProxyUtils.ultimateTargetClass(rocketMQListener); + consumerInterface = RocketMQListener.class; } else { targetClass = AopProxyUtils.ultimateTargetClass(rocketMQReplyListener); + consumerInterface = RocketMQReplyListener.class; } - Type messageType = this.getMessageType(); - Class clazz = null; - if (messageType instanceof ParameterizedType && messageConverter instanceof SmartMessageConverter) { - clazz = (Class) ((ParameterizedType) messageType).getRawType(); - } else if (messageType instanceof Class) { - clazz = (Class) messageType; - } else { - throw new RuntimeException("parameterType:" + messageType + " of onMessage method is not supported"); - } - try { - final Method method = targetClass.getMethod("onMessage", clazz); - return new MethodParameter(method, 0); - } catch (NoSuchMethodException e) { - e.printStackTrace(); - throw new RuntimeException("parameterType:" + messageType + " of onMessage method is not supported"); - } + ResolvableType resolvableType = ResolvableType.forClass(targetClass).as(consumerInterface); + Class methodParameterType = resolvableType.getGeneric().resolve(Object.class); + Method onMessage = ReflectionUtils.findMethod(targetClass, "onMessage", methodParameterType); + return MethodParameter.forExecutable(onMessage, 0); } + private Type getMessageType() { Class targetClass; + Class consumerInterface; if (rocketMQListener != null) { targetClass = AopProxyUtils.ultimateTargetClass(rocketMQListener); + consumerInterface = RocketMQListener.class; } else { targetClass = AopProxyUtils.ultimateTargetClass(rocketMQReplyListener); + consumerInterface = RocketMQReplyListener.class; } - Type matchedGenericInterface = null; - while (Objects.nonNull(targetClass)) { - Type[] interfaces = targetClass.getGenericInterfaces(); - if (Objects.nonNull(interfaces)) { - for (Type type : interfaces) { - if (type instanceof ParameterizedType && - (Objects.equals(((ParameterizedType) type).getRawType(), RocketMQListener.class) || Objects.equals(((ParameterizedType) type).getRawType(), RocketMQReplyListener.class))) { - matchedGenericInterface = type; - break; - } - } - } - targetClass = targetClass.getSuperclass(); - } - if (Objects.isNull(matchedGenericInterface)) { - return Object.class; - } - - Type[] actualTypeArguments = ((ParameterizedType) matchedGenericInterface).getActualTypeArguments(); - if (Objects.nonNull(actualTypeArguments) && actualTypeArguments.length > 0) { - return actualTypeArguments[0]; - } - return Object.class; + ResolvableType resolvableType = ResolvableType.forClass(targetClass).as(consumerInterface); + Type messageType = resolvableType.getGeneric().getType(); + return messageType; } + private void initRocketMQPushConsumer() throws MQClientException { if (rocketMQListener == null && rocketMQReplyListener == null) { throw new IllegalArgumentException("Property 'rocketMQListener' or 'rocketMQReplyListener' is required"); diff --git a/rocketmq-spring-boot/src/test/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainerTest.java b/rocketmq-spring-boot/src/test/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainerTest.java index de15fcdf..2b07d3e0 100644 --- a/rocketmq-spring-boot/src/test/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainerTest.java +++ b/rocketmq-spring-boot/src/test/java/org/apache/rocketmq/spring/support/DefaultRocketMQListenerContainerTest.java @@ -48,6 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -65,7 +66,7 @@ public void onMessage(String message) { } }); Class result = (Class) getMessageType.invoke(listenerContainer); - assertThat(result.getName().equals(String.class.getName())); + assertEquals(result, String.class); //support message listenerContainer.setRocketMQListener(new RocketMQListener() { @@ -74,7 +75,7 @@ public void onMessage(Message message) { } }); result = (Class) getMessageType.invoke(listenerContainer); - assertThat(result.getName().equals(Message.class.getName())); + assertEquals(result, Message.class); listenerContainer.setRocketMQListener(new RocketMQListener() { @Override @@ -82,8 +83,9 @@ public void onMessage(MessageExt message) { } }); result = (Class) getMessageType.invoke(listenerContainer); - assertThat(result.getName().equals(MessageExt.class.getName())); + assertEquals(result, MessageExt.class); + listenerContainer.setRocketMQListener(null); listenerContainer.setRocketMQReplyListener(new RocketMQReplyListener() { @Override @@ -92,7 +94,7 @@ public String onMessage(MessageExt message) { } }); result = (Class) getMessageType.invoke(listenerContainer); - assertThat(result.getName().equals(MessageExt.class.getName())); + assertEquals(result, MessageExt.class); listenerContainer.setRocketMQReplyListener(new RocketMQReplyListener() { @Override @@ -101,7 +103,7 @@ public String onMessage(String message) { } }); result = (Class) getMessageType.invoke(listenerContainer); - assertThat(result.getName().equals(String.class.getName())); + assertEquals(result, String.class); } @Test