Skip to content

Commit

Permalink
feat(firewall): check valid gossip and stream messages (#1402)
Browse files Browse the repository at this point in the history
  • Loading branch information
b00f authored Jul 11, 2024
1 parent 2115d85 commit 6318fad
Show file tree
Hide file tree
Showing 48 changed files with 642 additions and 405 deletions.
52 changes: 10 additions & 42 deletions consensus/mock.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package consensus

import (
"sync"

"github.com/pactus-project/pactus/crypto/bls"
"github.com/pactus-project/pactus/crypto/hash"
"github.com/pactus-project/pactus/state"
"github.com/pactus-project/pactus/types/proposal"
"github.com/pactus-project/pactus/types/vote"
"github.com/pactus-project/pactus/util/testsuite"
Expand All @@ -13,10 +12,9 @@ import (
var _ Consensus = &MockConsensus{}

type MockConsensus struct {
// This locks prevents the Data Race in tests
lk sync.RWMutex
ts *testsuite.TestSuite

State *state.MockState
ValKey *bls.ValidatorKey
Votes []*vote.Vote
CurProposal *proposal.Proposal
Expand All @@ -26,11 +24,13 @@ type MockConsensus struct {
Round int16
}

func MockingManager(ts *testsuite.TestSuite, valKeys []*bls.ValidatorKey) (Manager, []*MockConsensus) {
func MockingManager(ts *testsuite.TestSuite, st *state.MockState,
valKeys []*bls.ValidatorKey,
) (Manager, []*MockConsensus) {
mocks := make([]*MockConsensus, len(valKeys))
instances := make([]Consensus, len(valKeys))
for i, s := range valKeys {
cons := MockingConsensus(ts, s)
for i, key := range valKeys {
cons := MockingConsensus(ts, st, key)
mocks[i] = cons
instances[i] = cons
}
Expand All @@ -42,9 +42,10 @@ func MockingManager(ts *testsuite.TestSuite, valKeys []*bls.ValidatorKey) (Manag
}, mocks
}

func MockingConsensus(ts *testsuite.TestSuite, valKey *bls.ValidatorKey) *MockConsensus {
func MockingConsensus(ts *testsuite.TestSuite, st *state.MockState, valKey *bls.ValidatorKey) *MockConsensus {
return &MockConsensus{
ts: ts,
State: st,
ValKey: valKey,
}
}
Expand All @@ -54,39 +55,24 @@ func (m *MockConsensus) ConsensusKey() *bls.PublicKey {
}

func (m *MockConsensus) MoveToNewHeight() {
m.lk.Lock()
defer m.lk.Unlock()

m.Height++
m.Height = m.State.LastBlockHeight() + 1
}

func (*MockConsensus) Start() {}

func (m *MockConsensus) AddVote(v *vote.Vote) {
m.lk.Lock()
defer m.lk.Unlock()

m.Votes = append(m.Votes, v)
}

func (m *MockConsensus) AllVotes() []*vote.Vote {
m.lk.Lock()
defer m.lk.Unlock()

return m.Votes
}

func (m *MockConsensus) SetProposal(p *proposal.Proposal) {
m.lk.Lock()
defer m.lk.Unlock()

m.CurProposal = p
}

func (m *MockConsensus) HasVote(h hash.Hash) bool {
m.lk.Lock()
defer m.lk.Unlock()

for _, v := range m.Votes {
if v.Hash() == h {
return true
Expand All @@ -97,16 +83,10 @@ func (m *MockConsensus) HasVote(h hash.Hash) bool {
}

func (m *MockConsensus) Proposal() *proposal.Proposal {
m.lk.Lock()
defer m.lk.Unlock()

return m.CurProposal
}

func (m *MockConsensus) HeightRound() (uint32, int16) {
m.lk.Lock()
defer m.lk.Unlock()

return m.Height, m.Round
}

Expand All @@ -115,9 +95,6 @@ func (*MockConsensus) String() string {
}

func (m *MockConsensus) PickRandomVote(_ int16) *vote.Vote {
m.lk.Lock()
defer m.lk.Unlock()

if len(m.Votes) == 0 {
return nil
}
Expand All @@ -127,22 +104,13 @@ func (m *MockConsensus) PickRandomVote(_ int16) *vote.Vote {
}

func (m *MockConsensus) IsActive() bool {
m.lk.Lock()
defer m.lk.Unlock()

return m.Active
}

func (m *MockConsensus) IsProposer() bool {
m.lk.Lock()
defer m.lk.Unlock()

return m.Proposer
}

func (m *MockConsensus) SetActive(active bool) {
m.lk.Lock()
defer m.lk.Unlock()

m.Active = active
}
8 changes: 7 additions & 1 deletion network/gossip.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ func (g *gossipService) Broadcast(msg []byte, topicID TopicID) error {
g.logger.Debug("publishing new message", "topic", topicID)

switch topicID {
case TopicIDUnspecified:
return InvalidTopicError{TopicID: topicID}

case TopicIDBlock:
if g.topicBlock == nil {
return NotSubscribedError{TopicID: topicID}
Expand Down Expand Up @@ -114,6 +117,9 @@ func (g *gossipService) publish(msg []byte, topic *lp2pps.Topic) error {
// JoinTopic joins to the topic with the given name and subscribes to receive topic messages.
func (g *gossipService) JoinTopic(topicID TopicID, sp ShouldPropagate) error {
switch topicID {
case TopicIDUnspecified:
return InvalidTopicError{TopicID: topicID}

case TopicIDBlock:
if g.topicBlock != nil {
g.logger.Warn("already subscribed to block topic")
Expand Down Expand Up @@ -247,7 +253,7 @@ func (g *gossipService) onReceiveMessage(m *lp2pps.Message) {
return
}

g.logger.Debug("receiving new gossip message", "source", m.GetFrom())
g.logger.Debug("receiving new gossip message", "from", m.ReceivedFrom)
event := &GossipMessage{
From: m.ReceivedFrom,
Data: m.Data,
Expand Down
29 changes: 24 additions & 5 deletions network/gossip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,47 @@ package network
import (
"testing"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)

func TestJoinConsensusTopic(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

msg := []byte("test-consensus-topic")

require.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDConsensus),
assert.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDConsensus),
NotSubscribedError{
TopicID: TopicIDConsensus,
})
require.NoError(t, net.JoinTopic(TopicIDConsensus, alwaysPropagate))
require.NoError(t, net.gossip.Broadcast(msg, TopicIDConsensus))
assert.NoError(t, net.JoinTopic(TopicIDConsensus, alwaysPropagate))
assert.NoError(t, net.gossip.Broadcast(msg, TopicIDConsensus))
}

func TestJoinInvalidTopic(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

assert.ErrorIs(t, net.JoinTopic(TopicIDUnspecified, alwaysPropagate),
InvalidTopicError{
TopicID: TopicIDUnspecified,
})

assert.ErrorIs(t, net.JoinTopic(TopicID(-1), alwaysPropagate),
InvalidTopicError{
TopicID: TopicID(-1),
})
}

func TestInvalidTopic(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

msg := []byte("test-invalid-topic")

require.ErrorIs(t, net.gossip.Broadcast(msg, -1),
assert.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDUnspecified),
InvalidTopicError{
TopicID: TopicIDUnspecified,
})

assert.ErrorIs(t, net.gossip.Broadcast(msg, -1),
InvalidTopicError{
TopicID: TopicID(-1),
})
Expand Down
4 changes: 4 additions & 0 deletions network/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import (
type TopicID int

const (
TopicIDUnspecified TopicID = 0
TopicIDBlock TopicID = 1
TopicIDTransaction TopicID = 2
TopicIDConsensus TopicID = 3
)

func (t TopicID) String() string {
switch t {
case TopicIDUnspecified:
return "unspecified"

case TopicIDBlock:
return "block"

Expand Down
43 changes: 39 additions & 4 deletions network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func testConfig() *Config {
func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event {
t.Helper()

timeout := time.NewTimer(8 * time.Second)
timeout := time.NewTimer(10 * time.Second)

for {
select {
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestNetwork(t *testing.T) {
confM.EnableRelay = true
confM.BootstrapAddrStrings = bootstrapAddresses
confM.ListenAddrStrings = []string{
"/ip4/127.0.0.1/tcp/9987",
"/ip4/127.0.0.1/tcp/0",
}
fmt.Println("Starting Private node M")
networkM := makeTestNetwork(t, confM, []lp2p.Option{
Expand All @@ -179,7 +179,7 @@ func TestNetwork(t *testing.T) {
confN.EnableRelay = true
confN.BootstrapAddrStrings = bootstrapAddresses
confN.ListenAddrStrings = []string{
"/ip4/127.0.0.1/tcp/5678",
"/ip4/127.0.0.1/tcp/0",
}
fmt.Println("Starting Private node N")
networkN := makeTestNetwork(t, confN, []lp2p.Option{
Expand Down Expand Up @@ -232,6 +232,41 @@ func TestNetwork(t *testing.T) {
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop)
assert.Contains(t, protos, lp2pproto.ProtoIDv2Hop)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
protos := networkX.Protocols()
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Stop)
assert.NotContains(t, protos, lp2pproto.ProtoIDv2Hop)
}, time.Second, 100*time.Millisecond)
})

t.Run("Reachability", func(t *testing.T) {
fmt.Printf("Running %s\n", t.Name())

require.EventuallyWithT(t, func(_ *assert.CollectT) {
reachability := networkB.ReachabilityStatus()
assert.Equal(t, "Public", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
reachability := networkM.ReachabilityStatus()
assert.Equal(t, "Private", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
reachability := networkN.ReachabilityStatus()
assert.Equal(t, "Private", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
reachability := networkP.ReachabilityStatus()
assert.Equal(t, "Public", reachability)
}, time.Second, 100*time.Millisecond)

require.EventuallyWithT(t, func(_ *assert.CollectT) {
reachability := networkP.ReachabilityStatus()
assert.Equal(t, "Public", reachability)
}, time.Second, 100*time.Millisecond)
})

t.Run("all nodes have at least one connection to the bootstrap node B", func(t *testing.T) {
Expand Down Expand Up @@ -355,7 +390,7 @@ func TestNetwork(t *testing.T) {
// TODO: How to test this?
// t.Run("nodes M and N (private, connected via relay) can communicate using the relay node R", func(t *testing.T) {
// msgM := ts.RandBytes(64)
// require.NoError(t, networkM.SendTo(msgM, networkN.SelfID()))
// networkM.SendTo(msgM, networkN.SelfID())
// eM := shouldReceiveEvent(t, networkN, EventTypeStream).(*StreamMessage)
// assert.Equal(t, readData(t, eM.Reader, len(msgM)), msgM)
// })
Expand Down
2 changes: 1 addition & 1 deletion state/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type MockState struct {

func MockingState(ts *testsuite.TestSuite) *MockState {
cmt, valKeys := ts.GenerateTestCommittee(21)
genDoc := genesis.TestnetGenesis()
genDoc := genesis.MainnetGenesis()

return &MockState{
ts: ts,
Expand Down
9 changes: 9 additions & 0 deletions sync/bundle/message/block_announce.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package message
import (
"fmt"

"github.com/pactus-project/pactus/network"
"github.com/pactus-project/pactus/types/block"
"github.com/pactus-project/pactus/types/certificate"
)
Expand Down Expand Up @@ -35,6 +36,14 @@ func (*BlockAnnounceMessage) Type() Type {
return TypeBlockAnnounce
}

func (*BlockAnnounceMessage) TopicID() network.TopicID {
return network.TopicIDBlock
}

func (*BlockAnnounceMessage) ShouldBroadcast() bool {
return true
}

func (m *BlockAnnounceMessage) String() string {
return fmt.Sprintf("{⌘ %d %v}",
m.Certificate.Height(),
Expand Down
9 changes: 9 additions & 0 deletions sync/bundle/message/blocks_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package message
import (
"fmt"

"github.com/pactus-project/pactus/network"
"github.com/pactus-project/pactus/util/errors"
)

Expand Down Expand Up @@ -39,6 +40,14 @@ func (*BlocksRequestMessage) Type() Type {
return TypeBlocksRequest
}

func (*BlocksRequestMessage) TopicID() network.TopicID {
return network.TopicIDUnspecified
}

func (*BlocksRequestMessage) ShouldBroadcast() bool {
return false
}

func (m *BlocksRequestMessage) String() string {
return fmt.Sprintf("{⚓ %d %v:%v}", m.SessionID, m.From, m.To())
}
9 changes: 9 additions & 0 deletions sync/bundle/message/blocks_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package message
import (
"fmt"

"github.com/pactus-project/pactus/network"
"github.com/pactus-project/pactus/types/certificate"
)

Expand Down Expand Up @@ -42,6 +43,14 @@ func (*BlocksResponseMessage) Type() Type {
return TypeBlocksResponse
}

func (*BlocksResponseMessage) TopicID() network.TopicID {
return network.TopicIDUnspecified
}

func (*BlocksResponseMessage) ShouldBroadcast() bool {
return false
}

func (m *BlocksResponseMessage) Count() uint32 {
return uint32(len(m.CommittedBlocksData))
}
Expand Down
Loading

0 comments on commit 6318fad

Please sign in to comment.