diff --git a/src/main/java/org/junit/support/testng/engine/DiscoveryListener.java b/src/main/java/org/junit/support/testng/engine/DiscoveryListener.java index adf1815..a0e34db 100644 --- a/src/main/java/org/junit/support/testng/engine/DiscoveryListener.java +++ b/src/main/java/org/junit/support/testng/engine/DiscoveryListener.java @@ -12,8 +12,12 @@ import java.util.HashSet; import java.util.Set; +import java.util.function.Predicate; +import org.junit.platform.engine.EngineDiscoveryRequest; +import org.junit.platform.engine.Filter; import org.junit.platform.engine.TestDescriptor; +import org.junit.platform.engine.discovery.ClassNameFilter; import org.testng.ITestClass; import org.testng.ITestResult; @@ -21,8 +25,10 @@ class DiscoveryListener extends DefaultListener { private final TestClassRegistry testClassRegistry = new TestClassRegistry(); private final TestNGEngineDescriptor engineDescriptor; + private final Predicate classNameFilter; - public DiscoveryListener(TestNGEngineDescriptor engineDescriptor) { + public DiscoveryListener(EngineDiscoveryRequest request, TestNGEngineDescriptor engineDescriptor) { + this.classNameFilter = Filter.composeFilters(request.getFiltersByType(ClassNameFilter.class)).toPredicate(); this.engineDescriptor = engineDescriptor; } @@ -34,8 +40,15 @@ public void finalizeDiscovery() { @Override public void onBeforeClass(ITestClass testClass) { - testClassRegistry.start(testClass.getRealClass(), - () -> engineDescriptor.findClassDescriptor(testClass.getRealClass())); + testClassRegistry.start(testClass.getRealClass(), realClass -> { + ClassDescriptor classDescriptor = engineDescriptor.findClassDescriptor(realClass); + if (classDescriptor == null && classNameFilter.test(realClass.getName())) { + classDescriptor = engineDescriptor.getTestDescriptorFactory().createClassDescriptor(engineDescriptor, + realClass); + engineDescriptor.addChild(classDescriptor); + } + return classDescriptor; + }); } @Override diff --git a/src/main/java/org/junit/support/testng/engine/ExecutionListener.java b/src/main/java/org/junit/support/testng/engine/ExecutionListener.java index 57da6b7..5b63c15 100644 --- a/src/main/java/org/junit/support/testng/engine/ExecutionListener.java +++ b/src/main/java/org/junit/support/testng/engine/ExecutionListener.java @@ -55,7 +55,7 @@ class ExecutionListener extends DefaultListener { public void onBeforeClass(ITestClass testClass) { ClassDescriptor classDescriptor = requireNonNull(engineDescriptor.findClassDescriptor(testClass.getRealClass()), "Missing class descriptor"); - testClassRegistry.start(testClass.getRealClass(), () -> { + testClassRegistry.start(testClass.getRealClass(), __ -> { delegate.executionStarted(classDescriptor); return classDescriptor; }); diff --git a/src/main/java/org/junit/support/testng/engine/TestClassRegistry.java b/src/main/java/org/junit/support/testng/engine/TestClassRegistry.java index 06a3689..95ecdde 100644 --- a/src/main/java/org/junit/support/testng/engine/TestClassRegistry.java +++ b/src/main/java/org/junit/support/testng/engine/TestClassRegistry.java @@ -16,20 +16,25 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; -import java.util.function.Supplier; +import java.util.function.Function; class TestClassRegistry { private final Set classDescriptors = ConcurrentHashMap.newKeySet(); private final Map, Entry> testClasses = new ConcurrentHashMap<>(); - void start(Class testClass, Supplier onFirst) { + void start(Class testClass, Function, ClassDescriptor> onFirst) { Entry entry = testClasses.computeIfAbsent(testClass, __ -> { - ClassDescriptor classDescriptor = onFirst.get(); - classDescriptors.add(classDescriptor); - return new Entry(classDescriptor); + ClassDescriptor classDescriptor = onFirst.apply(testClass); + if (classDescriptor != null) { + classDescriptors.add(classDescriptor); + return new Entry(classDescriptor); + } + return null; }); - entry.inProgress.incrementAndGet(); + if (entry != null) { + entry.inProgress.incrementAndGet(); + } } Optional get(Class testClass) { diff --git a/src/main/java/org/junit/support/testng/engine/TestNGTestEngine.java b/src/main/java/org/junit/support/testng/engine/TestNGTestEngine.java index 3cd62c5..be3a7cb 100644 --- a/src/main/java/org/junit/support/testng/engine/TestNGTestEngine.java +++ b/src/main/java/org/junit/support/testng/engine/TestNGTestEngine.java @@ -84,7 +84,7 @@ public TestDescriptor discover(EngineDiscoveryRequest request, UniqueId uniqueId List methodNames = engineDescriptor.getQualifiedMethodNames(); ConfigurationParameters configurationParameters = request.getConfigurationParameters(); - DiscoveryListener listener = new DiscoveryListener(engineDescriptor); + DiscoveryListener listener = new DiscoveryListener(request, engineDescriptor); if (testClasses.length > 0) { withTemporarySystemProperty(TESTNG_MODE_DRYRUN, "true", diff --git a/src/test/java/org/junit/support/testng/engine/DiscoveryIntegrationTests.java b/src/test/java/org/junit/support/testng/engine/DiscoveryIntegrationTests.java index 7f228c5..dde0c24 100644 --- a/src/test/java/org/junit/support/testng/engine/DiscoveryIntegrationTests.java +++ b/src/test/java/org/junit/support/testng/engine/DiscoveryIntegrationTests.java @@ -32,6 +32,7 @@ import example.basics.InheritedClassLevelOnlyAnnotationTestCase; import example.basics.InheritingSubClassTestCase; import example.basics.JUnitTestCase; +import example.basics.NestedTestClass; import example.basics.SimpleTestCase; import example.basics.SuccessPercentageTestCase; import example.basics.TwoMethodsTestCase; @@ -287,6 +288,32 @@ void doesNotThrowExceptionWhenNonExecutableTypeOfClassIsSelected(Class testCl assertThat(rootDescriptor.getChildren()).isEmpty(); } + @Test + void discoversNestedTestClasses() { + var selectedTestClass = NestedTestClass.class; + var request = request().selectors(selectClass(selectedTestClass)).build(); + + var rootDescriptor = testEngine.discover(request, engineId); + + assertThat(rootDescriptor.getUniqueId()).isEqualTo(engineId); + assertThat(rootDescriptor.getChildren()).hasSize(2); + + Map classDescriptors = rootDescriptor.getChildren().stream() // + .collect(toMap(TestDescriptor::getDisplayName, identity())); + + TestDescriptor classDescriptor = classDescriptors.get("A"); + assertThat(classDescriptor.getLegacyReportingName()).isEqualTo(NestedTestClass.A.class.getName()); + assertThat(classDescriptor.getType()).isEqualTo(CONTAINER); + assertThat(classDescriptor.getSource()).contains(ClassSource.from(NestedTestClass.A.class)); + assertThat(classDescriptor.getChildren()).hasSize(1); + + classDescriptor = classDescriptors.get("B"); + assertThat(classDescriptor.getLegacyReportingName()).isEqualTo(NestedTestClass.B.class.getName()); + assertThat(classDescriptor.getType()).isEqualTo(CONTAINER); + assertThat(classDescriptor.getSource()).contains(ClassSource.from(NestedTestClass.B.class)); + assertThat(classDescriptor.getChildren()).hasSize(1); + } + interface InterfaceTestCase { } diff --git a/src/test/java/org/junit/support/testng/engine/ReportingIntegrationTests.java b/src/test/java/org/junit/support/testng/engine/ReportingIntegrationTests.java index 22cda48..bcf2987 100644 --- a/src/test/java/org/junit/support/testng/engine/ReportingIntegrationTests.java +++ b/src/test/java/org/junit/support/testng/engine/ReportingIntegrationTests.java @@ -13,6 +13,7 @@ import static org.junit.platform.commons.util.StringUtils.isBlank; import static org.junit.platform.engine.FilterResult.excluded; import static org.junit.platform.engine.FilterResult.includedIf; +import static org.junit.platform.engine.discovery.ClassNameFilter.excludeClassNamePatterns; import static org.junit.platform.engine.discovery.DiscoverySelectors.selectClass; import static org.junit.platform.engine.discovery.DiscoverySelectors.selectMethod; import static org.junit.platform.engine.discovery.DiscoverySelectors.selectUniqueId; @@ -35,6 +36,7 @@ import example.basics.CustomAttributeTestCase; import example.basics.ExpectedExceptionsTestCase; import example.basics.InheritingSubClassTestCase; +import example.basics.NestedTestClass; import example.basics.ParallelExecutionTestCase; import example.basics.RetriedTestCase; import example.basics.SimpleTestCase; @@ -46,6 +48,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.junit.platform.engine.Filter; import org.junit.platform.engine.UniqueId; import org.junit.platform.engine.support.descriptor.MethodSource; import org.junit.platform.launcher.PostDiscoveryFilter; @@ -302,4 +305,17 @@ void reportsParallelInvocations() { event(testClass(testClass), finishedSuccessfully())); } + @Test + void onlyExecutesNestedTestClassesThatMatchClassNameFilter() { + var selectedTestClass = NestedTestClass.class; + + var results = testNGEngine() // + .selectors(selectClass(selectedTestClass)) // + .filters((Filter) excludeClassNamePatterns(".*A$")) // + .execute(); + + results.containerEvents().assertStatistics(stats -> stats.started(2).finished(2)); + results.testEvents().assertStatistics(stats -> stats.started(1).finished(1)); + } + } diff --git a/src/testFixtures/java/example/basics/NestedTestClass.java b/src/testFixtures/java/example/basics/NestedTestClass.java new file mode 100644 index 0000000..5e2d6ec --- /dev/null +++ b/src/testFixtures/java/example/basics/NestedTestClass.java @@ -0,0 +1,28 @@ +/* + * Copyright 2021 the original author or authors. + * + * All rights reserved. This program and the accompanying materials are + * made available under the terms of the Eclipse Public License v2.0 which + * accompanies this distribution and is available at + * + * https://www.eclipse.org/legal/epl-v20.html + */ + +package example.basics; + +import org.testng.annotations.Test; + +public class NestedTestClass { + + public static class A { + @Test + public void test() { + } + } + + public static class B { + @Test + public void test() { + } + } +}