Skip to content

Commit

Permalink
Merge pull request #36 from nats-io/tls-reconnect-test
Browse files Browse the repository at this point in the history
Add basic TLS reconnect test
  • Loading branch information
Waldemar Quevedo authored May 9, 2017
2 parents a4f9ddf + dc7361f commit e04bf18
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 4 deletions.
81 changes: 79 additions & 2 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nats.aio.errors import ErrConnectionClosed, ErrNoServers, ErrTimeout, \
ErrBadSubject, NatsError
from tests.utils import async_test, start_gnatsd, NatsTestCase, \
SingleServerTestCase, MultiServerAuthTestCase, TLSServerTestCase
SingleServerTestCase, MultiServerAuthTestCase, TLSServerTestCase, MultiTLSServerAuthTestCase

class ClientUtilsTest(NatsTestCase):

Expand Down Expand Up @@ -823,7 +823,7 @@ def worker_handler(msg):
self.assertEqual(1, reconnected_count)
self.assertEqual(1, err_count)

class TLSTest(TLSServerTestCase):
class ClientTLSTest(TLSServerTestCase):

@async_test
def test_connect(self):
Expand Down Expand Up @@ -869,6 +869,83 @@ def subscription_handler(msg):
self.assertEqual(1, nc._subs[sid].received)
yield from nc.close()

class ClientTLSReconnectTest(MultiTLSServerAuthTestCase):

@async_test
def test_tls_reconnect(self):

nc = NATS()
disconnected_count = 0
reconnected_count = 0
closed_count = 0
err_count = 0

@asyncio.coroutine
def disconnected_cb():
nonlocal disconnected_count
disconnected_count += 1

@asyncio.coroutine
def reconnected_cb():
nonlocal reconnected_count
reconnected_count += 1

@asyncio.coroutine
def closed_cb():
nonlocal closed_count
closed_count += 1

@asyncio.coroutine
def err_cb(e):
nonlocal err_count
err_count += 1

counter = 0
@asyncio.coroutine
def worker_handler(msg):
nonlocal counter
counter += 1
if msg.reply != "":
yield from nc.publish(msg.reply, 'Reply:{}'.format(counter).encode())

options = {
'servers': [
"nats://foo:[email protected]:4223",
"nats://hoge:[email protected]:4224"
],
'io_loop': self.loop,
'disconnected_cb': disconnected_cb,
'closed_cb': closed_cb,
'reconnected_cb': reconnected_cb,
'error_cb': err_cb,
'dont_randomize': True,
'tls': self.ssl_ctx
}
yield from nc.connect(**options)
self.assertTrue(nc.is_connected)

yield from nc.subscribe("example", cb=worker_handler)
response = yield from nc.timed_request("example", b'Help!', timeout=1)
self.assertEqual(b'Reply:1', response.data)

# Trigger a reconnnect and should be fine
yield from self.loop.run_in_executor(None, self.server_pool[0].stop)
yield from asyncio.sleep(1, loop=self.loop)

yield from nc.subscribe("example", cb=worker_handler)
response = yield from nc.timed_request("example", b'Help!', timeout=1)
self.assertEqual(b'Reply:2', response.data)

yield from nc.close()
self.assertTrue(nc.is_closed)
self.assertFalse(nc.is_connected)
self.assertEqual(1, nc.stats['reconnects'])
self.assertEqual(1, closed_count)
self.assertEqual(2, disconnected_count)
self.assertEqual(1, reconnected_count)
self.assertEqual(1, err_count)

if __name__ == '__main__':
runner = unittest.TextTestRunner(stream=sys.stdout)
unittest.main(verbosity=2, exit=False, testRunner=runner)

3 changes: 2 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
test_suite.addTest(unittest.makeSuite(ClientUtilsTest))
test_suite.addTest(unittest.makeSuite(ClientTest))
test_suite.addTest(unittest.makeSuite(ClientReconnectTest))
test_suite.addTest(unittest.makeSuite(TLSTest))
test_suite.addTest(unittest.makeSuite(ClientTLSTest))
test_suite.addTest(unittest.makeSuite(ClientTLSReconnectTest))

# Skip tests using async/await syntax unless on Python 3.5
if sys.version_info >= (3, 5):
Expand Down
26 changes: 25 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def tearDown(self):
gnatsd.stop()
self.loop.close()


class TLSServerTestCase(NatsTestCase):
def setUp(self):
super().setUp()
Expand All @@ -150,6 +149,31 @@ def tearDown(self):
self.gnatsd.stop()
self.loop.close()

class MultiTLSServerAuthTestCase(NatsTestCase):

def setUp(self):
super(MultiTLSServerAuthTestCase, self).setUp()
self.server_pool = []
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

server1 = Gnatsd(port=4223, user="foo", password="bar", http_port=8223, tls=True)
self.server_pool.append(server1)
server2 = Gnatsd(port=4224, user="hoge", password="fuga", http_port=8224, tls=True)
self.server_pool.append(server2)
for gnatsd in self.server_pool:
start_gnatsd(gnatsd)

self.ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2
self.ssl_ctx.load_verify_locations('tests/certs/ca.pem')
self.ssl_ctx.load_cert_chain(certfile='tests/certs/client-cert.pem',
keyfile='tests/certs/client-key.pem')

def tearDown(self):
for gnatsd in self.server_pool:
gnatsd.stop()
self.loop.close()

def start_gnatsd(gnatsd: Gnatsd):
gnatsd.start()
Expand Down

0 comments on commit e04bf18

Please sign in to comment.