diff --git a/common/bufio/copy.go b/common/bufio/copy.go index e48677650..914865a65 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -45,7 +45,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) { dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) if srcIsSyscall && dstIsSyscall { var handled bool - handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) + handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) if handled { return } diff --git a/common/bufio/copy_direct.go b/common/bufio/copy_direct.go index 1648c03be..f34d3844a 100644 --- a/common/bufio/copy_direct.go +++ b/common/bufio/copy_direct.go @@ -1,12 +1,16 @@ package bufio import ( + "errors" + "io" "syscall" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { +func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) { rawSource, err := source.SyscallConn() if err != nil { return @@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N. handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters) return } + +func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { + handled = true + var ( + buffer *buf.Buffer + notFirstTime bool + ) + for { + buffer, err = source.WaitReadBuffer() + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + return + } + return + } + dataLen := buffer.Len() + err = destination.WriteBuffer(buffer) + if err != nil { + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } + return + } + n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + notFirstTime = true + } +} + +func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { + handled = true + var ( + buffer *buf.Buffer + destination M.Socksaddr + ) + for { + buffer, destination, err = source.WaitReadPacket() + if err != nil { + return + } + dataLen := buffer.Len() + err = destinationConn.WritePacket(buffer, destination) + if err != nil { + buffer.Leak() + if !notFirstTime { + err = N.ReportHandshakeFailure(originSource, err) + } + return + } + n += int64(dataLen) + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + notFirstTime = true + } +} diff --git a/common/bufio/copy_direct_posix.go b/common/bufio/copy_direct_posix.go index ce3d3c3a8..f24d53b9e 100644 --- a/common/bufio/copy_direct_posix.go +++ b/common/bufio/copy_direct_posix.go @@ -3,7 +3,6 @@ package bufio import ( - "errors" "io" "net/netip" "os" @@ -15,72 +14,6 @@ import ( N "github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { - handled = true - var ( - buffer *buf.Buffer - notFirstTime bool - ) - for { - buffer, err = source.WaitReadBuffer() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - return - } - return - } - dataLen := buffer.Len() - err = destination.WriteBuffer(buffer) - if err != nil { - buffer.Leak() - if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) - } - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - -func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { - handled = true - var ( - buffer *buf.Buffer - destination M.Socksaddr - ) - for { - buffer, destination, err = source.WaitReadPacket() - if err != nil { - return - } - dataLen := buffer.Len() - err = destinationConn.WritePacket(buffer, destination) - if err != nil { - buffer.Leak() - if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) - } - return - } - n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - notFirstTime = true - } -} - var _ N.ReadWaiter = (*syscallReadWaiter)(nil) type syscallReadWaiter struct { diff --git a/common/bufio/copy_direct_windows.go b/common/bufio/copy_direct_windows.go index 22a2de095..b6317aad5 100644 --- a/common/bufio/copy_direct_windows.go +++ b/common/bufio/copy_direct_windows.go @@ -1,19 +1,9 @@ package bufio import ( - "io" - N "github.com/sagernet/sing/common/network" ) -func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { - return -} - -func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { - return -} - func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) { return nil, false }