diff --git a/msg.go b/msg.go index 667bfbe..cdd02d5 100644 --- a/msg.go +++ b/msg.go @@ -9,6 +9,18 @@ import ( "golang.org/x/exp/slices" ) +var ( + RPCReplyName = xml.Name{ + Space: "urn:ietf:params:xml:ns:netconf:base:1.0", + Local: "rpc-reply", + } + + NofificationName = xml.Name{ + Space: "urn:ietf:params:xml:ns:netconf:notification:1.0", + Local: "notification", + } +) + // RawXML captures the raw xml for the given element. Used to process certain // elements later. type RawXML []byte @@ -69,13 +81,38 @@ type Reply struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:base:1.0 rpc-reply"` MessageID uint64 `xml:"message-id,attr"` Errors RPCErrors `xml:"rpc-error,omitempty"` - Body []byte `xml:",innerxml"` + + raw []byte `xml:"-"` +} + +func ParseReply(data []byte) (*Reply, error) { + reply := Reply{ + raw: data, + } + if err := xml.Unmarshal(data, &reply); err != nil { + return nil, fmt.Errorf("couldn't parse reply: %v", err) + } + + return &reply, nil } -// Decode will decode the body of a reply into a value pointed to by v. This is -// a simple wrapper around xml.Unmarshal. +// Decode will decode the entire `rpc-reply` into a value pointed to by v. This +// is a simple wrapper around xml.Unmarshal. func (r Reply) Decode(v interface{}) error { - return xml.Unmarshal(r.Body, v) + if r.raw == nil { + return fmt.Errorf("empty reply") + } + return xml.Unmarshal(r.raw, v) +} + +// Raw returns the native message as it came from the server +func (r Reply) Raw() []byte { + return r.raw +} + +// String returns the message as string. +func (r Reply) String() string { + return string(r.raw) } // Err will return go error(s) from a Reply that are of the given severities. If @@ -121,13 +158,38 @@ func (r Reply) Err(severity ...ErrSeverity) error { type Notification struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:netconf:notification:1.0 notification"` EventTime time.Time `xml:"eventTime"` - Body []byte `xml:",innerxml"` + + raw []byte `xml:"-"` } -// Decode will decode the body of a noticiation into a value pointed to by v. +func ParseNotification(data []byte) (*Notification, error) { + notif := Notification{ + raw: data, + } + if err := xml.Unmarshal(data, ¬if); err != nil { + return nil, fmt.Errorf("couldn't parse reply: %v", err) + } + + return ¬if, nil +} + +// Decode will decode the entire `noticiation` into a value pointed to by v. // This is a simple wrapper around xml.Unmarshal. -func (r Notification) Decode(v interface{}) error { - return xml.Unmarshal(r.Body, v) +func (n Notification) Decode(v interface{}) error { + if n.raw == nil { + return fmt.Errorf("empty reply") + } + return xml.Unmarshal(n.raw, v) +} + +// Raw returns the native message as it came from the server +func (n Notification) Raw() []byte { + return n.raw +} + +// String returns the message as string. +func (n Notification) String() string { + return string(n.raw) } type ErrSeverity string diff --git a/msg_test.go b/msg_test.go index 279dfb4..ddcd067 100644 --- a/msg_test.go +++ b/msg_test.go @@ -251,17 +251,6 @@ func TestUnmarshalRPCReply(t *testing.T) { `), }, }, - Body: []byte(` - -protocol -operation-failed -error -syntax error, expecting <candidate/> or <running/> - -non-exist - - -`), }, }, } diff --git a/ops.go b/ops.go index 6fdf571..da27191 100644 --- a/ops.go +++ b/ops.go @@ -33,7 +33,7 @@ func (b *ExtantBool) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error return nil } -type OKResp struct { +type OKReply struct { OK ExtantBool `xml:"ok"` } @@ -95,8 +95,7 @@ type GetConfigReq struct { } type GetConfigReply struct { - XMLName xml.Name `xml:"data"` - Config []byte `xml:",innerxml"` + Data []byte `xml:"data"` } // GetConfig implements the rpc operation defined in [RFC6241 7.1]. @@ -113,7 +112,7 @@ func (s *Session) GetConfig(ctx context.Context, source Datastore) ([]byte, erro return nil, err } - return resp.Config, nil + return resp.Data, nil } // MergeStrategy defines the strategies for merging configuration in a @@ -272,7 +271,7 @@ func (s *Session) EditConfig(ctx context.Context, target Datastore, config any, opt.apply(&req) } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -297,7 +296,7 @@ func (s *Session) CopyConfig(ctx context.Context, source, target any) error { Target: target, } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -311,7 +310,7 @@ func (s *Session) DeleteConfig(ctx context.Context, target Datastore) error { Target: target, } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -326,7 +325,7 @@ func (s *Session) Lock(ctx context.Context, target Datastore) error { Target: target, } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -336,7 +335,7 @@ func (s *Session) Unlock(ctx context.Context, target Datastore) error { Target: target, } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -356,7 +355,7 @@ func (s *Session) KillSession(ctx context.Context, sessionID uint32) error { SessionID: sessionID, } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -370,7 +369,7 @@ func (s *Session) Validate(ctx context.Context, source any) error { Source: source, } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -444,7 +443,7 @@ func (s *Session) Commit(ctx context.Context, opts ...CommitOption) error { return fmt.Errorf("PersistID cannot be used with Confirmed/ConfirmedTimeout or Persist options") } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -466,7 +465,7 @@ func (s *Session) CancelCommit(ctx context.Context, opts ...CancelCommitOption) opt.applyCancelCommit(&req) } - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } @@ -509,6 +508,6 @@ func (s *Session) CreateSubscription(ctx context.Context, opts ...CreateSubscrip } // TODO: eventual custom notifications rpc logic, e.g. create subscription only if notification capability is present - var resp OKResp + var resp OKReply return s.Call(ctx, &req, &resp) } diff --git a/session.go b/session.go index 5c38ee8..c574573 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package netconf import ( + "bytes" "context" "encoding/xml" "errors" @@ -189,30 +190,26 @@ func (s *Session) recvMsg() error { return err } defer r.Close() - dec := xml.NewDecoder(r) - root, err := startElement(dec) + msg, err := io.ReadAll(r) if err != nil { return err } - const ( - ncNamespace = "urn:ietf:params:xml:ns:netconf:base:1.0" - notifNamespace = "urn:ietf:params:xml:ns:netconf:notification:1.0" - ) + return s.parseMsg(msg) +} + +func (s *Session) parseMsg(msg []byte) error { + dec := xml.NewDecoder(bytes.NewReader(msg)) + + root, err := startElement(dec) + if err != nil { + return err + } switch root.Name { - case xml.Name{Space: notifNamespace, Local: "notification"}: - if s.notificationHandler == nil { - return nil - } - var notif Notification - if err := dec.DecodeElement(¬if, root); err != nil { - return fmt.Errorf("failed to decode notification message: %w", err) - } - s.notificationHandler(notif) - case xml.Name{Space: ncNamespace, Local: "rpc-reply"}: - var reply Reply + case RPCReplyName: + reply := Reply{raw: msg} if err := dec.DecodeElement(&reply, root); err != nil { // What should we do here? Kill the connection? return fmt.Errorf("failed to decode rpc-reply message: %w", err) @@ -228,6 +225,17 @@ func (s *Session) recvMsg() error { case <-req.ctx.Done(): return fmt.Errorf("message %d context canceled: %s", reply.MessageID, req.ctx.Err().Error()) } + + case NofificationName: + if s.notificationHandler == nil { + return nil + } + notif := Notification{raw: msg} + if err := dec.DecodeElement(¬if, root); err != nil { + return fmt.Errorf("failed to decode notification message: %w", err) + } + s.notificationHandler(notif) + default: return fmt.Errorf("unknown message type: %q", root.Name.Local) } @@ -342,6 +350,8 @@ func (s *Session) Call(ctx context.Context, req any, resp any) error { return err } + // Return any . This defaults to a severity of `error` (warning + // are omitted). if err := reply.Err(); err != nil { return err } @@ -377,7 +387,9 @@ func (s *Session) Close(ctx context.Context) error { } } - if callErr != io.EOF { + // it's ok if we are already closed + if !errors.Is(callErr, io.EOF) && + !errors.Is(callErr, ErrClosed) { return callErr }