From 2c9988d54c2a301f7e68a612d0768059cba073f0 Mon Sep 17 00:00:00 2001 From: Gabriel Levcovitz Date: Tue, 15 Oct 2024 17:32:06 -0300 Subject: [PATCH] refactor(p2p): async on_new_vertex --- hathor/p2p/dependencies/p2p_dependencies.py | 2 +- .../single_process_p2p_dependencies.py | 2 +- hathor/p2p/states/base.py | 13 ++++--- hathor/p2p/states/ready.py | 2 +- hathor/p2p/sync_agent.py | 4 +-- hathor/p2p/sync_v1/agent.py | 15 ++++---- hathor/p2p/sync_v2/agent.py | 22 ++++++------ .../sync_v2/blockchain_streaming_client.py | 4 +-- hathor/p2p/sync_v2/mempool.py | 32 ++++++++--------- .../sync_v2/transaction_streaming_client.py | 24 ++++++------- hathor/utils/twisted.py | 34 +++++++++++++++++++ tests/p2p/test_sync.py | 5 +-- 12 files changed, 97 insertions(+), 62 deletions(-) create mode 100644 hathor/utils/twisted.py diff --git a/hathor/p2p/dependencies/p2p_dependencies.py b/hathor/p2p/dependencies/p2p_dependencies.py index d9e05ac37..66dbd9394 100644 --- a/hathor/p2p/dependencies/p2p_dependencies.py +++ b/hathor/p2p/dependencies/p2p_dependencies.py @@ -47,7 +47,7 @@ def __init__( self.vertex_parser = vertex_parser @abstractmethod - def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool: + async def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool: raise NotImplementedError @abstractmethod diff --git a/hathor/p2p/dependencies/single_process_p2p_dependencies.py b/hathor/p2p/dependencies/single_process_p2p_dependencies.py index 01eb1281f..677ff10df 100644 --- a/hathor/p2p/dependencies/single_process_p2p_dependencies.py +++ b/hathor/p2p/dependencies/single_process_p2p_dependencies.py @@ -60,7 +60,7 @@ def __init__( self._indexes = not_none(tx_storage.indexes) @override - def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool: + async def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool: return self._vertex_handler.on_new_vertex(vertex=vertex, fails_silently=fails_silently) @override diff --git a/hathor/p2p/states/base.py b/hathor/p2p/states/base.py index 75a69140e..677ea5049 100644 --- a/hathor/p2p/states/base.py +++ b/hathor/p2p/states/base.py @@ -13,7 +13,7 @@ # limitations under the License. from collections.abc import Coroutine -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias from structlog import get_logger from twisted.internet.defer import Deferred @@ -27,13 +27,16 @@ logger = get_logger() +CommandHandler: TypeAlias = ( + Callable[[str], None] | + Callable[[str], Deferred[None]] | + Callable[[str], Coroutine[Deferred[None], Any, None]] +) + class BaseState: protocol: 'HathorProtocol' - cmd_map: dict[ - ProtocolMessages, - Callable[[str], None] | Callable[[str], Deferred[None]] | Callable[[str], Coroutine[Deferred[None], Any, None]] - ] + cmd_map: dict[ProtocolMessages, CommandHandler] def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies): self.log = logger.new(**protocol.get_logger_context()) diff --git a/hathor/p2p/states/ready.py b/hathor/p2p/states/ready.py index 3c11e7fff..d0b0584f8 100644 --- a/hathor/p2p/states/ready.py +++ b/hathor/p2p/states/ready.py @@ -23,7 +23,6 @@ from hathor.p2p.messages import ProtocolMessages from hathor.p2p.peer import PublicPeer, UnverifiedPeer from hathor.p2p.states.base import BaseState -from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.utils import to_height_info, to_serializable_best_blockchain from hathor.transaction import BaseTransaction from hathor.util import json_dumps, json_loads @@ -106,6 +105,7 @@ def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) self.log.debug(f'loading {sync_version}') sync_factory = connections.get_sync_factory(sync_version) + from hathor.p2p.sync_agent import SyncAgent self.sync_agent: SyncAgent = sync_factory.create_sync_agent(self.protocol) self.cmd_map.update(self.sync_agent.get_cmd_dict()) diff --git a/hathor/p2p/sync_agent.py b/hathor/p2p/sync_agent.py index a700335ed..8127f999a 100644 --- a/hathor/p2p/sync_agent.py +++ b/hathor/p2p/sync_agent.py @@ -13,9 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Callable from hathor.p2p.messages import ProtocolMessages +from hathor.p2p.states.base import CommandHandler from hathor.transaction import BaseTransaction @@ -36,7 +36,7 @@ def stop(self) -> None: raise NotImplementedError @abstractmethod - def get_cmd_dict(self) -> dict[ProtocolMessages, Callable[[str], None]]: + def get_cmd_dict(self) -> dict[ProtocolMessages, CommandHandler]: """Command dict to add to the protocol handler""" raise NotImplementedError diff --git a/hathor/p2p/sync_v1/agent.py b/hathor/p2p/sync_v1/agent.py index c1df77d16..70135e307 100644 --- a/hathor/p2p/sync_v1/agent.py +++ b/hathor/p2p/sync_v1/agent.py @@ -15,7 +15,7 @@ import base64 import struct from math import inf -from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional +from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional from weakref import WeakSet from structlog import get_logger @@ -24,6 +24,7 @@ from hathor.p2p import P2PDependencies from hathor.p2p.messages import GetNextPayload, GetTipsPayload, NextPayload, ProtocolMessages, TipsPayload +from hathor.p2p.states.base import CommandHandler from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_v1.downloader import Downloader from hathor.transaction import BaseTransaction @@ -128,7 +129,7 @@ def get_status(self): 'synced_timestamp': self.synced_timestamp, } - def get_cmd_dict(self) -> dict[ProtocolMessages, Callable[[str], None]]: + def get_cmd_dict(self) -> dict[ProtocolMessages, CommandHandler]: """ Return a dict of messages. """ return { @@ -249,7 +250,7 @@ def get_data(self, hash_bytes: bytes) -> Deferred: :rtype: Deferred """ d = self.downloader.get_tx(hash_bytes, self) - d.addCallback(self.on_tx_success) + d.addCallback(lambda tx: Deferred.fromCoroutine(self.on_tx_success(tx))) d.addErrback(self.on_get_data_failed, hash_bytes) return d @@ -591,7 +592,7 @@ def send_data(self, tx: BaseTransaction) -> None: payload = base64.b64encode(tx.get_struct()).decode('ascii') self.send_message(ProtocolMessages.DATA, payload) - def handle_data(self, payload: str) -> None: + async def handle_data(self, payload: str) -> None: """ Handle a received DATA message. """ if not payload: @@ -629,7 +630,7 @@ def handle_data(self, payload: str) -> None: self.log.info('tx received in real time from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id()) # If we have not requested the data, it is a new transaction being propagated # in the network, thus, we propagate it as well. - result = self.dependencies.on_new_vertex(tx) + result = await self.dependencies.on_new_vertex(tx) if result: self.protocol.connections.send_tx_to_peers(tx) self.update_received_stats(tx, result) @@ -668,7 +669,7 @@ def remove_deferred(self, reason: 'Failure', hash_bytes: bytes) -> None: key = self.get_data_key(hash_bytes) self.deferred_by_key.pop(key, None) - def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction': + async def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction': """ Callback for the deferred when we add a new tx to the DAG """ # When we have multiple callbacks in a deferred @@ -680,7 +681,7 @@ def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction': success = True else: # Add tx to the DAG. - success = self.dependencies.on_new_vertex(tx) + success = await self.dependencies.on_new_vertex(tx) if success: self.protocol.connections.send_tx_to_peers(tx) # Updating stats data diff --git a/hathor/p2p/sync_v2/agent.py b/hathor/p2p/sync_v2/agent.py index 6397eb17b..19adbe721 100644 --- a/hathor/p2p/sync_v2/agent.py +++ b/hathor/p2p/sync_v2/agent.py @@ -18,7 +18,7 @@ import struct from collections import OrderedDict from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Generator, NamedTuple, Optional +from typing import TYPE_CHECKING, Any, Generator, NamedTuple, Optional from structlog import get_logger from twisted.internet.defer import Deferred, inlineCallbacks @@ -27,6 +27,7 @@ from hathor.exception import InvalidNewTransaction from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages +from hathor.p2p.states.base import CommandHandler from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_v2.blockchain_streaming_client import BlockchainStreamingClient, StreamingError from hathor.p2p.sync_v2.mempool import SyncMempoolManager @@ -228,7 +229,7 @@ def stop(self) -> None: if self._lc_run.running: self._lc_run.stop() - def get_cmd_dict(self) -> dict[ProtocolMessages, Callable[[str], None]]: + def get_cmd_dict(self) -> dict[ProtocolMessages, CommandHandler]: """ Return a dict of messages of the plugin. For further information about each message, see the RFC. @@ -596,18 +597,17 @@ def find_best_common_block(self, self.log.debug('find_best_common_block n-ary search finished', lo=lo, hi=hi) return lo - @inlineCallbacks - def on_block_complete(self, blk: Block, vertex_list: list[BaseTransaction]) -> Generator[Any, Any, None]: + async def on_block_complete(self, blk: Block, vertex_list: list[BaseTransaction]) -> None: """This method is called when a block and its transactions are downloaded.""" # Note: Any vertex and block could have already been added by another concurrent syncing peer. try: for tx in vertex_list: if not self.dependencies.vertex_exists(tx.hash): - self.dependencies.on_new_vertex(tx, fails_silently=False) - yield deferLater(self.reactor, 0, lambda: None) + await self.dependencies.on_new_vertex(tx, fails_silently=False) + await deferLater(self.reactor, 0, lambda: None) if not self.dependencies.vertex_exists(blk.hash): - self.dependencies.on_new_vertex(blk, fails_silently=False) + await self.dependencies.on_new_vertex(blk, fails_silently=False) except InvalidNewTransaction: self.protocol.send_error_and_close_connection('invalid vertex received') @@ -752,7 +752,7 @@ def handle_blocks_end(self, payload: str) -> None: self._blk_streaming_client.handle_blocks_end(response_code) self.log.debug('block streaming ended', reason=str(response_code)) - def handle_blocks(self, payload: str) -> None: + async def handle_blocks(self, payload: str) -> None: """ Handle a BLOCKS message. """ if self.state is not PeerState.SYNCING_BLOCKS: @@ -769,7 +769,7 @@ def handle_blocks(self, payload: str) -> None: return assert self._blk_streaming_client is not None - self._blk_streaming_client.handle_blocks(blk) + await self._blk_streaming_client.handle_blocks(blk) def send_stop_block_streaming(self) -> None: """ Send a STOP-BLOCK-STREAMING message. @@ -1109,7 +1109,7 @@ def handle_get_data(self, payload: str) -> None: # In case the tx does not exist we send a NOT-FOUND message self.send_message(ProtocolMessages.NOT_FOUND, txid_hex) - def handle_data(self, payload: str) -> None: + async def handle_data(self, payload: str) -> None: """ Handle a DATA message. """ if not payload: @@ -1154,7 +1154,7 @@ def handle_data(self, payload: str) -> None: if self.dependencies.can_validate_full(tx): self.log.debug('tx received in real time from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id()) try: - result = self.dependencies.on_new_vertex(tx, fails_silently=False) + result = await self.dependencies.on_new_vertex(tx, fails_silently=False) if result: self.protocol.connections.send_tx_to_peers(tx) except InvalidNewTransaction: diff --git a/hathor/p2p/sync_v2/blockchain_streaming_client.py b/hathor/p2p/sync_v2/blockchain_streaming_client.py index 866b860a2..0f6fda711 100644 --- a/hathor/p2p/sync_v2/blockchain_streaming_client.py +++ b/hathor/p2p/sync_v2/blockchain_streaming_client.py @@ -81,7 +81,7 @@ def fails(self, reason: 'StreamingError') -> None: """Fail the execution by resolving the deferred with an error.""" self._deferred.errback(reason) - def handle_blocks(self, blk: Block) -> None: + async def handle_blocks(self, blk: Block) -> None: """This method is called by the sync agent when a BLOCKS message is received.""" if self._deferred.called: return @@ -133,7 +133,7 @@ def handle_blocks(self, blk: Block) -> None: if self.dependencies.can_validate_full(blk): try: - self.dependencies.on_new_vertex(blk, fails_silently=False) + await self.dependencies.on_new_vertex(blk, fails_silently=False) except HathorError: self.fails(InvalidVertexError(blk.hash.hex())) return diff --git a/hathor/p2p/sync_v2/mempool.py b/hathor/p2p/sync_v2/mempool.py index eddd563b3..eecea6e5b 100644 --- a/hathor/p2p/sync_v2/mempool.py +++ b/hathor/p2p/sync_v2/mempool.py @@ -13,14 +13,15 @@ # limitations under the License. from collections import deque -from typing import TYPE_CHECKING, Any, Generator, Optional +from typing import TYPE_CHECKING, Optional from structlog import get_logger -from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.defer import Deferred from hathor.exception import InvalidNewTransaction from hathor.p2p import P2PDependencies from hathor.transaction import BaseTransaction +from hathor.utils.twisted import call_coro_later if TYPE_CHECKING: from hathor.p2p.sync_v2.agent import NodeBlockSync @@ -62,7 +63,7 @@ def run(self) -> Deferred[bool]: assert self._deferred is not None return self._deferred self._is_running = True - self.reactor.callLater(0, self._run) + call_coro_later(self.reactor, 0, self._run) # TODO Implement a stop() and call it after N minutes. @@ -70,11 +71,10 @@ def run(self) -> Deferred[bool]: self._deferred = Deferred() return self._deferred - @inlineCallbacks - def _run(self) -> Generator[Deferred, Any, None]: + async def _run(self) -> None: is_synced = False try: - is_synced = yield self._unsafe_run() + is_synced = await self._unsafe_run() except InvalidNewTransaction: return finally: @@ -84,29 +84,27 @@ def _run(self) -> Generator[Deferred, Any, None]: self._deferred.callback(is_synced) self._deferred = None - @inlineCallbacks - def _unsafe_run(self) -> Generator[Deferred, Any, bool]: + async def _unsafe_run(self) -> bool: """Run a single loop of the sync-v2 mempool.""" if not self.missing_tips: # No missing tips? Let's get them! - tx_hashes: list[bytes] = yield self.sync_agent.get_tips() + tx_hashes: list[bytes] = await self.sync_agent.get_tips() self.missing_tips.update(h for h in tx_hashes if not self.dependencies.vertex_exists(h)) while self.missing_tips: self.log.debug('We have missing tips! Let\'s start!', missing_tips=[x.hex() for x in self.missing_tips]) tx_id = next(iter(self.missing_tips)) - tx: BaseTransaction = yield self.sync_agent.get_tx(tx_id) + tx: BaseTransaction = await self.sync_agent.get_tx(tx_id) # Stack used by the DFS in the dependencies. # We use a deque for performance reasons. self.log.debug('start mempool DSF', tx=tx.hash_hex) - yield self._dfs(deque([tx])) + await self._dfs(deque([tx])) if not self.missing_tips: return True return False - @inlineCallbacks - def _dfs(self, stack: deque[BaseTransaction]) -> Generator[Deferred, Any, None]: + async def _dfs(self, stack: deque[BaseTransaction]) -> None: """DFS method.""" while stack: tx = stack[-1] @@ -114,11 +112,11 @@ def _dfs(self, stack: deque[BaseTransaction]) -> Generator[Deferred, Any, None]: missing_dep = self._next_missing_dep(tx) if missing_dep is None: self.log.debug(r'No dependencies missing! \o/') - self._add_tx(tx) + await self._add_tx(tx) assert tx == stack.pop() else: self.log.debug('Iterate in the DFS.', missing_dep=missing_dep.hex()) - tx_dep = yield self.sync_agent.get_tx(missing_dep) + tx_dep = await self.sync_agent.get_tx(missing_dep) stack.append(tx_dep) if len(stack) > self.MAX_STACK_LENGTH: stack.popleft() @@ -134,13 +132,13 @@ def _next_missing_dep(self, tx: BaseTransaction) -> Optional[bytes]: return parent return None - def _add_tx(self, tx: BaseTransaction) -> None: + async def _add_tx(self, tx: BaseTransaction) -> None: """Add tx to the DAG.""" self.missing_tips.discard(tx.hash) if self.dependencies.vertex_exists(tx.hash): return try: - result = self.dependencies.on_new_vertex(tx, fails_silently=False) + result = await self.dependencies.on_new_vertex(tx, fails_silently=False) if result: self.sync_agent.protocol.connections.send_tx_to_peers(tx) except InvalidNewTransaction: diff --git a/hathor/p2p/sync_v2/transaction_streaming_client.py b/hathor/p2p/sync_v2/transaction_streaming_client.py index 700680514..600b65e0a 100644 --- a/hathor/p2p/sync_v2/transaction_streaming_client.py +++ b/hathor/p2p/sync_v2/transaction_streaming_client.py @@ -13,10 +13,10 @@ # limitations under the License. from collections import deque -from typing import TYPE_CHECKING, Any, Generator, Optional +from typing import TYPE_CHECKING, Optional from structlog import get_logger -from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.defer import Deferred from hathor.p2p import P2PDependencies from hathor.p2p.sync_v2.exception import ( @@ -29,6 +29,7 @@ from hathor.transaction import BaseTransaction from hathor.transaction.exceptions import HathorError, TxValidationError from hathor.types import VertexId +from hathor.utils.twisted import call_coro_later if TYPE_CHECKING: from hathor.p2p.sync_v2.agent import NodeBlockSync @@ -124,10 +125,9 @@ def handle_transaction(self, tx: BaseTransaction) -> None: assert len(self._queue) <= self._tx_max_quantity if not self._is_processing: - self.reactor.callLater(0, self.process_queue) + call_coro_later(self.reactor, 0, self.process_queue) - @inlineCallbacks - def process_queue(self) -> Generator[Any, Any, None]: + async def process_queue(self) -> None: """Process next transaction in the queue.""" if self._deferred.called: return @@ -143,14 +143,13 @@ def process_queue(self) -> Generator[Any, Any, None]: try: tx = self._queue.popleft() self.log.debug('processing tx', tx_id=tx.hash.hex()) - yield self._process_transaction(tx) + await self._process_transaction(tx) finally: self._is_processing = False - self.reactor.callLater(0, self.process_queue) + call_coro_later(self.reactor, 0, self.process_queue) - @inlineCallbacks - def _process_transaction(self, tx: BaseTransaction) -> Generator[Any, Any, None]: + async def _process_transaction(self, tx: BaseTransaction) -> None: """Process transaction.""" # Run basic verification. @@ -185,7 +184,7 @@ def _process_transaction(self, tx: BaseTransaction) -> Generator[Any, Any, None] if not self._waiting_for: self.log.debug('no pending dependencies, processing buffer') while not self._waiting_for: - result = yield self._execute_and_prepare_next() + result = await self._execute_and_prepare_next() if not result: break else: @@ -221,8 +220,7 @@ def check_end(self) -> None: self.log.info('transactions streaming ended', reason=self._response_code, waiting_for=len(self._waiting_for)) self._deferred.callback(self._response_code) - @inlineCallbacks - def _execute_and_prepare_next(self) -> Generator[Any, Any, bool]: + async def _execute_and_prepare_next(self) -> bool: """Add the block and its vertices to the DAG.""" assert not self._waiting_for @@ -231,7 +229,7 @@ def _execute_and_prepare_next(self) -> Generator[Any, Any, bool]: vertex_list.sort(key=lambda v: v.timestamp) try: - yield self.sync_agent.on_block_complete(blk, vertex_list) + await self.sync_agent.on_block_complete(blk, vertex_list) except HathorError as e: self.fails(InvalidVertexError(repr(e))) return False diff --git a/hathor/utils/twisted.py b/hathor/utils/twisted.py new file mode 100644 index 000000000..f3e924b69 --- /dev/null +++ b/hathor/utils/twisted.py @@ -0,0 +1,34 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Coroutine, ParamSpec + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IDelayedCall + +from hathor.reactor import ReactorProtocol + +P = ParamSpec('P') + + +def call_coro_later( + reactor: ReactorProtocol, + delay: float, + callable: Callable[P, Coroutine[Deferred[None], Any, None]], + *args: P.args, + **kwargs: P.kwargs, +) -> IDelayedCall: + """Utility function for performing twisted's `reactor.callLater` on coroutines (async functions).""" + coro = callable(*args, **kwargs) + return reactor.callLater(delay, lambda: Deferred.fromCoroutine(coro)) diff --git a/tests/p2p/test_sync.py b/tests/p2p/test_sync.py index 33fdaefc5..0303d6180 100644 --- a/tests/p2p/test_sync.py +++ b/tests/p2p/test_sync.py @@ -1,3 +1,4 @@ +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from hathor.checkpoint import Checkpoint as cp @@ -296,7 +297,7 @@ def test_downloader(self) -> None: self.assertTrue(isinstance(conn.proto2.state, PeerIdState)) deferred1 = downloader.get_tx(blocks[0].hash, node_sync1) - deferred1.addCallback(node_sync1.on_tx_success) + deferred1.addCallback(lambda tx: Deferred.fromCoroutine(node_sync1.on_tx_success(tx))) self.assertEqual(len(downloader.pending_transactions), 1) @@ -305,7 +306,7 @@ def test_downloader(self) -> None: self.assertEqual(len(downloader.downloading_deque), 1) deferred2 = downloader.get_tx(blocks[0].hash, node_sync2) - deferred2.addCallback(node_sync2.on_tx_success) + deferred2.addCallback(lambda tx: Deferred.fromCoroutine(node_sync2.on_tx_success(tx))) self.assertEqual(len(downloader.pending_transactions), 1) self.assertEqual(len(downloader.pending_transactions[blocks[0].hash].connections), 2)