Skip to content

Commit

Permalink
refactor(p2p): async on_new_vertex
Browse files Browse the repository at this point in the history
  • Loading branch information
glevco committed Oct 15, 2024
1 parent 2af3b11 commit 2c9988d
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 62 deletions.
2 changes: 1 addition & 1 deletion hathor/p2p/dependencies/p2p_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self.vertex_parser = vertex_parser

@abstractmethod
def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool:
async def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool:
raise NotImplementedError

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion hathor/p2p/dependencies/single_process_p2p_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self._indexes = not_none(tx_storage.indexes)

@override
def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool:
async def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool:
return self._vertex_handler.on_new_vertex(vertex=vertex, fails_silently=fails_silently)

@override
Expand Down
13 changes: 8 additions & 5 deletions hathor/p2p/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from collections.abc import Coroutine
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeAlias

from structlog import get_logger
from twisted.internet.defer import Deferred
Expand All @@ -27,13 +27,16 @@

logger = get_logger()

CommandHandler: TypeAlias = (
Callable[[str], None] |
Callable[[str], Deferred[None]] |
Callable[[str], Coroutine[Deferred[None], Any, None]]
)


class BaseState:
protocol: 'HathorProtocol'
cmd_map: dict[
ProtocolMessages,
Callable[[str], None] | Callable[[str], Deferred[None]] | Callable[[str], Coroutine[Deferred[None], Any, None]]
]
cmd_map: dict[ProtocolMessages, CommandHandler]

def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies):
self.log = logger.new(**protocol.get_logger_context())
Expand Down
2 changes: 1 addition & 1 deletion hathor/p2p/states/ready.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from hathor.p2p.messages import ProtocolMessages
from hathor.p2p.peer import PublicPeer, UnverifiedPeer
from hathor.p2p.states.base import BaseState
from hathor.p2p.sync_agent import SyncAgent
from hathor.p2p.utils import to_height_info, to_serializable_best_blockchain
from hathor.transaction import BaseTransaction
from hathor.util import json_dumps, json_loads
Expand Down Expand Up @@ -106,6 +105,7 @@ def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies)
self.log.debug(f'loading {sync_version}')
sync_factory = connections.get_sync_factory(sync_version)

from hathor.p2p.sync_agent import SyncAgent
self.sync_agent: SyncAgent = sync_factory.create_sync_agent(self.protocol)
self.cmd_map.update(self.sync_agent.get_cmd_dict())

Expand Down
4 changes: 2 additions & 2 deletions hathor/p2p/sync_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Callable

from hathor.p2p.messages import ProtocolMessages
from hathor.p2p.states.base import CommandHandler
from hathor.transaction import BaseTransaction


Expand All @@ -36,7 +36,7 @@ def stop(self) -> None:
raise NotImplementedError

@abstractmethod
def get_cmd_dict(self) -> dict[ProtocolMessages, Callable[[str], None]]:
def get_cmd_dict(self) -> dict[ProtocolMessages, CommandHandler]:
"""Command dict to add to the protocol handler"""
raise NotImplementedError

Expand Down
15 changes: 8 additions & 7 deletions hathor/p2p/sync_v1/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import base64
import struct
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterator, Optional
from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional
from weakref import WeakSet

from structlog import get_logger
Expand All @@ -24,6 +24,7 @@

from hathor.p2p import P2PDependencies
from hathor.p2p.messages import GetNextPayload, GetTipsPayload, NextPayload, ProtocolMessages, TipsPayload
from hathor.p2p.states.base import CommandHandler
from hathor.p2p.sync_agent import SyncAgent
from hathor.p2p.sync_v1.downloader import Downloader
from hathor.transaction import BaseTransaction
Expand Down Expand Up @@ -128,7 +129,7 @@ def get_status(self):
'synced_timestamp': self.synced_timestamp,
}

