diff --git a/src/main/java/org/java_websocket/drafts/Draft_6455.java b/src/main/java/org/java_websocket/drafts/Draft_6455.java index 68feb615..eb487997 100644 --- a/src/main/java/org/java_websocket/drafts/Draft_6455.java +++ b/src/main/java/org/java_websocket/drafts/Draft_6455.java @@ -115,13 +115,23 @@ public class Draft_6455 extends Draft { /** * Attribute for the used extension in this draft */ - private IExtension extension = new DefaultExtension(); + private IExtension negotiatedExtension = new DefaultExtension(); + + /** + * Attribute for the default extension + */ + private IExtension defaultExtension = new DefaultExtension(); /** * Attribute for all available extension in this draft */ private List knownExtensions; + /** + * Current active extension used to decode messages + */ + private IExtension currentDecodingExtension; + /** * Attribute for the used protocol in this draft */ @@ -241,10 +251,11 @@ public Draft_6455(List inputExtensions, List inputProtoco knownExtensions.addAll(inputExtensions); //We always add the DefaultExtension to implement the normal RFC 6455 specification if (!hasDefault) { - knownExtensions.add(this.knownExtensions.size(), extension); + knownExtensions.add(this.knownExtensions.size(), negotiatedExtension); } knownProtocols.addAll(inputProtocols); maxFrameSize = inputMaxFrameSize; + currentDecodingExtension = null; } @Override @@ -259,9 +270,9 @@ public HandshakeState acceptHandshakeAsServer(ClientHandshake handshakedata) String requestedExtension = handshakedata.getFieldValue(SEC_WEB_SOCKET_EXTENSIONS); for (IExtension knownExtension : knownExtensions) { if (knownExtension.acceptProvidedExtensionAsServer(requestedExtension)) { - extension = knownExtension; + negotiatedExtension = knownExtension; extensionState = HandshakeState.MATCHED; - log.trace("acceptHandshakeAsServer - Matching extension found: {}", extension); + log.trace("acceptHandshakeAsServer - Matching extension found: {}", negotiatedExtension); break; } } @@ -316,9 +327,9 @@ public HandshakeState acceptHandshakeAsClient(ClientHandshake request, ServerHan String requestedExtension = response.getFieldValue(SEC_WEB_SOCKET_EXTENSIONS); for (IExtension knownExtension : knownExtensions) { if (knownExtension.acceptProvidedExtensionAsClient(requestedExtension)) { - extension = knownExtension; + negotiatedExtension = knownExtension; extensionState = HandshakeState.MATCHED; - log.trace("acceptHandshakeAsClient - Matching extension found: {}", extension); + log.trace("acceptHandshakeAsClient - Matching extension found: {}", negotiatedExtension); break; } } @@ -337,7 +348,7 @@ public HandshakeState acceptHandshakeAsClient(ClientHandshake request, ServerHan * @return the extension which is used or null, if handshake is not yet done */ public IExtension getExtension() { - return extension; + return negotiatedExtension; } /** @@ -562,8 +573,20 @@ private Framedata translateSingleFrame(ByteBuffer buffer) frame.setRSV3(rsv3); payload.flip(); frame.setPayload(payload); - getExtension().isFrameValid(frame); - getExtension().decodeFrame(frame); + if (frame.getOpcode() != Opcode.CONTINUOUS) { + // Prioritize the negotiated extension + if (frame.isRSV1() || frame.isRSV2() || frame.isRSV3()) { + currentDecodingExtension = getExtension(); + } else { + // No encoded message, so we can use the default one + currentDecodingExtension = defaultExtension; + } + } + if (currentDecodingExtension == null) { + currentDecodingExtension = defaultExtension; + } + currentDecodingExtension.isFrameValid(frame); + currentDecodingExtension.decodeFrame(frame); if (log.isTraceEnabled()) { log.trace("afterDecoding({}): {}", frame.getPayloadData().remaining(), (frame.getPayloadData().remaining() > 1000 ? "too big to display" @@ -780,10 +803,10 @@ public List createFrames(String text, boolean mask) { @Override public void reset() { incompleteframe = null; - if (extension != null) { - extension.reset(); + if (negotiatedExtension != null) { + negotiatedExtension.reset(); } - extension = new DefaultExtension(); + negotiatedExtension = new DefaultExtension(); protocol = null; } @@ -1116,7 +1139,7 @@ public boolean equals(Object o) { if (maxFrameSize != that.getMaxFrameSize()) { return false; } - if (extension != null ? !extension.equals(that.getExtension()) : that.getExtension() != null) { + if (negotiatedExtension != null ? !negotiatedExtension.equals(that.getExtension()) : that.getExtension() != null) { return false; } return protocol != null ? protocol.equals(that.getProtocol()) : that.getProtocol() == null; @@ -1124,7 +1147,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - int result = extension != null ? extension.hashCode() : 0; + int result = negotiatedExtension != null ? negotiatedExtension.hashCode() : 0; result = 31 * result + (protocol != null ? protocol.hashCode() : 0); result = 31 * result + (maxFrameSize ^ (maxFrameSize >>> 32)); return result; diff --git a/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java b/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java index 824ceb68..f24f9b6c 100644 --- a/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java +++ b/src/main/java/org/java_websocket/extensions/permessage_deflate/PerMessageDeflateExtension.java @@ -142,15 +142,15 @@ public void decodeFrame(Framedata inputFrame) throws InvalidDataException { return; } + if (!inputFrame.isRSV1() && inputFrame.getOpcode() != Opcode.CONTINUOUS) { + return; + } + // RSV1 bit must be set only for the first frame. if (inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1()) { throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, "RSV1 bit can only be set for the first frame."); } - // If rsv1 is not set, we dont have a compressed message - if (!inputFrame.isRSV1()) { - return; - } // Decompressed output buffer. ByteArrayOutputStream output = new ByteArrayOutputStream(); @@ -181,11 +181,6 @@ We can check the getRemaining() method to see whether the data we supplied has b throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, e.getMessage()); } - // RSV1 bit must be cleared after decoding, so that other extensions don't throw an exception. - if (inputFrame.isRSV1()) { - ((DataFrame) inputFrame).setRSV1(false); - } - // Set frames payload to the new decompressed data. ((FramedataImpl1) inputFrame) .setPayload(ByteBuffer.wrap(output.toByteArray(), 0, output.size()));