Skip to content

Commit

Permalink
fix ssl
Browse files Browse the repository at this point in the history
  • Loading branch information
snower committed Apr 13, 2018
1 parent 1f133bd commit 3f07de2
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tormysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def finish(future):
else:
child_gr.switch(future.result())

future = self._sock.start_tls(False, self.ctx, server_hostname=self.host)
future = self._sock.start_tls(False, self.ctx, server_hostname=self.host, connect_timeout=self.connect_timeout)
future.add_done_callback(finish)
self._rfile = self._sock = main.switch()

Expand Down
54 changes: 49 additions & 5 deletions tormysql/platform/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, address, bind_address):
self._transport = None
self._close_callback = None
self._connect_future = None
self._connect_ssl_future = None
self._read_future = None
self._read_bytes = 0
self._closed = False
Expand All @@ -47,6 +48,13 @@ def on_closed(self, exc_info = False):
self._connect_future.set_exception(StreamClosedError(None))
self._connect_future = None

if self._connect_ssl_future:
if exc_info:
self._connect_ssl_future.set_exception(exc_info[1] if isinstance(exc_info, tuple) else exc_info)
else:
self._connect_ssl_future.set_exception(StreamClosedError(None))
self._connect_ssl_future = None

if self._read_future:
if exc_info:
self._read_future.set_exception(exc_info[1] if isinstance(exc_info, tuple) else exc_info)
Expand Down Expand Up @@ -82,12 +90,12 @@ def connect(self, address, connect_timeout = 0, server_hostname = None):
self._loop = current_ioloop()
future = self._connect_future = Future(loop=self._loop)
if connect_timeout:
def timeout():
def on_timeout():
self._loop_connect_timeout = None
if self._connect_future:
self.close((None, IOError("Connect timeout"), None))

self._loop_connect_timeout = self._loop.call_later(connect_timeout, connect_timeout)
self._loop_connect_timeout = self._loop.call_later(connect_timeout, on_timeout)

def connected(connect_future):
if self._loop_connect_timeout:
Expand All @@ -96,18 +104,18 @@ def connected(connect_future):

if connect_future._exception is not None:
self.on_closed(connect_future.exception())
self._connect_future = None
else:
self._connect_future = None
future.set_result(connect_future.result())
self._connect_future = None

connect_future = ensure_future(self._connect(address, server_hostname))
connect_future.add_done_callback(connected)
return self._connect_future

def connection_made(self, transport):
self._transport = transport
if self._connect_future is None:
if self._connect_future is None and self._connect_ssl_future is None:
transport.close()
else:
self._transport.set_write_buffer_limits(1024 * 1024 * 1024)
Expand Down Expand Up @@ -151,4 +159,40 @@ def write(self, data):
if self._closed:
raise StreamClosedError(IOError('Already Closed'))

self._transport.write(data)
self._transport.write(data)

def start_tls(self, server_side, ssl_options=None, server_hostname=None, connect_timeout=None):
if not self._transport or self._read_future:
raise ValueError("IOStream is not idle; cannot convert to SSL")

self._connect_ssl_future = connect_ssl_future = Future(loop=self._loop)
waiter = Future(loop=self._loop)

def on_connected(future):
if self._loop_connect_timeout:
self._loop_connect_timeout.cancel()
self._loop_connect_timeout = None

if connect_ssl_future._exception is not None:
self.on_closed(future.exception())
self._connect_ssl_future = None
else:
self._connect_ssl_future = None
connect_ssl_future.set_result(self)
waiter.add_done_callback(on_connected)

if connect_timeout:
def on_timeout():
self._loop_connect_timeout = None
if not waiter.done():
self.close((None, IOError("Connect timeout"), None))

self._loop_connect_timeout = self._loop.call_later(connect_timeout, on_timeout)

self._transport.pause_reading()
sock, self._transport._sock = self._transport._sock, None
self._transport = self._loop._make_ssl_transport(
sock, self, ssl_options, waiter,
server_side=False, server_hostname=server_hostname)

return connect_ssl_future
171 changes: 167 additions & 4 deletions tormysql/platform/tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import socket
import errno
from tornado.iostream import IOStream as BaseIOStream, StreamClosedError, _ERRNO_WOULDBLOCK
from tornado.iostream import IOStream as BaseIOStream, SSLIOStream as BaseSSLIOStream, StreamClosedError, _ERRNO_WOULDBLOCK, ssl, ssl_wrap_socket, _client_ssl_defaults
from tornado.concurrent import Future
from tornado.gen import coroutine
from tornado.ioloop import IOLoop
Expand All @@ -24,8 +24,9 @@ def current_ioloop():


class IOStream(BaseIOStream):
def __init__(self, address, bind_address, *args, **kwargs):
socket = self.init_socket(address, bind_address)
def __init__(self, address, bind_address, socket = None, *args, **kwargs):
if socket is None:
socket = self.init_socket(address, bind_address)

super(IOStream, self).__init__(socket, *args, **kwargs)

