Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Nov 14, 2023
1 parent 318b4cd commit c1af213
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
24 changes: 22 additions & 2 deletions httpx/private_ip_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package httpx
import (
"fmt"
"net"
"net/netip"
"net/url"

"code.dny.dev/ssrf"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -65,8 +67,26 @@ func DisallowIPPrivateAddresses(ipOrHostnameOrURL string) error {
}

for _, ip := range ips {
if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip))
ip, err := netip.ParseAddr(ip.String())
if err != nil {
return ErrPrivateIPAddressDisallowed(errors.WithStack(err)) // should be unreacheable
}

if ip.Is4() {
for _, deny := range ssrf.IPv4DeniedPrefixes {
if deny.Contains(ip) {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip))
}
}
} else {
if !ssrf.IPv6GlobalUnicast.Contains(ip) {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip))
}
for _, net := range ssrf.IPv6DeniedPrefixes {
if net.Contains(ip) {
return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip))
}
}
}
}

Expand Down
61 changes: 28 additions & 33 deletions httpx/private_ip_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
package httpx

import (
"net"
"net/http"
"net/url"
"testing"

"github.com/pkg/errors"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -28,9 +26,12 @@ func TestIsAssociatedIPAllowed(t *testing.T) {
"0.0.0.0",
"10.255.255.255",
"::1",
"100::1",
"fe80::1",
"169.254.169.254", // AWS instance metadata service
} {
t.Run("case="+disallowed, func(t *testing.T) {
require.Error(t, DisallowIPPrivateAddresses(disallowed))
assert.Error(t, DisallowIPPrivateAddresses(disallowed))
})
}
}
Expand All @@ -51,62 +52,56 @@ var _ http.RoundTripper = new(noOpRoundTripper)

type errRoundTripper struct{ err error }

var err1 = errors.New("error 1")
var err2 = errors.New("error 2")
var errNotOnWhitelist = errors.New("OK")
var errOnWhitelist = errors.New("OK (on whitelist)")

func (n errRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return nil, n.err
}

var _ http.RoundTripper = new(errRoundTripper)

// TestInternalRespectsRoundTripper tests if the RoundTripper picks the correct
// underlying transport for two allowed requests.
func TestInternalRespectsRoundTripper(t *testing.T) {
rt := &noInternalIPRoundTripper{
onWhitelist: &errRoundTripper{err1},
notOnWhitelist: &errRoundTripper{err2},
onWhitelist: &errRoundTripper{errOnWhitelist},
notOnWhitelist: &errRoundTripper{errNotOnWhitelist},
internalIPExceptions: []string{
"https://127.0.0.1/foo",
}}

req, err := http.NewRequest("GET", "https://google.com/foo", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, err2)
require.ErrorIs(t, err, errNotOnWhitelist)

req, err = http.NewRequest("GET", "https://127.0.0.1/foo", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, err1)
require.ErrorIs(t, err, errOnWhitelist)
}

func TestAllowExceptions(t *testing.T) {
rt := noInternalIPRoundTripper{
onWhitelist: &errRoundTripper{},
notOnWhitelist: &errRoundTripper{},
onWhitelist: &errRoundTripper{errOnWhitelist},
notOnWhitelist: &errRoundTripper{errNotOnWhitelist},
internalIPExceptions: []string{
"http://localhost/asdf",
}}

_, err := rt.RoundTrip(&http.Request{
Host: "localhost",
URL: &url.URL{Scheme: "http", Path: "/asdf", Host: "localhost"},
Header: http.Header{
"Host": []string{"localhost"},
},
})
// assert that the error is eiher nil or a dial error.
if err != nil {
opErr := new(net.OpError)
require.ErrorAs(t, err, &opErr)
require.Equal(t, "dial", opErr.Op)
}
req, err := http.NewRequest("GET", "http://localhost/asdf", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errOnWhitelist)

_, err = rt.RoundTrip(&http.Request{
Host: "localhost",
URL: &url.URL{Scheme: "http", Path: "/not-asdf", Host: "localhost"},
Header: http.Header{
"Host": []string{"localhost"},
},
})
require.Error(t, err)
req, err = http.NewRequest("GET", "http://localhost/not-asdf", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errNotOnWhitelist)

req, err = http.NewRequest("GET", "http://127.0.0.1", nil)
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.ErrorIs(t, err, errNotOnWhitelist)
}

0 comments on commit c1af213

Please sign in to comment.