diff --git a/main_test.go b/main_test.go index eb6df9f49..78b91f1bb 100644 --- a/main_test.go +++ b/main_test.go @@ -141,13 +141,21 @@ var _ = Describe("Router Integration", func() { go func() { defer GinkgoRecover() + + //Open a connection that never goes active + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", localIP, proxyPort)) + Expect(err).NotTo(HaveOccurred()) + err = conn.Close() + Expect(err).NotTo(HaveOccurred()) + + //Open a connection that goes active resp, err := http.Get(longApp.Endpoint()) - Ω(err).ShouldNot(HaveOccurred()) - Ω(resp.StatusCode).Should(Equal(http.StatusOK)) + Expect(err).ShouldNot(HaveOccurred()) + Expect(resp.StatusCode).Should(Equal(http.StatusOK)) bytes, err := ioutil.ReadAll(resp.Body) resp.Body.Close() - Ω(err).ShouldNot(HaveOccurred()) - Ω(bytes).Should(Equal([]byte{'b'})) + Expect(err).ShouldNot(HaveOccurred()) + Expect(bytes).Should(Equal([]byte{'b'})) responseRead <- true }() diff --git a/router/router.go b/router/router.go index 348d5be49..457a003e5 100644 --- a/router/router.go +++ b/router/router.go @@ -39,9 +39,9 @@ type Router struct { listener net.Listener tlsListener net.Listener closeConnections bool - activeConns uint32 connLock sync.Mutex idleConns map[net.Conn]struct{} + activeConns map[net.Conn]struct{} drainDone chan struct{} serveDone chan struct{} tlsServeDone chan struct{} @@ -89,6 +89,7 @@ func NewRouter(cfg *config.Config, p proxy.Proxy, mbusClient yagnats.NATSConn, r serveDone: make(chan struct{}), tlsServeDone: make(chan struct{}), idleConns: make(map[net.Conn]struct{}), + activeConns: make(map[net.Conn]struct{}), logger: steno.NewLogger("router"), } @@ -198,7 +199,7 @@ func (r *Router) Drain(drainTimeout time.Duration) error { r.connLock.Lock() r.closeIdleConns() - if r.activeConns == 0 { + if len(r.activeConns) == 0 { close(drained) } else { r.drainDone = drained @@ -318,12 +319,12 @@ func (r *Router) HandleConnState(conn net.Conn, state http.ConnState) { switch state { case http.StateActive: - r.activeConns++ + r.activeConns[conn] = struct{}{} delete(r.idleConns, conn) conn.SetDeadline(time.Time{}) case http.StateIdle: - r.activeConns-- + delete(r.activeConns, conn) r.idleConns[conn] = struct{}{} if r.closeConnections { @@ -340,11 +341,11 @@ func (r *Router) HandleConnState(conn net.Conn, state http.ConnState) { i := len(r.idleConns) delete(r.idleConns, conn) if i == len(r.idleConns) { - r.activeConns-- + delete(r.activeConns, conn) } } - if r.drainDone != nil && r.activeConns == 0 { + if r.drainDone != nil && len(r.activeConns) == 0 { close(r.drainDone) r.drainDone = nil }