885 lines
29 KiB
Python
885 lines
29 KiB
Python
|
import asyncio
|
||
|
import contextlib
|
||
|
import functools
|
||
|
import socket
|
||
|
import traceback
|
||
|
import typing
|
||
|
import unittest
|
||
|
|
||
|
from tornado.concurrent import Future
|
||
|
from tornado import gen
|
||
|
from tornado.httpclient import HTTPError, HTTPRequest
|
||
|
from tornado.locks import Event
|
||
|
from tornado.log import gen_log, app_log
|
||
|
from tornado.netutil import Resolver
|
||
|
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||
|
from tornado.template import DictLoader
|
||
|
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
||
|
from tornado.web import Application, RequestHandler
|
||
|
|
||
|
try:
|
||
|
import tornado.websocket # noqa: F401
|
||
|
from tornado.util import _websocket_mask_python
|
||
|
except ImportError:
|
||
|
# The unittest module presents misleading errors on ImportError
|
||
|
# (it acts as if websocket_test could not be found, hiding the underlying
|
||
|
# error). If we get an ImportError here (which could happen due to
|
||
|
# TORNADO_EXTENSION=1), print some extra information before failing.
|
||
|
traceback.print_exc()
|
||
|
raise
|
||
|
|
||
|
from tornado.websocket import (
|
||
|
WebSocketHandler,
|
||
|
websocket_connect,
|
||
|
WebSocketError,
|
||
|
WebSocketClosedError,
|
||
|
)
|
||
|
|
||
|
try:
|
||
|
from tornado import speedups
|
||
|
except ImportError:
|
||
|
speedups = None # type: ignore
|
||
|
|
||
|
|
||
|
class TestWebSocketHandler(WebSocketHandler):
    """Common base for the test handlers; reports the server-side close.

    When a ``close_future`` is supplied through the URL spec's handler
    kwargs, ``on_close`` resolves it with a ``(close_code, close_reason)``
    tuple so that a test can inspect how the server saw the connection end.
    """

    def initialize(self, close_future=None, compression_options=None):
        # Both values arrive as handler kwargs from the URL spec.
        self.compression_options = compression_options
        self.close_future = close_future

    def get_compression_options(self):
        # Return whatever the URL spec configured; the compression tests
        # pass None here to get an uncompressed server.
        return self.compression_options

    def on_close(self):
        future = self.close_future
        if future is None:
            return
        future.set_result((self.close_code, self.close_reason))
|
||
|
|
||
|
|
||
|
class EchoHandler(TestWebSocketHandler):
    """Send every incoming message straight back to the client."""

    @gen.coroutine
    def on_message(self, message):
        # Binary frames are echoed as binary, text frames as text.
        is_binary = isinstance(message, bytes)
        try:
            yield self.write_message(message, binary=is_binary)
        except (asyncio.CancelledError, WebSocketClosedError):
            # The peer may go away while the write is pending; a failed
            # echo is deliberately not treated as an error here.
            pass
|
||
|
|
||
|
|
||
|
class ErrorInOnMessageHandler(TestWebSocketHandler):
    """Handler whose on_message always fails with ZeroDivisionError."""

    def on_message(self, message):
        # Deliberately raise so the test can observe how an uncaught
        # exception inside on_message is handled.
        return 1 / 0
|
||
|
|
||
|
|
||
|
class HeaderHandler(TestWebSocketHandler):
    """Check that plain-HTTP response methods fail, then echo ``X-Test``."""

    def open(self):
        # In a websocket context these RequestHandler methods must all
        # raise RuntimeError.
        forbidden_calls = (
            functools.partial(self.write, "This should not work"),
            functools.partial(self.redirect, "http://localhost/elsewhere"),
            functools.partial(self.set_header, "X-Test", ""),
            functools.partial(self.set_cookie, "Chocolate", "Chip"),
            functools.partial(self.set_status, 503),
            self.flush,
            self.finish,
        )
        for call in forbidden_calls:
            try:
                call()  # type: ignore
            except RuntimeError:
                continue
            raise Exception("did not get expected exception")
        self.write_message(self.request.headers.get("X-Test", ""))
