diff --git a/httpx/resilient_client_test.go b/httpx/resilient_client_test.go index 8a90118b..9f90e738 100644 --- a/httpx/resilient_client_test.go +++ b/httpx/resilient_client_test.go @@ -5,59 +5,127 @@ package httpx import ( "context" - "net" + "fmt" "net/http" - "net/http/httptest" "net/http/httptrace" "net/netip" - "net/url" "sync/atomic" "testing" + "time" + + "code.dny.dev/ssrf" "github.com/hashicorp/go-retryablehttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestNoPrivateIPs(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, _ = w.Write([]byte("Hello, world!")) - })) - t.Cleanup(ts.Close) - - target, err := url.ParseRequestURI(ts.URL) - require.NoError(t, err) - - _, port, err := net.SplitHostPort(target.Host) - require.NoError(t, err) - - allowedURL := "http://localhost:" + port + "/foobar" - allowedGlob := "http://localhost:" + port + "/glob/*" - - c := NewResilientClient( - ResilientClientWithMaxRetry(1), - ResilientClientDisallowInternalIPs(), - ResilientClientAllowInternalIPRequestsTo(allowedURL, allowedGlob), - ) +func TestPrivateIPs(t *testing.T) { + testCases := []struct { + url string + disallowInternalIPs bool + allowedIP bool + }{ + { + url: "http://127.0.0.1/foobar", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://localhost/foobar", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://127.0.0.1:56789/test", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://192.168.178.5:56789", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://127.0.0.1:56789/foobar", + disallowInternalIPs: true, + allowedIP: true, + }, + { + url: "http://127.0.0.1:56789/glob/bar", + disallowInternalIPs: true, + allowedIP: true, + }, + { + url: "http://127.0.0.1:56789/glob/bar/baz", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://127.0.0.1:56789/FOOBAR", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://100.64.1.1:80/private", + disallowInternalIPs: true, + allowedIP: true, + }, + { + url: "http://100.64.1.1:80/route", + disallowInternalIPs: true, + allowedIP: false, + }, + { + url: "http://127.0.0.1", + disallowInternalIPs: false, + allowedIP: true, + }, + { + url: "http://localhost", + disallowInternalIPs: false, + allowedIP: true, + }, + { + url: "http://192.168.178.5", + disallowInternalIPs: false, + allowedIP: true, + }, + { + url: "http://127.0.0.1:80/glob/bar", + disallowInternalIPs: false, + allowedIP: true, + }, + { + url: "http://100.64.1.1:80/route", + disallowInternalIPs: false, + allowedIP: true, + }, + } + for _, tt := range testCases { + t.Run( + fmt.Sprintf("%s should be allowed %v when disallowed internal IPs is %v", tt.url, tt.allowedIP, tt.disallowInternalIPs), + func(t *testing.T) { + options := []ResilientOptions{ + ResilientClientWithMaxRetry(0), + ResilientClientWithConnectionTimeout(50 * time.Millisecond), + } + if tt.disallowInternalIPs { + options = append(options, ResilientClientDisallowInternalIPs()) + options = append(options, ResilientClientAllowInternalIPRequestsTo( + "http://127.0.0.1:56789/foobar", + "http://127.0.0.1:56789/glob/*", + "http://100.64.1.1:80/private")) + } - for i := 0; i < 10; i++ { - for destination, passes := range map[string]bool{ - "http://127.0.0.1:" + port: false, - "http://localhost:" + port: false, - "http://192.168.178.5:" + port: false, - allowedURL: true, - "http://localhost:" + port + "/glob/bar": true, - "http://localhost:" + port + "/glob/bar/baz": false, - "http://localhost:" + port + "/FOOBAR": false, - } { - _, err := c.Get(destination) - if !passes { - require.Errorf(t, err, "dest = %s", destination) - assert.Containsf(t, err.Error(), "is not a permitted destination", "dest = %s", destination) - } else { - require.NoErrorf(t, err, "dest = %s", destination) - } - } + c := NewResilientClient(options...) + _, err := c.Get(tt.url) + if tt.allowedIP { + assert.NotErrorIs(t, err, ssrf.ErrProhibitedIP) + } else { + assert.ErrorIs(t, err, ssrf.ErrProhibitedIP) + } + }) } } diff --git a/httpx/ssrf.go b/httpx/ssrf.go index 99b16e9e..a52f6ab3 100644 --- a/httpx/ssrf.go +++ b/httpx/ssrf.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/http/httptrace" - "net/netip" "time" "code.dny.dev/ssrf" @@ -88,15 +87,10 @@ func init() { ssrf.WithAnyPort(), ssrf.WithNetworks("tcp4", "tcp6"), ssrf.WithAllowedV4Prefixes( - netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918) - netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3)) - netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927) - netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918) - netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918) + ssrf.IPv4DeniedPrefixes..., ), ssrf.WithAllowedV6Prefixes( - netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193) - netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193) + ssrf.IPv6DeniedPrefixes..., ), ).Safe allowInternalAllowIPv6 = otelTransport(t) @@ -108,15 +102,10 @@ func init() { ssrf.WithAnyPort(), ssrf.WithNetworks("tcp4"), ssrf.WithAllowedV4Prefixes( - netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918) - netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3)) - netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927) - netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918) - netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918) + ssrf.IPv4DeniedPrefixes..., ), ssrf.WithAllowedV6Prefixes( - netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193) - netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193) + ssrf.IPv6DeniedPrefixes..., ), ).Safe t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {