Skip to content

Commit

Permalink
Clean up and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
rgzr committed Oct 3, 2024
1 parent bf8701c commit af05409
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 92 deletions.
81 changes: 30 additions & 51 deletions forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,48 +37,33 @@ func (s *TunneledConnState) String() string {
func (tun *SSHTun) forward(fromConn net.Conn) {
from := fromConn.RemoteAddr().String()

if tun.forwardType == Local {
tun.tunneledState(&TunneledConnState{
From: from,
Info: fmt.Sprintf("accepted %s connection", tun.local.Type()),
})
} else if tun.forwardType == Remote {
tun.tunneledState(&TunneledConnState{
From: from,
Info: fmt.Sprintf("accepted %s connection", tun.remote.Type()),
})
}
tun.tunneledState(&TunneledConnState{
From: from,
Info: fmt.Sprintf("accepted %s connection", tun.fromEndpoint().Type()),
})

var toConn net.Conn
var err error

if tun.forwardType == Local {
toConn, err = tun.sshClient.Dial(tun.remote.Type(), tun.remote.String())
if err != nil {
tun.tunneledState(&TunneledConnState{
From: from,
Error: fmt.Errorf("remote dial %s to %s failed: %w", tun.remote.Type(), tun.remote.String(), err),
})

fromConn.Close()
return
}
}
dialFunc := tun.sshClient.Dial
if tun.forwardType == Remote {
toConn, err = net.Dial(tun.local.Type(), tun.local.String())
if err != nil {
tun.tunneledState(&TunneledConnState{
From: from,
Error: fmt.Errorf("local dial %s to %s failed: %w", tun.local.Type(), tun.local.String(), err),
})
dialFunc = net.Dial
}

fromConn.Close()
return
}
toConn, err = dialFunc(tun.toEndpoint().Type(), tun.toEndpoint().String())
if err != nil {
tun.tunneledState(&TunneledConnState{
From: from,
Error: fmt.Errorf("%s dial %s to %s failed: %w", tun.forwardToName(),
tun.toEndpoint().Type(), tun.toEndpoint().String(), err),
})

fromConn.Close()
return
}

connStr := fmt.Sprintf("%s -(%s)> %s -(ssh)> %s -(%s)> %s", from, tun.local.Type(), tun.local.String(),
tun.server.String(), tun.remote.Type(), tun.remote.String())
connStr := fmt.Sprintf("%s -(%s)> %s <(ssh)> %s -(%s)> %s", from, tun.fromEndpoint().Type(),
tun.fromEndpoint().String(), tun.server.String(), tun.toEndpoint().Type(), tun.toEndpoint().String())

tun.tunneledState(&TunneledConnState{
From: from,
Expand All @@ -94,39 +79,34 @@ func (tun *SSHTun) forward(fromConn net.Conn) {
defer connCancel()
_, err = io.Copy(toConn, fromConn)
if err != nil {
if tun.forwardType == Local {
return fmt.Errorf("failed copying bytes from remote to local: %w", err)
} else if tun.forwardType == Remote {
return fmt.Errorf("failed copying bytes from local to remote: %w", err)
}
return fmt.Errorf("failed copying bytes from %s to %s: %w", tun.forwardToName(), tun.forwardFromName(), err)
}
return toConn.Close()
return nil
})

errGroup.Go(func() error {
defer connCancel()
_, err = io.Copy(fromConn, toConn)
if err != nil {
if tun.forwardType == Local {
return fmt.Errorf("failed copying bytes from local to remote: %w", err)
} else if tun.forwardType == Remote {
return fmt.Errorf("failed copying bytes from remote to local: %w", err)
}
return fmt.Errorf("failed copying bytes from %s to %s: %w", tun.forwardFromName(), tun.forwardToName(), err)
}
return fromConn.Close()
return nil
})

err = errGroup.Wait()

<-connCtx.Done()

fromConn.Close()
toConn.Close()

err = errGroup.Wait()

select {
case <-tun.ctx.Done():
default:
if err != nil {
tun.tunneledState(&TunneledConnState{
From: from,
Error: err,
From: from,
Error: err,
Closed: true,
})
}
Expand All @@ -135,7 +115,6 @@ func (tun *SSHTun) forward(fromConn net.Conn) {
tun.tunneledState(&TunneledConnState{
From: from,
Info: fmt.Sprintf("connection closed: %s", connStr),
Ready: false,
Closed: true,
})
}
Expand Down
71 changes: 48 additions & 23 deletions sshtun.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type SSHTun struct {
server *Endpoint
local *Endpoint
remote *Endpoint
forwardType ForwardType
forwardType ForwardType
timeout time.Duration
connState func(*SSHTun, ConnState)
tunneledConnState func(*SSHTun, *TunneledConnState)
Expand All @@ -42,7 +42,7 @@ type SSHTun struct {
type ForwardType int

const (
Local ForwardType = iota
Local ForwardType = iota
Remote
)

Expand Down Expand Up @@ -81,11 +81,9 @@ func New(localPort int, server string, remotePort int) *SSHTun {

// NewRemote does the same as New but for a remote port forward.
func NewRemote(localPort int, server string, remotePort int) *SSHTun {
sshTun := defaultSSHTun(server)
sshTun.local = NewTCPEndpoint("localhost", localPort)
sshTun.remote = NewTCPEndpoint("localhost", remotePort)
sshTun.forwardType = Remote
return sshTun
sshTun := New(localPort, server, remotePort)
sshTun.forwardType = Remote
return sshTun
}

// NewUnix does the same as New but using unix sockets.
Expand All @@ -98,20 +96,18 @@ func NewUnix(localUnixSocket string, server string, remoteUnixSocket string) *SS

// NewUnixRemote does the same as NewRemote but using unix sockets.
func NewUnixRemote(localUnixSocket string, server string, remoteUnixSocket string) *SSHTun {
sshTun := defaultSSHTun(server)
sshTun.local = NewUnixEndpoint(localUnixSocket)
sshTun.remote = NewUnixEndpoint(remoteUnixSocket)
sshTun := NewUnix(localUnixSocket, server, remoteUnixSocket)
sshTun.forwardType = Remote
return sshTun
}

func defaultSSHTun(server string) *SSHTun {
return &SSHTun{
mutex: &sync.Mutex{},
server: NewTCPEndpoint(server, 22),
user: "root",
authType: AuthTypeAuto,
timeout: time.Second * 15,
mutex: &sync.Mutex{},
server: NewTCPEndpoint(server, 22),
user: "root",
authType: AuthTypeAuto,
timeout: time.Second * 15,
forwardType: Local,
}
}
Expand Down Expand Up @@ -243,16 +239,15 @@ func (tun *SSHTun) Start(ctx context.Context) error {
if err != nil {
return tun.stop(fmt.Errorf("local listen %s on %s failed: %w", tun.local.Type(), tun.local.String(), err))
}
}
if tun.forwardType == Remote {
} else if tun.forwardType == Remote {
sshClient, err := ssh.Dial(tun.server.Type(), tun.server.String(), tun.sshConfig)
if err != nil {
return tun.stop(fmt.Errorf("ssh dial %s to %s failed: %w", tun.server.Type(), tun.server.String(), err))
}
listener, err = sshClient.Listen(tun.remote.Type(), tun.remote.String())
if err != nil {
return tun.stop(fmt.Errorf("remote listen %s on %s failed: %w", tun.remote.Type(), tun.remote.String(), err))
}
}
}

errChan := make(chan error)
Expand Down Expand Up @@ -306,17 +301,47 @@ func (tun *SSHTun) stop(err error) error {
return err
}

func (tun *SSHTun) fromEndpoint() *Endpoint {
if tun.forwardType == Remote {
return tun.remote
}

return tun.local
}

func (tun *SSHTun) toEndpoint() *Endpoint {
if tun.forwardType == Remote {
return tun.local
}

return tun.remote
}

func (tun *SSHTun) forwardFromName() string {
if tun.forwardType == Remote {
return "remote"
}

return "local"
}

func (tun *SSHTun) forwardToName() string {
if tun.forwardType == Remote {
return "local"
}

return "remote"
}

func (tun *SSHTun) listen(listener net.Listener) error {

errGroup, groupCtx := errgroup.WithContext(tun.ctx)
errGroup.Go(func() error {
for {
conn, err := listener.Accept()
if err != nil {
if tun.forwardType == Local {
return fmt.Errorf("local accept %s on %s failed: %w", tun.local.Type(), tun.local.String(), err)
} else if tun.forwardType == Remote {
return fmt.Errorf("remote accept %s on %s failed: %w", tun.remote.Type(), tun.remote.String(), err)
}
return fmt.Errorf("%s accept %s on %s failed: %w", tun.forwardFromName(),
tun.fromEndpoint().Type(), tun.fromEndpoint().String(), err)
}
errGroup.Go(func() error {
return tun.handle(conn)
Expand Down
Loading

0 comments on commit af05409

Please sign in to comment.