From 4a5d176aacd1246cb2a52f8974eecb3d7ee7b1e2 Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Fri, 31 Jan 2025 15:43:43 +0100 Subject: [PATCH 1/3] add: context support in dialers & DNS resolve --- dialer.go | 391 +++++++++++++++++++++++++++------------------------ dns.go | 11 +- dns_test.go | 11 +- transport.go | 6 +- 4 files changed, 227 insertions(+), 192 deletions(-) diff --git a/dialer.go b/dialer.go index 9cd1a8b..39ccee1 100644 --- a/dialer.go +++ b/dialer.go @@ -24,7 +24,7 @@ import ( ) type customDialer struct { - proxyDialer proxy.Dialer + proxyDialer proxy.ContextDialer client *CustomHTTPClient DNSConfig *dns.ClientConfig DNSClient *dns.Client @@ -66,9 +66,12 @@ func newCustomDialer(httpClient *CustomHTTPClient, proxyURL string, DialTimeout, return nil, err } - if d.proxyDialer, err = proxy.FromURL(u, d); err != nil { + var proxyDialer proxy.Dialer + if proxyDialer, err = proxy.FromURL(u, d); err != nil { return nil, err } + + d.proxyDialer = proxyDialer.(proxy.ContextDialer) } return d, nil @@ -101,12 +104,12 @@ func (cc *customConnection) Close() error { return cc.Conn.Close() } -func (d *customDialer) wrapConnection(c net.Conn, scheme string) net.Conn { +func (d *customDialer) wrapConnection(ctx context.Context, c net.Conn, scheme string) net.Conn { reqReader, reqWriter := io.Pipe() respReader, respWriter := io.Pipe() d.client.WaitGroup.Add(1) - go d.writeWARCFromConnection(reqReader, respReader, scheme, c) + go d.writeWARCFromConnection(ctx, reqReader, respReader, scheme, c) return &customConnection{ Conn: c, @@ -116,20 +119,20 @@ func (d *customDialer) wrapConnection(c net.Conn, scheme string) net.Conn { } } -func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err error) { +func (d *customDialer) CustomDialContext(ctx context.Context, network, address string) (conn net.Conn, err error) { // Determine the network based on IPv4/IPv6 settings network = d.getNetworkType(network) if network == "" { return nil, errors.New("no supported network type available") } - IP, err := d.archiveDNS(address) + IP, err := d.archiveDNS(ctx, address) if err != nil { return nil, err } if d.proxyDialer != nil { - conn, err = d.proxyDialer.Dial(network, address) + conn, err = d.proxyDialer.DialContext(ctx, network, address) } else { if d.client.randomLocalIP { localAddr := getLocalAddr(network, IP.String()) @@ -142,24 +145,28 @@ func (d *customDialer) CustomDial(network, address string) (conn net.Conn, err e } } - conn, err = d.DialContext(context.Background(), network, address) + conn, err = d.DialContext(ctx, network, address) } if err != nil { return nil, err } - return d.wrapConnection(conn, "http"), nil + return d.wrapConnection(ctx, conn, "http"), nil } -func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) { +func (d *customDialer) CustomDial(network, address string) (net.Conn, error) { + return d.CustomDialContext(context.Background(), network, address) +} + +func (d *customDialer) CustomDialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { // Determine the network based on IPv4/IPv6 settings network = d.getNetworkType(network) if network == "" { return nil, errors.New("no supported network type available") } - IP, err := d.archiveDNS(address) + IP, err := d.archiveDNS(ctx, address) if err != nil { return nil, err } @@ -167,7 +174,7 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) var plainConn net.Conn if d.proxyDialer != nil { - plainConn, err = d.proxyDialer.Dial(network, address) + plainConn, err = d.proxyDialer.DialContext(ctx, network, address) } else { if d.client.randomLocalIP { localAddr := getLocalAddr(network, IP.String()) @@ -180,7 +187,7 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) } } - plainConn, err = d.Dial(network, address) + plainConn, err = d.DialContext(ctx, network, address) } if err != nil { @@ -204,7 +211,7 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) }) go func() { - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(ctx) timer.Stop() errc <- err }() @@ -217,7 +224,11 @@ func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) return nil, err } - return d.wrapConnection(tlsConn, "https"), nil + return d.wrapConnection(ctx, tlsConn, "https"), nil +} + +func (d *customDialer) CustomDialTLS(network, address string) (net.Conn, error) { + return d.CustomDialTLSContext(context.Background(), network, address) } func (d *customDialer) getNetworkType(network string) string { @@ -245,7 +256,7 @@ func (d *customDialer) getNetworkType(network string) string { } } -func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, scheme string, conn net.Conn) { +func (d *customDialer) writeWARCFromConnection(ctx context.Context, reqPipe, respPipe *io.PipeReader, scheme string, conn net.Conn) { defer d.client.WaitGroup.Done() var ( @@ -255,33 +266,37 @@ func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, recordIDs []string target string host string - errs, _ = errgroup.WithContext(context.Background()) err = new(Error) ) + errs, ctx := errgroup.WithContext(ctx) + + // Run request and response readers in parallel, respecting context errs.Go(func() error { - return d.readRequest(scheme, reqPipe, target, host, warcTargetURIChannel, recordChan) + return d.readRequest(ctx, scheme, reqPipe, target, host, warcTargetURIChannel, recordChan) }) errs.Go(func() error { - return d.readResponse(respPipe, warcTargetURIChannel, recordChan) + return d.readResponse(ctx, respPipe, warcTargetURIChannel, recordChan) }) + // Wait for both goroutines to finish readErr := errs.Wait() close(recordChan) + + // Handle context cancellation + if ctx.Err() != nil { + return + } + if readErr != nil { - // Add the error to the err structure err.Err = readErr - d.client.ErrChan <- err - // Make sure we close the WARC content buffers for record := range recordChan { - // CHeck if there's an error when closing the content and send to channel if so. - err := record.Content.Close() - if err != nil { + if closeErr := record.Content.Close(); closeErr != nil { d.client.ErrChan <- &Error{ - Err: err, + Err: closeErr, Func: "writeWARCFromConnection", } } @@ -291,21 +306,23 @@ func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, } for record := range recordChan { - recordIDs = append(recordIDs, uuid.NewString()) - batch.Records = append(batch.Records, record) + select { + case <-ctx.Done(): + return + default: + recordIDs = append(recordIDs, uuid.NewString()) + batch.Records = append(batch.Records, record) + } } if len(batch.Records) != 2 { err.Err = errors.New("warc: there was an unspecified problem creating one of the WARC records") - d.client.ErrChan <- err - // Make sure we close the WARC content buffers for _, record := range batch.Records { - err := record.Content.Close() - if err != nil { + if closeErr := record.Content.Close(); closeErr != nil { d.client.ErrChan <- &Error{ - Err: err, + Err: closeErr, Func: "writeWARCFromConnection", } } @@ -314,71 +331,165 @@ func (d *customDialer) writeWARCFromConnection(reqPipe, respPipe *io.PipeReader, return } - // Most Internet Archive tools expect requests to be AFTER responses - // in the WARC file. So we make sure that's the case. if batch.Records[0].Header.Get("WARC-Type") != "response" { slices.Reverse(batch.Records) } - // Get the WARC-Target-URI value - var warcTargetURI = <-warcTargetURIChannel + var warcTargetURI string + select { + case warcTargetURI = <-warcTargetURIChannel: + case <-ctx.Done(): + return + } - // Add headers for i, r := range batch.Records { - // Generate WARC-IP-Address if we aren't using a proxy. If we are using a proxy, the real host IP cannot be determined. - if d.proxyDialer == nil { - switch addr := conn.RemoteAddr().(type) { - case *net.UDPAddr: - case *net.TCPAddr: - IP := addr.IP.String() - r.Header.Set("WARC-IP-Address", IP) + select { + case <-ctx.Done(): + return + default: + if d.proxyDialer == nil { + switch addr := conn.RemoteAddr().(type) { + case *net.TCPAddr: + IP := addr.IP.String() + r.Header.Set("WARC-IP-Address", IP) + } } - } - // Set WARC-Record-ID and WARC-Concurrent-To - r.Header.Set("WARC-Record-ID", "") + r.Header.Set("WARC-Record-ID", "") - if i == len(recordIDs)-1 { - r.Header.Set("WARC-Concurrent-To", "") - } else { - r.Header.Set("WARC-Concurrent-To", "") - } + if i == len(recordIDs)-1 { + r.Header.Set("WARC-Concurrent-To", "") + } else { + r.Header.Set("WARC-Concurrent-To", "") + } + + r.Header.Set("WARC-Target-URI", warcTargetURI) - // Add WARC-Target-URI - r.Header.Set("WARC-Target-URI", warcTargetURI) - - // Calculate WARC-Block-Digest and Content-Length - // Those 2 steps are done at this stage of the process ON PURPOSE, to take - // advantage of the parallelization context in which this function is called. - // That way, we reduce I/O bottleneck later when the record is at the "writing" step, - // because the actual WARC writing sequential, not parallel. - _, seekError := r.Content.Seek(0, 0) - if seekError != nil { - d.client.ErrChan <- &Error{ - Err: seekError, - Func: "writeWARCFromConnection", + if _, seekErr := r.Content.Seek(0, 0); seekErr != nil { + d.client.ErrChan <- &Error{ + Err: seekErr, + Func: "writeWARCFromConnection", + } + return + } + + r.Header.Set("WARC-Block-Digest", "sha1:"+GetSHA1(r.Content)) + r.Header.Set("Content-Length", strconv.Itoa(getContentLength(r.Content))) + + if d.client.dedupeOptions.LocalDedupe { + if r.Header.Get("WARC-Type") == "response" && r.Header.Get("WARC-Payload-Digest")[5:] != "3I42H3S6NNFQ2MSVX7XZKYAYSCX5QBYJ" { + d.client.dedupeHashTable.Store(r.Header.Get("WARC-Payload-Digest")[5:], revisitRecord{ + responseUUID: recordIDs[i], + size: getContentLength(r.Content), + targetURI: warcTargetURI, + date: batch.CaptureTime, + }) + } } } + } - r.Header.Set("WARC-Block-Digest", "sha1:"+GetSHA1(r.Content)) - r.Header.Set("Content-Length", strconv.Itoa(getContentLength(r.Content))) + select { + case d.client.WARCWriter <- batch: + case <-ctx.Done(): + return + } +} - if d.client.dedupeOptions.LocalDedupe { - if r.Header.Get("WARC-Type") == "response" && r.Header.Get("WARC-Payload-Digest")[5:] != "3I42H3S6NNFQ2MSVX7XZKYAYSCX5QBYJ" { - d.client.dedupeHashTable.Store(r.Header.Get("WARC-Payload-Digest")[5:], revisitRecord{ - responseUUID: recordIDs[i], - size: getContentLength(r.Content), - targetURI: warcTargetURI, - date: batch.CaptureTime, - }) +func (d *customDialer) readRequest(ctx context.Context, scheme string, reqPipe *io.PipeReader, target string, host string, warcTargetURIChannel chan string, recordChan chan *Record) error { + var ( + warcTargetURI = scheme + "://" + requestRecord = NewRecord(d.client.TempDir, d.client.FullOnDisk) + ) + + // Initialize the request record + requestRecord.Header.Set("WARC-Type", "request") + requestRecord.Header.Set("Content-Type", "application/http; msgtype=request") + + // Copy the content from the pipe + _, err := io.Copy(requestRecord.Content, reqPipe) + if err != nil { + return fmt.Errorf("readRequest: io.Copy failed: %s", err.Error()) + } + + // Parse data for WARC-Target-URI + var ( + block = make([]byte, 1) + line string + ) + +loop: + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + n, err := requestRecord.Content.Read(block) + if n > 0 { + if string(block) == "\n" { + if isHTTPRequest(line) { + target = strings.Split(line, " ")[1] + + if host != "" && target != "" { + break loop + } else { + line = "" + continue + } + } + + if strings.HasPrefix(line, "Host: ") { + host = strings.TrimPrefix(line, "Host: ") + host = strings.TrimSuffix(host, "\r") + + if host != "" && target != "" { + break loop + } else { + line = "" + continue + } + } + + line = "" + } else { + line += string(block) + } + } else { + break } + + if err == io.EOF { + break + } + + if err != nil { + return fmt.Errorf("readRequest: could not read from request content: %s", err.Error()) + } + } + } + + // Check that we achieved to parse all the necessary data + if host != "" && target != "" { + // HTTP's request first line can include a complete path, we check that + if strings.HasPrefix(target, scheme+"://"+host) { + warcTargetURI = target + } else { + warcTargetURI += host + target } + } else { + return errors.New("unable to parse data necessary for WARC-Target-URI") } - d.client.WARCWriter <- batch + // Send the WARC-Target-URI to a channel so that it can be picked-up + // by the goroutine responsible for writing the response + warcTargetURIChannel <- warcTargetURI + + recordChan <- requestRecord + + return nil } -func (d *customDialer) readResponse(respPipe *io.PipeReader, warcTargetURIChannel chan string, recordChan chan *Record) error { +func (d *customDialer) readResponse(ctx context.Context, respPipe *io.PipeReader, warcTargetURIChannel chan string, recordChan chan *Record) error { // Initialize the response record var responseRecord = NewRecord(d.client.TempDir, d.client.FullOnDisk) responseRecord.Header.Set("WARC-Type", "response") @@ -538,27 +649,34 @@ func (d *customDialer) readResponse(respPipe *io.PipeReader, warcTargetURIChanne block = make([]byte, 1) wrote := 0 responseRecord.Content.Seek(0, 0) + + loop: for { - n, err := responseRecord.Content.Read(block) - if n > 0 { - _, err = tempBuffer.Write(block) - if err != nil { - return fmt.Errorf("readResponse: could not write to temporary buffer: %s", err.Error()) + select { + case <-ctx.Done(): + return ctx.Err() + default: + n, err := responseRecord.Content.Read(block) + if n > 0 { + _, err = tempBuffer.Write(block) + if err != nil { + return fmt.Errorf("readResponse: could not write to temporary buffer: %s", err.Error()) + } } - } - if err == io.EOF { - break - } + if err == io.EOF { + break + } - if err != nil { - return fmt.Errorf("readResponse: could not read from response content: %s", err.Error()) - } + if err != nil { + return fmt.Errorf("readResponse: could not read from response content: %s", err.Error()) + } - wrote++ + wrote++ - if wrote == endOfHeadersOffset { - break + if wrote == endOfHeadersOffset { + break loop + } } } @@ -574,90 +692,3 @@ func (d *customDialer) readResponse(respPipe *io.PipeReader, warcTargetURIChanne return nil } - -func (d *customDialer) readRequest(scheme string, reqPipe *io.PipeReader, target string, host string, warcTargetURIChannel chan string, recordChan chan *Record) error { - var ( - warcTargetURI = scheme + "://" - requestRecord = NewRecord(d.client.TempDir, d.client.FullOnDisk) - ) - - // Initialize the request record - requestRecord.Header.Set("WARC-Type", "request") - requestRecord.Header.Set("Content-Type", "application/http; msgtype=request") - - // Copy the content from the pipe - _, err := io.Copy(requestRecord.Content, reqPipe) - if err != nil { - return fmt.Errorf("readRequest: io.Copy failed: %s", err.Error()) - } - - // Parse data for WARC-Target-URI - var ( - block = make([]byte, 1) - line string - ) - - for { - n, err := requestRecord.Content.Read(block) - if n > 0 { - if string(block) == "\n" { - if isHTTPRequest(line) { - target = strings.Split(line, " ")[1] - - if host != "" && target != "" { - break - } else { - line = "" - continue - } - } - - if strings.HasPrefix(line, "Host: ") { - host = strings.TrimPrefix(line, "Host: ") - host = strings.TrimSuffix(host, "\r") - - if host != "" && target != "" { - break - } else { - line = "" - continue - } - } - - line = "" - } else { - line += string(block) - } - } else { - break - } - - if err == io.EOF { - break - } - - if err != nil { - return fmt.Errorf("readRequest: could not read from request content: %s", err.Error()) - } - } - - // Check that we achieved to parse all the necessary data - if host != "" && target != "" { - // HTTP's request first line can include a complete path, we check that - if strings.HasPrefix(target, scheme+"://"+host) { - warcTargetURI = target - } else { - warcTargetURI += host + target - } - } else { - return errors.New("unable to parse data necessary for WARC-Target-URI") - } - - // Send the WARC-Target-URI to a channel so that it can be picked-up - // by the goroutine responsible for writing the response - warcTargetURIChannel <- warcTargetURI - - recordChan <- requestRecord - - return nil -} diff --git a/dns.go b/dns.go index 0a7b615..d715012 100644 --- a/dns.go +++ b/dns.go @@ -1,6 +1,7 @@ package warc import ( + "context" "fmt" "net" "sync" @@ -16,7 +17,7 @@ type cachedIP struct { const maxFallbackDNSServers = 3 -func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) { +func (d *customDialer) archiveDNS(ctx context.Context, address string) (resolvedIP net.IP, err error) { // Get the address without the port if there is one address, _, err = net.SplitHostPort(address) if err != nil { @@ -54,7 +55,7 @@ func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) wg.Add(1) go func() { defer wg.Done() - ipv4, errA = d.lookupIP(address, dns.TypeA, DNSServer) + ipv4, errA = d.lookupIP(ctx, address, dns.TypeA, DNSServer) }() } @@ -62,7 +63,7 @@ func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) wg.Add(1) go func() { defer wg.Done() - ipv6, errAAAA = d.lookupIP(address, dns.TypeAAAA, DNSServer) + ipv6, errAAAA = d.lookupIP(ctx, address, dns.TypeAAAA, DNSServer) }() } @@ -95,11 +96,11 @@ func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) return nil, fmt.Errorf("no suitable IP address found for %s", address) } -func (d *customDialer) lookupIP(address string, recordType uint16, DNSServer int) (net.IP, error) { +func (d *customDialer) lookupIP(ctx context.Context, address string, recordType uint16, DNSServer int) (net.IP, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(address), recordType) - r, _, err := d.DNSClient.Exchange(m, net.JoinHostPort(d.DNSConfig.Servers[DNSServer], d.DNSConfig.Port)) + r, _, err := d.DNSClient.ExchangeContext(ctx, m, net.JoinHostPort(d.DNSConfig.Servers[DNSServer], d.DNSConfig.Port)) if err != nil { return nil, err } diff --git a/dns_test.go b/dns_test.go index 45d6bac..710b2f4 100644 --- a/dns_test.go +++ b/dns_test.go @@ -1,6 +1,7 @@ package warc import ( + "context" "errors" "os" "sync" @@ -76,7 +77,7 @@ func TestNoDNSServersConfigured(t *testing.T) { wantErr := errors.New("no DNS servers configured") d.DNSConfig.Servers = []string{} - _, err := d.archiveDNS(target) + _, err := d.archiveDNS(context.Background(), target) if err.Error() != wantErr.Error() { t.Errorf("Want error %s, got %s", wantErr, err) } @@ -89,7 +90,7 @@ func TestNormalDNSResolution(t *testing.T) { defer cleanup() d.DNSConfig.Servers = []string{publicDNS} - IP, err := d.archiveDNS(target) + IP, err := d.archiveDNS(context.Background(), target) if err != nil { t.Fatal(err) } @@ -114,7 +115,7 @@ func TestIPv6Only(t *testing.T) { d.disableIPv6 = false d.DNSConfig.Servers = []string{publicDNS} - IP, err := d.archiveDNS(target) + IP, err := d.archiveDNS(context.Background(), target) if err != nil { t.Fatal(err) } @@ -127,7 +128,7 @@ func TestNXDOMAIN(t *testing.T) { d, _, cleanup := setup(t) defer cleanup() - IP, err := d.archiveDNS(nxdomain) + IP, err := d.archiveDNS(context.Background(), nxdomain) if err == nil { t.Error("Want failure,", "got resolved IP", IP) } @@ -141,7 +142,7 @@ func TestDNSFallback(t *testing.T) { d.DNSRecords.Delete(targetHost) d.DNSConfig.Servers = []string{invalidDNS, publicDNS} - IP, err := d.archiveDNS(target) + IP, err := d.archiveDNS(context.Background(), target) if err != nil { t.Fatal(err) } diff --git a/transport.go b/transport.go index d0e527c..2f4c02f 100644 --- a/transport.go +++ b/transport.go @@ -39,8 +39,10 @@ func newCustomTransport(dialer *customDialer, decompressBody bool, TLSHandshakeT t.t = http.Transport{ // configure HTTP transport - Dial: dialer.CustomDial, - DialTLS: dialer.CustomDialTLS, + Dial: dialer.CustomDial, + DialContext: dialer.CustomDialContext, + DialTLS: dialer.CustomDialTLS, + DialTLSContext: dialer.CustomDialTLSContext, // disable keep alive MaxConnsPerHost: 0, From 5ba230e4566b42ea916bf69cff95564780f4f48f Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Fri, 31 Jan 2025 16:08:26 +0100 Subject: [PATCH 2/3] add: context support in dialers & DNS resolve --- dialer.go | 245 +++++++++++++++++++++++++++--------------------------- 1 file changed, 123 insertions(+), 122 deletions(-) diff --git a/dialer.go b/dialer.go index 39ccee1..ebc2d01 100644 --- a/dialer.go +++ b/dialer.go @@ -267,10 +267,9 @@ func (d *customDialer) writeWARCFromConnection(ctx context.Context, reqPipe, res target string host string err = new(Error) + errs = errgroup.Group{} ) - errs, ctx := errgroup.WithContext(ctx) - // Run request and response readers in parallel, respecting context errs.Go(func() error { return d.readRequest(ctx, scheme, reqPipe, target, host, warcTargetURIChannel, recordChan) @@ -284,11 +283,6 @@ func (d *customDialer) writeWARCFromConnection(ctx context.Context, reqPipe, res readErr := errs.Wait() close(recordChan) - // Handle context cancellation - if ctx.Err() != nil { - return - } - if readErr != nil { err.Err = readErr d.client.ErrChan <- err @@ -396,99 +390,6 @@ func (d *customDialer) writeWARCFromConnection(ctx context.Context, reqPipe, res } } -func (d *customDialer) readRequest(ctx context.Context, scheme string, reqPipe *io.PipeReader, target string, host string, warcTargetURIChannel chan string, recordChan chan *Record) error { - var ( - warcTargetURI = scheme + "://" - requestRecord = NewRecord(d.client.TempDir, d.client.FullOnDisk) - ) - - // Initialize the request record - requestRecord.Header.Set("WARC-Type", "request") - requestRecord.Header.Set("Content-Type", "application/http; msgtype=request") - - // Copy the content from the pipe - _, err := io.Copy(requestRecord.Content, reqPipe) - if err != nil { - return fmt.Errorf("readRequest: io.Copy failed: %s", err.Error()) - } - - // Parse data for WARC-Target-URI - var ( - block = make([]byte, 1) - line string - ) - -loop: - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - n, err := requestRecord.Content.Read(block) - if n > 0 { - if string(block) == "\n" { - if isHTTPRequest(line) { - target = strings.Split(line, " ")[1] - - if host != "" && target != "" { - break loop - } else { - line = "" - continue - } - } - - if strings.HasPrefix(line, "Host: ") { - host = strings.TrimPrefix(line, "Host: ") - host = strings.TrimSuffix(host, "\r") - - if host != "" && target != "" { - break loop - } else { - line = "" - continue - } - } - - line = "" - } else { - line += string(block) - } - } else { - break - } - - if err == io.EOF { - break - } - - if err != nil { - return fmt.Errorf("readRequest: could not read from request content: %s", err.Error()) - } - } - } - - // Check that we achieved to parse all the necessary data - if host != "" && target != "" { - // HTTP's request first line can include a complete path, we check that - if strings.HasPrefix(target, scheme+"://"+host) { - warcTargetURI = target - } else { - warcTargetURI += host + target - } - } else { - return errors.New("unable to parse data necessary for WARC-Target-URI") - } - - // Send the WARC-Target-URI to a channel so that it can be picked-up - // by the goroutine responsible for writing the response - warcTargetURIChannel <- warcTargetURI - - recordChan <- requestRecord - - return nil -} - func (d *customDialer) readResponse(ctx context.Context, respPipe *io.PipeReader, warcTargetURIChannel chan string, recordChan chan *Record) error { // Initialize the response record var responseRecord = NewRecord(d.client.TempDir, d.client.FullOnDisk) @@ -506,6 +407,12 @@ func (d *customDialer) readResponse(ctx context.Context, respPipe *io.PipeReader return fmt.Errorf("readResponse: io.Copy failed: %s", err.Error()) } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + resp, err := http.ReadResponse(bufio.NewReader(responseRecord.Content), nil) if err != nil { closeErr := responseRecord.Content.Close() @@ -649,34 +556,27 @@ func (d *customDialer) readResponse(ctx context.Context, respPipe *io.PipeReader block = make([]byte, 1) wrote := 0 responseRecord.Content.Seek(0, 0) - - loop: for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - n, err := responseRecord.Content.Read(block) - if n > 0 { - _, err = tempBuffer.Write(block) - if err != nil { - return fmt.Errorf("readResponse: could not write to temporary buffer: %s", err.Error()) - } + n, err := responseRecord.Content.Read(block) + if n > 0 { + _, err = tempBuffer.Write(block) + if err != nil { + return fmt.Errorf("readResponse: could not write to temporary buffer: %s", err.Error()) } + } - if err == io.EOF { - break - } + if err == io.EOF { + break + } - if err != nil { - return fmt.Errorf("readResponse: could not read from response content: %s", err.Error()) - } + if err != nil { + return fmt.Errorf("readResponse: could not read from response content: %s", err.Error()) + } - wrote++ + wrote++ - if wrote == endOfHeadersOffset { - break loop - } + if wrote == endOfHeadersOffset { + break } } @@ -692,3 +592,104 @@ func (d *customDialer) readResponse(ctx context.Context, respPipe *io.PipeReader return nil } + +func (d *customDialer) readRequest(ctx context.Context, scheme string, reqPipe *io.PipeReader, target string, host string, warcTargetURIChannel chan string, recordChan chan *Record) error { + var ( + warcTargetURI = scheme + "://" + requestRecord = NewRecord(d.client.TempDir, d.client.FullOnDisk) + ) + + // Initialize the request record + requestRecord.Header.Set("WARC-Type", "request") + requestRecord.Header.Set("Content-Type", "application/http; msgtype=request") + + // Copy the content from the pipe + _, err := io.Copy(requestRecord.Content, reqPipe) + if err != nil { + return fmt.Errorf("readRequest: io.Copy failed: %s", err.Error()) + } + + // Parse data for WARC-Target-URI + var ( + block = make([]byte, 1) + line string + ) + +loop: + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + n, err := requestRecord.Content.Read(block) + if n > 0 { + if string(block) == "\n" { + if isHTTPRequest(line) { + target = strings.Split(line, " ")[1] + + if host != "" && target != "" { + break loop + } else { + line = "" + continue + } + } + + if strings.HasPrefix(line, "Host: ") { + host = strings.TrimPrefix(line, "Host: ") + host = strings.TrimSuffix(host, "\r") + + if host != "" && target != "" { + break loop + } else { + line = "" + continue + } + } + + line = "" + } else { + line += string(block) + } + } else { + break + } + + if err == io.EOF { + break + } + + if err != nil { + return fmt.Errorf("readRequest: could not read from request content: %s", err.Error()) + } + } + } + + // Check that we achieved to parse all the necessary data + if host != "" && target != "" { + // HTTP's request first line can include a complete path, we check that + if strings.HasPrefix(target, scheme+"://"+host) { + warcTargetURI = target + } else { + warcTargetURI += host + target + } + } else { + return errors.New("unable to parse data necessary for WARC-Target-URI") + } + + // Send the WARC-Target-URI to a channel so that it can be picked-up + // by the goroutine responsible for writing the response + select { + case <-ctx.Done(): + return ctx.Err() + case warcTargetURIChannel <- warcTargetURI: + } + + select { + case <-ctx.Done(): + return ctx.Err() + case recordChan <- requestRecord: + } + + return nil +} From 57cb0d1d356e1beb0e22ba5687d0d362f0de3d33 Mon Sep 17 00:00:00 2001 From: Corentin Barreau Date: Fri, 31 Jan 2025 16:23:51 +0100 Subject: [PATCH 3/3] cosmetic fix --- dialer.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dialer.go b/dialer.go index ebc2d01..f877c39 100644 --- a/dialer.go +++ b/dialer.go @@ -284,8 +284,10 @@ func (d *customDialer) writeWARCFromConnection(ctx context.Context, reqPipe, res close(recordChan) if readErr != nil { - err.Err = readErr - d.client.ErrChan <- err + d.client.ErrChan <- &Error{ + Err: readErr, + Func: "writeWARCFromConnection", + } for record := range recordChan { if closeErr := record.Content.Close(); closeErr != nil {