From bf64ccd2d5e25aaa9f0f8d6f1359595c1506827a Mon Sep 17 00:00:00 2001 From: dprotaso Date: Mon, 24 Jun 2024 13:21:23 -0400 Subject: [PATCH] fix unit test --- test/spoof/error_checks_test.go | 6 +++--- test/spoof/spoof_test.go | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/test/spoof/error_checks_test.go b/test/spoof/error_checks_test.go index d516d8cae5..55bdc5568e 100644 --- a/test/spoof/error_checks_test.go +++ b/test/spoof/error_checks_test.go @@ -51,7 +51,7 @@ func TestDNSError(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, tt.url, nil) resp, err := client.Do(req) - if err != nil { + if resp != nil { defer resp.Body.Close() } if dnsError := isDNSError(err); tt.dnsError != dnsError { @@ -84,7 +84,7 @@ func TestConnectionRefused(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, tt.url, nil) resp, err := client.Do(req) - if err != nil { + if resp != nil { defer resp.Body.Close() } if connRefused := isConnectionRefused(err); tt.connRefused != connRefused { @@ -144,7 +144,7 @@ func TestTCPTimeout(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, tt.url, nil) resp, err := client.Do(req) - if err != nil { + if resp != nil { defer resp.Body.Close() } if tcpTimeout := isTCPTimeout(err); tt.tcpTimeout != tcpTimeout { diff --git a/test/spoof/spoof_test.go b/test/spoof/spoof_test.go index 5a43a78149..355d1c6b9f 100644 --- a/test/spoof/spoof_test.go +++ b/test/spoof/spoof_test.go @@ -47,12 +47,16 @@ type fakeTransport struct { func (ft *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { call := ft.calls.Add(1) - if ft.response != nil && call == 2 { - // If both a response and an error is defined, we return just the response on - // the second call to simulate a retry that passes eventually. + if ft.response != nil && ft.err != nil { + if call == 2 { + return ft.response, nil + } + return nil, ft.err + } else if ft.response != nil { return ft.response, nil } - return ft.response, ft.err + + return nil, ft.err } func TestSpoofingClient_CheckEndpointState(t *testing.T) {