Skip to content

Commit

Permalink
Merge pull request #540 from nats-io/fetch-hb
Browse files Browse the repository at this point in the history
js: add fetch heartbeat option
  • Loading branch information
wallyqs authored Feb 27, 2024
2 parents 9e24c40 + 2bb6e8d commit 10d2cc6
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 13 deletions.
81 changes: 70 additions & 11 deletions nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from nats.aio.msg import Msg
from nats.aio.subscription import Subscription
from nats.js import api
from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError
from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError, FetchTimeoutError
from nats.js.kv import KeyValue
from nats.js.manager import JetStreamManager
from nats.js.object_store import (
Expand Down Expand Up @@ -547,6 +547,13 @@ def _is_temporary_error(cls, status: Optional[str]) -> bool:
else:
return False

@classmethod
def _is_heartbeat(cls, status: Optional[str]) -> bool:
if status == api.StatusCode.CONTROL_MESSAGE:
return True
else:
return False

@classmethod
def _time_until(cls, timeout: Optional[float],
start_time: float) -> Optional[float]:
Expand Down Expand Up @@ -620,9 +627,7 @@ async def activity_check(self):
self._active = False
if not active:
if self._ordered:
await self.reset_ordered_consumer(
self._sseq + 1
)
await self.reset_ordered_consumer(self._sseq + 1)
except asyncio.CancelledError:
break

Expand Down Expand Up @@ -882,14 +887,18 @@ async def consumer_info(self) -> api.ConsumerInfo:
)
return info

async def fetch(self,
batch: int = 1,
timeout: Optional[float] = 5) -> List[Msg]:
async def fetch(
self,
batch: int = 1,
timeout: Optional[float] = 5,
heartbeat: Optional[float] = None
) -> List[Msg]:
"""
fetch makes a request to JetStream to be delivered a set of messages.
:param batch: Number of messages to fetch from server.
:param timeout: Max duration of the fetch request before it expires.
:param heartbeat: Idle Heartbeat interval in seconds for the fetch request.
::
Expand Down Expand Up @@ -925,15 +934,16 @@ async def main():
timeout * 1_000_000_000
) - 100_000 if timeout else None
if batch == 1:
msg = await self._fetch_one(expires, timeout)
msg = await self._fetch_one(expires, timeout, heartbeat)
return [msg]
msgs = await self._fetch_n(batch, expires, timeout)
msgs = await self._fetch_n(batch, expires, timeout, heartbeat)
return msgs

async def _fetch_one(
self,
expires: Optional[int],
timeout: Optional[float],
heartbeat: Optional[float] = None
) -> Msg:
queue = self._sub._pending_queue

Expand All @@ -957,6 +967,10 @@ async def _fetch_one(
next_req['batch'] = 1
if expires:
next_req['expires'] = int(expires)
if heartbeat:
next_req['idle_heartbeat'] = int(
heartbeat * 1_000_000_000
) # to nanoseconds

await self._nc.publish(
self._nms,
Expand All @@ -965,6 +979,7 @@ async def _fetch_one(
)

start_time = time.monotonic()
got_any_response = False
while True:
try:
deadline = JetStreamContext._time_until(
Expand All @@ -976,6 +991,10 @@ async def _fetch_one(
# Should have received at least a processable message at this point,
status = JetStreamContext.is_status_msg(msg)
if status:
if JetStreamContext._is_heartbeat(status):
got_any_response = True
continue

# In case of a temporary error, treat it as a timeout to retry.
if JetStreamContext._is_temporary_error(status):
raise nats.errors.TimeoutError
Expand All @@ -993,17 +1012,21 @@ async def _fetch_one(
# due to a reconnect while the fetch request,
# the JS API not responding on time, or maybe
# there were no messages yet.
if got_any_response:
raise FetchTimeoutError
raise

async def _fetch_n(
self,
batch: int,
expires: Optional[int],
timeout: Optional[float],
heartbeat: Optional[float] = None
) -> List[Msg]:
msgs = []
queue = self._sub._pending_queue
start_time = time.monotonic()
got_any_response = False
needed = batch

# Fetch as many as needed from the internal pending queue.
Expand All @@ -1029,6 +1052,10 @@ async def _fetch_n(
next_req['batch'] = needed
if expires:
next_req['expires'] = expires
if heartbeat:
next_req['idle_heartbeat'] = int(
heartbeat * 1_000_000_000
) # to nanoseconds
next_req['no_wait'] = True
await self._nc.publish(
self._nms,
Expand All @@ -1040,12 +1067,20 @@ async def _fetch_n(
try:
msg = await self._sub.next_msg(timeout)
except asyncio.TimeoutError:
# Return any message that was already available in the internal queue.
if msgs:
return msgs
raise

got_any_response = False

status = JetStreamContext.is_status_msg(msg)
if JetStreamContext._is_processable_msg(status, msg):
if JetStreamContext._is_heartbeat(status):
# Mark that we got any response from the server so this is not
# a possible i/o timeout error or due to a disconnection.
got_any_response = True
pass
elif JetStreamContext._is_processable_msg(status, msg):
# First processable message received, do not raise error from now.
msgs.append(msg)
needed -= 1
Expand All @@ -1061,6 +1096,10 @@ async def _fetch_n(
# No more messages after this so fallthrough
# after receiving the rest.
break
elif JetStreamContext._is_heartbeat(status):
# Skip heartbeats.
got_any_response = True
continue
elif JetStreamContext._is_processable_msg(status, msg):
needed -= 1
msgs.append(msg)
Expand All @@ -1079,6 +1118,11 @@ async def _fetch_n(
next_req['batch'] = needed
if expires:
next_req['expires'] = expires
if heartbeat:
next_req['idle_heartbeat'] = int(
heartbeat * 1_000_000_000
) # to nanoseconds

await self._nc.publish(
self._nms,
json.dumps(next_req).encode(),
Expand All @@ -1099,7 +1143,12 @@ async def _fetch_n(
if len(msgs) == 0:
# Not a single processable message has been received so far,
# if this timed out then let the error be raised.
msg = await self._sub.next_msg(timeout=deadline)
try:
msg = await self._sub.next_msg(timeout=deadline)
except asyncio.TimeoutError:
if got_any_response:
raise FetchTimeoutError
raise
else:
try:
msg = await self._sub.next_msg(timeout=deadline)
Expand All @@ -1109,6 +1158,10 @@ async def _fetch_n(

if msg:
status = JetStreamContext.is_status_msg(msg)
if JetStreamContext._is_heartbeat(status):
got_any_response = True
continue

if not status:
needed -= 1
msgs.append(msg)
Expand All @@ -1132,6 +1185,9 @@ async def _fetch_n(

msg = await self._sub.next_msg(timeout=deadline)
status = JetStreamContext.is_status_msg(msg)
if JetStreamContext._is_heartbeat(status):
got_any_response = True
continue
if JetStreamContext._is_processable_msg(status, msg):
needed -= 1
msgs.append(msg)
Expand All @@ -1140,6 +1196,9 @@ async def _fetch_n(
# at least one message has already arrived.
pass

if len(msgs) == 0 and got_any_response:
raise FetchTimeoutError

return msgs

######################
Expand Down
11 changes: 10 additions & 1 deletion nats/js/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2016-2022 The NATS Authors
# Copyright 2016-2024 The NATS Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -133,6 +133,15 @@ def __str__(self) -> str:
return "nats: no response from stream"


class FetchTimeoutError(nats.errors.TimeoutError):
"""
Raised if the consumer timed out waiting for messages.
"""

def __str__(self) -> str:
return "nats: fetch timeout"


class ConsumerSequenceMismatchError(Error):
"""
Async error raised by the client with idle_heartbeat mode enabled
Expand Down
98 changes: 97 additions & 1 deletion tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ async def test_consumer_with_multiple_filters(self):
ok = await msgs[0].ack_sync()
assert ok

