diff --git a/netconf/transport_ssh.go b/netconf/transport_ssh.go index 1e40b55..eb6a1b9 100644 --- a/netconf/transport_ssh.go +++ b/netconf/transport_ssh.go @@ -34,6 +34,9 @@ type TransportSSH struct { TransportBasicIO sshClient *ssh.Client sshSession *ssh.Session + + // SSH Client connection is managed externally + externClient bool } // Close closes an existing SSH session and socket if they exist. @@ -48,13 +51,15 @@ func (t *TransportSSH) Close() error { if err := t.sshSession.Close(); err != nil { // If we receive an error when trying to close the session, then // lets try to close the socket, otherwise it will be left open - t.sshClient.Close() + if !t.externClient { + t.sshClient.Close() + } return err } } // Close the socket - if t.sshClient != nil { + if !t.externClient && t.sshClient != nil { return t.sshClient.Close() } return fmt.Errorf("No connection to close") @@ -117,6 +122,17 @@ func NewSSHSession(conn net.Conn, config *ssh.ClientConfig) (*Session, error) { return NewSession(t), nil } +// NewSSHSessionForExternalClient creates a new NETCONF session using an existing ssh.Client +// initiated and managed externally. +func NewSSHSessionForExternalClient(client *ssh.Client) (*Session, error) { + t, err := clientToTransport(client) + if err != nil { + return nil, err + } + + return NewSession(t), nil +} + // DialSSH creates a new NETCONF session using a SSH Transport. // See TransportSSH.Dial for arguments. func DialSSH(target string, config *ssh.ClientConfig) (*Session, error) { @@ -244,6 +260,20 @@ func connToTransport(conn net.Conn, config *ssh.ClientConfig) (*TransportSSH, er return t, nil } +func clientToTransport(client *ssh.Client) (*TransportSSH, error) { + t := &TransportSSH{ + sshClient: client, + externClient: true, + } + + err := t.setupSession() + if err != nil { + return nil, err + } + + return t, nil +} + type deadlineConn struct { net.Conn timeout time.Duration