diff --git a/fetcher/fetcher_test.go b/fetcher/fetcher_test.go index 32f24fab..b23ecf5c 100644 --- a/fetcher/fetcher_test.go +++ b/fetcher/fetcher_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "github.com/ory/x/httpx" + "github.com/hashicorp/go-retryablehttp" "github.com/gobuffalo/httptest" "github.com/julienschmidt/httprouter" @@ -35,9 +35,10 @@ func TestFetcher(t *testing.T) { _, err = file.WriteString(`{"foo":"baz"}`) require.NoError(t, err) require.NoError(t, file.Close()) - + rClient := retryablehttp.NewClient() + rClient.HTTPClient = ts.Client() for fc, fetcher := range []*Fetcher{ - NewFetcher(WithClient(httpx.NewResilientClient(httpx.ResilientClientWithClient(ts.Client())))), + NewFetcher(WithClient(rClient)), NewFetcher(), } { for k, tc := range []struct { diff --git a/go.mod b/go.mod index 91018eda..89eded99 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/ory/x go 1.20 require ( + code.dny.dev/ssrf v0.2.0 github.com/auth0/go-jwt-middleware v1.0.1 github.com/avast/retry-go/v4 v4.3.0 github.com/bmatcuk/doublestar/v2 v2.0.4 @@ -90,13 +91,14 @@ require ( go.opentelemetry.io/otel/trace v1.11.1 go.opentelemetry.io/proto/otlp v0.18.0 go.uber.org/goleak v1.2.1 - golang.org/x/crypto v0.9.0 + golang.org/x/crypto v0.15.0 golang.org/x/mod v0.8.0 - golang.org/x/net v0.10.0 + golang.org/x/net v0.18.0 + golang.org/x/oauth2 v0.14.0 golang.org/x/sync v0.1.0 gonum.org/v1/plot v0.12.0 google.golang.org/grpc v1.56.3 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.31.0 ) require ( @@ -204,12 +206,14 @@ require ( go.mongodb.org/mongo-driver v1.10.3 // indirect go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.11.1 // indirect go.opentelemetry.io/otel/metric v0.33.0 // indirect + golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 // indirect golang.org/x/image v0.5.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/sys v0.14.0 // indirect + golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.6.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect + google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 714c8592..1a7ccc44 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +code.dny.dev/ssrf v0.2.0 h1:wCBP990rQQ1CYfRpW+YK1+8xhwUjv189AQ3WMo1jQaI= +code.dny.dev/ssrf v0.2.0/go.mod h1:B+91l25OnyaLIeCx0WRJN5qfJ/4/ZTZxRXgm0lj/2w8= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= git.sr.ht/~sbinet/gg v0.3.1 h1:LNhjNn8DerC8f9DHLz6lS0YYul/b602DUxDgGkd/Aik= git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc= @@ -1034,8 +1036,9 @@ golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1046,7 +1049,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 h1:yZNXmy+j/JpX19vZkVktWqAo7Gny4PBWYYK3zskGpx4= +golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -1133,8 +1137,9 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= +golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1147,6 +1152,8 @@ golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= +golang.org/x/oauth2 v0.14.0 h1:P0Vrf/2538nmC0H+pEQ3MNFRRnVR7RlqyVw+bvm26z0= +golang.org/x/oauth2 v0.14.0/go.mod h1:lAtNWgaWfL4cm7j2OV8TxGi9Qb7ECORx8DktCY74OwM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1247,8 +1254,9 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1264,9 +1272,11 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1376,6 +1386,8 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1456,8 +1468,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/httpx/private_ip_validator.go b/httpx/private_ip_validator.go index 09e20732..f644d4c4 100644 --- a/httpx/private_ip_validator.go +++ b/httpx/private_ip_validator.go @@ -6,12 +6,10 @@ package httpx import ( "fmt" "net" - "net/http" + "net/netip" "net/url" - "syscall" - "time" - "github.com/gobwas/glob" + "code.dny.dev/ssrf" "github.com/pkg/errors" ) @@ -69,83 +67,28 @@ func DisallowIPPrivateAddresses(ipOrHostnameOrURL string) error { } for _, ip := range ips { - if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() { - return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) - } - } - - return nil -} - -var _ http.RoundTripper = (*NoInternalIPRoundTripper)(nil) - -// NoInternalIPRoundTripper is a RoundTripper that disallows internal IP addresses. -type NoInternalIPRoundTripper struct { - http.RoundTripper - internalIPExceptions []string -} - -func (n NoInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - rt := http.DefaultTransport - if n.RoundTripper != nil { - rt = n.RoundTripper - } - - incoming := IncomingRequestURL(request) - incoming.RawQuery = "" - incoming.RawFragment = "" - for _, exception := range n.internalIPExceptions { - compiled, err := glob.Compile(exception, '.', '/') + ip, err := netip.ParseAddr(ip.String()) if err != nil { - return nil, err + return ErrPrivateIPAddressDisallowed(errors.WithStack(err)) // should be unreacheable } - if compiled.Match(incoming.String()) { - return rt.RoundTrip(request) - } - } - - if err := DisallowIPPrivateAddresses(incoming.Hostname()); err != nil { - return nil, err - } - - return rt.RoundTrip(request) -} -var NoInternalDialer = &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - Control: func(network, address string, _ syscall.RawConn) error { - if !(network == "tcp4" || network == "tcp6") { - return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a safe network type", network)) - } - - host, _, err := net.SplitHostPort(address) - if err != nil { - return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a valid host/port pair: %s", address, err)) - } - - ip := net.ParseIP(host) - if ip == nil { - return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a valid IP address", host)) - } - - if ip.IsPrivate() || ip.IsLoopback() || ip.IsUnspecified() { - return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + if ip.Is4() { + for _, deny := range ssrf.IPv4DeniedPrefixes { + if deny.Contains(ip) { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + } + } else { + if !ssrf.IPv6GlobalUnicast.Contains(ip) { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + for _, net := range ssrf.IPv6DeniedPrefixes { + if net.Contains(ip) { + return ErrPrivateIPAddressDisallowed(fmt.Errorf("%s is not a public IP address", ip)) + } + } } + } - return nil - }, -} - -// NoInternalTransport -// -// DEPRECATED: do not use -var NoInternalTransport http.RoundTripper = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: NoInternalDialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + return nil } diff --git a/httpx/private_ip_validator_test.go b/httpx/private_ip_validator_test.go index 055a58ff..e3520ffc 100644 --- a/httpx/private_ip_validator_test.go +++ b/httpx/private_ip_validator_test.go @@ -4,13 +4,10 @@ package httpx import ( - "net" "net/http" - "net/url" "testing" "github.com/pkg/errors" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,9 +26,12 @@ func TestIsAssociatedIPAllowed(t *testing.T) { "0.0.0.0", "10.255.255.255", "::1", + "100::1", + "fe80::1", + "169.254.169.254", // AWS instance metadata service } { t.Run("case="+disallowed, func(t *testing.T) { - require.Error(t, DisallowIPPrivateAddresses(disallowed)) + assert.Error(t, DisallowIPPrivateAddresses(disallowed)) }) } } @@ -50,104 +50,58 @@ func (n noOpRoundTripper) RoundTrip(request *http.Request) (*http.Response, erro var _ http.RoundTripper = new(noOpRoundTripper) -type errRoundTripper struct{} +type errRoundTripper struct{ err error } -var fakeErr = errors.New("error") +var errNotOnWhitelist = errors.New("OK") +var errOnWhitelist = errors.New("OK (on whitelist)") func (n errRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { - return nil, fakeErr + return nil, n.err } var _ http.RoundTripper = new(errRoundTripper) +// TestInternalRespectsRoundTripper tests if the RoundTripper picks the correct +// underlying transport for two allowed requests. func TestInternalRespectsRoundTripper(t *testing.T) { - rt := &NoInternalIPRoundTripper{RoundTripper: &errRoundTripper{}, internalIPExceptions: []string{ - "https://127.0.0.1/foo", - }} + rt := &noInternalIPRoundTripper{ + onWhitelist: &errRoundTripper{errOnWhitelist}, + notOnWhitelist: &errRoundTripper{errNotOnWhitelist}, + internalIPExceptions: []string{ + "https://127.0.0.1/foo", + }} req, err := http.NewRequest("GET", "https://google.com/foo", nil) require.NoError(t, err) _, err = rt.RoundTrip(req) - require.ErrorIs(t, err, fakeErr) + require.ErrorIs(t, err, errNotOnWhitelist) req, err = http.NewRequest("GET", "https://127.0.0.1/foo", nil) require.NoError(t, err) _, err = rt.RoundTrip(req) - require.ErrorIs(t, err, fakeErr) + require.ErrorIs(t, err, errOnWhitelist) } func TestAllowExceptions(t *testing.T) { - rt := &NoInternalIPRoundTripper{internalIPExceptions: []string{"http://localhost/asdf"}} - - _, err := rt.RoundTrip(&http.Request{ - Host: "localhost", - URL: &url.URL{Scheme: "http", Path: "/asdf", Host: "localhost"}, - Header: http.Header{ - "Host": []string{"localhost"}, - }, - }) - // assert that the error is eiher nil or a dial error. - if err != nil { - opErr := new(net.OpError) - require.ErrorAs(t, err, &opErr) - require.Equal(t, "dial", opErr.Op) - } - - _, err = rt.RoundTrip(&http.Request{ - Host: "localhost", - URL: &url.URL{Scheme: "http", Path: "/not-asdf", Host: "localhost"}, - Header: http.Header{ - "Host": []string{"localhost"}, - }, - }) - require.Error(t, err) -} + rt := noInternalIPRoundTripper{ + onWhitelist: &errRoundTripper{errOnWhitelist}, + notOnWhitelist: &errRoundTripper{errNotOnWhitelist}, + internalIPExceptions: []string{ + "http://localhost/asdf", + }} + + req, err := http.NewRequest("GET", "http://localhost/asdf", nil) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.ErrorIs(t, err, errOnWhitelist) -func assertErrorContains(msg string) assert.ErrorAssertionFunc { - return func(t assert.TestingT, err error, i ...interface{}) bool { - if !assert.Error(t, err, i...) { - return false - } - return assert.Contains(t, err.Error(), msg) - } -} + req, err = http.NewRequest("GET", "http://localhost/not-asdf", nil) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.ErrorIs(t, err, errNotOnWhitelist) -func TestNoInternalDialer(t *testing.T) { - for _, tt := range []struct { - name string - network string - address string - assertErr assert.ErrorAssertionFunc - }{{ - name: "TCP public is allowed", - network: "tcp", - address: "www.google.de:443", - assertErr: assert.NoError, - }, { - name: "TCP private is denied", - network: "tcp", - address: "localhost:443", - assertErr: assertErrorContains("is not a public IP address"), - }, { - name: "UDP public is denied", - network: "udp", - address: "www.google.de:443", - assertErr: assertErrorContains("not a safe network type"), - }, { - name: "UDP public is denied", - network: "udp", - address: "www.google.de:443", - assertErr: assertErrorContains("not a safe network type"), - }, { - name: "UNIX sockets are denied", - network: "unix", - address: "/etc/passwd", - assertErr: assertErrorContains("not a safe network type"), - }} { - - t.Run("case="+tt.name, func(t *testing.T) { - _, err := NoInternalDialer.Dial(tt.network, tt.address) - tt.assertErr(t, err) - }) - } + req, err = http.NewRequest("GET", "http://127.0.0.1", nil) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.ErrorIs(t, err, errNotOnWhitelist) } diff --git a/httpx/resilient_client.go b/httpx/resilient_client.go index 4ab77276..76188118 100644 --- a/httpx/resilient_client.go +++ b/httpx/resilient_client.go @@ -4,6 +4,7 @@ package httpx import ( + "context" "io" "log" "net/http" @@ -11,6 +12,7 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/trace" + "golang.org/x/oauth2" "github.com/hashicorp/go-retryablehttp" @@ -19,6 +21,8 @@ import ( type resilientOptions struct { c *http.Client + oauthConfig *oauth2.Config + oauthToken *oauth2.Token l interface{} retryWaitMin time.Duration retryWaitMax time.Duration @@ -31,7 +35,7 @@ type resilientOptions struct { func newResilientOptions() *resilientOptions { connTimeout := time.Minute return &resilientOptions{ - c: &http.Client{Timeout: connTimeout}, + c: &http.Client{Timeout: connTimeout, Transport: http.DefaultTransport}, retryWaitMin: 1 * time.Second, retryWaitMax: 30 * time.Second, retryMax: 4, @@ -42,10 +46,11 @@ func newResilientOptions() *resilientOptions { // ResilientOptions is a set of options for the ResilientClient. type ResilientOptions func(o *resilientOptions) -// ResilientClientWithClient sets the underlying http client to use. -func ResilientClientWithClient(c *http.Client) ResilientOptions { +// ResilientClientWithOAuth2 +func ResilientClientWithOAuth2(c *oauth2.Config, t *oauth2.Token) ResilientOptions { return func(o *resilientOptions) { - o.c = c + o.oauthConfig = c + o.oauthToken = t } } @@ -114,23 +119,30 @@ func NewResilientClient(opts ...ResilientOptions) *retryablehttp.Client { } if o.noInternalIPs { - o.c.Transport = &NoInternalIPRoundTripper{ - RoundTripper: o.c.Transport, - internalIPExceptions: o.internalIPExceptions, - } + o.c.Transport = NewNoInternalIPRoundTripper(o.internalIPExceptions) } if o.tracer != nil { o.c.Transport = otelhttp.NewTransport(o.c.Transport) } - return &retryablehttp.Client{ - HTTPClient: o.c, - Logger: o.l, - RetryWaitMin: o.retryWaitMin, - RetryWaitMax: o.retryWaitMax, - RetryMax: o.retryMax, - CheckRetry: retryablehttp.DefaultRetryPolicy, - Backoff: retryablehttp.DefaultBackoff, + if o.oauthConfig != nil { + o.c.Transport = &oauth2.Transport{ + Base: o.c.Transport, + Source: o.oauthConfig.TokenSource( + context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: o.c.Transport}), + o.oauthToken, + ), + } } + + cl := retryablehttp.NewClient() + cl.HTTPClient = o.c + cl.Logger = o.l + cl.RetryWaitMin = o.retryWaitMin + cl.RetryWaitMax = o.retryWaitMax + cl.RetryMax = o.retryMax + cl.CheckRetry = retryablehttp.DefaultRetryPolicy + cl.Backoff = retryablehttp.DefaultBackoff + return cl } diff --git a/httpx/resilient_client_test.go b/httpx/resilient_client_test.go index 8dc4fc8d..b30898eb 100644 --- a/httpx/resilient_client_test.go +++ b/httpx/resilient_client_test.go @@ -10,8 +10,6 @@ import ( "net/url" "testing" - "go.opentelemetry.io/otel" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -37,75 +35,24 @@ func TestNoPrivateIPs(t *testing.T) { ResilientClientAllowInternalIPRequestsTo(allowedURL, allowedGlob), ) - for destination, passes := range map[string]bool{ - "http://127.0.0.1:" + port: false, - "http://localhost:" + port: false, - "http://192.168.178.5:" + port: false, - allowedURL: true, - "http://localhost:" + port + "/glob/bar": true, - "http://localhost:" + port + "/glob/bar/baz": false, - "http://localhost:" + port + "/FOOBAR": false, - } { - _, err := c.Get(destination) - if !passes { - require.Error(t, err) - assert.Contains(t, err.Error(), "is not a public IP address") - } else { - require.NoError(t, err) + for i := 0; i < 10; i++ { + for destination, passes := range map[string]bool{ + "http://127.0.0.1:" + port: false, + "http://localhost:" + port: false, + "http://192.168.178.5:" + port: false, + allowedURL: true, + "http://localhost:" + port + "/glob/bar": true, + "http://localhost:" + port + "/glob/bar/baz": false, + "http://localhost:" + port + "/FOOBAR": false, + // "http://make-httpbin.org-rebind-127.0.0.1-rr.1u.ms:80": true, + } { + _, err := c.Get(destination) + if !passes { + require.Errorf(t, err, "dest = %s", destination) + assert.Containsf(t, err.Error(), "is not a permitted destination", "dest = %s", destination) + } else { + require.NoErrorf(t, err, "dest = %s", destination) + } } } } - -var errClient = &http.Client{Transport: errRoundTripper{}} - -func TestNoPrivateIPsRespectsWrappedClient(t *testing.T) { - c := NewResilientClient( - ResilientClientWithMaxRetry(1), - ResilientClientDisallowInternalIPs(), - ResilientClientWithClient(errClient), - ) - _, err := c.Get("https://google.com") - require.ErrorIs(t, err, fakeErr) -} - -func TestClientWithTracerRespectsWrappedClient(t *testing.T) { - tracer := otel.Tracer("github.com/ory/x/httpx test") - c := NewResilientClient( - ResilientClientWithMaxRetry(1), - ResilientClientWithTracer(tracer), - ResilientClientWithClient(errClient), - ) - _, err := c.Get("https://google.com") - require.ErrorIs(t, err, fakeErr) -} - -func TestClientWithMultiConfigRespectsWrapperClient(t *testing.T) { - tracer := otel.Tracer("github.com/ory/x/httpx test") - c := NewResilientClient( - ResilientClientWithMaxRetry(1), - ResilientClientWithTracer(tracer), - ResilientClientDisallowInternalIPs(), - ResilientClientWithClient(errClient), - ) - _, err := c.Get("https://google.com") - require.ErrorIs(t, err, fakeErr) -} - -func TestClientWithTracer(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, _ = w.Write([]byte("Hello, world!")) - })) - t.Cleanup(ts.Close) - - tracer := otel.Tracer("github.com/ory/x/httpx test") - c := NewResilientClient( - ResilientClientWithTracer(tracer), - ) - - target, err := url.ParseRequestURI(ts.URL) - require.NoError(t, err) - - _, err = c.Get(target.String()) - - assert.NoError(t, err) -} diff --git a/httpx/ssrf.go b/httpx/ssrf.go new file mode 100644 index 00000000..64b1901c --- /dev/null +++ b/httpx/ssrf.go @@ -0,0 +1,101 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package httpx + +import ( + "net" + "net/http" + "net/netip" + "time" + + "code.dny.dev/ssrf" + "github.com/gobwas/glob" +) + +var _ http.RoundTripper = (*noInternalIPRoundTripper)(nil) + +type noInternalIPRoundTripper struct { + onWhitelist, notOnWhitelist http.RoundTripper + internalIPExceptions []string +} + +// NewNoInternalIPRoundTripper creates a RoundTripper that disallows +// non-publicly routable IP addresses, except for URLs matching the given +// exception globs. +func NewNoInternalIPRoundTripper(exceptions []string) http.RoundTripper { + if len(exceptions) > 0 { + prohibitInternal := newSSRFTransport(ssrf.New( + ssrf.WithAnyPort(), + ssrf.WithNetworks("tcp4", "tcp6"), + )) + + allowInternal := newSSRFTransport(ssrf.New( + ssrf.WithAnyPort(), + ssrf.WithNetworks("tcp4", "tcp6"), + ssrf.WithAllowedV4Prefixes( + netip.MustParsePrefix("10.0.0.0/8"), // Private-Use (RFC 1918) + netip.MustParsePrefix("127.0.0.0/8"), // Loopback (RFC 1122, Section 3.2.1.3)) + netip.MustParsePrefix("169.254.0.0/16"), // Link Local (RFC 3927) + netip.MustParsePrefix("172.16.0.0/12"), // Private-Use (RFC 1918) + netip.MustParsePrefix("192.168.0.0/16"), // Private-Use (RFC 1918) + netip.MustParsePrefix("::1/128"), // Loopback (RFC 4193) + netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193) + ), + )) + return noInternalIPRoundTripper{ + onWhitelist: allowInternal, + notOnWhitelist: prohibitInternal, + internalIPExceptions: exceptions, + } + } + prohibitInternal := newSSRFTransport(ssrf.New( + ssrf.WithAnyPort(), + ssrf.WithNetworks("tcp4", "tcp6"), + )) + return noInternalIPRoundTripper{ + onWhitelist: prohibitInternal, + notOnWhitelist: prohibitInternal, + } +} + +// RoundTrip implements http.RoundTripper. +func (n noInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + incoming := IncomingRequestURL(request) + incoming.RawQuery = "" + incoming.RawFragment = "" + for _, exception := range n.internalIPExceptions { + compiled, err := glob.Compile(exception, '.', '/') + if err != nil { + return nil, err + } + if compiled.Match(incoming.String()) { + println("from whitelist: ", incoming.String()) + return n.onWhitelist.RoundTrip(request) + } + } + + println("not on whitelist: ", incoming.String()) + return n.notOnWhitelist.RoundTrip(request) +} + +func newSSRFTransport(g *ssrf.Guardian) http.RoundTripper { + t := newDefaultTransport() + t.DialContext = (&net.Dialer{Control: g.Safe}).DialContext + return t +} + +func newDefaultTransport() *http.Transport { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} diff --git a/osx/file_test.go b/osx/file_test.go index ef1de190..d7ff173f 100644 --- a/osx/file_test.go +++ b/osx/file_test.go @@ -10,10 +10,9 @@ import ( "net/http/httptest" "testing" + "github.com/hashicorp/go-retryablehttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/ory/x/httpx" ) var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { @@ -27,6 +26,9 @@ func TestReadFileFromAllSources(t *testing.T) { sslTS := httptest.NewTLSServer(handler) defer sslTS.Close() + rClient := retryablehttp.NewClient() + rClient.HTTPClient = sslTS.Client() + for k, tc := range []struct { opts []Option src string @@ -49,7 +51,7 @@ func TestReadFileFromAllSources(t *testing.T) { {src: ts.URL, expectedBody: "hello world"}, {src: sslTS.URL, expectedErrContains: "x509:"}, - {src: sslTS.URL, expectedBody: "hello world", opts: []Option{WithHTTPClient(httpx.NewResilientClient(httpx.ResilientClientWithClient(sslTS.Client())))}}, + {src: sslTS.URL, expectedBody: "hello world", opts: []Option{WithHTTPClient(rClient)}}, {src: sslTS.URL, expectedErr: "http(s) loader disabled", opts: []Option{WithDisabledHTTPLoader()}}, {src: "file://stub/text.txt", expectedErr: "file loader disabled", opts: []Option{WithDisabledFileLoader()}},