@async_debug_test
@async_long_test
async def test_add_consumer_with_backoff(self):
nc = NATS()
await nc.connect()
Expand Down Expand Up @@ -901,6 +901,102 @@ async def cb(msg):
assert info.config.backoff == [1, 2]
await nc.close()

@async_long_test
async def test_fetch_heartbeats(self):
nc = NATS()
await nc.connect()

js = nc.jetstream()

await js.add_stream(name="events", subjects=["events.>"])
await js.add_consumer(
"events",
durable_name="a",
max_deliver=2,
max_waiting=5,
ack_wait=30,
max_ack_pending=5,
filter_subject="events.>",
)
sub = await js.pull_subscribe_bind("a", stream="events")

with pytest.raises(nats.js.errors.FetchTimeoutError):
await sub.fetch(1, timeout=1, heartbeat=0.1)

with pytest.raises(asyncio.TimeoutError):
await sub.fetch(1, timeout=1, heartbeat=0.1)

with pytest.raises(nats.errors.TimeoutError):
await sub.fetch(1, timeout=1, heartbeat=0.1)

for i in range(0, 15):
await js.publish("events.%d" % i, b'i:%d' % i)

# Fetch(n)
msgs = await sub.fetch(5, timeout=5, heartbeat=0.1)
assert len(msgs) == 5
for msg in msgs:
await msg.ack_sync()
info = await js.consumer_info("events", "a")
assert info.num_pending == 10

# Fetch(1)
msgs = await sub.fetch(1, timeout=1, heartbeat=0.1)
assert len(msgs) == 1
for msg in msgs:
await msg.ack_sync()

# Receive some messages.
msgs = await sub.fetch(20, timeout=2, heartbeat=0.1)
for msg in msgs:
await msg.ack_sync()
msgs = await sub.fetch(4, timeout=2, heartbeat=0.1)
for msg in msgs:
await msg.ack_sync()

# Check that messages were removed from being pending.
info = await js.consumer_info("events", "a")
assert info.num_pending == 0

# Ask for more messages but there aren't any.
with pytest.raises(nats.js.errors.FetchTimeoutError):
await sub.fetch(4, timeout=1, heartbeat=0.1)

with pytest.raises(asyncio.TimeoutError):
msgs = await sub.fetch(4, timeout=1, heartbeat=0.1)

with pytest.raises(nats.errors.TimeoutError):
msgs = await sub.fetch(4, timeout=1, heartbeat=0.1)

with pytest.raises(nats.js.errors.APIError) as err:
await sub.fetch(1, timeout=1, heartbeat=0.5)
assert err.value.description == 'Bad Request - heartbeat value too large'

# Example of catching fetch timeout instead first.
got_fetch_timeout = False
got_io_timeout = False
try:
await sub.fetch(1, timeout=1, heartbeat=0.2)
except nats.js.errors.FetchTimeoutError:
got_fetch_timeout = True
except nats.errors.TimeoutError:
got_io_timeout = True
assert got_fetch_timeout == True
assert got_io_timeout == False

got_fetch_timeout = False
got_io_timeout = False
try:
await sub.fetch(1, timeout=1, heartbeat=0.2)
except nats.errors.TimeoutError:
got_io_timeout = True
except nats.js.errors.FetchTimeoutError:
got_fetch_timeout = True
assert got_fetch_timeout == False
assert got_io_timeout == True

await nc.close()


class JSMTest(SingleJetStreamServerTestCase):

Expand Down

0 comments on commit 10d2cc6

Please sign in to comment.