Skip to content

Commit

Permalink
feat: add http client option to disable IPv6
Browse files Browse the repository at this point in the history
* improve connection re-use
* add detailed trace events (DNS, TLS, etc)
  • Loading branch information
alnr committed Feb 7, 2024
1 parent a6f7cce commit 041443c
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 49 deletions.
31 changes: 28 additions & 3 deletions httpx/resilient_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import (
"io"
"log"
"net/http"
"net/http/httptrace"
"time"

"go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
Expand All @@ -29,17 +31,19 @@ type resilientOptions struct {
retryMax int
noInternalIPs bool
internalIPExceptions []string
ipV6 bool
tracer trace.Tracer
}

func newResilientOptions() *resilientOptions {
connTimeout := time.Minute
return &resilientOptions{
c: &http.Client{Timeout: connTimeout, Transport: http.DefaultTransport},
c: &http.Client{Timeout: connTimeout},
retryWaitMin: 1 * time.Second,
retryWaitMax: 30 * time.Second,
retryMax: 4,
l: log.New(io.Discard, "", log.LstdFlags),
ipV6: true,
}
}

Expand Down Expand Up @@ -103,6 +107,12 @@ func ResilientClientAllowInternalIPRequestsTo(urlGlobs ...string) ResilientOptio
}
}

func ResilientClientNoIPv6() ResilientOptions {
return func(o *resilientOptions) {
o.ipV6 = false
}
}

// NewResilientClient creates a new ResilientClient.
func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client {
o := newResilientOptions()
Expand All @@ -111,11 +121,19 @@ func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client {
}

if o.noInternalIPs {
o.c.Transport = NewNoInternalIPRoundTripper(o.internalIPExceptions)
o.c.Transport = &noInternalIPRoundTripper{
onWhitelist: ifelse(o.ipV6, allowInternalAllowIPv6, allowInternalProhibitIPv6),
notOnWhitelist: ifelse(o.ipV6, prohibitInternalAllowIPv6, prohibitInternalProhibitIPv6),
internalIPExceptions: o.internalIPExceptions,
}
} else {
o.c.Transport = ifelse(o.ipV6, allowInternalAllowIPv6, allowInternalProhibitIPv6)
}

if o.tracer != nil {
o.c.Transport = otelhttp.NewTransport(o.c.Transport)
o.c.Transport = otelhttp.NewTransport(o.c.Transport, otelhttp.WithClientTrace(func(ctx context.Context) *httptrace.ClientTrace {
return otelhttptrace.NewClientTrace(ctx, otelhttptrace.WithoutHeaders(), otelhttptrace.WithoutSubSpans())
}))
}

cl := retryablehttp.NewClient()
Expand Down Expand Up @@ -146,3 +164,10 @@ func SetOAuth2(ctx context.Context, cl *retryablehttp.Client, c OAuth2Config, t
type OAuth2Config interface {
Client(context.Context, *oauth2.Token) *http.Client
}

func ifelse[A any](b bool, x, y A) A {
if b {
return x
}
return y
}
73 changes: 73 additions & 0 deletions httpx/resilient_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
package httpx

import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/netip"
"net/url"
"sync/atomic"
"testing"

"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -55,3 +60,71 @@ func TestNoPrivateIPs(t *testing.T) {
}
}
}

func TestNoIPV6(t *testing.T) {
for _, tc := range []struct {
name string
c *retryablehttp.Client
}{
{
"internal IPs allowed",
NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientNoIPv6(),
),
}, {
"internal IPs disallowed",
NewResilientClient(
ResilientClientWithMaxRetry(1),
ResilientClientDisallowInternalIPs(),
ResilientClientNoIPv6(),
),
},
} {
t.Run(tc.name, func(t *testing.T) {
var connectDone int32
ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
DNSDone: func(dnsInfo httptrace.DNSDoneInfo) {
for _, ip := range dnsInfo.Addrs {
netIP, ok := netip.AddrFromSlice(ip.IP)
assert.True(t, ok)
assert.Truef(t, netIP.Is4(), "ip = %s", ip)
}
},
ConnectDone: func(network, addr string, err error) {
atomic.AddInt32(&connectDone, 1)
assert.NoError(t, err)
assert.Equalf(t, "tcp4", network, "network = %s addr = %s", network, addr)
},
})

// Dual stack
req, err := retryablehttp.NewRequestWithContext(ctx, "GET", "http://dual.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
res, err := tc.c.Do(req)
require.GreaterOrEqual(t, int32(1), atomic.LoadInt32(&connectDone))
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
require.EqualValues(t, http.StatusOK, res.StatusCode)

// IPv4 only
req, err = retryablehttp.NewRequestWithContext(ctx, "GET", "http://ipv4.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
res, err = tc.c.Do(req)
require.EqualValues(t, 1, atomic.LoadInt32(&connectDone))
require.NoError(t, err)
t.Cleanup(func() { _ = res.Body.Close() })
require.EqualValues(t, http.StatusOK, res.StatusCode)

// IPv6 only
req, err = retryablehttp.NewRequestWithContext(ctx, "GET", "http://ipv6.tlund.se/", nil)
require.NoError(t, err)
atomic.StoreInt32(&connectDone, 0)
_, err = tc.c.Do(req)
require.EqualValues(t, 0, atomic.LoadInt32(&connectDone))
require.ErrorContains(t, err, "no such host")
})
}
}
128 changes: 83 additions & 45 deletions httpx/ssrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package httpx

