Skip to content

Commit

Permalink
XHTTP XMUX: Abandon client if client.Do(req) failed (#4253)
Browse files Browse the repository at this point in the history
  • Loading branch information
RPRX authored Jan 6, 2025
1 parent aeb12d9 commit ce6c0dc
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
25 changes: 15 additions & 10 deletions transport/internet/splithttp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i
},
})

method := "GET"
method := "GET" // stream-down
if body != nil {
method = "POST"
method = "POST" // stream-up/one
}
req, _ := http.NewRequestWithContext(ctx, method, url, body)
req, _ := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body)
req.Header = c.transportConfig.GetRequestHeader()
if method == "POST" && !c.transportConfig.NoGRPCHeader {
req.Header.Set("Content-Type", "application/grpc")
Expand All @@ -69,17 +69,20 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i
go func() {
resp, err := c.client.Do(req)
if err != nil {
if !uploadOnly {
c.closed = true
}
errors.LogInfoInner(ctx, err, "failed to "+method+" "+url)
gotConn.Close()
wrc.Close()
return
}
if resp.StatusCode != 200 && !uploadOnly {
// c.closed = true
errors.LogInfo(ctx, "unexpected status ", resp.StatusCode)
}
if resp.StatusCode != 200 || uploadOnly {
resp.Body.Close()
if resp.StatusCode != 200 || uploadOnly { // stream-up
io.Copy(io.Discard, resp.Body)
resp.Body.Close() // if it is called immediately, the upload will be interrupted also
wrc.Close()
return
}
Expand All @@ -91,7 +94,7 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i
}

func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error {
req, err := http.NewRequestWithContext(ctx, "POST", url, body)
req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), "POST", url, body)
if err != nil {
return err
}
Expand All @@ -101,13 +104,14 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body i
if c.httpVersion != "1.1" {
resp, err := c.client.Do(req)
if err != nil {
c.closed = true
return err
}

io.Copy(io.Discard, resp.Body)
defer resp.Body.Close()

if resp.StatusCode != 200 {
// c.closed = true
return errors.New("bad status code:", resp.Status)
}
} else {
Expand Down Expand Up @@ -139,11 +143,12 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body i
if h1UploadConn.UnreadedResponsesCount > 0 {
resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req)
if err != nil {
c.closed = true
return fmt.Errorf("error while reading response: %s", err.Error())
}
io.Copy(io.Discard, resp.Body)
defer resp.Body.Close()
if resp.StatusCode != 200 {
// c.closed = true
// resp.Body.Close() // I'm not sure
return fmt.Errorf("got non-200 error response code: %d", resp.StatusCode)
}
}
Expand Down
5 changes: 4 additions & 1 deletion transport/internet/splithttp/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/transport/internet"
)

Expand Down Expand Up @@ -36,10 +37,11 @@ func (c *Config) GetNormalizedQuery() string {
if query != "" {
query += "&"
}
query += "x_version=" + core.Version()

paddingLen := c.GetNormalizedXPaddingBytes().rand()
if paddingLen > 0 {
query += "x_padding=" + strings.Repeat("0", int(paddingLen))
query += "&x_padding=" + strings.Repeat("0", int(paddingLen))
}

return query
Expand All @@ -58,6 +60,7 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter) {
// CORS headers for the browser dialer
writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.Header().Set("Access-Control-Allow-Methods", "GET, POST")
writer.Header().Set("X-Version", core.Version())
paddingLen := c.GetNormalizedXPaddingBytes().rand()
if paddingLen > 0 {
writer.Header().Set("X-Padding", strings.Repeat("0", int(paddingLen)))
Expand Down
6 changes: 3 additions & 3 deletions transport/internet/splithttp/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,14 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
if xmuxClient != nil {
xmuxClient.LeftRequests.Add(-1)
}
conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(context.WithoutCancel(ctx), requestURL.String(), reader, false)
conn.reader, conn.remoteAddr, conn.localAddr, _ = httpClient.OpenStream(ctx, requestURL.String(), reader, false)
return stat.Connection(&conn), nil
} else { // stream-down
var err error
if xmuxClient2 != nil {
xmuxClient2.LeftRequests.Add(-1)
}
conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(context.WithoutCancel(ctx), requestURL2.String(), nil, false)
conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(ctx, requestURL2.String(), nil, false)
if err != nil { // browser dialer only
return nil, err
}
Expand Down Expand Up @@ -454,7 +454,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me

go func() {
err := httpClient.PostPacket(
context.WithoutCancel(ctx),
ctx,
url.String(),
&buf.MultiBufferContainer{MultiBuffer: chunk},
int64(chunk.Len()),
Expand Down
14 changes: 14 additions & 0 deletions transport/internet/splithttp/hub.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package splithttp

import (
"bytes"
"context"
"crypto/tls"
"io"
Expand Down Expand Up @@ -102,6 +103,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req

h.config.WriteResponseHeader(writer)

clientVer := []int{0, 0, 0}
x_version := strings.Split(request.URL.Query().Get("x_version"), ".")
for j := 0; j < 3 && len(x_version) > j; j++ {
clientVer[j], _ = strconv.Atoi(x_version[j])
}

validRange := h.config.GetNormalizedXPaddingBytes()
x_padding := int32(len(request.URL.Query().Get("x_padding")))
if validRange.To > 0 && (x_padding < validRange.From || x_padding > validRange.To) {
Expand Down Expand Up @@ -160,6 +167,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
writer.WriteHeader(http.StatusConflict)
} else {
writer.WriteHeader(http.StatusOK)
if request.ProtoMajor != 1 && len(clientVer) > 0 && clientVer[0] >= 25 {
paddingLen := h.config.GetNormalizedXPaddingBytes().rand()
if paddingLen > 0 {
writer.Write(bytes.Repeat([]byte{'0'}, int(paddingLen)))
}
writer.(http.Flusher).Flush()
}
<-request.Context().Done()
}
return
Expand Down

0 comments on commit ce6c0dc

Please sign in to comment.