diff --git a/diameter.go b/diameter.go index a65dc82..f741b46 100644 --- a/diameter.go +++ b/diameter.go @@ -101,7 +101,11 @@ func (c *DiameterClient) Connect(address string) error { return nil } -func (d *Diameter) Send(client *DiameterClient, msg *DiameterMessage) (uint32, error) { +func (d *Diameter) Send( + client *DiameterClient, + msg *DiameterMessage, + requestTimeoutMillis int, +) (uint32, error) { if client.conn == nil { return 0, errors.New("Not connected") @@ -113,6 +117,14 @@ func (d *Diameter) Send(client *DiameterClient, msg *DiameterMessage) (uint32, e hopByHopID := req.Header.HopByHopID client.hopIds[hopByHopID] = make(chan *diam.Message) + // Timeout settings + var timeout <-chan time.Time + if requestTimeoutMillis == 0 { + timeout = time.After(60 * time.Second) + } else { + timeout = time.After(time.Duration(requestTimeoutMillis) * time.Millisecond) + } + // Send CCR _, err := req.WriteTo(client.conn) if err != nil { @@ -120,7 +132,12 @@ func (d *Diameter) Send(client *DiameterClient, msg *DiameterMessage) (uint32, e } // Wait for CCA - resp := <-client.hopIds[hopByHopID] + var resp *diam.Message + select { + case resp = <-client.hopIds[hopByHopID]: + case <-timeout: + return uint32(5012), errors.New("Response timeout") + } //log.Infof("Received CCA \n%s", resp) delete(client.hopIds, hopByHopID)