diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java index 12004a5d0e..867330f82b 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/AbstractAutowireCapableBeanFactory.java @@ -657,10 +657,10 @@ public abstract class AbstractAutowireCapableBeanFactory extends AbstractBeanFac // If all factory methods have the same return type, return that type. // Can't clearly figure out exact method due to type converting / autowiring! + Class commonType = null; boolean cache = false; int minNrOfArgs = mbd.getConstructorArgumentValues().getArgumentCount(); Method[] candidates = ReflectionUtils.getUniqueDeclaredMethods(factoryClass); - Set> returnTypes = new HashSet>(1); for (Method factoryMethod : candidates) { if (Modifier.isStatic(factoryMethod.getModifiers()) == isStatic && factoryMethod.getName().equals(mbd.getFactoryMethodName()) && @@ -694,7 +694,7 @@ public abstract class AbstractAutowireCapableBeanFactory extends AbstractBeanFac factoryMethod, args, getBeanClassLoader()); if (returnType != null) { cache = true; - returnTypes.add(returnType); + commonType = ClassUtils.determineCommonAncestor(returnType, commonType); } } catch (Throwable ex) { @@ -704,18 +704,17 @@ public abstract class AbstractAutowireCapableBeanFactory extends AbstractBeanFac } } else { - returnTypes.add(factoryMethod.getReturnType()); + commonType = ClassUtils.determineCommonAncestor(factoryMethod.getReturnType(), commonType); } } } - if (returnTypes.size() == 1) { + if (commonType != null) { // Clear return type found: all factory methods return same type. - Class result = returnTypes.iterator().next(); if (cache) { - mbd.resolvedFactoryMethodReturnType = result; + mbd.resolvedFactoryMethodReturnType = commonType; } - return result; + return commonType; } else { // Ambiguous return types found: return null to indicate "not determinable". diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/xml/FactoryMethods.java b/spring-beans/src/test/java/org/springframework/beans/factory/xml/FactoryMethods.java index eb7113448b..4a9703eaa7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/xml/FactoryMethods.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/xml/FactoryMethods.java @@ -54,11 +54,11 @@ public class FactoryMethods { return new FactoryMethods(tb, name, num); } - static FactoryMethods newInstance(TestBean tb, int num, Integer something) { + static ExtendedFactoryMethods newInstance(TestBean tb, int num, Integer something) { if (something != null) { throw new IllegalStateException("Should never be called with non-null value"); } - return new FactoryMethods(tb, null, num); + return new ExtendedFactoryMethods(tb, null, num); } @SuppressWarnings("unused") @@ -119,4 +119,12 @@ public class FactoryMethods { this.name = name; } + + public static class ExtendedFactoryMethods extends FactoryMethods { + + ExtendedFactoryMethods(TestBean tb, String name, int num) { + super(tb, name, num); + } + } + } diff --git a/spring-core/src/main/java/org/springframework/util/ClassUtils.java b/spring-core/src/main/java/org/springframework/util/ClassUtils.java index 1877f2aa48..710b5bbf8f 100644 --- a/spring-core/src/main/java/org/springframework/util/ClassUtils.java +++ b/spring-core/src/main/java/org/springframework/util/ClassUtils.java @@ -1143,6 +1143,39 @@ public abstract class ClassUtils { return Proxy.getProxyClass(classLoader, interfaces); } + /** + * Determine the common ancestor of the given classes, if any. + * @param clazz1 the class to introspect + * @param clazz2 the other class to introspect + * @return the common ancestor (i.e. common superclass, one interface + * extending the other), or {@code null} if none found. If any of the + * given classes is {@code null}, the other class will be returned. + * @since 3.2.6 + */ + public static Class determineCommonAncestor(Class clazz1, Class clazz2) { + if (clazz1 == null) { + return clazz2; + } + if (clazz2 == null) { + return clazz1; + } + if (clazz1.isAssignableFrom(clazz2)) { + return clazz1; + } + if (clazz2.isAssignableFrom(clazz1)) { + return clazz2; + } + Class ancestor = clazz1; + do { + ancestor = ancestor.getSuperclass(); + if (ancestor == null || Object.class.equals(ancestor)) { + return null; + } + } + while (!ancestor.isAssignableFrom(clazz2)); + return ancestor; + } + /** * Check whether the given class is visible in the given ClassLoader. * @param clazz the class to check (typically an interface) diff --git a/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java b/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java index ce187d7c71..f29b8b6198 100644 --- a/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/util/ClassUtilsTests.java @@ -20,41 +20,48 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Proxy; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; - -import junit.framework.TestCase; +import java.util.Set; import org.springframework.tests.sample.objects.DerivedTestObject; import org.springframework.tests.sample.objects.ITestInterface; import org.springframework.tests.sample.objects.ITestObject; import org.springframework.tests.sample.objects.TestObject; +import static org.junit.Assert.*; + +import org.junit.Before; +import org.junit.Test; + /** * @author Colin Sampaleanu * @author Juergen Hoeller * @author Rob Harrop * @author Rick Evans */ -public class ClassUtilsTests extends TestCase { +public class ClassUtilsTests { private ClassLoader classLoader = getClass().getClassLoader(); - @Override + @Before public void setUp() { InnerClass.noArgCalled = false; InnerClass.argCalled = false; InnerClass.overloadedCalled = false; } + @Test public void testIsPresent() throws Exception { assertTrue(ClassUtils.isPresent("java.lang.String", classLoader)); assertFalse(ClassUtils.isPresent("java.lang.MySpecialString", classLoader)); } + @Test public void testForName() throws ClassNotFoundException { assertEquals(String.class, ClassUtils.forName("java.lang.String", classLoader)); assertEquals(String[].class, ClassUtils.forName("java.lang.String[]", classLoader)); @@ -69,6 +76,7 @@ public class ClassUtilsTests extends TestCase { assertEquals(short[][][].class, ClassUtils.forName("[[[S", classLoader)); } + @Test public void testForNameWithPrimitiveClasses() throws ClassNotFoundException { assertEquals(boolean.class, ClassUtils.forName("boolean", classLoader)); assertEquals(byte.class, ClassUtils.forName("byte", classLoader)); @@ -81,6 +89,7 @@ public class ClassUtilsTests extends TestCase { assertEquals(void.class, ClassUtils.forName("void", classLoader)); } + @Test public void testForNameWithPrimitiveArrays() throws ClassNotFoundException { assertEquals(boolean[].class, ClassUtils.forName("boolean[]", classLoader)); assertEquals(byte[].class, ClassUtils.forName("byte[]", classLoader)); @@ -92,6 +101,7 @@ public class ClassUtilsTests extends TestCase { assertEquals(double[].class, ClassUtils.forName("double[]", classLoader)); } + @Test public void testForNameWithPrimitiveArraysInternalName() throws ClassNotFoundException { assertEquals(boolean[].class, ClassUtils.forName(boolean[].class.getName(), classLoader)); assertEquals(byte[].class, ClassUtils.forName(byte[].class.getName(), classLoader)); @@ -103,76 +113,91 @@ public class ClassUtilsTests extends TestCase { assertEquals(double[].class, ClassUtils.forName(double[].class.getName(), classLoader)); } + @Test public void testGetShortName() { String className = ClassUtils.getShortName(getClass()); assertEquals("Class name did not match", "ClassUtilsTests", className); } + @Test public void testGetShortNameForObjectArrayClass() { String className = ClassUtils.getShortName(Object[].class); assertEquals("Class name did not match", "Object[]", className); } + @Test public void testGetShortNameForMultiDimensionalObjectArrayClass() { String className = ClassUtils.getShortName(Object[][].class); assertEquals("Class name did not match", "Object[][]", className); } + @Test public void testGetShortNameForPrimitiveArrayClass() { String className = ClassUtils.getShortName(byte[].class); assertEquals("Class name did not match", "byte[]", className); } + @Test public void testGetShortNameForMultiDimensionalPrimitiveArrayClass() { String className = ClassUtils.getShortName(byte[][][].class); assertEquals("Class name did not match", "byte[][][]", className); } + @Test public void testGetShortNameForInnerClass() { String className = ClassUtils.getShortName(InnerClass.class); assertEquals("Class name did not match", "ClassUtilsTests.InnerClass", className); } + @Test public void testGetShortNameAsProperty() { String shortName = ClassUtils.getShortNameAsProperty(this.getClass()); assertEquals("Class name did not match", "classUtilsTests", shortName); } + @Test public void testGetClassFileName() { assertEquals("String.class", ClassUtils.getClassFileName(String.class)); assertEquals("ClassUtilsTests.class", ClassUtils.getClassFileName(getClass())); } + @Test public void testGetPackageName() { assertEquals("java.lang", ClassUtils.getPackageName(String.class)); assertEquals(getClass().getPackage().getName(), ClassUtils.getPackageName(getClass())); } + @Test public void testGetQualifiedName() { String className = ClassUtils.getQualifiedName(getClass()); assertEquals("Class name did not match", "org.springframework.util.ClassUtilsTests", className); } + @Test public void testGetQualifiedNameForObjectArrayClass() { String className = ClassUtils.getQualifiedName(Object[].class); assertEquals("Class name did not match", "java.lang.Object[]", className); } + @Test public void testGetQualifiedNameForMultiDimensionalObjectArrayClass() { String className = ClassUtils.getQualifiedName(Object[][].class); assertEquals("Class name did not match", "java.lang.Object[][]", className); } + @Test public void testGetQualifiedNameForPrimitiveArrayClass() { String className = ClassUtils.getQualifiedName(byte[].class); assertEquals("Class name did not match", "byte[]", className); } + @Test public void testGetQualifiedNameForMultiDimensionalPrimitiveArrayClass() { String className = ClassUtils.getQualifiedName(byte[][].class); assertEquals("Class name did not match", "byte[][]", className); } + @Test public void testHasMethod() throws Exception { assertTrue(ClassUtils.hasMethod(Collection.class, "size")); assertTrue(ClassUtils.hasMethod(Collection.class, "remove", Object.class)); @@ -180,6 +205,7 @@ public class ClassUtilsTests extends TestCase { assertFalse(ClassUtils.hasMethod(Collection.class, "someOtherMethod")); } + @Test public void testGetMethodIfAvailable() throws Exception { Method method = ClassUtils.getMethodIfAvailable(Collection.class, "size"); assertNotNull(method); @@ -193,6 +219,7 @@ public class ClassUtilsTests extends TestCase { assertNull(ClassUtils.getMethodIfAvailable(Collection.class, "someOtherMethod")); } + @Test public void testGetMethodCountForName() { assertEquals("Verifying number of overloaded 'print' methods for OverloadedMethodsClass.", 2, ClassUtils.getMethodCountForName(OverloadedMethodsClass.class, "print")); @@ -200,6 +227,7 @@ public class ClassUtilsTests extends TestCase { ClassUtils.getMethodCountForName(SubOverloadedMethodsClass.class, "print")); } + @Test public void testCountOverloadedMethods() { assertFalse(ClassUtils.hasAtLeastOneMethodWithName(TestObject.class, "foobar")); // no args @@ -208,6 +236,7 @@ public class ClassUtilsTests extends TestCase { assertTrue(ClassUtils.hasAtLeastOneMethodWithName(TestObject.class, "setAge")); } + @Test public void testNoArgsStaticMethod() throws IllegalAccessException, InvocationTargetException { Method method = ClassUtils.getStaticMethod(InnerClass.class, "staticMethod", (Class[]) null); method.invoke(null, (Object[]) null); @@ -215,6 +244,7 @@ public class ClassUtilsTests extends TestCase { InnerClass.noArgCalled); } + @Test public void testArgsStaticMethod() throws IllegalAccessException, InvocationTargetException { Method method = ClassUtils.getStaticMethod(InnerClass.class, "argStaticMethod", new Class[] {String.class}); @@ -222,6 +252,7 @@ public class ClassUtilsTests extends TestCase { assertTrue("argument method was not invoked.", InnerClass.argCalled); } + @Test public void testOverloadedStaticMethod() throws IllegalAccessException, InvocationTargetException { Method method = ClassUtils.getStaticMethod(InnerClass.class, "staticMethod", new Class[] {String.class}); @@ -230,6 +261,7 @@ public class ClassUtilsTests extends TestCase { InnerClass.overloadedCalled); } + @Test public void testIsAssignable() { assertTrue(ClassUtils.isAssignable(Object.class, Object.class)); assertTrue(ClassUtils.isAssignable(String.class, String.class)); @@ -245,11 +277,13 @@ public class ClassUtilsTests extends TestCase { assertFalse(ClassUtils.isAssignable(double.class, Integer.class)); } + @Test public void testClassPackageAsResourcePath() { String result = ClassUtils.classPackageAsResourcePath(Proxy.class); assertTrue(result.equals("java/lang/reflect")); } + @Test public void testAddResourcePathToPackagePath() { String result = "java/lang/reflect/xyzabc.xml"; assertEquals(result, ClassUtils.addResourcePathToPackagePath(Proxy.class, "xyzabc.xml")); @@ -259,6 +293,7 @@ public class ClassUtilsTests extends TestCase { ClassUtils.addResourcePathToPackagePath(Proxy.class, "a/b/c/d.xml")); } + @Test public void testGetAllInterfaces() { DerivedTestObject testBean = new DerivedTestObject(); List ifcs = Arrays.asList(ClassUtils.getAllInterfaces(testBean)); @@ -268,6 +303,7 @@ public class ClassUtilsTests extends TestCase { assertTrue("Contains IOther", ifcs.contains(ITestInterface.class)); } + @Test public void testClassNamesToString() { List ifcs = new LinkedList(); ifcs.add(Serializable.class); @@ -288,6 +324,36 @@ public class ClassUtilsTests extends TestCase { assertEquals("[]", ClassUtils.classNamesToString(Collections.EMPTY_LIST)); } + @Test + public void testDetermineCommonAncestor() { + assertEquals(Number.class, ClassUtils.determineCommonAncestor(Integer.class, Number.class)); + assertEquals(Number.class, ClassUtils.determineCommonAncestor(Number.class, Integer.class)); + assertEquals(Number.class, ClassUtils.determineCommonAncestor(Number.class, null)); + assertEquals(Integer.class, ClassUtils.determineCommonAncestor(null, Integer.class)); + assertEquals(Integer.class, ClassUtils.determineCommonAncestor(Integer.class, Integer.class)); + + assertEquals(Number.class, ClassUtils.determineCommonAncestor(Integer.class, Float.class)); + assertEquals(Number.class, ClassUtils.determineCommonAncestor(Float.class, Integer.class)); + assertNull(ClassUtils.determineCommonAncestor(Integer.class, String.class)); + assertNull(ClassUtils.determineCommonAncestor(String.class, Integer.class)); + + assertEquals(Collection.class, ClassUtils.determineCommonAncestor(List.class, Collection.class)); + assertEquals(Collection.class, ClassUtils.determineCommonAncestor(Collection.class, List.class)); + assertEquals(Collection.class, ClassUtils.determineCommonAncestor(Collection.class, null)); + assertEquals(List.class, ClassUtils.determineCommonAncestor(null, List.class)); + assertEquals(List.class, ClassUtils.determineCommonAncestor(List.class, List.class)); + + assertNull(ClassUtils.determineCommonAncestor(List.class, Set.class)); + assertNull(ClassUtils.determineCommonAncestor(Set.class, List.class)); + assertNull(ClassUtils.determineCommonAncestor(List.class, Runnable.class)); + assertNull(ClassUtils.determineCommonAncestor(Runnable.class, List.class)); + + assertEquals(List.class, ClassUtils.determineCommonAncestor(List.class, ArrayList.class)); + assertEquals(List.class, ClassUtils.determineCommonAncestor(ArrayList.class, List.class)); + assertNull(ClassUtils.determineCommonAncestor(List.class, String.class)); + assertNull(ClassUtils.determineCommonAncestor(String.class, List.class)); + } + public static class InnerClass {