Expand Down Expand Up @@ -192,4 +193,166 @@ def write(self, data):
if self._write_buffer_size:
if not self._state & self.io_loop.WRITE:
self._state = self._state | self.io_loop.WRITE
self.io_loop.update_handler(self.fileno(), self._state)
self.io_loop.update_handler(self.fileno(), self._state)

def start_tls(self, server_side, ssl_options=None, server_hostname=None, connect_timeout = None):
if (self._read_callback or self._read_future or
self._write_callback or self._write_futures or
self._connect_callback or self._connect_future or
self._pending_callbacks or self._closed or
self._read_buffer or self._write_buffer):
raise ValueError("IOStream is not idle; cannot convert to SSL")

if ssl_options is None:
ssl_options = _client_ssl_defaults

socket = self.socket
self.io_loop.remove_handler(socket)
self.socket = None
socket = ssl_wrap_socket(socket, ssl_options,
server_hostname=server_hostname,
server_side=server_side,
do_handshake_on_connect=False)
orig_close_callback = self._close_callback
self._close_callback = None

future = Future()
ssl_stream = SSLIOStream(socket, ssl_options=ssl_options)

# Wrap the original close callback so we can fail our Future as well.
# If we had an "unwrap" counterpart to this method we would need
# to restore the original callback after our Future resolves
# so that repeated wrap/unwrap calls don't build up layers.

def close_callback():
if not future.done():
# Note that unlike most Futures returned by IOStream,
# this one passes the underlying error through directly
# instead of wrapping everything in a StreamClosedError
# with a real_error attribute. This is because once the
# connection is established it's more helpful to raise
# the SSLError directly than to hide it behind a
# StreamClosedError (and the client is expecting SSL
# issues rather than network issues since this method is
# named start_tls).
future.set_exception(ssl_stream.error or StreamClosedError())
if orig_close_callback is not None:
orig_close_callback()

if connect_timeout:
def timeout():
ssl_stream._loop_connect_timeout = None
if not future.done():
ssl_stream.close((None, IOError("Connect timeout"), None))

ssl_stream._loop_connect_timeout = self.io_loop.call_later(connect_timeout, timeout)

ssl_stream.set_close_callback(close_callback)
ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream)
ssl_stream.max_buffer_size = self.max_buffer_size
ssl_stream.read_chunk_size = self.read_chunk_size
return future

class SSLIOStream(IOStream, BaseSSLIOStream):
def __init__(self, socket, *args, **kwargs):
self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
IOStream.__init__(self, None, None, socket, *args, **kwargs)

self._ssl_accepting = True
self._handshake_reading = False
self._handshake_writing = False
self._ssl_connect_callback = None
self._loop_connect_timeout = None
self._server_hostname = None

# If the socket is already connected, attempt to start the handshake.
try:
self.socket.getpeername()
except socket.error:
pass
else:
# Indirectly start the handshake, which will run on the next
# IOLoop iteration and then the real IO state will be set in
# _handle_events.
self._add_io_state(self.io_loop.WRITE)

def _handle_read(self):
if self._ssl_accepting:
self._do_ssl_handshake()
return

chunk = True

while True:
try:
chunk = self.socket.recv(self.read_chunk_size)
if not chunk:
break
if self._read_buffer_size:
self._read_buffer += chunk
else:
self._read_buffer = bytearray(chunk)
self._read_buffer_size += len(chunk)
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
break

self.close(exc_info=True)
return
except (socket.error, IOError, OSError) as e:
en = e.errno if hasattr(e, 'errno') else e.args[0]
if en in _ERRNO_WOULDBLOCK:
break

if en == errno.EINTR:
continue

self.close(exc_info=True)
return

if self._read_future is not None and self._read_buffer_size >= self._read_bytes:
future, self._read_future = self._read_future, None
self._read_buffer, data = bytearray(), self._read_buffer
self._read_buffer_size = 0
self._read_bytes = 0
future.set_result(data)

if not chunk:
self.close()
return

def _handle_write(self):
if self._ssl_accepting:
self._do_ssl_handshake()
return

try:
num_bytes = self.socket.send(memoryview(self._write_buffer)[
self._write_buffer_pos: self._write_buffer_pos + self._write_buffer_size])
self._write_buffer_pos += num_bytes
self._write_buffer_size -= num_bytes
except ssl.SSLError as e:
if e.args[0] != ssl.SSL_ERROR_WANT_WRITE:
self.close(exc_info=True)
return
except (socket.error, IOError, OSError) as e:
en = e.errno if hasattr(e, 'errno') else e.args[0]
if en not in _ERRNO_WOULDBLOCK:
self.close(exc_info=True)
return

if not self._write_buffer_size:
if self._write_buffer_pos > 0:
self._write_buffer = bytearray()
self._write_buffer_pos = 0

if self._state & self.io_loop.WRITE:
self._state = self._state & ~self.io_loop.WRITE
self.io_loop.update_handler(self.fileno(), self._state)

def _run_ssl_connect_callback(self):
if self._state & self.io_loop.WRITE:
self._state = self._state & ~self.io_loop.WRITE
self.io_loop.update_handler(self.fileno(), self._state)

BaseSSLIOStream._run_ssl_connect_callback(self)

0 comments on commit 3f07de2

Please sign in to comment.