|
||
|
|
||
|
|
||
|
class HeaderEchoHandler(TestWebSocketHandler):
    """Reflect ``X-Test*`` request headers into the handshake response."""

    def set_default_headers(self):
        # A fixed extra header, so the client can check it is delivered.
        self.set_header("X-Extra-Response-Header", "Extra-Response-Value")

    def prepare(self):
        # Copy every request header whose name starts with "x-test"
        # (compared case-insensitively) onto the response.
        matching = (
            (name, value)
            for name, value in self.request.headers.get_all()
            if name.lower().startswith("x-test")
        )
        for name, value in matching:
            self.set_header(name, value)
|
||
|
|
||
|
|
||
|
class NonWebSocketHandler(RequestHandler):
    """Ordinary HTTP endpoint that answers without upgrading."""

    def get(self):
        body = "ok"
        self.write(body)
|
||
|
|
||
|
|
||
|
class RedirectHandler(RequestHandler):
    """Plain HTTP endpoint that redirects to the echo websocket route."""

    def get(self):
        target = "/echo"
        self.redirect(target)
|
||
|
|
||
|
|
||
|
class CloseReasonHandler(TestWebSocketHandler):
    """Close the connection from the server side as soon as it opens."""

    def open(self):
        self.on_close_called = False
        # A specific code and reason, so the client can check both arrive.
        code, reason = 1001, "goodbye"
        self.close(code, reason)
|
||
|
|
||
|
|
||
|
class AsyncPrepareHandler(TestWebSocketHandler):
    """Echo handler whose ``prepare`` is a coroutine."""

    @gen.coroutine
    def prepare(self):
        # Yield to the event loop once so prepare() really is asynchronous.
        yield gen.moment

    def on_message(self, message):
        echoed = message
        self.write_message(echoed)
|
||
|
|
||
|
|
||
|
class PathArgsHandler(TestWebSocketHandler):
    """Send the captured URL path argument back when the socket opens."""

    def open(self, arg):
        captured = arg
        self.write_message(captured)