def get_cmd_dict(self) -> dict[ProtocolMessages, Callable[[str], None]]:
def get_cmd_dict(self) -> dict[ProtocolMessages, CommandHandler]:
""" Return a dict of messages.
"""
return {
Expand Down Expand Up @@ -249,7 +250,7 @@ def get_data(self, hash_bytes: bytes) -> Deferred:
:rtype: Deferred
"""
d = self.downloader.get_tx(hash_bytes, self)
d.addCallback(self.on_tx_success)
d.addCallback(lambda tx: Deferred.fromCoroutine(self.on_tx_success(tx)))
d.addErrback(self.on_get_data_failed, hash_bytes)
return d

Expand Down Expand Up @@ -591,7 +592,7 @@ def send_data(self, tx: BaseTransaction) -> None:
payload = base64.b64encode(tx.get_struct()).decode('ascii')
self.send_message(ProtocolMessages.DATA, payload)

def handle_data(self, payload: str) -> None:
async def handle_data(self, payload: str) -> None:
""" Handle a received DATA message.
"""
if not payload:
Expand Down Expand Up @@ -629,7 +630,7 @@ def handle_data(self, payload: str) -> None:
self.log.info('tx received in real time from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id())
# If we have not requested the data, it is a new transaction being propagated
# in the network, thus, we propagate it as well.
result = self.dependencies.on_new_vertex(tx)
result = await self.dependencies.on_new_vertex(tx)
if result:
self.protocol.connections.send_tx_to_peers(tx)
self.update_received_stats(tx, result)
Expand Down Expand Up @@ -668,7 +669,7 @@ def remove_deferred(self, reason: 'Failure', hash_bytes: bytes) -> None:
key = self.get_data_key(hash_bytes)
self.deferred_by_key.pop(key, None)

def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction':
async def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction':
""" Callback for the deferred when we add a new tx to the DAG
"""
# When we have multiple callbacks in a deferred
Expand All @@ -680,7 +681,7 @@ def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction':
success = True
else:
# Add tx to the DAG.
success = self.dependencies.on_new_vertex(tx)
success = await self.dependencies.on_new_vertex(tx)
if success:
self.protocol.connections.send_tx_to_peers(tx)
# Updating stats data
Expand Down
22 changes: 11 additions & 11 deletions hathor/p2p/sync_v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import struct
from collections import OrderedDict
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Generator, NamedTuple, Optional
from typing import TYPE_CHECKING, Any, Generator, NamedTuple, Optional

from structlog import get_logger
from twisted.internet.defer import Deferred, inlineCallbacks
Expand All @@ -27,6 +27,7 @@
from hathor.exception import InvalidNewTransaction
from hathor.p2p import P2PDependencies
from hathor.p2p.messages import ProtocolMessages
from hathor.p2p.states.base import CommandHandler
from hathor.p2p.sync_agent import SyncAgent
from hathor.p2p.sync_v2.blockchain_streaming_client import BlockchainStreamingClient, StreamingError
from hathor.p2p.sync_v2.mempool import SyncMempoolManager
Expand Down Expand Up @@ -228,7 +229,7 @@ def stop(self) -> None:
if self._lc_run.running:
self._lc_run.stop()

def get_cmd_dict(self) -> dict[ProtocolMessages, Callable[[str], None]]:
def get_cmd_dict(self) -> dict[ProtocolMessages, CommandHandler]:
""" Return a dict of messages of the plugin.
For further information about each message, see the RFC.
Expand Down Expand Up @@ -596,18 +597,17 @@ def find_best_common_block(self,
self.log.debug('find_best_common_block n-ary search finished', lo=lo, hi=hi)
return lo

@inlineCallbacks
def on_block_complete(self, blk: Block, vertex_list: list[BaseTransaction]) -> Generator[Any, Any, None]:
async def on_block_complete(self, blk: Block, vertex_list: list[BaseTransaction]) -> None:
"""This method is called when a block and its transactions are downloaded."""
# Note: Any vertex and block could have already been added by another concurrent syncing peer.
try:
for tx in vertex_list:
if not self.dependencies.vertex_exists(tx.hash):
self.dependencies.on_new_vertex(tx, fails_silently=False)
yield deferLater(self.reactor, 0, lambda: None)
await self.dependencies.on_new_vertex(tx, fails_silently=False)
await deferLater(self.reactor, 0, lambda: None)

if not self.dependencies.vertex_exists(blk.hash):
self.dependencies.on_new_vertex(blk, fails_silently=False)
await self.dependencies.on_new_vertex(blk, fails_silently=False)
except InvalidNewTransaction:
self.protocol.send_error_and_close_connection('invalid vertex received')

Expand Down Expand Up @@ -752,7 +752,7 @@ def handle_blocks_end(self, payload: str) -> None:
self._blk_streaming_client.handle_blocks_end(response_code)
self.log.debug('block streaming ended', reason=str(response_code))

def handle_blocks(self, payload: str) -> None:
async def handle_blocks(self, payload: str) -> None:
""" Handle a BLOCKS message.
"""
if self.state is not PeerState.SYNCING_BLOCKS:
Expand All @@ -769,7 +769,7 @@ def handle_blocks(self, payload: str) -> None:
return

assert self._blk_streaming_client is not None
self._blk_streaming_client.handle_blocks(blk)
await self._blk_streaming_client.handle_blocks(blk)

def send_stop_block_streaming(self) -> None:
""" Send a STOP-BLOCK-STREAMING message.
Expand Down Expand Up @@ -1109,7 +1109,7 @@ def handle_get_data(self, payload: str) -> None:
# In case the tx does not exist we send a NOT-FOUND message
self.send_message(ProtocolMessages.NOT_FOUND, txid_hex)

def handle_data(self, payload: str) -> None:
async def handle_data(self, payload: str) -> None:
""" Handle a DATA message.
"""
if not payload:
Expand Down Expand Up @@ -1154,7 +1154,7 @@ def handle_data(self, payload: str) -> None:
if self.dependencies.can_validate_full(tx):
self.log.debug('tx received in real time from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id())
try:
result = self.dependencies.on_new_vertex(tx, fails_silently=False)
result = await self.dependencies.on_new_vertex(tx, fails_silently=False)
if result:
self.protocol.connections.send_tx_to_peers(tx)
except InvalidNewTransaction:
Expand Down
4 changes: 2 additions & 2 deletions hathor/p2p/sync_v2/blockchain_streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def fails(self, reason: 'StreamingError') -> None:
"""Fail the execution by resolving the deferred with an error."""
self._deferred.errback(reason)

def handle_blocks(self, blk: Block) -> None:
async def handle_blocks(self, blk: Block) -> None:
"""This method is called by the sync agent when a BLOCKS message is received."""
if self._deferred.called:
return
Expand Down Expand Up @@ -133,7 +133,7 @@ def handle_blocks(self, blk: Block) -> None:

if self.dependencies.can_validate_full(blk):
try:
self.dependencies.on_new_vertex(blk, fails_silently=False)
await self.dependencies.on_new_vertex(blk, fails_silently=False)
except HathorError:
self.fails(InvalidVertexError(blk.hash.hex()))
return
Expand Down
32 changes: 15 additions & 17 deletions hathor/p2p/sync_v2/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

from collections import deque
from typing import TYPE_CHECKING, Any, Generator, Optional
from typing import TYPE_CHECKING, Optional

from structlog import get_logger
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.defer import Deferred

from hathor.exception import InvalidNewTransaction
from hathor.p2p import P2PDependencies
from hathor.transaction import BaseTransaction
from hathor.utils.twisted import call_coro_later

if TYPE_CHECKING:
from hathor.p2p.sync_v2.agent import NodeBlockSync
Expand Down Expand Up @@ -62,19 +63,18 @@ def run(self) -> Deferred[bool]:
assert self._deferred is not None
return self._deferred
self._is_running = True
self.reactor.callLater(0, self._run)
call_coro_later(self.reactor, 0, self._run)

# TODO Implement a stop() and call it after N minutes.

assert self._deferred is None
self._deferred = Deferred()
return self._deferred

@inlineCallbacks
def _run(self) -> Generator[Deferred, Any, None]:
async def _run(self) -> None:
is_synced = False
try:
is_synced = yield self._unsafe_run()
is_synced = await self._unsafe_run()
except InvalidNewTransaction:
return
finally:
Expand All @@ -84,41 +84,39 @@ def _run(self) -> Generator[Deferred, Any, None]:
self._deferred.callback(is_synced)
self._deferred = None

@inlineCallbacks
def _unsafe_run(self) -> Generator[Deferred, Any, bool]:
async def _unsafe_run(self) -> bool:
"""Run a single loop of the sync-v2 mempool."""
if not self.missing_tips:
# No missing tips? Let's get them!
tx_hashes: list[bytes] = yield self.sync_agent.get_tips()
tx_hashes: list[bytes] = await self.sync_agent.get_tips()
self.missing_tips.update(h for h in tx_hashes if not self.dependencies.vertex_exists(h))

while self.missing_tips:
self.log.debug('We have missing tips! Let\'s start!', missing_tips=[x.hex() for x in self.missing_tips])
tx_id = next(iter(self.missing_tips))
tx: BaseTransaction = yield self.sync_agent.get_tx(tx_id)
tx: BaseTransaction = await self.sync_agent.get_tx(tx_id)
# Stack used by the DFS in the dependencies.
# We use a deque for performance reasons.
self.log.debug('start mempool DSF', tx=tx.hash_hex)
yield self._dfs(deque([tx]))
await self._dfs(deque([tx]))

if not self.missing_tips:
return True
return False

@inlineCallbacks
def _dfs(self, stack: deque[BaseTransaction]) -> Generator[Deferred, Any, None]:
async def _dfs(self, stack: deque[BaseTransaction]) -> None:
"""DFS method."""
while stack:
tx = stack[-1]
self.log.debug('step mempool DSF', tx=tx.hash_hex, stack_len=len(stack))
missing_dep = self._next_missing_dep(tx)
if missing_dep is None:
self.log.debug(r'No dependencies missing! \o/')
self._add_tx(tx)
await self._add_tx(tx)
assert tx == stack.pop()
else:
self.log.debug('Iterate in the DFS.', missing_dep=missing_dep.hex())
tx_dep = yield self.sync_agent.get_tx(missing_dep)
tx_dep = await self.sync_agent.get_tx(missing_dep)
stack.append(tx_dep)
if len(stack) > self.MAX_STACK_LENGTH:
stack.popleft()
Expand All @@ -134,13 +132,13 @@ def _next_missing_dep(self, tx: BaseTransaction) -> Optional[bytes]:
return parent
return None

def _add_tx(self, tx: BaseTransaction) -> None:
async def _add_tx(self, tx: BaseTransaction) -> None:
"""Add tx to the DAG."""
self.missing_tips.discard(tx.hash)
if self.dependencies.vertex_exists(tx.hash):
return
try:
result = self.dependencies.on_new_vertex(tx, fails_silently=False)
result = await self.dependencies.on_new_vertex(tx, fails_silently=False)
if result:
self.sync_agent.protocol.connections.send_tx_to_peers(tx)
except InvalidNewTransaction:
Expand Down
Loading

0 comments on commit 2c9988d

Please sign in to comment.