Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom DNS resolver #49

Merged
merged 11 commits into from
Sep 26, 2024
42 changes: 34 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ type HTTPClientSettings struct {
RotatorSettings *RotatorSettings
Proxy string
TempDir string
DNSServer string
SkipHTTPStatusCodes []int
DNSServers []string
DedupeOptions DedupeOptions
DialTimeout time.Duration
ResponseHeaderTimeout time.Duration
DNSResolutionTimeout time.Duration
DNSRecordsTTL time.Duration
TLSHandshakeTimeout time.Duration
MaxReadBeforeTruncate int
TCPTimeout time.Duration
MaxReadBeforeTruncate int
DecompressBody bool
FollowRedirects bool
FullOnDisk bool
Expand All @@ -33,17 +37,19 @@ type HTTPClientSettings struct {
}

type CustomHTTPClient struct {
WARCWriter chan *RecordBatch
WaitGroup *WaitGroupWithCount
dedupeHashTable *sync.Map
ErrChan chan *Error
interfacesWatcherStop chan bool
WaitGroup *WaitGroupWithCount
dedupeHashTable *sync.Map
ErrChan chan *Error
WARCWriter chan *RecordBatch
interfacesWatcherStarted chan bool
http.Client
TempDir string
WARCWriterDoneChannels []chan bool
skipHTTPStatusCodes []int
dedupeOptions DedupeOptions
MaxReadBeforeTruncate int
TLSHandshakeTimeout time.Duration
MaxReadBeforeTruncate int
verifyCerts bool
FullOnDisk bool
randomLocalIP bool
Expand All @@ -67,6 +73,11 @@ func (c *CustomHTTPClient) Close() error {
wg.Wait()
close(c.ErrChan)

if c.randomLocalIP {
c.interfacesWatcherStop <- true
close(c.interfacesWatcherStop)
}

return nil
}

Expand All @@ -76,7 +87,10 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient
// Configure random local IP
httpClient.randomLocalIP = HTTPClientSettings.RandomLocalIP
if httpClient.randomLocalIP {
go getAvailableIPs()
httpClient.interfacesWatcherStop = make(chan bool)
httpClient.interfacesWatcherStarted = make(chan bool)
go httpClient.getAvailableIPs()
<-httpClient.interfacesWatcherStarted
}

// Toggle deduplication options and create map for deduplication records.
Expand Down Expand Up @@ -146,10 +160,22 @@ func NewWARCWritingHTTPClient(HTTPClientSettings HTTPClientSettings) (httpClient
HTTPClientSettings.TLSHandshakeTimeout = 10 * time.Second
}

if HTTPClientSettings.TCPTimeout == 0 {
HTTPClientSettings.TCPTimeout = 10 * time.Second
}

if HTTPClientSettings.DNSResolutionTimeout == 0 {
HTTPClientSettings.DNSResolutionTimeout = 5 * time.Second
}

if HTTPClientSettings.DNSRecordsTTL == 0 {
HTTPClientSettings.DNSRecordsTTL = 5 * time.Minute
}

httpClient.TLSHandshakeTimeout = HTTPClientSettings.TLSHandshakeTimeout

// Configure custom dialer / transport
customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6)
customDialer, err := newCustomDialer(httpClient, HTTPClientSettings.Proxy, HTTPClientSettings.DialTimeout, HTTPClientSettings.DNSRecordsTTL, HTTPClientSettings.DNSResolutionTimeout, HTTPClientSettings.DNSServers, HTTPClientSettings.DisableIPv4, HTTPClientSettings.DisableIPv6)
if err != nil {
return nil, err
}
Expand Down
49 changes: 40 additions & 9 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/miekg/dns"
tls "github.com/refraction-networking/utls"
"golang.org/x/net/proxy"
"golang.org/x/sync/errgroup"
Expand All @@ -24,27 +25,48 @@ import (
type customDialer struct {
proxyDialer proxy.Dialer
client *CustomHTTPClient
disableIPv4 bool
disableIPv6 bool
DNSConfig *dns.ClientConfig
DNSClient *dns.Client
DNSRecords *sync.Map
net.Dialer
DNSServer string
DNSRecordsTTL time.Duration
disableIPv4 bool
disableIPv6 bool
}

func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout time.Duration, disableIPv4, disableIPv6 bool) (d *customDialer, err error) {
func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, DNSRecordsTTL, DNSResolutionTimeout time.Duration, DNSServers []string, disableIPv4, disableIPv6 bool) (d *customDialer, err error) {
d = new(customDialer)

d.Timeout = DialTimeout
d.client = httpClient
d.disableIPv4 = disableIPv4
d.disableIPv6 = disableIPv6

d.DNSRecordsTTL = DNSRecordsTTL
d.DNSRecords = new(sync.Map)
d.DNSConfig, err = dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, err
}

if len(DNSServers) > 0 {
d.DNSConfig.Servers = DNSServers
}

d.DNSClient = &dns.Client{
Net: "udp",
Timeout: DNSResolutionTimeout,
}

if proxyURL != "" {
u, err := url.Parse(proxyURL)
if err != nil {
panic(err.Error())
return nil, err
}

if d.proxyDialer, err = proxy.FromURL(u, d); err != nil {
panic(err.Error())
return nil, err
}
}

Expand Down Expand Up @@ -97,11 +119,16 @@ func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err e
return nil, errors.New("no supported network type available")
}

