From ff8bdf1eee80ff9f1e9d8e7d38bc12b2bea7d222 Mon Sep 17 00:00:00 2001 From: Simone Bordet Date: Tue, 2 Dec 2014 11:12:08 +0100 Subject: [PATCH] Backport 478e4284efac50cf5a811aac3728b210fc929a5c --- alpn-boot/pom.xml | 2 +- .../sun/security/ssl/ClientHandshaker.java | 57 +- alpn-tests/pom.xml | 2 +- .../mortbay/jetty/alpn/AbstractALPNTest.java | 362 +++++++ .../mortbay/jetty/alpn/SSLEngineALPNTest.java | 910 +++--------------- .../mortbay/jetty/alpn/SSLSocketALPNTest.java | 237 ++--- pom.xml | 2 +- 7 files changed, 642 insertions(+), 930 deletions(-) create mode 100644 alpn-tests/src/test/java/org/mortbay/jetty/alpn/AbstractALPNTest.java diff --git a/alpn-boot/pom.xml b/alpn-boot/pom.xml index dedf466..0104d2c 100644 --- a/alpn-boot/pom.xml +++ b/alpn-boot/pom.xml @@ -3,7 +3,7 @@ org.mortbay.jetty.alpn alpn-project - 7.1.0.v20141016 + 7.0.0.1-dev 4.0.0 diff --git a/alpn-boot/src/main/java/sun/security/ssl/ClientHandshaker.java b/alpn-boot/src/main/java/sun/security/ssl/ClientHandshaker.java index adbfab3..163823f 100644 --- a/alpn-boot/src/main/java/sun/security/ssl/ClientHandshaker.java +++ b/alpn-boot/src/main/java/sun/security/ssl/ClientHandshaker.java @@ -567,6 +567,12 @@ public Subject run() throws Exception { } setHandshakeSessionSE(session); + + // ALPN_CHANGES_BEGIN + if (isInitialHandshake) + alpnSelected(mesg); + // ALPN_CHANGES_END + return; } @@ -596,42 +602,47 @@ public Subject run() throws Exception { // ALPN_CHANGES_BEGIN if (isInitialHandshake) + alpnSelected(mesg); + // ALPN_CHANGES_END + } + + // ALPN_CHANGES_BEGIN + private void alpnSelected(ServerHello mesg) throws IOException + { + ALPN.ClientProvider provider = (ALPN.ClientProvider)(conn != null ? ALPN.get(conn) : ALPN.get(engine)); + Object ssl = conn != null ? conn : engine; + if (provider != null) { - ALPN.ClientProvider provider = (ALPN.ClientProvider)(conn != null ? ALPN.get(conn) : ALPN.get(engine)); - Object ssl = conn != null ? conn : engine; - if (provider != null) + ALPNExtension extension = (ALPNExtension)mesg.extensions.get(ExtensionType.EXT_ALPN); + if (extension != null) { - ALPNExtension extension = (ALPNExtension)mesg.extensions.get(ExtensionType.EXT_ALPN); - if (extension != null) + List protocols = extension.getProtocols(); + try { - List protocols = extension.getProtocols(); - try - { - String protocol = protocols == null || protocols.isEmpty() ? null : protocols.get(0); - if (ALPN.debug) - System.err.println("[C] ALPN protocol '" + protocol + "' selected by server for " + ssl); - provider.selected(protocol); - } - catch (Throwable x) - { - fatalSE(Alerts.alert_no_application_protocol, "Could not negotiate application protocol", x); - } + String protocol = protocols == null || protocols.isEmpty() ? null : protocols.get(0); + if (ALPN.debug) + System.err.println("[C] ALPN protocol '" + protocol + "' selected by server for " + ssl); + provider.selected(protocol); } - else + catch (Throwable x) { - if (ALPN.debug) - System.err.println("[C] ALPN not supported by server for " + ssl); - provider.unsupported(); + fatalSE(Alerts.alert_no_application_protocol, "Could not negotiate application protocol", x); } } else { if (ALPN.debug) - System.err.println("[C] ALPN client provider not present for " + ssl); + System.err.println("[C] ALPN not supported by server for " + ssl); + provider.unsupported(); } } - // ALPN_CHANGES_END + else + { + if (ALPN.debug) + System.err.println("[C] ALPN client provider not present for " + ssl); + } } + // ALPN_CHANGES_END /* * Server's own key was either a signing-only key, or was too diff --git a/alpn-tests/pom.xml b/alpn-tests/pom.xml index 4834848..d0c20b2 100644 --- a/alpn-tests/pom.xml +++ b/alpn-tests/pom.xml @@ -3,7 +3,7 @@ org.mortbay.jetty.alpn alpn-project - 7.1.0.v20141016 + 7.0.0.1-dev 4.0.0 diff --git a/alpn-tests/src/test/java/org/mortbay/jetty/alpn/AbstractALPNTest.java b/alpn-tests/src/test/java/org/mortbay/jetty/alpn/AbstractALPNTest.java new file mode 100644 index 0000000..6287562 --- /dev/null +++ b/alpn-tests/src/test/java/org/mortbay/jetty/alpn/AbstractALPNTest.java @@ -0,0 +1,362 @@ +// +// ======================================================================== +// Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.mortbay.jetty.alpn; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; + +import org.eclipse.jetty.alpn.ALPN; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public abstract class AbstractALPNTest +{ + protected abstract SSLResult performTLSHandshake(SSLResult handshake, ALPN.ClientProvider clientProvider, ALPN.ServerProvider serverProvider) throws Exception; + + protected abstract void performTLSClose(SSLResult sslResult) throws Exception; + + protected abstract void performDataExchange(SSLResult sslResult) throws Exception; + + protected abstract void performTLSRenegotiation(SSLResult sslResult, boolean client) throws Exception; + + protected abstract SSLSession getSSLSession(SSLResult sslResult, boolean client) throws Exception; + + @Before + public void prepare() throws Exception + { + Assert.assertNull("ALPN classes must be in the bootclasspath.", ALPN.class.getClassLoader()); + ALPN.debug = true; + } + + @Test + public void testALPNSuccessful() throws Exception + { + final String protocolName = "test"; + final CountDownLatch latch = new CountDownLatch(3); + ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() + { + @Override + public List protocols() + { + latch.countDown(); + return Arrays.asList(protocolName); + } + + @Override + public void unsupported() + { + Assert.fail(); + } + + @Override + public void selected(String protocol) + { + Assert.assertEquals(protocolName, protocol); + latch.countDown(); + } + }; + ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() + { + @Override + public void unsupported() + { + Assert.fail(); + } + + @Override + public String select(List protocols) + { + Assert.assertEquals(1, protocols.size()); + String protocol = protocols.get(0); + Assert.assertEquals(protocolName, protocol); + latch.countDown(); + return protocol; + } + }; + SSLResult sslResult = performTLSHandshake(null, clientProvider, serverProvider); + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + + // Verify that we can exchange data without errors. + performDataExchange(sslResult); + + performTLSClose(sslResult); + } + + @Test + public void testServerDoesNotSendALPN() throws Exception + { + final String protocolName = "test"; + final CountDownLatch latch = new CountDownLatch(3); + ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() + { + @Override + public List protocols() + { + latch.countDown(); + return Arrays.asList(protocolName); + } + + @Override + public void unsupported() + { + latch.countDown(); + } + + @Override + public void selected(String protocol) + { + Assert.fail(); + } + }; + ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() + { + @Override + public void unsupported() + { + Assert.fail(); + } + + @Override + public String select(List protocols) + { + Assert.assertEquals(1, protocols.size()); + String protocol = protocols.get(0); + Assert.assertEquals(protocolName, protocol); + latch.countDown(); + // By returning null, the server won't send the ALPN extension. + return null; + } + }; + SSLResult sslResult = performTLSHandshake(null, clientProvider, serverProvider); + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + + // Verify that we can exchange data without errors. + performDataExchange(sslResult); + + performTLSClose(sslResult); + } + + @Test + public void testServerThrowsException() throws Exception + { + final String protocolName = "test"; + final CountDownLatch latch = new CountDownLatch(3); + ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() + { + @Override + public List protocols() + { + latch.countDown(); + return Arrays.asList(protocolName); + } + + @Override + public void unsupported() + { + latch.countDown(); + } + + @Override + public void selected(String protocol) + { + Assert.fail(); + } + }; + ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() + { + @Override + public void unsupported() + { + Assert.fail(); + } + + @Override + public String select(List protocols) + { + // By throwing, the server will close the connection. + throw new IllegalStateException("explicitly_thrown_by_test"); + } + }; + + try + { + performTLSHandshake(null, clientProvider, serverProvider); + Assert.fail(); + } + catch (Exception x) + { + // Expected. + } + } + + @Test + public void testClientTLSRenegotiation() throws Exception + { + testTLSRenegotiation(true); + } + + @Test + public void testServerTLSRenegotiation() throws Exception + { + testTLSRenegotiation(false); + } + + private void testTLSRenegotiation(boolean client) throws Exception + { + final String protocolName = "test"; + final AtomicReference latch = new AtomicReference<>(new CountDownLatch(3)); + ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() + { + @Override + public List protocols() + { + latch.get().countDown(); + return Arrays.asList(protocolName); + } + + @Override + public void unsupported() + { + latch.get().countDown(); + Assert.fail(); + } + + @Override + public void selected(String protocol) + { + Assert.assertEquals(protocolName, protocol); + latch.get().countDown(); + } + }; + ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() + { + @Override + public void unsupported() + { + latch.get().countDown(); + Assert.fail(); + } + + @Override + public String select(List protocols) + { + Assert.assertEquals(1, protocols.size()); + String protocol = protocols.get(0); + Assert.assertEquals(protocolName, protocol); + latch.get().countDown(); + return protocol; + } + }; + SSLResult sslResult = performTLSHandshake(null, clientProvider, serverProvider); + Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); + + // Verify that we can exchange data without errors. + performDataExchange(sslResult); + + latch.set(new CountDownLatch(1)); + performTLSRenegotiation(sslResult, client); + + // The data exchange may trigger the completion of the TLS renegotiation. + performDataExchange(sslResult); + + // ALPN must not trigger. + Assert.assertFalse(latch.get().await(1, TimeUnit.SECONDS)); + + performTLSClose(sslResult); + } + + @Test + public void testTLSSessionResumption() throws Exception + { + final String protocolName = "test"; + final AtomicReference latch = new AtomicReference<>(); + ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() + { + @Override + public List protocols() + { + latch.get().countDown(); + return Arrays.asList(protocolName); + } + + @Override + public void unsupported() + { + Assert.fail(); + } + + @Override + public void selected(String protocol) + { + Assert.assertEquals(protocolName, protocol); + latch.get().countDown(); + } + }; + ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() + { + @Override + public void unsupported() + { + Assert.fail(); + } + + @Override + public String select(List protocols) + { + Assert.assertEquals(1, protocols.size()); + String protocol = protocols.get(0); + Assert.assertEquals(protocolName, protocol); + latch.get().countDown(); + return protocol; + } + }; + + // First TLS handshake. + latch.set(new CountDownLatch(3)); + SSLResult sslResult = performTLSHandshake(null, clientProvider, serverProvider); + Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); + SSLSession clientSession1 = getSSLSession(sslResult, true); + SSLSession serverSession1 = getSSLSession(sslResult, false); + // Must close the first session before starting the second one. + performTLSClose(sslResult); + + // Second TLS handshake. + latch.set(new CountDownLatch(3)); + sslResult = performTLSHandshake(sslResult, clientProvider, serverProvider); + Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); + + Assert.assertSame(clientSession1, getSSLSession(sslResult, true)); + Assert.assertSame(serverSession1, getSSLSession(sslResult, false)); + + performTLSClose(sslResult); + } + + public static class SSLResult + { + public SSLContext context; + public S client; + public S server; + } +} diff --git a/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLEngineALPNTest.java b/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLEngineALPNTest.java index 672b04e..e832004 100644 --- a/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLEngineALPNTest.java +++ b/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLEngineALPNTest.java @@ -18,821 +18,229 @@ package org.mortbay.jetty.alpn; -import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.nio.channels.ServerSocketChannel; -import java.nio.channels.SocketChannel; -import java.nio.charset.Charset; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; +import java.util.Random; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLSession; import org.eclipse.jetty.alpn.ALPN; import org.junit.Assert; -import org.junit.Test; -public class SSLEngineALPNTest +public class SSLEngineALPNTest extends AbstractALPNTest { - @Test - public void testNegotiationSuccessful() throws Exception + @Override + protected SSLResult performTLSHandshake(SSLResult handshake, ALPN.ClientProvider clientProvider, ALPN.ServerProvider serverProvider) throws Exception { - ALPN.debug = true; - final String protocolName = "test"; - final AtomicReference latch = new AtomicReference<>(new CountDownLatch(3)); - ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() - { - @Override - public List protocols() - { - latch.get().countDown(); - return Arrays.asList(protocolName); - } - - @Override - public void unsupported() - { - Assert.fail(); - } - - @Override - public void selected(String protocol) - { - Assert.assertEquals(protocolName, protocol); - latch.get().countDown(); - } - }; - ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() - { - @Override - public void unsupported() - { - Assert.fail(); - } - - @Override - public String select(List protocols) - { - Assert.assertEquals(1, protocols.size()); - String protocol = protocols.get(0); - Assert.assertEquals(protocolName, protocol); - latch.get().countDown(); - return protocol; - } - }; - testNegotiationSuccessful(clientProvider, serverProvider, latch); - } + SSLContext sslContext = handshake == null ? SSLSupport.newSSLContext() : handshake.context; + SSLResult sslResult = new SSLResult<>(); + sslResult.context = sslContext; + // Use host and port to allow for session resumption. + int randomPort = new Random().nextInt(20000) + 10000; + int clientPort = handshake == null ? randomPort : handshake.client.getPeerPort(); + SSLEngine clientSSLEngine = sslContext.createSSLEngine("localhost", clientPort); + sslResult.client = clientSSLEngine; + clientSSLEngine.setUseClientMode(true); + int serverPort = handshake == null ? randomPort + 1 : handshake.server.getPeerPort(); + SSLEngine serverSSLEngine = sslContext.createSSLEngine("localhost", serverPort); + sslResult.server = serverSSLEngine; + serverSSLEngine.setUseClientMode(false); + + ByteBuffer encrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getPacketBufferSize()); + ByteBuffer decrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getApplicationBufferSize()); + + ALPN.put(clientSSLEngine, clientProvider); + clientSSLEngine.beginHandshake(); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, clientSSLEngine.getHandshakeStatus()); + + ALPN.put(serverSSLEngine, serverProvider); + serverSSLEngine.beginHandshake(); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, serverSSLEngine.getHandshakeStatus()); - @Test - public void testServerDoesNotNegotiate() throws Exception - { - ALPN.debug = true; - final String protocolName = "test"; - final AtomicReference latch = new AtomicReference<>(new CountDownLatch(3)); - ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() - { - @Override - public List protocols() - { - latch.get().countDown(); - return Arrays.asList(protocolName); - } - - @Override - public void unsupported() - { - latch.get().countDown(); - } - - @Override - public void selected(String protocol) - { - Assert.fail(); - } - }; - ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() - { - @Override - public void unsupported() - { - Assert.fail(); - } - - @Override - public String select(List protocols) - { - Assert.assertEquals(1, protocols.size()); - String protocol = protocols.get(0); - Assert.assertEquals(protocolName, protocol); - latch.get().countDown(); - // By returning null, the server won't send the ALPN extension. - return null; - } - }; - testNegotiationSuccessful(clientProvider, serverProvider, latch); - } - - @Test - public void testServerFailsNegotiation() throws Exception - { - ALPN.debug = true; - final SSLContext context = SSLSupport.newSSLContext(); - final int readTimeout = 5000; - final String protocolName = "test"; - final ServerSocketChannel server = ServerSocketChannel.open(); - final AtomicReference latch = new AtomicReference<>(new CountDownLatch(1)); - server.bind(new InetSocketAddress("localhost", 0)); - System.err.println("Server listening on " + server.getLocalAddress()); - new Thread() - { - @Override - public void run() - { - SSLEngine sslEngine = context.createSSLEngine(); - try - { - sslEngine.setUseClientMode(false); - ALPN.put(sslEngine, new ALPN.ServerProvider() - { - @Override - public void unsupported() - { - Assert.fail(); - } - - @Override - public String select(List protocols) - { - throw new IllegalStateException("Server says nothing is good enough"); - } - }); - ByteBuffer encrypted = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); - ByteBuffer decrypted = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); - - SocketChannel socket = server.accept(); - socket.socket().setSoTimeout(readTimeout); - - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, sslEngine.getHandshakeStatus()); - - // Read ClientHello - socket.read(encrypted); - encrypted.flip(); - SSLEngineResult result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - - // Generate and write ServerHello, it will throw an exception. - encrypted.clear(); - sslEngine.wrap(decrypted, encrypted); - Assert.fail(); - } - catch (Exception x) - { - Throwable cause = x.getCause(); - if (cause instanceof IllegalStateException) - latch.get().countDown(); - else - x.printStackTrace(); - } - finally - { - ALPN.remove(sslEngine); - } - } - }.start(); + // Generate and write ClientHello + wrap(clientSSLEngine, decrypted, encrypted); - SSLEngine sslEngine = context.createSSLEngine(); - sslEngine.setUseClientMode(true); - ALPN.put(sslEngine, new ALPN.ClientProvider() - { - @Override - public List protocols() - { - return Arrays.asList(protocolName); - } + // Read the ClientHello + unwrap(serverSSLEngine, encrypted, decrypted); - @Override - public void unsupported() - { - Assert.fail(); - } + // Generate and write ServerHello (and other messages) + wrap(serverSSLEngine, decrypted, encrypted); - @Override - public void selected(String protocol) - { - Assert.fail(); - } - }); - ByteBuffer encrypted = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); - ByteBuffer decrypted = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); + // Read the ServerHello (and other messages) + unwrap(clientSSLEngine, encrypted, decrypted); - SocketChannel client = SocketChannel.open(server.getLocalAddress()); - client.socket().setSoTimeout(readTimeout); + // Generate and write ClientKeyExchange, ChangeCipherSpec and Finished + wrap(clientSSLEngine, decrypted, encrypted); - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); + // Read ClientKeyExchange, ChangeCipherSpec and Finished + unwrap(serverSSLEngine, encrypted, decrypted); - // Generate and write ClientHello - SSLEngineResult result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); + // Generate and write ChangeCipherSpec and Finished + wrap(serverSSLEngine, decrypted, encrypted); - Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); + // Read ChangeCipherSpec and Finished + unwrap(clientSSLEngine, encrypted, decrypted); - // The server was supposed to send the TLS close alert, - // but it does not (at least in this JDK version). + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, clientSSLEngine.getHandshakeStatus()); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, serverSSLEngine.getHandshakeStatus()); - // Close - ALPN.remove(sslEngine); - client.close(); - server.close(); + return sslResult; } - @Test - public void testClientFailsNegotiation() throws Exception + @Override + protected void performTLSClose(SSLResult sslResult) throws Exception { - ALPN.debug = true; - final SSLContext context = SSLSupport.newSSLContext(); - final int readTimeout = 5000; - final String protocolName = "test"; - final ServerSocketChannel server = ServerSocketChannel.open(); - server.bind(new InetSocketAddress("localhost", 0)); - System.err.println("Server listening on " + server.getLocalAddress()); - new Thread() - { - @Override - public void run() - { - try - { - final SSLEngine sslEngine = context.createSSLEngine(); - sslEngine.setUseClientMode(false); - ALPN.put(sslEngine, new ALPN.ServerProvider() - { - @Override - public void unsupported() - { - Assert.fail(); - } - - @Override - public String select(List protocols) - { - return protocolName + "NotInList"; - } - }); - ByteBuffer encrypted = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); - ByteBuffer decrypted = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); - - SocketChannel socket = server.accept(); - socket.socket().setSoTimeout(readTimeout); - - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, sslEngine.getHandshakeStatus()); - - // Read ClientHello - socket.read(encrypted); - encrypted.flip(); - SSLEngineResult result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); - - // Generate and write ServerHello - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - // Close - ALPN.remove(sslEngine); - socket.close(); - } - catch (Exception x) - { - x.printStackTrace(); - } - } - }.start(); + SSLEngine clientSSLEngine = sslResult.client; + SSLEngine serverSSLEngine = sslResult.server; - final SSLEngine sslEngine = context.createSSLEngine(); - sslEngine.setUseClientMode(true); - ALPN.put(sslEngine, new ALPN.ClientProvider() - { - @Override - public void unsupported() - { - Assert.fail(); - } + ByteBuffer encrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getPacketBufferSize()); + ByteBuffer decrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getApplicationBufferSize()); - @Override - public List protocols() - { - return Arrays.asList(protocolName); - } + clientSSLEngine.closeOutbound(); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, clientSSLEngine.getHandshakeStatus()); - @Override - public void selected(String protocol) - { - Assert.assertNotEquals(protocolName, protocol); - throw new IllegalStateException("Client says I didn't ask for that protocol"); - } - }); - ByteBuffer encrypted = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); - ByteBuffer decrypted = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); + wrap(clientSSLEngine, decrypted, encrypted); - SocketChannel client = SocketChannel.open(server.getLocalAddress()); - client.socket().setSoTimeout(readTimeout); + unwrap(serverSSLEngine, encrypted, decrypted); - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); + wrap(serverSSLEngine, decrypted, encrypted); - // Generate and write ClientHello - SSLEngineResult result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); + unwrap(clientSSLEngine, encrypted, decrypted); - // Read Server Hello - while (sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) - { - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) - sslEngine.getDelegatedTask().run(); - } - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); - - try - { - // Generate and write ClientKeyExchange - encrypted.clear(); - sslEngine.wrap(decrypted, encrypted); - Assert.fail(); - } - catch (RuntimeException x) - { - Throwable cause = x.getCause(); - if (!(cause instanceof IllegalStateException)) - throw x; - } - finally - { - // Close - ALPN.remove(sslEngine); - client.close(); - server.close(); - } + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, clientSSLEngine.getHandshakeStatus()); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, serverSSLEngine.getHandshakeStatus()); } - private void testNegotiationSuccessful(ALPN.ClientProvider clientProvider, final ALPN.ServerProvider serverProvider, AtomicReference latch) throws Exception + @Override + protected void performDataExchange(SSLResult sslResult) throws Exception { - final SSLContext context = SSLSupport.newSSLContext(); - - final int readTimeout = 5000; - final String data = "data"; - final ServerSocketChannel server = ServerSocketChannel.open(); - server.bind(new InetSocketAddress("localhost", 0)); - System.err.println("Server listening on " + server.getLocalAddress()); - new Thread() - { - @Override - public void run() - { - try - { - SSLEngine sslEngine = context.createSSLEngine(); - sslEngine.setUseClientMode(false); - ALPN.put(sslEngine, serverProvider); - ByteBuffer encrypted = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); - ByteBuffer decrypted = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); - - SocketChannel socket = server.accept(); - socket.socket().setSoTimeout(readTimeout); - - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, sslEngine.getHandshakeStatus()); - - // Read ClientHello - socket.read(encrypted); - encrypted.flip(); - SSLEngineResult result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); - - // Generate and write ServerHello - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - // Read up to Finished - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, sslEngine.getHandshakeStatus()); - if (!encrypted.hasRemaining()) - { - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - } - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - if (!encrypted.hasRemaining()) - { - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - } - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - - if (SSLEngineResult.HandshakeStatus.NEED_UNWRAP == result.getHandshakeStatus()) - { - if (!encrypted.hasRemaining()) - { - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - } - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - } - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - - // Generate and write ChangeCipherSpec - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - // Generate and write Finished - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.FINISHED, result.getHandshakeStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, sslEngine.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - // Read data - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - - // Echo the data back - encrypted.clear(); - decrypted.flip(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - // Read re-handshake - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); - - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, sslEngine.getHandshakeStatus()); - if (!encrypted.hasRemaining()) - { - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - } - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, sslEngine.getHandshakeStatus()); - if (!encrypted.hasRemaining()) - { - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - } - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.FINISHED, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - // Read more data - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - - // Echo the data back - encrypted.clear(); - decrypted.flip(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - - // TODO - // Re-handshake -// sslEngine.beginHandshake(); - - // Close - ALPN.remove(sslEngine); - encrypted.clear(); - socket.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.CLOSED, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - encrypted.clear(); - Assert.assertEquals(-1, socket.read(encrypted)); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.CLOSED, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - encrypted.flip(); - socket.write(encrypted); - socket.close(); - } - catch (Exception x) - { - x.printStackTrace(); - } - } - }.start(); - - SSLEngine sslEngine = context.createSSLEngine(); - sslEngine.setUseClientMode(true); - ALPN.put(sslEngine, clientProvider); - ByteBuffer encrypted = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); - ByteBuffer decrypted = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); + SSLEngine clientSSLEngine = sslResult.client; + SSLEngine serverSSLEngine = sslResult.server; - SocketChannel client = SocketChannel.open(server.getLocalAddress()); - client.socket().setSoTimeout(readTimeout); + ByteBuffer encrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getPacketBufferSize()); + ByteBuffer decrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getApplicationBufferSize()); - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); - - // Generate and write ClientHello - SSLEngineResult result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, clientSSLEngine.getHandshakeStatus()); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, serverSSLEngine.getHandshakeStatus()); - // Read Server Hello - while (sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) - { - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) - sslEngine.getDelegatedTask().run(); - } - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); + byte[] data = new byte[1024]; + new Random().nextBytes(data); - // Generate and write ClientKeyExchange + // Write the data. encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); + decrypted.clear(); + decrypted.put(data).flip(); + SSLEngineResult result = clientSSLEngine.wrap(decrypted, encrypted); Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); + + // Read the data. encrypted.flip(); - client.write(encrypted); - // Generate and write ChangeCipherSpec - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); + decrypted.clear(); + result = serverSSLEngine.unwrap(encrypted, decrypted); Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); - // Generate and write Finished + + // Write the data back. encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); + decrypted.flip(); + result = serverSSLEngine.wrap(decrypted, encrypted); Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - encrypted.flip(); - client.write(encrypted); - - if (SSLEngineResult.HandshakeStatus.NEED_WRAP == result.getHandshakeStatus()) - { - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - encrypted.flip(); - client.write(encrypted); - } - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - // Read ChangeCipherSpec - encrypted.clear(); - client.read(encrypted); + // Read the echo. encrypted.flip(); decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); + result = clientSSLEngine.unwrap(encrypted, decrypted); Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - // Read Finished - if (!encrypted.hasRemaining()) + } + + @Override + protected void performTLSRenegotiation(SSLResult sslResult, boolean client) throws Exception + { + SSLEngine clientSSLEngine = sslResult.client; + SSLEngine serverSSLEngine = sslResult.server; + + ByteBuffer encrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getPacketBufferSize()); + ByteBuffer decrypted = ByteBuffer.allocate(clientSSLEngine.getSession().getApplicationBufferSize()); + + if (client) { - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); + clientSSLEngine.beginHandshake(); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, clientSSLEngine.getHandshakeStatus()); } - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.FINISHED, result.getHandshakeStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, sslEngine.getHandshakeStatus()); + else + { + serverSSLEngine.beginHandshake(); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, serverSSLEngine.getHandshakeStatus()); - Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); + wrap(serverSSLEngine, decrypted, encrypted); - // Now try to write real data to see if it works - encrypted.clear(); - result = sslEngine.wrap(ByteBuffer.wrap(data.getBytes("UTF-8")), encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); + unwrap(clientSSLEngine, encrypted, decrypted); + } - // Read echo - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); + wrap(clientSSLEngine, decrypted, encrypted); - decrypted.flip(); - Assert.assertEquals(data, Charset.forName("UTF-8").decode(decrypted).toString()); + unwrap(serverSSLEngine, encrypted, decrypted); - // Perform a re-handshake, and verify that ALPN does not trigger - latch.set(new CountDownLatch(4)); - sslEngine.beginHandshake(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); + wrap(serverSSLEngine, decrypted, encrypted); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); + unwrap(clientSSLEngine, encrypted, decrypted); - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_TASK, result.getHandshakeStatus()); - sslEngine.getDelegatedTask().run(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); + wrap(clientSSLEngine, decrypted, encrypted); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); + unwrap(serverSSLEngine, encrypted, decrypted); + + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, clientSSLEngine.getHandshakeStatus()); + Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, serverSSLEngine.getHandshakeStatus()); + } + + @Override + protected SSLSession getSSLSession(SSLResult sslResult, boolean client) throws Exception + { + SSLEngine sslEngine = client ? sslResult.client : sslResult.server; + return sslEngine.getSession(); + } + private void wrap(SSLEngine sslEngine, ByteBuffer decrypted, ByteBuffer encrypted) throws Exception + { encrypted.clear(); - client.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - if (!encrypted.hasRemaining()) + ByteBuffer tmp = ByteBuffer.allocate(encrypted.capacity()); + while (true) { encrypted.clear(); - client.read(encrypted); + SSLEngineResult result = sslEngine.wrap(decrypted, encrypted); + SSLEngineResult.Status status = result.getStatus(); + if (status != SSLEngineResult.Status.OK && status != SSLEngineResult.Status.CLOSED) + throw new AssertionError(status.toString()); encrypted.flip(); + tmp.put(encrypted); + if (result.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_WRAP) + { + tmp.flip(); + encrypted.clear(); + encrypted.put(tmp).flip(); + return; + } } - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.FINISHED, result.getHandshakeStatus()); - - // Re-handshake completed, check ALPN was not invoked - Assert.assertEquals(4, latch.get().getCount()); - - // Write more data - encrypted.clear(); - result = sslEngine.wrap(ByteBuffer.wrap(data.getBytes("UTF-8")), encrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); - - // Read echo - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); - decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.OK, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - - decrypted.flip(); - Assert.assertEquals(data, Charset.forName("UTF-8").decode(decrypted).toString()); + } - // Close - ALPN.remove(sslEngine); - sslEngine.closeOutbound(); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_WRAP, sslEngine.getHandshakeStatus()); - encrypted.clear(); - result = sslEngine.wrap(decrypted, encrypted); - Assert.assertSame(SSLEngineResult.Status.CLOSED, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, result.getHandshakeStatus()); - encrypted.flip(); - client.write(encrypted); - client.shutdownOutput(); - encrypted.clear(); - client.read(encrypted); - encrypted.flip(); + private void unwrap(SSLEngine sslEngine, ByteBuffer encrypted, ByteBuffer decrypted) throws Exception + { decrypted.clear(); - result = sslEngine.unwrap(encrypted, decrypted); - Assert.assertSame(SSLEngineResult.Status.CLOSED, result.getStatus()); - Assert.assertSame(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); - client.close(); - - server.close(); + while (true) + { + decrypted.clear(); + SSLEngineResult result = sslEngine.unwrap(encrypted, decrypted); + SSLEngineResult.Status status = result.getStatus(); + if (status != SSLEngineResult.Status.OK && status != SSLEngineResult.Status.CLOSED) + throw new AssertionError(status.toString()); + SSLEngineResult.HandshakeStatus handshakeStatus = result.getHandshakeStatus(); + if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) + { + sslEngine.getDelegatedTask().run(); + handshakeStatus = sslEngine.getHandshakeStatus(); + } + if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) + return; + } } } diff --git a/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLSocketALPNTest.java b/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLSocketALPNTest.java index 4f2e6d6..975936c 100644 --- a/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLSocketALPNTest.java +++ b/alpn-tests/src/test/java/org/mortbay/jetty/alpn/SSLSocketALPNTest.java @@ -18,40 +18,45 @@ package org.mortbay.jetty.alpn; -import java.io.InputStream; -import java.io.OutputStream; import java.net.InetSocketAddress; -import java.util.Arrays; -import java.util.List; +import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import javax.net.ssl.HandshakeCompletedEvent; -import javax.net.ssl.HandshakeCompletedListener; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import org.eclipse.jetty.alpn.ALPN; +import org.junit.After; import org.junit.Assert; -import org.junit.Test; -public class SSLSocketALPNTest +public class SSLSocketALPNTest extends AbstractALPNTest { - @Test - public void testNegotiationSuccessful() throws Exception + private SSLServerSocket acceptor; + + @After + public void dispose() throws Exception { - ALPN.debug = true; + if (acceptor != null) + acceptor.close(); + } - SSLContext context = SSLSupport.newSSLContext(); + @Override + protected SSLResult performTLSHandshake(SSLResult handshake, ALPN.ClientProvider clientProvider, final ALPN.ServerProvider serverProvider) throws Exception + { + SSLContext sslContext = handshake == null ? SSLSupport.newSSLContext() : handshake.context; + + final CountDownLatch latch = new CountDownLatch(2); + final SSLResult sslResult = new SSLResult<>(); + sslResult.context = sslContext; + final int readTimeout = 5000; + if (handshake == null) + { + acceptor = (SSLServerSocket)sslContext.getServerSocketFactory().createServerSocket(); + acceptor.bind(new InetSocketAddress("localhost", 0)); + } - final int readTimeout = 50000; - final String data = "data"; - final String protocolName = "test"; - final AtomicReference latch = new AtomicReference<>(new CountDownLatch(3)); - final SSLServerSocket server = (SSLServerSocket)context.getServerSocketFactory().createServerSocket(); - server.bind(new InetSocketAddress("localhost", 0)); - final CountDownLatch handshakeLatch = new CountDownLatch(2); new Thread() { @Override @@ -59,71 +64,15 @@ public void run() { try { - SSLSocket socket = (SSLSocket)server.accept(); - socket.setUseClientMode(false); - socket.setSoTimeout(readTimeout); - ALPN.put(socket, new ALPN.ServerProvider() - { - @Override - public void unsupported() - { - } - - @Override - public String select(List protocols) - { - Assert.assertEquals(1, protocols.size()); - String protocol = protocols.get(0); - Assert.assertEquals(protocolName, protocol); - latch.get().countDown(); - return protocol; - } - }); - socket.addHandshakeCompletedListener(new HandshakeCompletedListener() - { - @Override - public void handshakeCompleted(HandshakeCompletedEvent event) - { - handshakeLatch.countDown(); - } - }); - socket.startHandshake(); - - InputStream serverInput = socket.getInputStream(); - for (int i = 0; i < data.length(); ++i) - { - int read = serverInput.read(); - Assert.assertEquals(data.charAt(i), read); - } - - OutputStream serverOutput = socket.getOutputStream(); - serverOutput.write(data.getBytes("UTF-8")); - serverOutput.flush(); - - for (int i = 0; i < data.length(); ++i) - { - int read = serverInput.read(); - Assert.assertEquals(data.charAt(i), read); - } - - serverOutput.write(data.getBytes("UTF-8")); - serverOutput.flush(); - - // Re-handshake - socket.startHandshake(); - - for (int i = 0; i < data.length(); ++i) - { - int read = serverInput.read(); - Assert.assertEquals(data.charAt(i), read); - } - - serverOutput.write(data.getBytes("UTF-8")); - serverOutput.flush(); - - Assert.assertEquals(4, latch.get().getCount()); - - socket.close(); + SSLSocket serverSSLSocket = (SSLSocket)acceptor.accept(); + System.err.println(serverSSLSocket); + sslResult.server = serverSSLSocket; + + serverSSLSocket.setUseClientMode(false); + serverSSLSocket.setSoTimeout(readTimeout); + ALPN.put(serverSSLSocket, serverProvider); + serverSSLSocket.startHandshake(); + latch.countDown(); } catch (Exception x) { @@ -132,84 +81,66 @@ public void handshakeCompleted(HandshakeCompletedEvent event) } }.start(); - SSLSocket client = (SSLSocket)context.getSocketFactory().createSocket("localhost", server.getLocalPort()); - client.setUseClientMode(true); - client.setSoTimeout(readTimeout); - ALPN.put(client, new ALPN.ClientProvider() - { - @Override - public List protocols() - { - latch.get().countDown(); - return Arrays.asList(protocolName); - } + SSLSocket clientSSLSocket = (SSLSocket)sslContext.getSocketFactory().createSocket("localhost", acceptor.getLocalPort()); + sslResult.client = clientSSLSocket; + latch.countDown(); - @Override - public void unsupported() - { - } + clientSSLSocket.setUseClientMode(true); + clientSSLSocket.setSoTimeout(readTimeout); + ALPN.put(clientSSLSocket, clientProvider); + clientSSLSocket.startHandshake(); - @Override - public void selected(String protocol) - { - Assert.assertEquals(protocolName, protocol); - latch.get().countDown(); - } - }); - - client.addHandshakeCompletedListener(new HandshakeCompletedListener() - { - @Override - public void handshakeCompleted(HandshakeCompletedEvent event) - { - handshakeLatch.countDown(); - } - }); - client.startHandshake(); - - Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); - Assert.assertTrue(handshakeLatch.await(5, TimeUnit.SECONDS)); - - // Check whether we can write real data to the connection - OutputStream clientOutput = client.getOutputStream(); - clientOutput.write(data.getBytes("UTF-8")); - clientOutput.flush(); + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + return sslResult; + } - InputStream clientInput = client.getInputStream(); - for (int i = 0; i < data.length(); ++i) - { - int read = clientInput.read(); - Assert.assertEquals(data.charAt(i), read); - } + @Override + protected void performTLSClose(SSLResult sslResult) throws Exception + { + sslResult.client.close(); + sslResult.server.close(); + } - // Re-handshake - latch.set(new CountDownLatch(4)); - client.startHandshake(); - Assert.assertEquals(4, latch.get().getCount()); + @Override + protected void performDataExchange(SSLResult sslResult) throws Exception + { + SSLSocket clientSSLSocket = sslResult.client; + SSLSocket serverSSLSocket = sslResult.server; - clientOutput.write(data.getBytes("UTF-8")); - clientOutput.flush(); + byte[] data = new byte[1024]; + new Random().nextBytes(data); - for (int i = 0; i < data.length(); ++i) - { - int read = clientInput.read(); - Assert.assertEquals(data.charAt(i), read); - } + // Write the data. + clientSSLSocket.getOutputStream().write(data); - clientOutput.write(data.getBytes("UTF-8")); - clientOutput.flush(); + // Read the data. + byte[] buffer = new byte[data.length]; + int read = 0; + while (read < data.length) + read += serverSSLSocket.getInputStream().read(buffer); - for (int i = 0; i < data.length(); ++i) - { - int read = clientInput.read(); - Assert.assertEquals(data.charAt(i), read); - } + // Write the data back. + serverSSLSocket.getOutputStream().write(buffer, 0, read); - int read = clientInput.read(); - Assert.assertEquals(-1, read); + // Read the echo. + read = 0; + while (read < data.length) + read += clientSSLSocket.getInputStream().read(buffer); + } - client.close(); + @Override + protected void performTLSRenegotiation(SSLResult sslResult, boolean client) throws Exception + { + if (client) + sslResult.client.startHandshake(); + else + sslResult.server.startHandshake(); + } - server.close(); + @Override + protected SSLSession getSSLSession(SSLResult sslResult, boolean client) throws Exception + { + SSLSocket sslSocket = client ? sslResult.client : sslResult.server; + return sslSocket.getSession(); } } diff --git a/pom.xml b/pom.xml index 1393b8c..67770cc 100644 --- a/pom.xml +++ b/pom.xml @@ -9,7 +9,7 @@ 4.0.0 org.mortbay.jetty.alpn alpn-project - 7.1.0.v20141016 + 7.0.0.1-dev pom Jetty :: ALPN :: Project