diff --git a/firewall.go b/firewall.go index f863566a1..06b8e8589 100644 --- a/firewall.go +++ b/firewall.go @@ -23,7 +23,8 @@ import ( ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error + //TODO: name these better addr, localAddr. Are they vpnAddrs? + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error } type conn struct { @@ -52,9 +53,12 @@ type Firewall struct { UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s - // Used to ensure we don't emit local packets for ips we don't own - localIps *bart.Table[struct{}] - assignedCIDR netip.Prefix + // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. + // The vpn addresses are a full bit match while the unsafe networks only match the prefix + routableNetworks *bart.Table[struct{}] + + // assignedNetworks is a list of vpn networks assigned to us in the certificate. + assignedNetworks []netip.Prefix hasUnsafeNetworks bool rules string @@ -68,9 +72,9 @@ type Firewall struct { } type firewallMetrics struct { - droppedLocalIP metrics.Counter - droppedRemoteIP metrics.Counter - droppedNoRule metrics.Counter + droppedLocalAddr metrics.Counter + droppedRemoteAddr metrics.Counter + droppedNoRule metrics.Counter } type FirewallConntrack struct { @@ -127,84 +131,87 @@ type firewallLocalCIDR struct { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. +// The certificate provided should be the highest version loaded in memory. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration - var min, max time.Duration + var tmin, tmax time.Duration if tcpTimeout < UDPTimeout { - min = tcpTimeout - max = UDPTimeout + tmin = tcpTimeout + tmax = UDPTimeout } else { - min = UDPTimeout - max = tcpTimeout + tmin = UDPTimeout + tmax = tcpTimeout } - if defaultTimeout < min { - min = defaultTimeout - } else if defaultTimeout > max { - max = defaultTimeout + if defaultTimeout < tmin { + tmin = defaultTimeout + } else if defaultTimeout > tmax { + tmax = defaultTimeout } - localIps := new(bart.Table[struct{}]) - var assignedCIDR netip.Prefix - var assignedSet bool + routableNetworks := new(bart.Table[struct{}]) + var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - localIps.Insert(nprefix, struct{}{}) - - if !assignedSet { - // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = nprefix - assignedSet = true - } + routableNetworks.Insert(nprefix, struct{}{}) + assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { - localIps.Insert(n, struct{}{}) + routableNetworks.Insert(n, struct{}{}) hasUnsafeNetworks = true } return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel[firewall.Packet](min, max), + TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, InRules: newFirewallTable(), OutRules: newFirewallTable(), TCPTimeout: tcpTimeout, UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, l: l, incomingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), }, outgoingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), }, } } -func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificate available to reconfigure the firewall") + } + fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - nc, + certificate, //TODO: max_connections ) - //TODO: Flip to false after v1.9 release - fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true) + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { @@ -426,26 +433,24 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := remoteCidr.Lookup(fp.RemoteIP) + _, ok := remoteCidr.Lookup(fp.RemoteAddr) if !ok { - f.metrics(incoming).droppedRemoteIP.Inc(1) + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } else { // Simple case: Certificate has one IP and no subnets //TODO: we can make this more performant - if !slices.Contains(h.vpnAddrs, fp.RemoteIP) { - f.metrics(incoming).droppedRemoteIP.Inc(1) + if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) { + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } // Make sure we are supposed to be handling this local ip address - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := f.localIps.Lookup(fp.LocalIP) + _, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !ok { - f.metrics(incoming).droppedLocalIP.Inc(1) + f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -861,9 +866,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool } matched := false - prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen()) fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { - if prefix.Contains(p.RemoteIP) && val.match(p, c) { + if prefix.Contains(p.RemoteAddr) && val.match(p, c) { matched = true return false } @@ -879,9 +884,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } - localIp = f.assignedCIDR + for _, network := range f.assignedNetworks { + flc.LocalCIDR.Insert(network, struct{}{}) + } + return nil + } else if localIp.Bits() == 0 { flc.Any = true + return nil } flc.LocalCIDR.Insert(localIp, struct{}{}) @@ -897,7 +907,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate return true } - _, ok := flc.LocalCIDR.Lookup(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) return ok } diff --git a/firewall/packet.go b/firewall/packet.go index b3cf1fb98..1d8f12a0c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -20,8 +20,8 @@ const ( ) type Packet struct { - LocalIP netip.Addr - RemoteIP netip.Addr + LocalAddr netip.Addr + RemoteAddr netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 @@ -30,8 +30,8 @@ type Packet struct { func (fp *Packet) Copy() *Packet { return &Packet{ - LocalIP: fp.LocalIP, - RemoteIP: fp.RemoteIP, + LocalAddr: fp.LocalAddr, + RemoteAddr: fp.RemoteAddr, LocalPort: fp.LocalPort, RemotePort: fp.RemotePort, Protocol: fp.Protocol, @@ -52,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = fmt.Sprintf("unknown %v", fp.Protocol) } return json.Marshal(m{ - "LocalIP": fp.LocalIP.String(), - "RemoteIP": fp.RemoteIP.String(), + "LocalAddr": fp.LocalAddr.String(), + "RemoteAddr": fp.RemoteAddr.String(), "LocalPort": fp.LocalPort, "RemotePort": fp.RemotePort, "Protocol": proto, diff --git a/firewall_test.go b/firewall_test.go index 79e90b692..1bdfe6f93 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewFirewall(t *testing.T) { @@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch - oldRemote := p.RemoteIP - p.RemoteIP = netip.MustParseAddr("1.2.3.10") + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) - p.RemoteIP = oldRemote + p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) @@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { } ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) @@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) @@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) @@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -364,8 +365,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -446,8 +447,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -622,55 +623,58 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + require.NoError(t, err) + conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err := NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } diff --git a/hostmap.go b/hostmap.go index ff7d1e6fa..dc16af208 100644 --- a/hostmap.go +++ b/hostmap.go @@ -191,7 +191,7 @@ type HostInfo struct { localIndexId uint32 vpnAddrs []netip.Addr recvError atomic.Uint32 - remoteCidr *bart.Table[struct{}] + remoteCidr *bart.Table[struct{}] //TODO: rename `vpnNetworks` relayState RelayState // HandshakePacket records the packets used to create this hostinfo diff --git a/inside.go b/inside.go index 149210ade..d882f516e 100644 --- a/inside.go +++ b/inside.go @@ -21,18 +21,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // Ignore local broadcast packets if f.dropLocalBroadcast { - _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteIP) + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) if found { return } } //TODO: seems like a huge bummer - _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteIP) + _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) if found { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which - // routes packets from the Nebula IP to the Nebula IP through the Nebula + // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) @@ -41,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } // Otherwise, drop. On linux, we should never see these packets - Linux - // routes packets from the nebula IP to the nebula IP through the loopback device. + // routes packets from the nebula addr to the nebula addr through the loopback device. return } // Ignore multicast packets - if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", fwPacket.RemoteIP). + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } return } @@ -122,22 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp netip.Addr) { - f.getOrHandshake(vpnIp, nil) +func (f *Interface) Handshake(vpnAddr netip.Addr) { + f.getOrHandshake(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnIp is not routable. +// getOrHandshake returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - _, found := f.myVpnNetworksTable.Lookup(vpnIp) +func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + _, found := f.myVpnNetworksTable.Lookup(vpnAddr) if !found { - vpnIp = f.inside.RouteFor(vpnIp) - if !vpnIp.IsValid() { + vpnAddr = f.inside.RouteFor(vpnAddr) + if !vpnAddr.IsValid() { return nil, false } } - return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -162,7 +162,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) @@ -291,7 +291,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType f.connectionManager.Out(hostinfo.localIndexId) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against - // all our IPs and enable a faster roaming. + // all our addrs and enable a faster roaming. if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. diff --git a/interface.go b/interface.go index d88dbbff1..f8f2b0960 100644 --- a/interface.go +++ b/interface.go @@ -14,7 +14,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -328,17 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - cs := f.pki.getCertState() - certificate := cs.getCertificate(cert.Version2) - if certificate == nil { - certificate = cs.getCertificate(cert.Version1) - } - - if certificate == nil { - panic("No certificate available to reconfigure the firewall") - } - - fw, err := NewFirewallFromConfig(f.l, certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return diff --git a/main.go b/main.go index 894305ede..1e77a4767 100644 --- a/main.go +++ b/main.go @@ -7,8 +7,6 @@ import ( "net/netip" "time" - "github.com/slackhq/nebula/cert" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" @@ -62,16 +60,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - cs := pki.getCertState() - certificate := cs.getCertificate(cert.Version2) - if certificate == nil { - certificate = cs.getCertificate(cert.Version1) - } - - if certificate == nil { - panic("No certificates available to configure the firewall") - } - fw, err := NewFirewallFromConfig(l, certificate, c) + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } diff --git a/outside.go b/outside.go index 91af76712..fe5fc5317 100644 --- a/outside.go +++ b/outside.go @@ -302,11 +302,11 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { } if incoming { - fp.RemoteIP, _ = netip.AddrFromSlice(data[8:24]) - fp.LocalIP, _ = netip.AddrFromSlice(data[24:40]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40]) } else { - fp.LocalIP, _ = netip.AddrFromSlice(data[8:24]) - fp.RemoteIP, _ = netip.AddrFromSlice(data[24:40]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) } //TODO: whats a reasonable number of extension headers to attempt to parse? @@ -417,8 +417,8 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Firewall packets are locally oriented if incoming { - fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) - fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -427,8 +427,8 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) - fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 diff --git a/outside_test.go b/outside_test.go index 05537a474..cbe622345 100644 --- a/outside_test.go +++ b/outside_test.go @@ -64,8 +64,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -85,8 +85,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } @@ -127,8 +127,8 @@ func Test_newPacket_v6(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("ff02::2")) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("ff02::1")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1")) assert.Equal(t, p.RemotePort, uint16(36123)) assert.Equal(t, p.LocalPort, uint16(22)) @@ -137,8 +137,8 @@ func Test_newPacket_v6(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("ff02::2")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("ff02::1")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1")) assert.Equal(t, p.LocalPort, uint16(36123)) assert.Equal(t, p.RemotePort, uint16(22)) } diff --git a/timeout_test.go b/timeout_test.go index 4c6364ef5..db36fec73 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: netip.MustParseAddr("0.0.0.1")}, - {LocalIP: netip.MustParseAddr("0.0.0.2")}, - {LocalIP: netip.MustParseAddr("0.0.0.3")}, - {LocalIP: netip.MustParseAddr("0.0.0.4")}, + {LocalAddr: netip.MustParseAddr("0.0.0.1")}, + {LocalAddr: netip.MustParseAddr("0.0.0.2")}, + {LocalAddr: netip.MustParseAddr("0.0.0.3")}, + {LocalAddr: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1)