IP, err := d.archiveDNS(address)
if err != nil {
return nil, err
}

if d.proxyDialer != nil {
conn, err = d.proxyDialer.Dial(network, address)
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
localAddr := getLocalAddr(network, IP.String())
if localAddr != nil {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
Expand All @@ -111,7 +138,7 @@ func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err e
}
}

conn, err = d.Dial(network, address)
conn, err = d.DialContext(context.Background(), network, address)
}

if err != nil {
Expand All @@ -128,14 +155,18 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error)
return nil, errors.New("no supported network type available")
}

IP, err := d.archiveDNS(address)
if err != nil {
return nil, err
}

var plainConn net.Conn
var err error

if d.proxyDialer != nil {
plainConn, err = d.proxyDialer.Dial(network, address)
} else {
if d.client.randomLocalIP {
localAddr := getLocalAddr(network, address)
localAddr := getLocalAddr(network, IP.String())
if localAddr != nil {
if network == "tcp" || network == "tcp4" || network == "tcp6" {
d.LocalAddr = localAddr.(*net.TCPAddr)
Expand Down
112 changes: 112 additions & 0 deletions dns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package warc

import (
"fmt"
"net"
"sync"
"time"

"github.com/miekg/dns"
)

type cachedIP struct {
expiresAt time.Time
ip net.IP
}

func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) {
// Get the address without the port if there is one
address, _, err = net.SplitHostPort(address)
if err != nil {
return resolvedIP, err
}

// Check if the address is already an IP
resolvedIP = net.ParseIP(address)
if resolvedIP != nil {
return resolvedIP, nil
}

// Check cache first
if cached, ok := d.DNSRecords.Load(address); ok {
cachedEntry := cached.(cachedIP)
if time.Now().Before(cachedEntry.expiresAt) {
return cachedEntry.ip, nil
}
// Cache entry expired, remove it
d.DNSRecords.Delete(address)
}

var wg sync.WaitGroup
var ipv4, ipv6 net.IP
var errA, errAAAA error

wg.Add(2)

go func() {
defer wg.Done()
ipv4, errA = d.lookupIP(address, dns.TypeA)
}()

go func() {
defer wg.Done()
ipv6, errAAAA = d.lookupIP(address, dns.TypeAAAA)
}()

wg.Wait()

if errA != nil && errAAAA != nil {
return nil, fmt.Errorf("failed to resolve DNS: A error: %v, AAAA error: %v", errA, errAAAA)
}

// Prioritize IPv6 if both are available and enabled
if ipv6 != nil && !d.disableIPv6 {
resolvedIP = ipv6
} else if ipv4 != nil && !d.disableIPv4 {
resolvedIP = ipv4
}

if resolvedIP != nil {
// Cache the result
d.DNSRecords.Store(address, cachedIP{
ip: resolvedIP,
expiresAt: time.Now().Add(d.DNSRecordsTTL),
})
return resolvedIP, nil
}

return nil, fmt.Errorf("no suitable IP address found for %s", address)
}

func (d *customDialer) lookupIP(address string, recordType uint16) (net.IP, error) {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(address), recordType)

r, _, err := d.DNSClient.Exchange(m, net.JoinHostPort(d.DNSConfig.Servers[0], d.DNSConfig.Port))
if err != nil {
return nil, err
}

// Record the DNS response
recordTypeStr := "A"
if recordType == dns.TypeAAAA {
recordTypeStr = "AAAA"
}

d.client.WriteRecord(fmt.Sprintf("dns:%s:%s", address, recordTypeStr), "resource", "text/dns", r.String())

for _, answer := range r.Answer {
switch recordType {
case dns.TypeA:
if a, ok := answer.(*dns.A); ok {
return a.A, nil
}
case dns.TypeAAAA:
if aaaa, ok := answer.(*dns.AAAA); ok {
return aaaa.AAAA, nil
}
}
}

return nil, fmt.Errorf("no %s record found", recordTypeStr)
}
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module github.com/CorentinB/warc

go 1.22
go 1.22.0

require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/google/uuid v1.6.0
github.com/klauspost/compress v1.17.9
github.com/klauspost/compress v1.17.10
github.com/miekg/dns v1.1.62
github.com/paulbellamy/ratecounter v0.2.0
github.com/refraction-networking/utls v1.6.7
github.com/remeh/sizedwaitgroup v1.0.0
Expand All @@ -22,5 +23,7 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/tools v0.25.0 // indirect
)
10 changes: 8 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0=
github.com/klauspost/compress v1.17.10/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/paulbellamy/ratecounter v0.2.0 h1:2L/RhJq+HA8gBQImDXtLPrDXK5qAj6ozWVK/zFXVJGs=
github.com/paulbellamy/ratecounter v0.2.0/go.mod h1:Hfx1hDpSGoqxkVVpBi/IlYD7kChlfo5C6hzIHwPqfFE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand All @@ -34,12 +36,16 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE=
golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Loading
Loading