Skip to content

Commit

Permalink
feat: add http client option to disable IPv6
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Feb 5, 2024
1 parent a6f7cce commit cafcf11
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 20 deletions.
15 changes: 13 additions & 2 deletions httpx/resilient_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,20 @@ type resilientOptions struct {
retryMax int
noInternalIPs bool
internalIPExceptions []string
ipV6 bool
tracer trace.Tracer
}

func newResilientOptions() *resilientOptions {
connTimeout := time.Minute
transport, _ := newDefaultTransport(true)
return &resilientOptions{
c: &http.Client{Timeout: connTimeout, Transport: http.DefaultTransport},
c: &http.Client{Timeout: connTimeout, Transport: transport},
retryWaitMin: 1 * time.Second,
retryWaitMax: 30 * time.Second,
retryMax: 4,
l: log.New(io.Discard, "", log.LstdFlags),
ipV6: true,
}
}

Expand Down Expand Up @@ -103,6 +106,12 @@ func ResilientClientAllowInternalIPRequestsTo(urlGlobs ...string) ResilientOptio
}
}

func ResilientClientNoIPv6() ResilientOptions {
return func(o *resilientOptions) {
o.ipV6 = false
}
}

// NewResilientClient creates a new ResilientClient.
func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client {
o := newResilientOptions()
Expand All @@ -111,7 +120,9 @@ func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client {
}

if o.noInternalIPs {
o.c.Transport = NewNoInternalIPRoundTripper(o.internalIPExceptions)
o.c.Transport = newNoInternalIPRoundTripper(o.internalIPExceptions, o.ipV6)
} else {
o.c.Transport, _ = newDefaultTransport(o.ipV6)
}

if o.tracer != nil {
Expand Down
73 changes: 73 additions & 0 deletions httpx/resilient_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
package httpx

import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/netip"
"net/url"
"sync/atomic"
"testing"

"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -55,3 +60,71 @@ func TestNoPrivateIPs(t *testing.T) {
}
}
}

func TestNoIPV6(t *testing.T) {
for _, tc := range []struct {
name string
c *retryablehttp.Client
}{
{
"internal IPs allowed",
NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientNoIPv6(),
),
}, {
"internal IPs disallowed",
NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientDisallowInternalIPs(),
ResilientClientNoIPv6(),
),
},
} {
t.Run(tc.name, func(t *testing.T) {
var connectDone int32
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
DNSDone: func(dnsInfo httptrace.DNSDoneInfo) {
for _, ip := range dnsInfo.Addrs {
netIP, ok := netip.AddrFromSlice(ip.IP)
assert.True(t, ok)
assert.Truef(t, netIP.Is4(), "ip = %s", ip)
}
},
ConnectDone: func(network, addr string, err error) {
atomic.AddInt32(&connectDone, 1)
assert.NoError(t, err)
assert.Equalf(t, "tcp4", network, "network = %s addr = %s", network, addr)
},
})

// Dual stack
req, err := retryablehttp.NewRequestWithContext(ctx, "GET", "http://dual.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
res, err := tc.c.Do(req)
require.GreaterOrEqual(t, int32(1), atomic.LoadInt32(&connectDone))
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
require.EqualValues(t, http.StatusOK, res.StatusCode)

// IPv4 only
req, err = retryablehttp.NewRequestWithContext(ctx, "GET", "http://ipv4.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
res, err = tc.c.Do(req)
require.EqualValues(t, 1, atomic.LoadInt32(&connectDone))
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
require.EqualValues(t, http.StatusOK, res.StatusCode)

// IPv6 only
req, err = retryablehttp.NewRequestWithContext(ctx, "GET", "http://ipv6.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
_, err = tc.c.Do(req)
require.EqualValues(t, 0, atomic.LoadInt32(&connectDone))
require.ErrorContains(t, err, "no such host")
})
}
}
48 changes: 30 additions & 18 deletions httpx/ssrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package httpx

import (
"context"
"net"
"net/http"
"net/netip"
Expand All @@ -20,19 +21,23 @@ type noInternalIPRoundTripper struct {
internalIPExceptions []string
}

// NewNoInternalIPRoundTripper creates a RoundTripper that disallows
// newNoInternalIPRoundTripper creates a RoundTripper that disallows
// non-publicly routable IP addresses, except for URLs matching the given
// exception globs.
func NewNoInternalIPRoundTripper(exceptions []string) http.RoundTripper {
func newNoInternalIPRoundTripper(exceptions []string, ipV6 bool) http.RoundTripper {
networks := []string{"tcp4"}
if ipV6 {
networks = []string{"tcp4", "tcp6"}
}
if len(exceptions) > 0 {
prohibitInternal := newSSRFTransport(ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
))
ssrf.WithNetworks(networks...),
), ipV6)

allowInternal := newSSRFTransport(ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
ssrf.WithNetworks(networks...),
ssrf.WithAllowedV4Prefixes(
netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918)
netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3))
Expand All @@ -44,7 +49,7 @@ func NewNoInternalIPRoundTripper(exceptions []string) http.RoundTripper {
netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193)
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
),
))
), ipV6)
return noInternalIPRoundTripper{
onWhitelist: allowInternal,
notOnWhitelist: prohibitInternal,
Expand All @@ -53,8 +58,8 @@ func NewNoInternalIPRoundTripper(exceptions []string) http.RoundTripper {
}
prohibitInternal := newSSRFTransport(ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
))
ssrf.WithNetworks(networks...),
), ipV6)
return noInternalIPRoundTripper{
onWhitelist: prohibitInternal,
notOnWhitelist: prohibitInternal,
Expand All @@ -79,23 +84,30 @@ func (n noInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Respon
return n.notOnWhitelist.RoundTrip(request)
}

func newSSRFTransport(g *ssrf.Guardian) http.RoundTripper {
t := newDefaultTransport()
t.DialContext = (&net.Dialer{Control: g.Safe}).DialContext
func newSSRFTransport(g *ssrf.Guardian, ipV6 bool) *http.Transport {
t, d := newDefaultTransport(ipV6)
d.Control = g.Safe
return t
}

func newDefaultTransport() *http.Transport {
func newDefaultTransport(ipV6 bool) (*http.Transport, *net.Dialer) {
dialer := net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
dialContext := dialer.DialContext
if !ipV6 {
dialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp4", address)
}
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}, &dialer
}

0 comments on commit cafcf11

Please sign in to comment.