From 3f07de27915ac5eb00c520acb5d145ba1180dad5 Mon Sep 17 00:00:00 2001 From: snower Date: Fri, 13 Apr 2018 14:09:06 +0800 Subject: [PATCH] fix ssl --- tormysql/connections.py | 2 +- tormysql/platform/asyncio.py | 54 ++++++++++- tormysql/platform/tornado.py | 171 ++++++++++++++++++++++++++++++++++- 3 files changed, 217 insertions(+), 10 deletions(-) diff --git a/tormysql/connections.py b/tormysql/connections.py index b767de4..3ab4602 100644 --- a/tormysql/connections.py +++ b/tormysql/connections.py @@ -221,7 +221,7 @@ def finish(future): else: child_gr.switch(future.result()) - future = self._sock.start_tls(False, self.ctx, server_hostname=self.host) + future = self._sock.start_tls(False, self.ctx, server_hostname=self.host, connect_timeout=self.connect_timeout) future.add_done_callback(finish) self._rfile = self._sock = main.switch() diff --git a/tormysql/platform/asyncio.py b/tormysql/platform/asyncio.py index 1409a55..83b2a5a 100644 --- a/tormysql/platform/asyncio.py +++ b/tormysql/platform/asyncio.py @@ -26,6 +26,7 @@ def __init__(self, address, bind_address): self._transport = None self._close_callback = None self._connect_future = None + self._connect_ssl_future = None self._read_future = None self._read_bytes = 0 self._closed = False @@ -47,6 +48,13 @@ def on_closed(self, exc_info = False): self._connect_future.set_exception(StreamClosedError(None)) self._connect_future = None + if self._connect_ssl_future: + if exc_info: + self._connect_ssl_future.set_exception(exc_info[1] if isinstance(exc_info, tuple) else exc_info) + else: + self._connect_ssl_future.set_exception(StreamClosedError(None)) + self._connect_ssl_future = None + if self._read_future: if exc_info: self._read_future.set_exception(exc_info[1] if isinstance(exc_info, tuple) else exc_info) @@ -82,12 +90,12 @@ def connect(self, address, connect_timeout = 0, server_hostname = None): self._loop = current_ioloop() future = self._connect_future = Future(loop=self._loop) if connect_timeout: - def timeout(): + def on_timeout(): self._loop_connect_timeout = None if self._connect_future: self.close((None, IOError("Connect timeout"), None)) - self._loop_connect_timeout = self._loop.call_later(connect_timeout, connect_timeout) + self._loop_connect_timeout = self._loop.call_later(connect_timeout, on_timeout) def connected(connect_future): if self._loop_connect_timeout: @@ -96,10 +104,10 @@ def connected(connect_future): if connect_future._exception is not None: self.on_closed(connect_future.exception()) + self._connect_future = None else: self._connect_future = None future.set_result(connect_future.result()) - self._connect_future = None connect_future = ensure_future(self._connect(address, server_hostname)) connect_future.add_done_callback(connected) @@ -107,7 +115,7 @@ def connected(connect_future): def connection_made(self, transport): self._transport = transport - if self._connect_future is None: + if self._connect_future is None and self._connect_ssl_future is None: transport.close() else: self._transport.set_write_buffer_limits(1024 * 1024 * 1024) @@ -151,4 +159,40 @@ def write(self, data): if self._closed: raise StreamClosedError(IOError('Already Closed')) - self._transport.write(data) \ No newline at end of file + self._transport.write(data) + + def start_tls(self, server_side, ssl_options=None, server_hostname=None, connect_timeout=None): + if not self._transport or self._read_future: + raise ValueError("IOStream is not idle; cannot convert to SSL") + + self._connect_ssl_future = connect_ssl_future = Future(loop=self._loop) + waiter = Future(loop=self._loop) + + def on_connected(future): + if self._loop_connect_timeout: + self._loop_connect_timeout.cancel() + self._loop_connect_timeout = None + + if connect_ssl_future._exception is not None: + self.on_closed(future.exception()) + self._connect_ssl_future = None + else: + self._connect_ssl_future = None + connect_ssl_future.set_result(self) + waiter.add_done_callback(on_connected) + + if connect_timeout: + def on_timeout(): + self._loop_connect_timeout = None + if not waiter.done(): + self.close((None, IOError("Connect timeout"), None)) + + self._loop_connect_timeout = self._loop.call_later(connect_timeout, on_timeout) + + self._transport.pause_reading() + sock, self._transport._sock = self._transport._sock, None + self._transport = self._loop._make_ssl_transport( + sock, self, ssl_options, waiter, + server_side=False, server_hostname=server_hostname) + + return connect_ssl_future \ No newline at end of file diff --git a/tormysql/platform/tornado.py b/tormysql/platform/tornado.py index 2e9d383..178b3ee 100644 --- a/tormysql/platform/tornado.py +++ b/tormysql/platform/tornado.py @@ -7,7 +7,7 @@ import sys import socket import errno -from tornado.iostream import IOStream as BaseIOStream, StreamClosedError, _ERRNO_WOULDBLOCK +from tornado.iostream import IOStream as BaseIOStream, SSLIOStream as BaseSSLIOStream, StreamClosedError, _ERRNO_WOULDBLOCK, ssl, ssl_wrap_socket, _client_ssl_defaults from tornado.concurrent import Future from tornado.gen import coroutine from tornado.ioloop import IOLoop @@ -24,8 +24,9 @@ def current_ioloop(): class IOStream(BaseIOStream): - def __init__(self, address, bind_address, *args, **kwargs): - socket = self.init_socket(address, bind_address) + def __init__(self, address, bind_address, socket = None, *args, **kwargs): + if socket is None: + socket = self.init_socket(address, bind_address) super(IOStream, self).__init__(socket, *args, **kwargs) @@ -192,4 +193,166 @@ def write(self, data): if self._write_buffer_size: if not self._state & self.io_loop.WRITE: self._state = self._state | self.io_loop.WRITE - self.io_loop.update_handler(self.fileno(), self._state) \ No newline at end of file + self.io_loop.update_handler(self.fileno(), self._state) + + def start_tls(self, server_side, ssl_options=None, server_hostname=None, connect_timeout = None): + if (self._read_callback or self._read_future or + self._write_callback or self._write_futures or + self._connect_callback or self._connect_future or + self._pending_callbacks or self._closed or + self._read_buffer or self._write_buffer): + raise ValueError("IOStream is not idle; cannot convert to SSL") + + if ssl_options is None: + ssl_options = _client_ssl_defaults + + socket = self.socket + self.io_loop.remove_handler(socket) + self.socket = None + socket = ssl_wrap_socket(socket, ssl_options, + server_hostname=server_hostname, + server_side=server_side, + do_handshake_on_connect=False) + orig_close_callback = self._close_callback + self._close_callback = None + + future = Future() + ssl_stream = SSLIOStream(socket, ssl_options=ssl_options) + + # Wrap the original close callback so we can fail our Future as well. + # If we had an "unwrap" counterpart to this method we would need + # to restore the original callback after our Future resolves + # so that repeated wrap/unwrap calls don't build up layers. + + def close_callback(): + if not future.done(): + # Note that unlike most Futures returned by IOStream, + # this one passes the underlying error through directly + # instead of wrapping everything in a StreamClosedError + # with a real_error attribute. This is because once the + # connection is established it's more helpful to raise + # the SSLError directly than to hide it behind a + # StreamClosedError (and the client is expecting SSL + # issues rather than network issues since this method is + # named start_tls). + future.set_exception(ssl_stream.error or StreamClosedError()) + if orig_close_callback is not None: + orig_close_callback() + + if connect_timeout: + def timeout(): + ssl_stream._loop_connect_timeout = None + if not future.done(): + ssl_stream.close((None, IOError("Connect timeout"), None)) + + ssl_stream._loop_connect_timeout = self.io_loop.call_later(connect_timeout, timeout) + + ssl_stream.set_close_callback(close_callback) + ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream) + ssl_stream.max_buffer_size = self.max_buffer_size + ssl_stream.read_chunk_size = self.read_chunk_size + return future + +class SSLIOStream(IOStream, BaseSSLIOStream): + def __init__(self, socket, *args, **kwargs): + self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults) + IOStream.__init__(self, None, None, socket, *args, **kwargs) + + self._ssl_accepting = True + self._handshake_reading = False + self._handshake_writing = False + self._ssl_connect_callback = None + self._loop_connect_timeout = None + self._server_hostname = None + + # If the socket is already connected, attempt to start the handshake. + try: + self.socket.getpeername() + except socket.error: + pass + else: + # Indirectly start the handshake, which will run on the next + # IOLoop iteration and then the real IO state will be set in + # _handle_events. + self._add_io_state(self.io_loop.WRITE) + + def _handle_read(self): + if self._ssl_accepting: + self._do_ssl_handshake() + return + + chunk = True + + while True: + try: + chunk = self.socket.recv(self.read_chunk_size) + if not chunk: + break + if self._read_buffer_size: + self._read_buffer += chunk + else: + self._read_buffer = bytearray(chunk) + self._read_buffer_size += len(chunk) + except ssl.SSLError as e: + if e.args[0] == ssl.SSL_ERROR_WANT_READ: + break + + self.close(exc_info=True) + return + except (socket.error, IOError, OSError) as e: + en = e.errno if hasattr(e, 'errno') else e.args[0] + if en in _ERRNO_WOULDBLOCK: + break + + if en == errno.EINTR: + continue + + self.close(exc_info=True) + return + + if self._read_future is not None and self._read_buffer_size >= self._read_bytes: + future, self._read_future = self._read_future, None + self._read_buffer, data = bytearray(), self._read_buffer + self._read_buffer_size = 0 + self._read_bytes = 0 + future.set_result(data) + + if not chunk: + self.close() + return + + def _handle_write(self): + if self._ssl_accepting: + self._do_ssl_handshake() + return + + try: + num_bytes = self.socket.send(memoryview(self._write_buffer)[ + self._write_buffer_pos: self._write_buffer_pos + self._write_buffer_size]) + self._write_buffer_pos += num_bytes + self._write_buffer_size -= num_bytes + except ssl.SSLError as e: + if e.args[0] != ssl.SSL_ERROR_WANT_WRITE: + self.close(exc_info=True) + return + except (socket.error, IOError, OSError) as e: + en = e.errno if hasattr(e, 'errno') else e.args[0] + if en not in _ERRNO_WOULDBLOCK: + self.close(exc_info=True) + return + + if not self._write_buffer_size: + if self._write_buffer_pos > 0: + self._write_buffer = bytearray() + self._write_buffer_pos = 0 + + if self._state & self.io_loop.WRITE: + self._state = self._state & ~self.io_loop.WRITE + self.io_loop.update_handler(self.fileno(), self._state) + + def _run_ssl_connect_callback(self): + if self._state & self.io_loop.WRITE: + self._state = self._state & ~self.io_loop.WRITE + self.io_loop.update_handler(self.fileno(), self._state) + + BaseSSLIOStream._run_ssl_connect_callback(self) \ No newline at end of file