mirror of
https://github.com/juanfont/headscale
synced 2026-04-25 17:15:33 +02:00
state: drain pending pings on Close
Blocked callers waiting on a pingTracker response channel would hang forever if the server Close()d mid-probe. Drain the pending map on Close so those goroutines unblock and exit cleanly. Updates #3157
This commit is contained in:
@@ -10,8 +10,10 @@ import (
|
||||
|
||||
const pingIDLength = 16
|
||||
|
||||
// pingTracker manages pending ping requests and their response channels.
|
||||
// It correlates outgoing PingRequests with incoming HEAD callbacks.
|
||||
// pingTracker correlates outgoing PingRequests with incoming HEAD
|
||||
// callbacks. Entries have no server-side TTL: callers are responsible
|
||||
// for cleaning up via CancelPing or by reading from the response channel
|
||||
// within their own timeout.
|
||||
type pingTracker struct {
|
||||
mu sync.Mutex
|
||||
pending map[string]*pendingPing
|
||||
@@ -29,9 +31,9 @@ func newPingTracker() *pingTracker {
|
||||
}
|
||||
}
|
||||
|
||||
// register creates a new pending ping and returns a unique ping ID
|
||||
// and a channel that will receive the round-trip latency when the
|
||||
// ping response arrives.
|
||||
// register creates a pending ping and returns a unique ping ID and a
|
||||
// channel that receives the round-trip latency once the response
|
||||
// arrives.
|
||||
func (pt *pingTracker) register(nodeID types.NodeID) (string, <-chan time.Duration) {
|
||||
pingID, _ := util.GenerateRandomStringDNSSafe(pingIDLength)
|
||||
ch := make(chan time.Duration, 1)
|
||||
@@ -47,9 +49,9 @@ func (pt *pingTracker) register(nodeID types.NodeID) (string, <-chan time.Durati
|
||||
return pingID, ch
|
||||
}
|
||||
|
||||
// complete signals that a ping response was received.
|
||||
// It sends the measured latency on the response channel and returns true.
|
||||
// Returns false if the pingID is unknown (already completed, cancelled, or expired).
|
||||
// complete sends the measured latency on the response channel and
|
||||
// returns true. Returns false if the pingID is unknown (already
|
||||
// completed or cancelled).
|
||||
func (pt *pingTracker) complete(pingID string) bool {
|
||||
pt.mu.Lock()
|
||||
|
||||
@@ -70,28 +72,40 @@ func (pt *pingTracker) complete(pingID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// cancel removes a pending ping without completing it.
|
||||
// Used for cleanup when the caller times out or disconnects.
|
||||
// cancel removes a pending ping without completing it. Idempotent.
|
||||
func (pt *pingTracker) cancel(pingID string) {
|
||||
pt.mu.Lock()
|
||||
delete(pt.pending, pingID)
|
||||
pt.mu.Unlock()
|
||||
}
|
||||
|
||||
// RegisterPing creates a pending ping for the given node and returns
|
||||
// a unique ping ID and a channel that receives the round-trip latency
|
||||
// when the response arrives.
|
||||
// drain closes every outstanding response channel and clears the map.
|
||||
// Called from State.Close to unblock any caller still waiting on a
|
||||
// channel that will never receive.
|
||||
func (pt *pingTracker) drain() {
|
||||
pt.mu.Lock()
|
||||
defer pt.mu.Unlock()
|
||||
|
||||
for id, pp := range pt.pending {
|
||||
close(pp.responseCh)
|
||||
delete(pt.pending, id)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterPing tracks a pending ping and returns its ID and a channel
|
||||
// for the latency. Callers must defer CancelPing or read the channel
|
||||
// within their own timeout; there is no server-side TTL.
|
||||
func (s *State) RegisterPing(nodeID types.NodeID) (string, <-chan time.Duration) {
|
||||
return s.pings.register(nodeID)
|
||||
}
|
||||
|
||||
// CompletePing signals that a ping response was received for the given ID.
|
||||
// Returns true if the ping was found and completed, false otherwise.
|
||||
// CompletePing signals that a ping response arrived. Returns true if
|
||||
// the ID was known, false otherwise.
|
||||
func (s *State) CompletePing(pingID string) bool {
|
||||
return s.pings.complete(pingID)
|
||||
}
|
||||
|
||||
// CancelPing removes a pending ping without completing it.
|
||||
// CancelPing removes a pending ping. Idempotent.
|
||||
func (s *State) CancelPing(pingID string) {
|
||||
s.pings.cancel(pingID)
|
||||
}
|
||||
|
||||
@@ -140,6 +140,27 @@ func TestPingTracker_TwoToSameNode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingTracker_Drain(t *testing.T) {
|
||||
pt := newPingTracker()
|
||||
|
||||
_, ch1 := pt.register(types.NodeID(1))
|
||||
_, ch2 := pt.register(types.NodeID(2))
|
||||
|
||||
pt.drain()
|
||||
|
||||
// Drained channels must be closed so blocked readers unblock.
|
||||
for i, ch := range []<-chan time.Duration{ch1, ch2} {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
assert.False(t, ok, "channel %d should be closed, got value", i)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("channel %d not closed by drain", i)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Empty(t, pt.pending, "pending map should be empty after drain")
|
||||
}
|
||||
|
||||
func TestPingTracker_LatencyNonNegative(t *testing.T) {
|
||||
pt := newPingTracker()
|
||||
|
||||
|
||||
@@ -267,6 +267,7 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
|
||||
// Close gracefully shuts down the State instance and releases all resources.
|
||||
func (s *State) Close() error {
|
||||
s.pings.drain()
|
||||
s.nodeStore.Stop()
|
||||
|
||||
err := s.db.Close()
|
||||
|
||||
Reference in New Issue
Block a user