diff --git a/hscontrol/app.go b/hscontrol/app.go index ed0da82a..ac63b472 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -509,6 +509,10 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { r.Use(h.httpAuthenticationMiddleware) r.HandleFunc("/v1/*", grpcMux.ServeHTTP) }) + // Ping response endpoint: receives HEAD from clients responding + // to a PingRequest. The unguessable ping ID serves as authentication. + r.Head("/machine/ping-response", h.PingResponseHandler) + r.Get("/favicon.ico", FaviconHandler) r.Get("/", BlankHandler) diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 64c8e0d5..096ea25b 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -1,16 +1,21 @@ package hscontrol import ( + "context" "encoding/json" "fmt" "net" "net/http" "net/netip" "strings" + "time" "github.com/arl/statsviz" + "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/prometheus/client_golang/prometheus/promhttp" + "tailscale.com/tailcfg" "tailscale.com/tsweb" ) @@ -324,6 +329,42 @@ func (h *Headscale) debugHTTPServer() *http.Server { } })) + // Ping endpoint: sends a PingRequest to a node and waits for it to respond. + // Supports POST (form submit) and GET with ?node= (clickable quick-ping links). + debug.Handle("ping", "Ping a node to check connectivity", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ( + query string + result *templates.PingResult + ) + + switch r.Method { + case http.MethodPost: + r.Body = http.MaxBytesReader(w, r.Body, 4096) //nolint:mnd + + err := r.ParseForm() + if err != nil { + http.Error(w, "bad form data", http.StatusBadRequest) + return + } + + query = r.FormValue("node") + result = h.doPing(r.Context(), query) + case http.MethodGet: + // Support ?node= for auto-ping links from other debug pages. + if q := r.URL.Query().Get("node"); q != "" { + query = q + result = h.doPing(r.Context(), query) + } + } + + nodes := h.connectedNodesList() + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + //nolint:gosec // elem-go auto-escapes all attribute values; no XSS risk. + _, _ = w.Write([]byte(templates.PingPage(query, result, nodes).Render())) + })) + // statsviz.Register would mount handlers directly on the raw mux, // bypassing the access gate. Build the server by hand and wrap // each handler with protectedDebugHandler. @@ -436,3 +477,94 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { return info } + +// connectedNodesList returns a list of connected nodes for the ping page. +func (h *Headscale) connectedNodesList() []templates.ConnectedNode { + debugInfo := h.mapBatcher.Debug() + + var nodes []templates.ConnectedNode + + for nodeID, info := range debugInfo { + if !info.Connected { + continue + } + + nv, ok := h.state.GetNodeByID(nodeID) + if !ok { + continue + } + + cn := templates.ConnectedNode{ + ID: nodeID, + Hostname: nv.Hostname(), + } + + for _, ip := range nv.IPs() { + cn.IPs = append(cn.IPs, ip.String()) + } + + nodes = append(nodes, cn) + } + + return nodes +} + +const pingTimeout = 30 * time.Second + +// doPing sends a PingRequest to the node identified by query and waits for a response. +func (h *Headscale) doPing(ctx context.Context, query string) *templates.PingResult { + if query == "" { + return &templates.PingResult{ + Status: "error", + Message: "No node specified.", + } + } + + node, ok := h.state.ResolveNode(query) + if !ok { + return &templates.PingResult{ + Status: "error", + Message: fmt.Sprintf("Node %q not found.", query), + } + } + + nodeID := node.ID() + + if !h.mapBatcher.IsConnected(nodeID) { + return &templates.PingResult{ + Status: "error", + NodeID: nodeID, + Message: fmt.Sprintf("Node %d is not connected.", nodeID), + } + } + + pingID, responseCh := h.state.RegisterPing(nodeID) + defer h.state.CancelPing(pingID) + + callbackURL := h.cfg.ServerURL + "/machine/ping-response?id=" + pingID + h.Change(change.PingNode(nodeID, &tailcfg.PingRequest{ + URL: callbackURL, + Log: true, + })) + + select { + case latency := <-responseCh: + return &templates.PingResult{ + Status: "ok", + Latency: latency, + NodeID: nodeID, + } + case <-time.After(pingTimeout): + return &templates.PingResult{ + Status: "timeout", + NodeID: nodeID, + Message: fmt.Sprintf("No response after %s.", pingTimeout), + } + case <-ctx.Done(): + return &templates.PingResult{ + Status: "error", + NodeID: nodeID, + Message: "Request cancelled.", + } + } +} diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 0410e16a..2a3add8e 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -277,6 +277,12 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ( return tailPeers, nil } +// WithPingRequest adds a PingRequest to the response. +func (b *MapResponseBuilder) WithPingRequest(pr *tailcfg.PingRequest) *MapResponseBuilder { + b.resp.PingRequest = pr + return b +} + // WithPeerChangedPatch adds peer change patches. func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { b.resp.PeersChangedPatch = changes diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 0aba4175..ab2399a5 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -308,6 +308,10 @@ func (m *mapper) buildFromChange( builder.WithPeerChangedPatch(resp.PeerPatches) } + if resp.PingRequest != nil { + builder.WithPingRequest(resp.PingRequest) + } + return builder.Build() } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 12d9adf7..0a84d96f 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -295,6 +295,26 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht http.Error(writer, "Not implemented yet", http.StatusNotImplemented) } +// PingResponseHandler handles HEAD requests from clients responding to a +// PingRequest. The client calls this endpoint to prove connectivity. +// The unguessable ping ID serves as authentication. +func (h *Headscale) PingResponseHandler( + writer http.ResponseWriter, + req *http.Request, +) { + pingID := req.URL.Query().Get("id") + if pingID == "" { + http.Error(writer, "missing ping ID", http.StatusBadRequest) + return + } + + if h.state.CompletePing(pingID) { + writer.WriteHeader(http.StatusOK) + } else { + http.Error(writer, "unknown or expired ping", http.StatusNotFound) + } +} + func urlParam[T any](req *http.Request, key string) (T, error) { var zero T diff --git a/hscontrol/state/ping.go b/hscontrol/state/ping.go new file mode 100644 index 00000000..a7e2df30 --- /dev/null +++ b/hscontrol/state/ping.go @@ -0,0 +1,97 @@ +package state + +import ( + "sync" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" +) + +const pingIDLength = 16 + +// pingTracker manages pending ping requests and their response channels. +// It correlates outgoing PingRequests with incoming HEAD callbacks. +type pingTracker struct { + mu sync.Mutex + pending map[string]*pendingPing +} + +type pendingPing struct { + nodeID types.NodeID + startTime time.Time + responseCh chan time.Duration +} + +func newPingTracker() *pingTracker { + return &pingTracker{ + pending: make(map[string]*pendingPing), + } +} + +// register creates a new pending ping and returns a unique ping ID +// and a channel that will receive the round-trip latency when the +// ping response arrives. +func (pt *pingTracker) register(nodeID types.NodeID) (string, <-chan time.Duration) { + pingID, _ := util.GenerateRandomStringDNSSafe(pingIDLength) + ch := make(chan time.Duration, 1) + + pt.mu.Lock() + pt.pending[pingID] = &pendingPing{ + nodeID: nodeID, + startTime: time.Now(), + responseCh: ch, + } + pt.mu.Unlock() + + return pingID, ch +} + +// complete signals that a ping response was received. +// It sends the measured latency on the response channel and returns true. +// Returns false if the pingID is unknown (already completed, cancelled, or expired). +func (pt *pingTracker) complete(pingID string) bool { + pt.mu.Lock() + + pp, ok := pt.pending[pingID] + if ok { + delete(pt.pending, pingID) + } + pt.mu.Unlock() + + if ok { + pp.responseCh <- time.Since(pp.startTime) + + close(pp.responseCh) + + return true + } + + return false +} + +// cancel removes a pending ping without completing it. +// Used for cleanup when the caller times out or disconnects. +func (pt *pingTracker) cancel(pingID string) { + pt.mu.Lock() + delete(pt.pending, pingID) + pt.mu.Unlock() +} + +// RegisterPing creates a pending ping for the given node and returns +// a unique ping ID and a channel that receives the round-trip latency +// when the response arrives. +func (s *State) RegisterPing(nodeID types.NodeID) (string, <-chan time.Duration) { + return s.pings.register(nodeID) +} + +// CompletePing signals that a ping response was received for the given ID. +// Returns true if the ping was found and completed, false otherwise. +func (s *State) CompletePing(pingID string) bool { + return s.pings.complete(pingID) +} + +// CancelPing removes a pending ping without completing it. +func (s *State) CancelPing(pingID string) { + s.pings.cancel(pingID) +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 9053735e..af9e984c 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -146,6 +146,9 @@ type State struct { // only proceeds when the generation it carries matches the latest. connectGen sync.Map // types.NodeID → *atomic.Uint64 + // pings tracks pending ping requests and their response channels. + pings *pingTracker + // sshCheckAuth tracks when source nodes last completed SSH check auth. // // For rules without explicit checkPeriod (default 12h), auth covers any @@ -256,6 +259,7 @@ func NewState(cfg *types.Config) (*State, error) { authCache: authCache, primaryRoutes: routes.New(), nodeStore: nodeStore, + pings: newPingTracker(), sshCheckAuth: make(map[sshCheckPair]time.Time), }, nil @@ -699,6 +703,37 @@ func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.U return s.nodeStore.GetNodeByMachineKey(machineKey, userID) } +// ResolveNode looks up a node by numeric ID, IPv4/IPv6 address, hostname, or given name. +// It tries ID first, then IP, then name matching. +func (s *State) ResolveNode(query string) (types.NodeView, bool) { + // Try numeric ID first. + id, idErr := types.ParseNodeID(query) + if idErr == nil { + return s.GetNodeByID(id) + } + + // Try IP address. + addr, addrErr := netip.ParseAddr(query) + if addrErr == nil { + for _, n := range s.ListNodes().All() { + if slices.Contains(n.IPs(), addr) { + return n, true + } + } + + return types.NodeView{}, false + } + + // Try hostname / given name. + for _, n := range s.ListNodes().All() { + if n.Hostname() == query || n.GivenName() == query { + return n, true + } + } + + return types.NodeView{}, false +} + // ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided. func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] { if len(nodeIDs) == 0 { diff --git a/hscontrol/templates/ping.go b/hscontrol/templates/ping.go new file mode 100644 index 00000000..02121455 --- /dev/null +++ b/hscontrol/templates/ping.go @@ -0,0 +1,152 @@ +package templates + +import ( + "fmt" + "strings" + "time" + + elem "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" + "github.com/chasefleming/elem-go/styles" + "github.com/juanfont/headscale/hscontrol/types" +) + +// PingResult contains the outcome of a ping request. +type PingResult struct { + // Status is "ok", "timeout", or "error". + Status string + + // Latency is the round-trip time (only meaningful when Status is "ok"). + Latency time.Duration + + // NodeID is the ID of the pinged node. + NodeID types.NodeID + + // Message is a human-readable description of the result. + Message string +} + +// ConnectedNode is a node currently connected to the batcher, +// displayed as a quick-ping link on the debug ping page. +type ConnectedNode struct { + ID types.NodeID + Hostname string + IPs []string +} + +// PingPage renders the /debug/ping page with a form, optional result, +// and a list of connected nodes as quick-ping links. +func PingPage(query string, result *PingResult, nodes []ConnectedNode) *elem.Element { + children := []elem.Node{ + headscaleLogo(), + H1(elem.Text("Ping Node")), + P(elem.Text("Check if a connected node responds to a PingRequest.")), + pingForm(query), + } + + if result != nil { + children = append(children, pingResult(result)) + } + + if len(nodes) > 0 { + children = append(children, connectedNodeList(nodes)) + } + + children = append(children, pageFooter()) + + return HtmlStructure( + elem.Title(nil, elem.Text("Ping Node - Headscale")), + mdTypesetBody(children...), + ) +} + +func pingForm(query string) *elem.Element { + inputStyle := styles.Props{ + styles.Padding: spaceS, + styles.Border: "1px solid " + colorBorderMedium, + styles.BorderRadius: "0.25rem", + styles.FontSize: fontSizeBase, + styles.FontFamily: fontFamilySystem, + styles.Width: "280px", + } + + buttonStyle := styles.Props{ + styles.Padding: spaceS + " " + spaceM, + styles.BackgroundColor: colorPrimaryAccent, + styles.Color: "#ffffff", + styles.Border: "none", + styles.BorderRadius: "0.25rem", + styles.FontSize: fontSizeBase, + styles.FontFamily: fontFamilySystem, + "cursor": "pointer", + } + + return elem.Form(attrs.Props{ + attrs.Method: "POST", + attrs.Action: "/debug/ping", + attrs.Style: styles.Props{ + styles.Display: "flex", + styles.Gap: spaceS, + styles.AlignItems: "center", + styles.MarginTop: spaceM, + }.ToInline(), + }, + elem.Input(attrs.Props{ + attrs.Type: "text", + attrs.Name: "node", + attrs.Value: query, + attrs.Placeholder: "Node ID, IP, or hostname", + attrs.Autofocus: "true", + attrs.Style: inputStyle.ToInline(), + }), + elem.Button(attrs.Props{ + attrs.Type: "submit", + attrs.Style: buttonStyle.ToInline(), + }, elem.Text("Ping")), + ) +} + +func connectedNodeList(nodes []ConnectedNode) *elem.Element { + items := make([]elem.Node, 0, len(nodes)) + + for _, n := range nodes { + label := fmt.Sprintf("%s (ID: %d, %s)", n.Hostname, n.ID, strings.Join(n.IPs, ", ")) + href := fmt.Sprintf("/debug/ping?node=%d", n.ID) + + items = append(items, elem.Li(nil, + elem.A(attrs.Props{ + attrs.Href: href, + attrs.Style: styles.Props{ + styles.Color: colorPrimaryAccent, + }.ToInline(), + }, elem.Text(label)), + )) + } + + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.MarginTop: spaceL, + }.ToInline(), + }, + H2(elem.Text("Connected Nodes")), + elem.Ul(nil, items...), + ) +} + +func pingResult(result *PingResult) *elem.Element { + switch result.Status { + case "ok": + return successBox( + "Pong", + elem.Text(fmt.Sprintf("Node %d responded in %s", + result.NodeID, result.Latency.Round(time.Millisecond))), + ) + case "timeout": + return warningBox( + "Timeout", + fmt.Sprintf("Node %d did not respond. %s", result.NodeID, result.Message), + ) + default: + return warningBox("Error", result.Message) + } +} diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 7014a6a8..a1f1f8c4 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -39,6 +39,11 @@ type Change struct { // must be computed at runtime per-node. Used for policy changes // where each node may have different peer visibility. RequiresRuntimePeerComputation bool + + // PingRequest, if non-nil, is a ping request to send to the node. + // Used by the debug ping endpoint to verify node connectivity. + // PingRequest is always targeted to a specific node via TargetNode. + PingRequest *tailcfg.PingRequest } // boolFieldNames returns all boolean field names for exhaustive testing. @@ -93,6 +98,11 @@ func (r Change) Merge(other Change) Change { merged.TargetNode = other.TargetNode } + // Preserve PingRequest (first wins). + if merged.PingRequest == nil { + merged.PingRequest = other.PingRequest + } + if r.Reason != "" && other.Reason != "" && r.Reason != other.Reason { merged.Reason = r.Reason + "; " + other.Reason } else if other.Reason != "" { @@ -112,6 +122,10 @@ func (r Change) IsEmpty() bool { return false } + if r.PingRequest != nil { + return false + } + return len(r.PeersChanged) == 0 && len(r.PeersRemoved) == 0 && len(r.PeerPatches) == 0 @@ -168,6 +182,10 @@ func (r Change) Type() string { return "config" } + if r.PingRequest != nil { + return "ping" + } + return "unknown" } @@ -454,6 +472,16 @@ func UserRemoved() Change { return c } +// PingNode creates a Change that sends a PingRequest to a specific node. +// The node will respond to the PingRequest URL to prove connectivity. +func PingNode(nodeID types.NodeID, pr *tailcfg.PingRequest) Change { + return Change{ + Reason: "ping node", + TargetNode: nodeID, + PingRequest: pr, + } +} + // ExtraRecords returns a Change for when DNS extra records change. func ExtraRecords() Change { c := DNSConfig()