|
||
|
|
||
|
|
||
|
class CoroutineOnMessageHandler(TestWebSocketHandler):
    """Echo handler with a sleeping ``on_message`` coroutine.

    If a second on_message call starts while another one is still
    sleeping, the client receives an extra warning message. The test uses
    this to check that messages are processed one at a time.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Number of on_message calls currently parked in the sleep below.
        self.sleeping = 0

    @gen.coroutine
    def on_message(self, message):
        already_busy = self.sleeping > 0
        if already_busy:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        yield gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
|
||
|
|
||
|
|
||
|
class RenderMessageHandler(TestWebSocketHandler):
    """Reply with the incoming message rendered through ``message.html``."""

    def on_message(self, message):
        rendered = self.render_string("message.html", message=message)
        self.write_message(rendered)
|
||
|
|
||
|
|
||
|
class SubprotocolHandler(TestWebSocketHandler):
    """Accept only ``goodproto`` and report the negotiated subprotocol."""

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Ensures select_subprotocol runs exactly once, and before open().
        self.select_subprotocol_called = False

    def select_subprotocol(self, subprotocols):
        if self.select_subprotocol_called:
            raise Exception("select_subprotocol called twice")
        self.select_subprotocol_called = True
        return "goodproto" if "goodproto" in subprotocols else None

    def open(self):
        if not self.select_subprotocol_called:
            raise Exception("select_subprotocol not called")
        chosen = self.selected_subprotocol
        self.write_message("subprotocol=%s" % chosen)
|
||
|
|
||
|
|
||
|
class OpenCoroutineHandler(TestWebSocketHandler):
    """Handler whose ``open`` coroutine finishes only after a delay.

    on_message raises if it runs before open() has completed. The test
    uses this to check that incoming messages are held back until open()
    is done.
    """

    def initialize(self, test, **kwargs):
        super().initialize(**kwargs)
        self.test = test
        self.open_finished = False

    @gen.coroutine
    def open(self):
        # Wait until the test has sent its message, then a little longer,
        # so that the message is already pending while open() still runs.
        yield self.test.message_sent.wait()
        yield gen.sleep(0.010)
        self.open_finished = True

    def on_message(self, message):
        if self.open_finished:
            self.write_message("ok")
            return
        raise Exception("on_message called before open finished")
|
||
|
|
||
|
|
||
|
class ErrorInOpenHandler(TestWebSocketHandler):
    """Handler whose synchronous ``open`` always raises."""

    def open(self):
        failure = Exception("boom")
        raise failure
|
||
|
|
||
|
|
||
|
class ErrorInAsyncOpenHandler(TestWebSocketHandler):
    """Handler whose native-coroutine ``open`` raises after yielding once."""

    async def open(self):
        # Suspend once so the failure happens asynchronously.
        await asyncio.sleep(0)
        failure = Exception("boom")
        raise failure
|
||
|
|
||
|
|
||
|
class NoDelayHandler(TestWebSocketHandler):
    """Call ``set_nodelay(True)`` on open, then greet the client."""

    def open(self):
        greeting = "hello"
        self.set_nodelay(True)
        self.write_message(greeting)
|
||
|
|
||
|
|
||
|
class WebSocketBaseTestCase(AsyncHTTPTestCase):
    """Base test case that closes every client connection it opened.

    Connections created through ``ws_connect`` are recorded and closed in
    ``tearDown``, so individual tests do not need their own cleanup.
    """

    def setUp(self):
        super().setUp()
        # Client connections opened via ws_connect; closed in tearDown.
        self.conns_to_close = []  # type: typing.List[typing.Any]

    def tearDown(self):
        try:
            for conn in self.conns_to_close:
                conn.close()
        finally:
            # Always run the base class's cleanup, even if closing one of
            # the client connections raised.
            super().tearDown()

    @gen.coroutine
    def ws_connect(self, path, **kwargs):
        """Open a websocket to ``path`` on this test's HTTP server.

        Extra keyword arguments are passed through to ``websocket_connect``.
        The connection is registered so that ``tearDown`` closes it.
        """
        ws = yield websocket_connect(
            "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
        )
        self.conns_to_close.append(ws)
        # Plain ``return`` from a gen.coroutine is supported on Python 3.
        return ws
|
||
|
|
||
|
|
||
|
class WebSocketTest(WebSocketBaseTestCase):
    """End-to-end tests of WebSocketHandler against websocket_connect.

    ``get_app`` creates one ``close_future`` per test. Every route that is
    given it resolves it from ``on_close`` with the server-side
    ``(close_code, close_reason)`` pair.
    """

    def get_app(self):
        # The future carries a (close_code, close_reason) tuple rather than
        # None; tests unpack it as ``code, reason = yield self.close_future``.
        self.close_future = Future()  # type: Future[typing.Tuple[typing.Optional[int], typing.Optional[str]]]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/redirect", RedirectHandler),
                ("/header", HeaderHandler, dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
                ("/error_in_open", ErrorInOpenHandler),
                ("/error_in_async_open", ErrorInAsyncOpenHandler),
                ("/nodelay", NoDelayHandler),
            ],
            template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # RequestHandler caches template loaders in a class attribute;
            # clear it even if the base tearDown failed, so the DictLoader
            # from get_app does not leak into other tests.
            RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        websocket_connect(
            "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
        )
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message("hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, "hello \u00e9")

    @gen_test
    def test_error_in_closed_client_write_message(self):
        ws = yield self.ws_connect("/echo")
        ws.close()
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_http_redirect(self):
        with self.assertRaises(HTTPError):
            yield self.ws_connect("/redirect")

    @gen_test
    def test_websocket_network_fail(self):
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*", required=False):
                yield websocket_connect(
                    "ws://127.0.0.1:%d/" % port, connect_timeout=3600
                )

    @gen_test
    def test_websocket_close_buffered_data(self):
        with contextlib.closing(
            (yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()))
        ) as ws:
            ws.write_message("hello")
            ws.write_message("world")
            # Close the underlying stream.
            ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        with contextlib.closing(
            (
                yield websocket_connect(
                    HTTPRequest(
                        "ws://127.0.0.1:%d/header" % self.get_http_port(),
                        headers={"X-Test": "hello"},
                    )
                )
            )
        ) as ws:
            response = yield ws.read_message()
            self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        with contextlib.closing(
            (
                yield websocket_connect(
                    HTTPRequest(
                        "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                        headers={"X-Test-Hello": "hello"},
                    )
                )
            )
        ) as ws:
            self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
            self.assertEqual(
                ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
            )

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, _ = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}

        with contextlib.closing(
            (yield websocket_connect(HTTPRequest(url, headers=headers)))
        ) as ws:
            ws.write_message("hello")
            response = yield ws.read_message()
            self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}

        with contextlib.closing(
            (yield websocket_connect(HTTPRequest(url, headers=headers)))
        ) as ws:
            ws.write_message("hello")
            response = yield ws.read_message()
            self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "127.0.0.1:%d" % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = "ws://127.0.0.1:%d/echo" % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {"Origin": "http://somewhereelse.com"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        # CaresResolver may return ipv6-only results for localhost, but our
        # server is only running on ipv4. Test for this edge case and skip
        # the test if it happens.
        addrinfo = yield Resolver().resolve("localhost", port)
        families = {addr[0] for addr in addrinfo}
        if socket.AF_INET not in families:
            # skipTest raises unittest.SkipTest, so nothing below runs.
            self.skipTest("localhost does not resolve to ipv4")

        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default. If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect(
            "/subprotocol", subprotocols=["badproto", "goodproto"]
        )
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")

    @gen_test
    def test_error_in_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_error_in_async_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_async_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_nodelay(self):
        ws = yield self.ws_connect("/nodelay")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")
|
||
|
|
||
|
|
||
|
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
    """``async def`` version of the sleeping echo handler.

    A warning message is sent if a second on_message call begins while an
    earlier one is still sleeping.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Number of on_message calls currently inside the sleep.
        self.sleeping = 0

    async def on_message(self, message):
        overlapping = self.sleeping > 0
        if overlapping:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        await gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
