From 8332dcefb31f1983c6ffe588df12d2ad96ea2491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 9 Feb 2024 17:46:58 +0800 Subject: [PATCH] Refactor and add support for the client-subnet option --- client.go | 4 +++ extension_edns0_subnet.go | 55 ++++++++++++++++++++++++++++++++++++++ extensions.go | 56 +++++++++++++++++++++++++++++++++++++++ loopback.go | 14 ---------- options.go | 31 ---------------------- quic/transport_http3.go | 33 +++++++++++------------ transport.go | 29 +++++++++++++++----- transport_base.go | 10 ++++--- transport_https.go | 18 +++++-------- transport_local.go | 15 +++++------ transport_rcode.go | 36 +++++++++++-------------- transport_tcp.go | 22 +++++++-------- transport_tls.go | 22 +++++++-------- transport_udp.go | 23 +++++++--------- 14 files changed, 218 insertions(+), 150 deletions(-) create mode 100644 extension_edns0_subnet.go create mode 100644 extensions.go delete mode 100644 loopback.go delete mode 100644 options.go diff --git a/client.go b/client.go index 7190b4d..e35a8ef 100644 --- a/client.go +++ b/client.go @@ -117,6 +117,10 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return nil, E.New("DNS query loopback in transport[", contextTransport, "]") } ctx = contextWithTransportName(ctx, transport.Name()) + clientSubnet, loaded := ClientSubnetFromContext(ctx) + if loaded { + SetClientSubnet(message, clientSubnet, true) + } response, err := transport.Exchange(ctx, message) if err != nil { return nil, err diff --git a/extension_edns0_subnet.go b/extension_edns0_subnet.go new file mode 100644 index 0000000..fcbf370 --- /dev/null +++ b/extension_edns0_subnet.go @@ -0,0 +1,55 @@ +package dns + +import ( + "context" + "net/netip" + + "github.com/miekg/dns" +) + +type edns0SubnetTransportWrapper struct { + Transport + clientSubnet netip.Addr +} + +func (t *edns0SubnetTransportWrapper) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) { + SetClientSubnet(message, t.clientSubnet, false) + return t.Transport.Exchange(ctx, message) +} + +func SetClientSubnet(message *dns.Msg, clientSubnet netip.Addr, override bool) { + var subnetOption *dns.EDNS0_SUBNET +findExists: + for _, record := range message.Extra { + if optRecord, isOPTRecord := record.(*dns.OPT); isOPTRecord { + for _, option := range optRecord.Option { + var isEDNS0Subnet bool + subnetOption, isEDNS0Subnet = option.(*dns.EDNS0_SUBNET) + if isEDNS0Subnet { + if !override { + return + } + break findExists + } + } + } + } + if subnetOption == nil { + subnetOption = new(dns.EDNS0_SUBNET) + message.Extra = append(message.Extra, &dns.OPT{ + Hdr: dns.RR_Header{ + Name: ".", + Rrtype: dns.TypeOPT, + }, + Option: []dns.EDNS0{subnetOption}, + }) + } + subnetOption.Code = dns.EDNS0SUBNET + if clientSubnet.Is4() { + subnetOption.Family = 1 + } else { + subnetOption.Family = 2 + } + subnetOption.SourceNetmask = uint8(clientSubnet.BitLen()) + subnetOption.Address = clientSubnet.AsSlice() +} diff --git a/extensions.go b/extensions.go new file mode 100644 index 0000000..6b89562 --- /dev/null +++ b/extensions.go @@ -0,0 +1,56 @@ +package dns + +import ( + "context" + "net/netip" +) + +type disableCacheKey struct{} + +func ContextWithDisableCache(ctx context.Context, val bool) context.Context { + return context.WithValue(ctx, (*disableCacheKey)(nil), val) +} + +func DisableCacheFromContext(ctx context.Context) bool { + val := ctx.Value((*disableCacheKey)(nil)) + if val == nil { + return false + } + return val.(bool) +} + +type rewriteTTLKey struct{} + +func ContextWithRewriteTTL(ctx context.Context, val uint32) context.Context { + return context.WithValue(ctx, (*rewriteTTLKey)(nil), val) +} + +func RewriteTTLFromContext(ctx context.Context) (uint32, bool) { + val := ctx.Value((*rewriteTTLKey)(nil)) + if val == nil { + return 0, false + } + return val.(uint32), true +} + +type transportKey struct{} + +func contextWithTransportName(ctx context.Context, transportName string) context.Context { + return context.WithValue(ctx, transportKey{}, transportName) +} + +func transportNameFromContext(ctx context.Context) (string, bool) { + value, loaded := ctx.Value(transportKey{}).(string) + return value, loaded +} + +type clientSubnetKey struct{} + +func ContextWithClientSubnet(ctx context.Context, clientSubnet netip.Addr) context.Context { + return context.WithValue(ctx, clientSubnetKey{}, clientSubnet) +} + +func ClientSubnetFromContext(ctx context.Context) (netip.Addr, bool) { + clientSubnet, ok := ctx.Value(clientSubnetKey{}).(netip.Addr) + return clientSubnet, ok +} diff --git a/loopback.go b/loopback.go deleted file mode 100644 index ff77f4c..0000000 --- a/loopback.go +++ /dev/null @@ -1,14 +0,0 @@ -package dns - -import "context" - -type transportKey struct{} - -func contextWithTransportName(ctx context.Context, transportName string) context.Context { - return context.WithValue(ctx, transportKey{}, transportName) -} - -func transportNameFromContext(ctx context.Context) (string, bool) { - value, loaded := ctx.Value(transportKey{}).(string) - return value, loaded -} diff --git a/options.go b/options.go deleted file mode 100644 index b371f7f..0000000 --- a/options.go +++ /dev/null @@ -1,31 +0,0 @@ -package dns - -import "context" - -type disableCacheKey struct{} - -func ContextWithDisableCache(ctx context.Context, val bool) context.Context { - return context.WithValue(ctx, (*disableCacheKey)(nil), val) -} - -func DisableCacheFromContext(ctx context.Context) bool { - val := ctx.Value((*disableCacheKey)(nil)) - if val == nil { - return false - } - return val.(bool) -} - -type rewriteTTLKey struct{} - -func ContextWithRewriteTTL(ctx context.Context, val uint32) context.Context { - return context.WithValue(ctx, (*rewriteTTLKey)(nil), val) -} - -func RewriteTTLFromContext(ctx context.Context) (uint32, bool) { - val := ctx.Value((*rewriteTTLKey)(nil)) - if val == nil { - return 0, false - } - return val.(uint32), true -} diff --git a/quic/transport_http3.go b/quic/transport_http3.go index 96db2c0..d512c73 100644 --- a/quic/transport_http3.go +++ b/quic/transport_http3.go @@ -14,7 +14,6 @@ import ( "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -24,16 +23,9 @@ import ( var _ dns.Transport = (*HTTP3Transport)(nil) func init() { - dns.RegisterTransport([]string{"h3"}, CreateHTTP3Transport) -} - -func CreateHTTP3Transport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { - linkURL, err := url.Parse(link) - if err != nil { - return nil, err - } - linkURL.Scheme = "https" - return NewHTTP3Transport(name, dialer, linkURL.String()), nil + dns.RegisterTransport([]string{"h3"}, func(options dns.TransportOptions) (dns.Transport, error) { + return NewHTTP3Transport(options) + }) } type HTTP3Transport struct { @@ -42,16 +34,21 @@ type HTTP3Transport struct { transport *http3.RoundTripper } -func NewHTTP3Transport(name string, dialer N.Dialer, serverURL string) *HTTP3Transport { +func NewHTTP3Transport(options dns.TransportOptions) (*HTTP3Transport, error) { + serverURL, err := url.Parse(options.Address) + if err != nil { + return nil, err + } + serverURL.Scheme = "https" return &HTTP3Transport{ - name: name, - destination: serverURL, + name: options.Name, + destination: options.Address, transport: &http3.RoundTripper{ Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { destinationAddr := M.ParseSocksaddr(addr) - conn, err := dialer.DialContext(ctx, N.NetworkUDP, destinationAddr) - if err != nil { - return nil, err + conn, dialErr := options.Dialer.DialContext(ctx, N.NetworkUDP, destinationAddr) + if dialErr != nil { + return nil, dialErr } return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg) }, @@ -59,7 +56,7 @@ func NewHTTP3Transport(name string, dialer N.Dialer, serverURL string) *HTTP3Tra NextProtos: []string{"dns"}, }, }, - } + }, nil } func (t *HTTP3Transport) Name() string { diff --git a/transport.go b/transport.go index 89a1f46..29d5c58 100644 --- a/transport.go +++ b/transport.go @@ -12,7 +12,7 @@ import ( "github.com/miekg/dns" ) -type TransportConstructor = func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) +type TransportConstructor = func(options TransportOptions) (Transport, error) type Transport interface { Name() string @@ -24,6 +24,15 @@ type Transport interface { Lookup(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, error) } +type TransportOptions struct { + Context context.Context + Logger logger.ContextLogger + Name string + Dialer N.Dialer + Address string + ClientSubnet netip.Addr +} + var transports map[string]TransportConstructor func RegisterTransport(schemes []string, constructor TransportConstructor) { @@ -35,10 +44,10 @@ func RegisterTransport(schemes []string, constructor TransportConstructor) { } } -func CreateTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, address string) (Transport, error) { - constructor := transports[address] +func CreateTransport(options TransportOptions) (Transport, error) { + constructor := transports[options.Address] if constructor == nil { - serverURL, _ := url.Parse(address) + serverURL, _ := url.Parse(options.Address) var scheme string if serverURL != nil { scheme = serverURL.Scheme @@ -46,7 +55,15 @@ func CreateTransport(name string, ctx context.Context, logger logger.ContextLogg constructor = transports[scheme] } if constructor == nil { - return nil, E.New("unknown DNS server format: " + address) + return nil, E.New("unknown DNS server format: " + options.Address) + } + options.Context = contextWithTransportName(options.Context, options.Name) + transport, err := constructor(options) + if err != nil { + return nil, err + } + if options.ClientSubnet.IsValid() { + transport = &edns0SubnetTransportWrapper{transport, options.ClientSubnet} } - return constructor(name, contextWithTransportName(ctx, name), logger, dialer, address) + return transport, nil } diff --git a/transport_base.go b/transport_base.go index a336370..ae51842 100644 --- a/transport_base.go +++ b/transport_base.go @@ -28,19 +28,21 @@ type myTransportAdapter struct { cancel context.CancelFunc dialer N.Dialer serverAddr M.Socksaddr + clientAddr netip.Addr handler myTransportHandler access sync.Mutex conn *dnsConnection } -func newAdapter(name string, ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr) myTransportAdapter { - ctx, cancel := context.WithCancel(ctx) +func newAdapter(options TransportOptions, serverAddr M.Socksaddr) myTransportAdapter { + ctx, cancel := context.WithCancel(options.Context) return myTransportAdapter{ - name: name, + name: options.Name, ctx: ctx, cancel: cancel, - dialer: dialer, + dialer: options.Dialer, serverAddr: serverAddr, + clientAddr: options.ClientSubnet, } } diff --git a/transport_https.go b/transport_https.go index 3cf3c5b..fd332f9 100644 --- a/transport_https.go +++ b/transport_https.go @@ -10,9 +10,7 @@ import ( "net/netip" "os" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/miekg/dns" ) @@ -28,21 +26,19 @@ type HTTPSTransport struct { } func init() { - RegisterTransport([]string{"https"}, CreateHTTPSTransport) + RegisterTransport([]string{"https"}, func(options TransportOptions) (Transport, error) { + return NewHTTPSTransport(options), nil + }) } -func CreateHTTPSTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) { - return NewHTTPSTransport(name, dialer, link), nil -} - -func NewHTTPSTransport(name string, dialer N.Dialer, serverURL string) *HTTPSTransport { +func NewHTTPSTransport(options TransportOptions) *HTTPSTransport { return &HTTPSTransport{ - name: name, - destination: serverURL, + name: options.Name, + destination: options.Address, transport: &http.Transport{ ForceAttemptHTTP2: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + return options.Dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) }, TLSClientConfig: &tls.Config{ NextProtos: []string{"dns"}, diff --git a/transport_local.go b/transport_local.go index ef16521..0852046 100644 --- a/transport_local.go +++ b/transport_local.go @@ -8,7 +8,6 @@ import ( "sort" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -16,11 +15,9 @@ import ( ) func init() { - RegisterTransport([]string{"local"}, CreateLocalTransport) -} - -func CreateLocalTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) { - return NewLocalTransport(name, dialer), nil + RegisterTransport([]string{"local"}, func(options TransportOptions) (Transport, error) { + return NewLocalTransport(options), nil + }) } var _ Transport = (*LocalTransport)(nil) @@ -30,12 +27,12 @@ type LocalTransport struct { resolver net.Resolver } -func NewLocalTransport(name string, dialer N.Dialer) *LocalTransport { +func NewLocalTransport(options TransportOptions) *LocalTransport { return &LocalTransport{ - name: name, + name: options.Name, resolver: net.Resolver{ Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return dialer.DialContext(ctx, N.NetworkName(network), M.ParseSocksaddr(address)) + return options.Dialer.DialContext(ctx, N.NetworkName(network), M.ParseSocksaddr(address)) }, }, } diff --git a/transport_rcode.go b/transport_rcode.go index 431bcf7..e26788e 100644 --- a/transport_rcode.go +++ b/transport_rcode.go @@ -7,8 +7,6 @@ import ( "os" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - N "github.com/sagernet/sing/common/network" "github.com/miekg/dns" ) @@ -16,15 +14,9 @@ import ( var _ Transport = (*RCodeTransport)(nil) func init() { - RegisterTransport([]string{"rcode"}, CreateRCodeTransport) -} - -func CreateRCodeTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) { - serverURL, err := url.Parse(link) - if err != nil { - return nil, err - } - return NewRCodeTransport(name, serverURL.Host) + RegisterTransport([]string{"rcode"}, func(options TransportOptions) (Transport, error) { + return NewRCodeTransport(options) + }) } type RCodeTransport struct { @@ -32,22 +24,26 @@ type RCodeTransport struct { code RCodeError } -func NewRCodeTransport(name string, code string) (*RCodeTransport, error) { - switch code { +func NewRCodeTransport(options TransportOptions) (*RCodeTransport, error) { + serverURL, err := url.Parse(options.Address) + if err != nil { + return nil, err + } + switch serverURL.Host { case "success": - return &RCodeTransport{name, RCodeSuccess}, nil + return &RCodeTransport{options.Name, RCodeSuccess}, nil case "format_error": - return &RCodeTransport{name, RCodeFormatError}, nil + return &RCodeTransport{options.Name, RCodeFormatError}, nil case "server_failure": - return &RCodeTransport{name, RCodeServerFailure}, nil + return &RCodeTransport{options.Name, RCodeServerFailure}, nil case "name_error": - return &RCodeTransport{name, RCodeNameError}, nil + return &RCodeTransport{options.Name, RCodeNameError}, nil case "not_implemented": - return &RCodeTransport{name, RCodeNotImplemented}, nil + return &RCodeTransport{options.Name, RCodeNotImplemented}, nil case "refused": - return &RCodeTransport{name, RCodeRefused}, nil + return &RCodeTransport{options.Name, RCodeRefused}, nil default: - return nil, E.New("unknown rcode: " + code) + return nil, E.New("unknown rcode: " + options.Name) } } diff --git a/transport_tcp.go b/transport_tcp.go index 01feff9..a25a4c8 100644 --- a/transport_tcp.go +++ b/transport_tcp.go @@ -9,7 +9,6 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -19,22 +18,21 @@ import ( var _ Transport = (*TCPTransport)(nil) func init() { - RegisterTransport([]string{"tcp"}, CreateTCPTransport) -} - -func CreateTCPTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) { - serverURL, err := url.Parse(link) - if err != nil { - return nil, err - } - return NewTCPTransport(name, ctx, dialer, M.ParseSocksaddr(serverURL.Host)) + RegisterTransport([]string{"tcp"}, func(options TransportOptions) (Transport, error) { + return NewTCPTransport(options) + }) } type TCPTransport struct { myTransportAdapter } -func NewTCPTransport(name string, ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr) (*TCPTransport, error) { +func NewTCPTransport(options TransportOptions) (*TCPTransport, error) { + serverURL, err := url.Parse(options.Address) + if err != nil { + return nil, err + } + serverAddr := M.ParseSocksaddr(serverURL.Host) if !serverAddr.IsValid() { return nil, E.New("invalid server address") } @@ -42,7 +40,7 @@ func NewTCPTransport(name string, ctx context.Context, dialer N.Dialer, serverAd serverAddr.Port = 53 } transport := &TCPTransport{ - newAdapter(name, ctx, dialer, serverAddr), + newAdapter(options, serverAddr), } transport.handler = transport return transport, nil diff --git a/transport_tls.go b/transport_tls.go index 9444dd3..375a529 100644 --- a/transport_tls.go +++ b/transport_tls.go @@ -10,7 +10,6 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -20,22 +19,21 @@ import ( var _ Transport = (*TLSTransport)(nil) func init() { - RegisterTransport([]string{"tls"}, CreateTLSTransport) -} - -func CreateTLSTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) { - serverURL, err := url.Parse(link) - if err != nil { - return nil, err - } - return NewTLSTransport(name, ctx, dialer, M.ParseSocksaddr(serverURL.Host)) + RegisterTransport([]string{"tls"}, func(options TransportOptions) (Transport, error) { + return NewTLSTransport(options) + }) } type TLSTransport struct { myTransportAdapter } -func NewTLSTransport(name string, ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr) (*TLSTransport, error) { +func NewTLSTransport(options TransportOptions) (*TLSTransport, error) { + serverURL, err := url.Parse(options.Address) + if err != nil { + return nil, err + } + serverAddr := M.ParseSocksaddr(serverURL.Host) if !serverAddr.IsValid() { return nil, E.New("invalid server address") } @@ -43,7 +41,7 @@ func NewTLSTransport(name string, ctx context.Context, dialer N.Dialer, serverAd serverAddr.Port = 853 } transport := &TLSTransport{ - newAdapter(name, ctx, dialer, serverAddr), + newAdapter(options, serverAddr), } transport.handler = transport return transport, nil diff --git a/transport_udp.go b/transport_udp.go index b4006a6..dbe689f 100644 --- a/transport_udp.go +++ b/transport_udp.go @@ -8,9 +8,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/miekg/dns" ) @@ -20,22 +18,21 @@ const FixedPacketSize = 16384 var _ Transport = (*UDPTransport)(nil) func init() { - RegisterTransport([]string{"udp", ""}, CreateUDPTransport) -} - -func CreateUDPTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) { - serverURL, err := url.Parse(link) - if err != nil || serverURL.Scheme == "" { - return NewUDPTransport(name, ctx, dialer, M.ParseSocksaddr(link)) - } - return NewUDPTransport(name, ctx, dialer, M.ParseSocksaddr(serverURL.Host)) + RegisterTransport([]string{"udp", ""}, func(options TransportOptions) (Transport, error) { + return NewUDPTransport(options) + }) } type UDPTransport struct { myTransportAdapter } -func NewUDPTransport(name string, ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr) (*UDPTransport, error) { +func NewUDPTransport(options TransportOptions) (*UDPTransport, error) { + serverURL, err := url.Parse(options.Address) + if err != nil { + return nil, err + } + serverAddr := M.ParseSocksaddr(serverURL.Host) if !serverAddr.IsValid() { return nil, E.New("invalid server address") } @@ -43,7 +40,7 @@ func NewUDPTransport(name string, ctx context.Context, dialer N.Dialer, serverAd serverAddr.Port = 53 } transport := &UDPTransport{ - newAdapter(name, ctx, dialer, serverAddr), + newAdapter(options, serverAddr), } transport.handler = transport return transport, nil