Skip to content

Commit

Permalink
Fix firewall for assigned networks and more ip -> addr
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Oct 17, 2024
1 parent df3b751 commit 1377b84
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 144 deletions.
114 changes: 62 additions & 52 deletions firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import (
)

type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
//TODO: name these better addr, localAddr. Are they vpnAddrs?
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
}

type conn struct {
Expand Down Expand Up @@ -52,9 +53,12 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s

// Used to ensure we don't emit local packets for ips we don't own
localIps *bart.Table[struct{}]
assignedCIDR netip.Prefix
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Table[struct{}]

// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool

rules string
Expand All @@ -68,9 +72,9 @@ type Firewall struct {
}

type firewallMetrics struct {
droppedLocalIP metrics.Counter
droppedRemoteIP metrics.Counter
droppedNoRule metrics.Counter
droppedLocalAddr metrics.Counter
droppedRemoteAddr metrics.Counter
droppedNoRule metrics.Counter
}

type FirewallConntrack struct {
Expand Down Expand Up @@ -127,84 +131,87 @@ type firewallLocalCIDR struct {
}

// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
var tmin, tmax time.Duration

if tcpTimeout < UDPTimeout {
min = tcpTimeout
max = UDPTimeout
tmin = tcpTimeout
tmax = UDPTimeout
} else {
min = UDPTimeout
max = tcpTimeout
tmin = UDPTimeout
tmax = tcpTimeout
}

if defaultTimeout < min {
min = defaultTimeout
} else if defaultTimeout > max {
max = defaultTimeout
if defaultTimeout < tmin {
tmin = defaultTimeout
} else if defaultTimeout > tmax {
tmax = defaultTimeout
}

localIps := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix
var assignedSet bool
routableNetworks := new(bart.Table[struct{}])
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
localIps.Insert(nprefix, struct{}{})

if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = nprefix
assignedSet = true
}
routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
}

hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
localIps.Insert(n, struct{}{})
routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true
}

return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
l: l,

incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
},
outgoingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
},
}
}

func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}

if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}

fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
nc,
certificate,
//TODO: max_connections
)

//TODO: Flip to false after v1.9 release
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)

inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction {
Expand Down Expand Up @@ -426,26 +433,24 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *

// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := remoteCidr.Lookup(fp.RemoteIP)
_, ok := remoteCidr.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
//TODO: we can make this more performant
if !slices.Contains(h.vpnAddrs, fp.RemoteIP) {
f.metrics(incoming).droppedRemoteIP.Inc(1)
if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
}

// Make sure we are supposed to be handling this local ip address
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := f.localIps.Lookup(fp.LocalIP)
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}

Expand Down Expand Up @@ -861,9 +866,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
}

matched := false
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
matched = true
return false
}
Expand All @@ -879,9 +884,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil
}

localIp = f.assignedCIDR
for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil

} else if localIp.Bits() == 0 {
flc.Any = true
return nil
}

flc.LocalCIDR.Insert(localIp, struct{}{})
Expand All @@ -897,7 +907,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
return true
}

_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok
}

Expand Down
12 changes: 6 additions & 6 deletions firewall/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ const (
)

type Packet struct {
LocalIP netip.Addr
RemoteIP netip.Addr
LocalAddr netip.Addr
RemoteAddr netip.Addr
LocalPort uint16
RemotePort uint16
Protocol uint8
Expand All @@ -30,8 +30,8 @@ type Packet struct {

func (fp *Packet) Copy() *Packet {
return &Packet{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalAddr: fp.LocalAddr,
RemoteAddr: fp.RemoteAddr,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
Expand All @@ -52,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": fp.LocalIP.String(),
"RemoteIP": fp.RemoteIP.String(),
"LocalAddr": fp.LocalAddr.String(),
"RemoteAddr": fp.RemoteAddr.String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,
Expand Down
Loading

0 comments on commit 1377b84

Please sign in to comment.