From c1af21333e560c87d908e4b29a201500e48e3c6a Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Tue, 14 Nov 2023 13:42:42 +0100 Subject: [PATCH] fix --- httpx/private_ip_validator.go | 24 +++++++++++- httpx/private_ip_validator_test.go | 61 ++++++++++++++---------------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/httpx/private_ip_validator.go b/httpx/private_ip_validator.go index 4999f119..f644d4c4 100644 --- a/httpx/private_ip_validator.go +++ b/httpx/private_ip_validator.go @@ -6,8 +6,10 @@ package httpx import ( "fmt" "net" + "net/netip" "net/url" + "code.dny.dev/ssrf" "github.com/pkg/errors" ) @@ -65,8 +67,26 @@ func DisallowIPPrivateAddresses(ipOrHostnameOrURL string) error { } for _, ip := range ips { - if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() { - return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + ip, err := netip.ParseAddr(ip.String()) + if err != nil { + return ErrPrivateIPAddressDisallowed(errors.WithStack(err)) // should be unreacheable + } + + if ip.Is4() { + for _, deny := range ssrf.IPv4DeniedPrefixes { + if deny.Contains(ip) { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + } + } else { + if !ssrf.IPv6GlobalUnicast.Contains(ip) { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + for _, net := range ssrf.IPv6DeniedPrefixes { + if net.Contains(ip) { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + } } } diff --git a/httpx/private_ip_validator_test.go b/httpx/private_ip_validator_test.go index a9c5e4a3..e3520ffc 100644 --- a/httpx/private_ip_validator_test.go +++ b/httpx/private_ip_validator_test.go @@ -4,13 +4,11 @@ package httpx import ( - "net" "net/http" - "net/url" "testing" "github.com/pkg/errors" - + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -28,9 +26,12 @@ func TestIsAssociatedIPAllowed(t *testing.T) { "0.0.0.0", "10.255.255.255", "::1", + "100::1", + "fe80::1", + "169.254.169.254", // AWS instance metadata service } { t.Run("case="+disallowed, func(t *testing.T) { - require.Error(t, DisallowIPPrivateAddresses(disallowed)) + assert.Error(t, DisallowIPPrivateAddresses(disallowed)) }) } } @@ -51,8 +52,8 @@ var _ http.RoundTripper = new(noOpRoundTripper) type errRoundTripper struct{ err error } -var err1 = errors.New("error 1") -var err2 = errors.New("error 2") +var errNotOnWhitelist = errors.New("OK") +var errOnWhitelist = errors.New("OK (on whitelist)") func (n errRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { return nil, n.err @@ -60,10 +61,12 @@ func (n errRoundTripper) RoundTrip(request *http.Request) (*http.Response, error var _ http.RoundTripper = new(errRoundTripper) +// TestInternalRespectsRoundTripper tests if the RoundTripper picks the correct +// underlying transport for two allowed requests. func TestInternalRespectsRoundTripper(t *testing.T) { rt := &noInternalIPRoundTripper{ - onWhitelist: &errRoundTripper{err1}, - notOnWhitelist: &errRoundTripper{err2}, + onWhitelist: &errRoundTripper{errOnWhitelist}, + notOnWhitelist: &errRoundTripper{errNotOnWhitelist}, internalIPExceptions: []string{ "https://127.0.0.1/foo", }} @@ -71,42 +74,34 @@ func TestInternalRespectsRoundTripper(t *testing.T) { req, err := http.NewRequest("GET", "https://google.com/foo", nil) require.NoError(t, err) _, err = rt.RoundTrip(req) - require.ErrorIs(t, err, err2) + require.ErrorIs(t, err, errNotOnWhitelist) req, err = http.NewRequest("GET", "https://127.0.0.1/foo", nil) require.NoError(t, err) _, err = rt.RoundTrip(req) - require.ErrorIs(t, err, err1) + require.ErrorIs(t, err, errOnWhitelist) } func TestAllowExceptions(t *testing.T) { rt := noInternalIPRoundTripper{ - onWhitelist: &errRoundTripper{}, - notOnWhitelist: &errRoundTripper{}, + onWhitelist: &errRoundTripper{errOnWhitelist}, + notOnWhitelist: &errRoundTripper{errNotOnWhitelist}, internalIPExceptions: []string{ "http://localhost/asdf", }} - _, err := rt.RoundTrip(&http.Request{ - Host: "localhost", - URL: &url.URL{Scheme: "http", Path: "/asdf", Host: "localhost"}, - Header: http.Header{ - "Host": []string{"localhost"}, - }, - }) - // assert that the error is eiher nil or a dial error. - if err != nil { - opErr := new(net.OpError) - require.ErrorAs(t, err, &opErr) - require.Equal(t, "dial", opErr.Op) - } + req, err := http.NewRequest("GET", "http://localhost/asdf", nil) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.ErrorIs(t, err, errOnWhitelist) - _, err = rt.RoundTrip(&http.Request{ - Host: "localhost", - URL: &url.URL{Scheme: "http", Path: "/not-asdf", Host: "localhost"}, - Header: http.Header{ - "Host": []string{"localhost"}, - }, - }) - require.Error(t, err) + req, err = http.NewRequest("GET", "http://localhost/not-asdf", nil) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.ErrorIs(t, err, errNotOnWhitelist) + + req, err = http.NewRequest("GET", "http://127.0.0.1", nil) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.ErrorIs(t, err, errNotOnWhitelist) }