|
||
|
|
||
|
|
||
|
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
    """Message serialization with a native-coroutine ``on_message``."""

    def get_app(self):
        routes = [("/native", NativeCoroutineOnMessageHandler)]
        return Application(routes)

    @gen_test
    def test_native_coroutine(self):
        ws = yield self.ws_connect("/native")
        # Both messages go out back to back; the handler must still process
        # them one at a time, so the echoes arrive in order.
        for outgoing in ("hello1", "hello2"):
            yield ws.write_message(outgoing)
        for expected in ("hello1", "hello2"):
            received = yield ws.read_message()
            self.assertEqual(received, expected)
|
||
|
|
||
|
|
||
|
class CompressionTestMixin(object):
    """Shared compression tests; subclasses pick each side's options.

    Subclasses override ``get_server_compression_options`` and/or
    ``get_client_compression_options`` and implement ``verify_wire_bytes``
    to check how many bytes actually crossed the wire.
    """

    MESSAGE = "Hello world. Testing 123 123"

    def get_app(self):
        class LimitedHandler(TestWebSocketHandler):
            @property
            def max_message_size(self):
                return 1024

            def on_message(self, message):
                # Reply with the received length instead of the payload.
                self.write_message(str(len(message)))

        def handler_kwargs():
            # A fresh kwargs dict (and fresh options) for each route.
            return dict(compression_options=self.get_server_compression_options())

        return Application(
            [
                ("/echo", EchoHandler, handler_kwargs()),
                ("/limited", LimitedHandler, handler_kwargs()),
            ]
        )

    def get_server_compression_options(self):
        return None

    def get_client_compression_options(self):
        return None

    def verify_wire_bytes(self, bytes_in: int, bytes_out: int) -> None:
        raise NotImplementedError()

    @gen_test
    def test_message_sizes(self: typing.Any):
        ws = yield self.ws_connect(
            "/echo", compression_options=self.get_client_compression_options()
        )
        # Repeating the same message makes the effect of the
        # context_takeover options visible in the wire byte counts.
        rounds = 3
        for _ in range(rounds):
            ws.write_message(self.MESSAGE)
            echoed = yield ws.read_message()
            self.assertEqual(echoed, self.MESSAGE)

        protocol = ws.protocol
        expected_payload = len(self.MESSAGE) * rounds
        self.assertEqual(protocol._message_bytes_out, expected_payload)
        self.assertEqual(protocol._message_bytes_in, expected_payload)
        self.verify_wire_bytes(protocol._wire_bytes_in, protocol._wire_bytes_out)

    @gen_test
    def test_size_limit(self: typing.Any):
        ws = yield self.ws_connect(
            "/limited", compression_options=self.get_client_compression_options()
        )
        # A small message is accepted and its length reported back.
        ws.write_message("a" * 128)
        reply = yield ws.read_message()
        self.assertEqual(reply, "128")
        # This one exceeds the limit once decompressed, but compresses
        # small enough to pass the initial size checks.
        ws.write_message("a" * 2048)
        reply = yield ws.read_message()
        self.assertIsNone(reply)
|
||
|
|
||
|
|
||
|
class UncompressedTestMixin(CompressionTestMixin):
    """CompressionTestMixin variant for setups where nothing is compressed."""

    def verify_wire_bytes(self: typing.Any, bytes_in, bytes_out):
        # Each of the three frames adds 2 bytes of overhead. Outgoing
        # (client) frames also carry the 4-byte mask key.
        payload = len(self.MESSAGE)
        self.assertEqual(bytes_out, 3 * (payload + 6))
        self.assertEqual(bytes_in, 3 * (payload + 2))
|
||
|
|
||
|
|
||
|
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Neither side asks for compression (the mixin's defaults)."""
|
||
|
|
||
|
|
||
|
# If only one side tries to compress, the extension is not negotiated.
|
||
|
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Only the server enables compression, so none is negotiated."""

    def get_server_compression_options(self):
        return dict()
|
||
|
|
||
|
|
||
|
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Only the client requests compression, so none is negotiated."""

    def get_client_compression_options(self):
        return dict()
|
||
|
|
||
|
|
||
|
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
    """Both sides enable compression with default options."""

    def get_server_compression_options(self):
        return dict()

    def get_client_compression_options(self):
        return dict()

    def verify_wire_bytes(self, bytes_in, bytes_out):
        payload = len(self.MESSAGE)
        # Compressed traffic must be smaller than the uncompressed framing.
        self.assertLess(bytes_out, 3 * (payload + 6))
        self.assertLess(bytes_in, 3 * (payload + 2))
        # The only asymmetry is the 4-byte mask key on each of the three
        # outgoing client messages.
        self.assertEqual(bytes_out, bytes_in + 3 * 4)
|
||
|
|
||
|
|
||
|
class MaskFunctionMixin(object):
    """Known-answer tests for a websocket masking function.

    Subclasses implement ``mask(mask, data)``. The expected values below
    are ``data`` XORed byte-wise with the repeating 4-byte ``mask`` key.
    """

    # Subclasses should define self.mask(mask, data)
    def mask(self, mask: bytes, data: bytes) -> bytes:
        raise NotImplementedError()

    def test_mask(self: typing.Any):
        known_answers = [
            (b"abcd", b"", b""),
            (b"abcd", b"b", b"\x03"),
            (b"abcd", b"54321", b"TVPVP"),
            (b"ZXCV", b"98765432", b"c`t`olpd"),
            # Cases with \x00 bytes (so the C extension can't rely on
            # null-terminated strings) and bytes with the high bit set
            # (to smoke out signedness issues).
            (
                b"\x00\x01\x02\x03",
                b"\xff\xfb\xfd\xfc\xfe\xfa",
                b"\xff\xfa\xff\xff\xfe\xfb",
            ),
            (
                b"\xff\xfb\xfd\xfc",
                b"\x00\x01\x02\x03\x04\x05",
                b"\xff\xfa\xff\xff\xfb\xfe",
            ),
        ]
        for key, data, expected in known_answers:
            self.assertEqual(self.mask(key, data), expected)
|
||
|
|
||
|
|
||
|
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the masking known-answer tests on the pure-Python implementation."""

    def mask(self, mask, data):
        masked = _websocket_mask_python(mask, data)
        return masked
|
||
|
|
||
|
|
||
|
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the masking known-answer tests on ``tornado.speedups``.

    The test case is skipped when the speedups module could not be imported.
    """

    def mask(self, mask, data):
        masked = speedups.websocket_mask(mask, data)
        return masked
|
||
|
|
||
|
|
||
|
class ServerPeriodicPingTest(WebSocketBaseTestCase):
    """The server pings periodically and reports every pong it receives."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_pong(self, data):
                self.write_message("got pong")

        routes = [("/", PingHandler)]
        return Application(routes, websocket_ping_interval=0.01)

    @gen_test
    def test_server_ping(self):
        ws = yield self.ws_connect("/")
        # Several pongs in a row show the pings really are periodic.
        for _ in range(3):
            reply = yield ws.read_message()
            self.assertEqual(reply, "got pong")
        # TODO: test that the connection gets closed if ping responses stop.
|
||
|
|
||
|
|
||
|
class ClientPeriodicPingTest(WebSocketBaseTestCase):
    """The client pings periodically; the server reports every ping."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                self.write_message("got ping")

        routes = [("/", PingHandler)]
        return Application(routes)

    @gen_test
    def test_client_ping(self):
        ws = yield self.ws_connect("/", ping_interval=0.01)
        # Several pings in a row show the client really pings periodically.
        for _ in range(3):
            reply = yield ws.read_message()
            self.assertEqual(reply, "got ping")
        # TODO: test that the connection gets closed if ping responses stop.
        ws.close()
|
||
|
|
||
|
|
||
|
class ManualPingTest(WebSocketBaseTestCase):
    """Explicit ``ws.ping`` calls; the server echoes each ping payload."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                self.write_message(data, binary=isinstance(data, bytes))

        routes = [("/", PingHandler)]
        return Application(routes)

    @gen_test
    def test_manual_ping(self):
        ws = yield self.ws_connect("/")

        # A 126-byte ping payload is rejected before anything is sent.
        with self.assertRaises(ValueError):
            ws.ping("a" * 126)

        # on_ping always sees bytes, so even a str ping comes back as bytes.
        for payload, expected in (
            ("hello", b"hello"),
            (b"binary hello", b"binary hello"),
        ):
            ws.ping(payload)
            echoed = yield ws.read_message()
            self.assertEqual(echoed, expected)
|
||
|
|
||
|
|
||
|
class MaxMessageSizeTest(WebSocketBaseTestCase):
    """The application-wide ``websocket_max_message_size`` is enforced."""

    def get_app(self):
        return Application([("/", EchoHandler)], websocket_max_message_size=1024)

    @gen_test
    def test_large_message(self):
        ws = yield self.ws_connect("/")
        limit = 1024

        # A message exactly at the limit is echoed back unchanged.
        allowed = "a" * limit
        ws.write_message(allowed)
        echoed = yield ws.read_message()
        self.assertEqual(echoed, allowed)

        # One extra character makes the server close the connection; the
        # client sees that as a None message with close code 1009.
        ws.write_message(allowed + "b")
        echoed = yield ws.read_message()
        self.assertIs(echoed, None)
        self.assertEqual(ws.close_code, 1009)
        self.assertEqual(ws.close_reason, "message too big")
        # TODO: Needs tests of messages split over multiple
        # continuation frames.
|