diff --git a/nats/js/client.py b/nats/js/client.py index eda06798..6cea6039 100644 --- a/nats/js/client.py +++ b/nats/js/client.py @@ -25,7 +25,7 @@ from nats.aio.msg import Msg from nats.aio.subscription import Subscription from nats.js import api -from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError +from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError, FetchTimeoutError from nats.js.kv import KeyValue from nats.js.manager import JetStreamManager from nats.js.object_store import ( @@ -547,6 +547,13 @@ def _is_temporary_error(cls, status: Optional[str]) -> bool: else: return False + @classmethod + def _is_heartbeat(cls, status: Optional[str]) -> bool: + if status == api.StatusCode.CONTROL_MESSAGE: + return True + else: + return False + @classmethod def _time_until(cls, timeout: Optional[float], start_time: float) -> Optional[float]: @@ -620,9 +627,7 @@ async def activity_check(self): self._active = False if not active: if self._ordered: - await self.reset_ordered_consumer( - self._sseq + 1 - ) + await self.reset_ordered_consumer(self._sseq + 1) except asyncio.CancelledError: break @@ -882,14 +887,18 @@ async def consumer_info(self) -> api.ConsumerInfo: ) return info - async def fetch(self, - batch: int = 1, - timeout: Optional[float] = 5) -> List[Msg]: + async def fetch( + self, + batch: int = 1, + timeout: Optional[float] = 5, + heartbeat: Optional[float] = None + ) -> List[Msg]: """ fetch makes a request to JetStream to be delivered a set of messages. :param batch: Number of messages to fetch from server. :param timeout: Max duration of the fetch request before it expires. + :param heartbeat: Idle Heartbeat interval in seconds for the fetch request. :: @@ -925,15 +934,16 @@ async def main(): timeout * 1_000_000_000 ) - 100_000 if timeout else None if batch == 1: - msg = await self._fetch_one(expires, timeout) + msg = await self._fetch_one(expires, timeout, heartbeat) return [msg] - msgs = await self._fetch_n(batch, expires, timeout) + msgs = await self._fetch_n(batch, expires, timeout, heartbeat) return msgs async def _fetch_one( self, expires: Optional[int], timeout: Optional[float], + heartbeat: Optional[float] = None ) -> Msg: queue = self._sub._pending_queue @@ -957,6 +967,10 @@ async def _fetch_one( next_req['batch'] = 1 if expires: next_req['expires'] = int(expires) + if heartbeat: + next_req['idle_heartbeat'] = int( + heartbeat * 1_000_000_000 + ) # to nanoseconds await self._nc.publish( self._nms, @@ -965,6 +979,7 @@ async def _fetch_one( ) start_time = time.monotonic() + got_any_response = False while True: try: deadline = JetStreamContext._time_until( @@ -976,6 +991,10 @@ async def _fetch_one( # Should have received at least a processable message at this point, status = JetStreamContext.is_status_msg(msg) if status: + if JetStreamContext._is_heartbeat(status): + got_any_response = True + continue + # In case of a temporary error, treat it as a timeout to retry. if JetStreamContext._is_temporary_error(status): raise nats.errors.TimeoutError @@ -993,6 +1012,8 @@ async def _fetch_one( # due to a reconnect while the fetch request, # the JS API not responding on time, or maybe # there were no messages yet. + if got_any_response: + raise FetchTimeoutError raise async def _fetch_n( @@ -1000,10 +1021,12 @@ async def _fetch_n( batch: int, expires: Optional[int], timeout: Optional[float], + heartbeat: Optional[float] = None ) -> List[Msg]: msgs = [] queue = self._sub._pending_queue start_time = time.monotonic() + got_any_response = False needed = batch # Fetch as many as needed from the internal pending queue. @@ -1029,6 +1052,10 @@ async def _fetch_n( next_req['batch'] = needed if expires: next_req['expires'] = expires + if heartbeat: + next_req['idle_heartbeat'] = int( + heartbeat * 1_000_000_000 + ) # to nanoseconds next_req['no_wait'] = True await self._nc.publish( self._nms, @@ -1040,12 +1067,20 @@ async def _fetch_n( try: msg = await self._sub.next_msg(timeout) except asyncio.TimeoutError: + # Return any message that was already available in the internal queue. if msgs: return msgs raise + got_any_response = False + status = JetStreamContext.is_status_msg(msg) - if JetStreamContext._is_processable_msg(status, msg): + if JetStreamContext._is_heartbeat(status): + # Mark that we got any response from the server so this is not + # a possible i/o timeout error or due to a disconnection. + got_any_response = True + pass + elif JetStreamContext._is_processable_msg(status, msg): # First processable message received, do not raise error from now. msgs.append(msg) needed -= 1 @@ -1061,6 +1096,10 @@ async def _fetch_n( # No more messages after this so fallthrough # after receiving the rest. break + elif JetStreamContext._is_heartbeat(status): + # Skip heartbeats. + got_any_response = True + continue elif JetStreamContext._is_processable_msg(status, msg): needed -= 1 msgs.append(msg) @@ -1079,6 +1118,11 @@ async def _fetch_n( next_req['batch'] = needed if expires: next_req['expires'] = expires + if heartbeat: + next_req['idle_heartbeat'] = int( + heartbeat * 1_000_000_000 + ) # to nanoseconds + await self._nc.publish( self._nms, json.dumps(next_req).encode(), @@ -1099,7 +1143,12 @@ async def _fetch_n( if len(msgs) == 0: # Not a single processable message has been received so far, # if this timed out then let the error be raised. - msg = await self._sub.next_msg(timeout=deadline) + try: + msg = await self._sub.next_msg(timeout=deadline) + except asyncio.TimeoutError: + if got_any_response: + raise FetchTimeoutError + raise else: try: msg = await self._sub.next_msg(timeout=deadline) @@ -1109,6 +1158,10 @@ async def _fetch_n( if msg: status = JetStreamContext.is_status_msg(msg) + if JetStreamContext._is_heartbeat(status): + got_any_response = True + continue + if not status: needed -= 1 msgs.append(msg) @@ -1132,6 +1185,9 @@ async def _fetch_n( msg = await self._sub.next_msg(timeout=deadline) status = JetStreamContext.is_status_msg(msg) + if JetStreamContext._is_heartbeat(status): + got_any_response = True + continue if JetStreamContext._is_processable_msg(status, msg): needed -= 1 msgs.append(msg) @@ -1140,6 +1196,9 @@ async def _fetch_n( # at least one message has already arrived. pass + if len(msgs) == 0 and got_any_response: + raise FetchTimeoutError + return msgs ###################### diff --git a/nats/js/errors.py b/nats/js/errors.py index 69bfaa7f..faad26a1 100644 --- a/nats/js/errors.py +++ b/nats/js/errors.py @@ -1,4 +1,4 @@ -# Copyright 2016-2022 The NATS Authors +# Copyright 2016-2024 The NATS Authors # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -133,6 +133,15 @@ def __str__(self) -> str: return "nats: no response from stream" +class FetchTimeoutError(nats.errors.TimeoutError): + """ + Raised if the consumer timed out waiting for messages. + """ + + def __str__(self) -> str: + return "nats: fetch timeout" + + class ConsumerSequenceMismatchError(Error): """ Async error raised by the client with idle_heartbeat mode enabled diff --git a/tests/test_js.py b/tests/test_js.py index e6608222..0d4d893c 100644 --- a/tests/test_js.py +++ b/tests/test_js.py @@ -842,7 +842,7 @@ async def test_consumer_with_multiple_filters(self): ok = await msgs[0].ack_sync() assert ok - @async_debug_test + @async_long_test async def test_add_consumer_with_backoff(self): nc = NATS() await nc.connect() @@ -901,6 +901,102 @@ async def cb(msg): assert info.config.backoff == [1, 2] await nc.close() + @async_long_test + async def test_fetch_heartbeats(self): + nc = NATS() + await nc.connect() + + js = nc.jetstream() + + await js.add_stream(name="events", subjects=["events.>"]) + await js.add_consumer( + "events", + durable_name="a", + max_deliver=2, + max_waiting=5, + ack_wait=30, + max_ack_pending=5, + filter_subject="events.>", + ) + sub = await js.pull_subscribe_bind("a", stream="events") + + with pytest.raises(nats.js.errors.FetchTimeoutError): + await sub.fetch(1, timeout=1, heartbeat=0.1) + + with pytest.raises(asyncio.TimeoutError): + await sub.fetch(1, timeout=1, heartbeat=0.1) + + with pytest.raises(nats.errors.TimeoutError): + await sub.fetch(1, timeout=1, heartbeat=0.1) + + for i in range(0, 15): + await js.publish("events.%d" % i, b'i:%d' % i) + + # Fetch(n) + msgs = await sub.fetch(5, timeout=5, heartbeat=0.1) + assert len(msgs) == 5 + for msg in msgs: + await msg.ack_sync() + info = await js.consumer_info("events", "a") + assert info.num_pending == 10 + + # Fetch(1) + msgs = await sub.fetch(1, timeout=1, heartbeat=0.1) + assert len(msgs) == 1 + for msg in msgs: + await msg.ack_sync() + + # Receive some messages. + msgs = await sub.fetch(20, timeout=2, heartbeat=0.1) + for msg in msgs: + await msg.ack_sync() + msgs = await sub.fetch(4, timeout=2, heartbeat=0.1) + for msg in msgs: + await msg.ack_sync() + + # Check that messages were removed from being pending. + info = await js.consumer_info("events", "a") + assert info.num_pending == 0 + + # Ask for more messages but there aren't any. + with pytest.raises(nats.js.errors.FetchTimeoutError): + await sub.fetch(4, timeout=1, heartbeat=0.1) + + with pytest.raises(asyncio.TimeoutError): + msgs = await sub.fetch(4, timeout=1, heartbeat=0.1) + + with pytest.raises(nats.errors.TimeoutError): + msgs = await sub.fetch(4, timeout=1, heartbeat=0.1) + + with pytest.raises(nats.js.errors.APIError) as err: + await sub.fetch(1, timeout=1, heartbeat=0.5) + assert err.value.description == 'Bad Request - heartbeat value too large' + + # Example of catching fetch timeout instead first. + got_fetch_timeout = False + got_io_timeout = False + try: + await sub.fetch(1, timeout=1, heartbeat=0.2) + except nats.js.errors.FetchTimeoutError: + got_fetch_timeout = True + except nats.errors.TimeoutError: + got_io_timeout = True + assert got_fetch_timeout == True + assert got_io_timeout == False + + got_fetch_timeout = False + got_io_timeout = False + try: + await sub.fetch(1, timeout=1, heartbeat=0.2) + except nats.errors.TimeoutError: + got_io_timeout = True + except nats.js.errors.FetchTimeoutError: + got_fetch_timeout = True + assert got_fetch_timeout == False + assert got_io_timeout == True + + await nc.close() + class JSMTest(SingleJetStreamServerTestCase):