diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/CustomizedCallbackHandler.java b/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/CustomizedCallbackHandler.java index eff093490bcd1..a15282bd6307f 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/CustomizedCallbackHandler.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/CustomizedCallbackHandler.java @@ -20,13 +20,15 @@ import javax.security.auth.callback.Callback; import javax.security.auth.callback.UnsupportedCallbackException; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.List; /** For handling customized {@link Callback}. */ public interface CustomizedCallbackHandler { class DefaultHandler implements CustomizedCallbackHandler{ @Override - public void handleCallback(List callbacks, String username, char[] password) + public void handleCallbacks(List callbacks, String username, char[] password) throws UnsupportedCallbackException { if (!callbacks.isEmpty()) { throw new UnsupportedCallbackException(callbacks.get(0)); @@ -34,6 +36,25 @@ public void handleCallback(List callbacks, String username, char[] pas } } - void handleCallback(List callbacks, String name, char[] password) + static CustomizedCallbackHandler delegate(Object delegated) { + final String methodName = "handleCallbacks"; + final Class clazz = delegated.getClass(); + final Method method; + try { + method = clazz.getMethod(methodName, List.class, String.class, char[].class); + } catch (NoSuchMethodException e) { + throw new IllegalStateException("Failed to get method " + methodName + " from " + clazz, e); + } + + return (callbacks, name, password) -> { + try { + method.invoke(delegated, callbacks, name, password); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IOException("Failed to invoke " + method, e); + } + }; + } + + void handleCallbacks(List callbacks, String name, char[] password) throws UnsupportedCallbackException, IOException; } diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferServer.java b/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferServer.java index ae79800b3ed37..d71544fc77dc6 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferServer.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/SaslDataTransferServer.java @@ -225,14 +225,20 @@ static final class SaslServerCallbackHandler SaslServerCallbackHandler(Configuration conf, PasswordFunction passwordFunction) { this.passwordFunction = passwordFunction; - final Class clazz = conf.getClass( + final Class clazz = conf.getClass( HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY, - CustomizedCallbackHandler.DefaultHandler.class, CustomizedCallbackHandler.class); + CustomizedCallbackHandler.DefaultHandler.class); + final Object callbackHandler; try { - this.customizedCallbackHandler = clazz.newInstance(); + callbackHandler = clazz.newInstance(); } catch (Exception e) { throw new IllegalStateException("Failed to create a new instance of " + clazz, e); } + if (callbackHandler instanceof CustomizedCallbackHandler) { + customizedCallbackHandler = (CustomizedCallbackHandler) callbackHandler; + } else { + customizedCallbackHandler = CustomizedCallbackHandler.delegate(callbackHandler); + } } @Override @@ -271,7 +277,7 @@ public void handle(Callback[] callbacks) throws IOException, if (unknownCallbacks != null) { final String name = nc != null ? nc.getDefaultName() : null; final char[] password = name != null ? passwordFunction.apply(name) : null; - customizedCallbackHandler.handleCallback(unknownCallbacks, name, password); + customizedCallbackHandler.handleCallbacks(unknownCallbacks, name, password); } } } diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestCustomizedCallbackHandler.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestCustomizedCallbackHandler.java index 88d1d66bc40ff..37de661720839 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestCustomizedCallbackHandler.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/protocol/datatransfer/sasl/TestCustomizedCallbackHandler.java @@ -20,6 +20,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hdfs.client.HdfsClientConfigKeys; import org.apache.hadoop.hdfs.protocol.datatransfer.sasl.SaslDataTransferServer.SaslServerCallbackHandler; +import org.apache.hadoop.test.LambdaTestUtils; import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; @@ -27,18 +28,37 @@ import javax.security.auth.callback.Callback; import javax.security.auth.callback.UnsupportedCallbackException; -import java.util.Arrays; +import java.io.IOException; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +/** For testing {@link CustomizedCallbackHandler}. */ public class TestCustomizedCallbackHandler { - public static final Logger LOG = LoggerFactory.getLogger(TestCustomizedCallbackHandler.class); + static final Logger LOG = LoggerFactory.getLogger(TestCustomizedCallbackHandler.class); + + static final AtomicReference> LAST_CALLBACKS = new AtomicReference<>(); + + static void runHandleCallbacks(Object caller, List callbacks, String name) { + LOG.info("{}: handling {} for {}", caller.getClass().getSimpleName(), callbacks, name); + LAST_CALLBACKS.set(callbacks); + } + + /** Assert if the callbacks in {@link #LAST_CALLBACKS} are the same as the expected callbacks. */ + static void assertCallbacks(Callback[] expected) { + final List computed = LAST_CALLBACKS.getAndSet(null); + Assert.assertNotNull(computed); + Assert.assertEquals(expected.length, computed.size()); + for (int i = 0; i < expected.length; i++) { + Assert.assertSame(expected[i], computed.get(i)); + } + } static class MyCallback implements Callback { } static class MyCallbackHandler implements CustomizedCallbackHandler { @Override - public void handleCallback(List callbacks, String name, char[] password) { - LOG.info("{}: handling {} for {}", getClass().getSimpleName(), callbacks, name); + public void handleCallbacks(List callbacks, String name, char[] password) { + runHandleCallbacks(this, callbacks, name); } } @@ -48,16 +68,52 @@ public void testCustomizedCallbackHandler() throws Exception { final Callback[] callbacks = {new MyCallback()}; // without setting conf, expect UnsupportedCallbackException - try { - new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks); - Assert.fail("Expected UnsupportedCallbackException for " + Arrays.asList(callbacks)); - } catch (UnsupportedCallbackException e) { - LOG.info("The failure is expected", e); - } + LambdaTestUtils.intercept(UnsupportedCallbackException.class, () -> runTest(conf, callbacks)); // set conf and expect success conf.setClass(HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY, MyCallbackHandler.class, CustomizedCallbackHandler.class); new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks); + assertCallbacks(callbacks); + } + + static class MyCallbackMethod { + public void handleCallbacks(List callbacks, String name, char[] password) + throws UnsupportedCallbackException { + runHandleCallbacks(this, callbacks, name); + } + } + + static class MyExceptionMethod { + public void handleCallbacks(List callbacks, String name, char[] password) + throws UnsupportedCallbackException { + runHandleCallbacks(this, callbacks, name); + throw new UnsupportedCallbackException(callbacks.get(0)); + } + } + + @Test + public void testCustomizedCallbackMethod() throws Exception { + final Configuration conf = new Configuration(); + final Callback[] callbacks = {new MyCallback()}; + + // without setting conf, expect UnsupportedCallbackException + LambdaTestUtils.intercept(UnsupportedCallbackException.class, () -> runTest(conf, callbacks)); + + // set conf and expect success + conf.setClass(HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY, + MyCallbackMethod.class, Object.class); + new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks); + assertCallbacks(callbacks); + + // set conf and expect exception + conf.setClass(HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY, + MyExceptionMethod.class, Object.class); + LambdaTestUtils.intercept(IOException.class, () -> runTest(conf, callbacks)); + } + + static void runTest(Configuration conf, Callback... callbacks) + throws IOException, UnsupportedCallbackException { + new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks); } }