import (
"context"
"net"
"net/http"
"net/netip"
Expand All @@ -23,41 +24,12 @@ type noInternalIPRoundTripper struct {
// NewNoInternalIPRoundTripper creates a RoundTripper that disallows
// non-publicly routable IP addresses, except for URLs matching the given
// exception globs.
// Deprecated: Use ResilientClientDisallowInternalIPs instead.
func NewNoInternalIPRoundTripper(exceptions []string) http.RoundTripper {
if len(exceptions) > 0 {
prohibitInternal := newSSRFTransport(ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
))

allowInternal := newSSRFTransport(ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
ssrf.WithAllowedV4Prefixes(
netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918)
netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3))
netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927)
netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918)
netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918)
),
ssrf.WithAllowedV6Prefixes(
netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193)
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
),
))
return noInternalIPRoundTripper{
onWhitelist: allowInternal,
notOnWhitelist: prohibitInternal,
internalIPExceptions: exceptions,
}
}
prohibitInternal := newSSRFTransport(ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
))
return noInternalIPRoundTripper{
onWhitelist: prohibitInternal,
notOnWhitelist: prohibitInternal,
return &noInternalIPRoundTripper{
onWhitelist: allowInternalAllowIPv6,
notOnWhitelist: prohibitInternalAllowIPv6,
internalIPExceptions: exceptions,
}
}

Expand All @@ -79,23 +51,89 @@ func (n noInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Respon
return n.notOnWhitelist.RoundTrip(request)
}

func newSSRFTransport(g *ssrf.Guardian) http.RoundTripper {
t := newDefaultTransport()
t.DialContext = (&net.Dialer{Control: g.Safe}).DialContext
return t
var (
prohibitInternalAllowIPv6 http.RoundTripper
prohibitInternalProhibitIPv6 http.RoundTripper
allowInternalAllowIPv6 http.RoundTripper
allowInternalProhibitIPv6 http.RoundTripper
)

func init() {
t, d := newDefaultTransport()
d.Control = ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
).Safe
prohibitInternalAllowIPv6 = t
}

func newDefaultTransport() *http.Transport {
func init() {
t, d := newDefaultTransport()
d.Control = ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4"),
).Safe
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.DialContext(ctx, "tcp4", addr)
}
prohibitInternalProhibitIPv6 = t
}

func init() {
t, d := newDefaultTransport()
d.Control = ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
ssrf.WithAllowedV4Prefixes(
netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918)
netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3))
netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927)
netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918)
netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918)
),
ssrf.WithAllowedV6Prefixes(
netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193)
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
),
).Safe
allowInternalAllowIPv6 = t
}

func init() {
t, d := newDefaultTransport()
d.Control = ssrf.New(
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4"),
ssrf.WithAllowedV4Prefixes(
netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918)
netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3))
netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927)
netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918)
netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918)
),
ssrf.WithAllowedV6Prefixes(
netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193)
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
),
).Safe
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.DialContext(ctx, "tcp4", addr)
}
allowInternalProhibitIPv6 = t
}

func newDefaultTransport() (*http.Transport, *net.Dialer) {
dialer := net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}, &dialer
}
3 changes: 2 additions & 1 deletion jsonnetsecure/jsonnet_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ func (w worker) destroy() {

func (w worker) eval(ctx context.Context, processParams []byte) (output string, err error) {
tracer := trace.SpanFromContext(ctx).TracerProvider().Tracer("")
ctx, span := tracer.Start(ctx, "jsonnetsecure.worker.eval")
ctx, span := tracer.Start(ctx, "jsonnetsecure.worker.eval",
trace.WithAttributes(attribute.Int("cmd.Process.Pid", w.cmd.Process.Pid)))
defer otelx.End(span, &err)

select {
Expand Down

0 comments on commit 041443c

Please sign in to comment.