diff --git a/hscontrol/state/ping.go b/hscontrol/state/ping.go index a7e2df30..43cf6914 100644 --- a/hscontrol/state/ping.go +++ b/hscontrol/state/ping.go @@ -10,8 +10,10 @@ import ( const pingIDLength = 16 -// pingTracker manages pending ping requests and their response channels. -// It correlates outgoing PingRequests with incoming HEAD callbacks. +// pingTracker correlates outgoing PingRequests with incoming HEAD +// callbacks. Entries have no server-side TTL: callers are responsible +// for cleaning up via CancelPing or by reading from the response channel +// within their own timeout. type pingTracker struct { mu sync.Mutex pending map[string]*pendingPing @@ -29,9 +31,9 @@ func newPingTracker() *pingTracker { } } -// register creates a new pending ping and returns a unique ping ID -// and a channel that will receive the round-trip latency when the -// ping response arrives. +// register creates a pending ping and returns a unique ping ID and a +// channel that receives the round-trip latency once the response +// arrives. func (pt *pingTracker) register(nodeID types.NodeID) (string, <-chan time.Duration) { pingID, _ := util.GenerateRandomStringDNSSafe(pingIDLength) ch := make(chan time.Duration, 1) @@ -47,9 +49,9 @@ func (pt *pingTracker) register(nodeID types.NodeID) (string, <-chan time.Durati return pingID, ch } -// complete signals that a ping response was received. -// It sends the measured latency on the response channel and returns true. -// Returns false if the pingID is unknown (already completed, cancelled, or expired). +// complete sends the measured latency on the response channel and +// returns true. Returns false if the pingID is unknown (already +// completed or cancelled). func (pt *pingTracker) complete(pingID string) bool { pt.mu.Lock() @@ -70,28 +72,40 @@ func (pt *pingTracker) complete(pingID string) bool { return false } -// cancel removes a pending ping without completing it. -// Used for cleanup when the caller times out or disconnects. +// cancel removes a pending ping without completing it. Idempotent. func (pt *pingTracker) cancel(pingID string) { pt.mu.Lock() delete(pt.pending, pingID) pt.mu.Unlock() } -// RegisterPing creates a pending ping for the given node and returns -// a unique ping ID and a channel that receives the round-trip latency -// when the response arrives. +// drain closes every outstanding response channel and clears the map. +// Called from State.Close to unblock any caller still waiting on a +// channel that will never receive. +func (pt *pingTracker) drain() { + pt.mu.Lock() + defer pt.mu.Unlock() + + for id, pp := range pt.pending { + close(pp.responseCh) + delete(pt.pending, id) + } +} + +// RegisterPing tracks a pending ping and returns its ID and a channel +// for the latency. Callers must defer CancelPing or read the channel +// within their own timeout; there is no server-side TTL. func (s *State) RegisterPing(nodeID types.NodeID) (string, <-chan time.Duration) { return s.pings.register(nodeID) } -// CompletePing signals that a ping response was received for the given ID. -// Returns true if the ping was found and completed, false otherwise. +// CompletePing signals that a ping response arrived. Returns true if +// the ID was known, false otherwise. func (s *State) CompletePing(pingID string) bool { return s.pings.complete(pingID) } -// CancelPing removes a pending ping without completing it. +// CancelPing removes a pending ping. Idempotent. func (s *State) CancelPing(pingID string) { s.pings.cancel(pingID) } diff --git a/hscontrol/state/ping_test.go b/hscontrol/state/ping_test.go index 595688fe..a51a851c 100644 --- a/hscontrol/state/ping_test.go +++ b/hscontrol/state/ping_test.go @@ -140,6 +140,27 @@ func TestPingTracker_TwoToSameNode(t *testing.T) { } } +func TestPingTracker_Drain(t *testing.T) { + pt := newPingTracker() + + _, ch1 := pt.register(types.NodeID(1)) + _, ch2 := pt.register(types.NodeID(2)) + + pt.drain() + + // Drained channels must be closed so blocked readers unblock. + for i, ch := range []<-chan time.Duration{ch1, ch2} { + select { + case _, ok := <-ch: + assert.False(t, ok, "channel %d should be closed, got value", i) + case <-time.After(time.Second): + t.Fatalf("channel %d not closed by drain", i) + } + } + + assert.Empty(t, pt.pending, "pending map should be empty after drain") +} + func TestPingTracker_LatencyNonNegative(t *testing.T) { pt := newPingTracker() diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index a40991d1..1857e057 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -267,6 +267,7 @@ func NewState(cfg *types.Config) (*State, error) { // Close gracefully shuts down the State instance and releases all resources. func (s *State) Close() error { + s.pings.drain() s.nodeStore.Stop() err := s.db.Close()