mirror of
https://github.com/juanfont/headscale
synced 2026-04-25 17:15:33 +02:00
all: fix golangci-lint issues (#3064)
This commit is contained in:
@@ -18,6 +18,7 @@ linters:
|
||||
- lll
|
||||
- maintidx
|
||||
- makezero
|
||||
- mnd
|
||||
- musttag
|
||||
- nestif
|
||||
- nolintlint
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// 90 days.
|
||||
// DefaultAPIKeyExpiry is 90 days.
|
||||
DefaultAPIKeyExpiry = "90d"
|
||||
)
|
||||
|
||||
|
||||
@@ -19,10 +19,12 @@ func init() {
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
createNodeCmd.Flags().StringP("name", "", "", "Name")
|
||||
|
||||
err := createNodeCmd.MarkFlagRequired("name")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
createNodeCmd.Flags().StringP("user", "u", "", "User")
|
||||
|
||||
createNodeCmd.Flags().StringP("namespace", "n", "", "User")
|
||||
@@ -34,11 +36,14 @@ func init() {
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
createNodeCmd.Flags().StringP("key", "k", "", "Key")
|
||||
|
||||
err = createNodeCmd.MarkFlagRequired("key")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
createNodeCmd.Flags().
|
||||
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -20,6 +20,7 @@ const (
|
||||
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
|
||||
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
|
||||
errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
|
||||
errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined")
|
||||
refreshTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
@@ -47,33 +48,39 @@ func mockOIDC() error {
|
||||
if clientID == "" {
|
||||
return errMockOidcClientIDNotDefined
|
||||
}
|
||||
|
||||
clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
|
||||
if clientSecret == "" {
|
||||
return errMockOidcClientSecretNotDefined
|
||||
}
|
||||
|
||||
addrStr := os.Getenv("MOCKOIDC_ADDR")
|
||||
if addrStr == "" {
|
||||
return errMockOidcPortNotDefined
|
||||
}
|
||||
|
||||
portStr := os.Getenv("MOCKOIDC_PORT")
|
||||
if portStr == "" {
|
||||
return errMockOidcPortNotDefined
|
||||
}
|
||||
|
||||
accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
|
||||
if accessTTLOverride != "" {
|
||||
newTTL, err := time.ParseDuration(accessTTLOverride)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accessTTL = newTTL
|
||||
}
|
||||
|
||||
userStr := os.Getenv("MOCKOIDC_USERS")
|
||||
if userStr == "" {
|
||||
return errors.New("MOCKOIDC_USERS not defined")
|
||||
return errMockOidcUsersNotDefined
|
||||
}
|
||||
|
||||
var users []mockoidc.MockUser
|
||||
|
||||
err := json.Unmarshal([]byte(userStr), &users)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshalling users: %w", err)
|
||||
@@ -93,7 +100,7 @@ func mockOIDC() error {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port))
|
||||
listener, err := new(net.ListenConfig).Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", addrStr, port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -105,6 +112,7 @@ func mockOIDC() error {
|
||||
|
||||
log.Info().Msgf("mock OIDC server listening on %s", listener.Addr().String())
|
||||
log.Info().Msgf("issuer: %s", mock.Issuer())
|
||||
|
||||
c := make(chan struct{})
|
||||
<-c
|
||||
|
||||
@@ -135,10 +143,11 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
|
||||
ErrorQueue: &mockoidc.ErrorQueue{},
|
||||
}
|
||||
|
||||
mock.AddMiddleware(func(h http.Handler) http.Handler {
|
||||
_ = mock.AddMiddleware(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Info().Msgf("request: %+v", r)
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if r.Response != nil {
|
||||
log.Info().Msgf("response: %+v", r.Response)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ func init() {
|
||||
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
|
||||
listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
|
||||
listNodesNamespaceFlag.Hidden = true
|
||||
|
||||
nodeCmd.AddCommand(listNodesCmd)
|
||||
|
||||
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
@@ -42,42 +43,51 @@ func init() {
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
registerNodeCmd.Flags().StringP("key", "k", "", "Key")
|
||||
|
||||
err = registerNodeCmd.MarkFlagRequired("key")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
nodeCmd.AddCommand(registerNodeCmd)
|
||||
|
||||
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
|
||||
|
||||
err = expireNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
nodeCmd.AddCommand(expireNodeCmd)
|
||||
|
||||
renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
|
||||
err = renameNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
nodeCmd.AddCommand(renameNodeCmd)
|
||||
|
||||
deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
|
||||
err = deleteNodeCmd.MarkFlagRequired("identifier")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
nodeCmd.AddCommand(deleteNodeCmd)
|
||||
|
||||
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
tagCmd.MarkFlagRequired("identifier")
|
||||
_ = tagCmd.MarkFlagRequired("identifier")
|
||||
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
|
||||
nodeCmd.AddCommand(tagCmd)
|
||||
|
||||
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
|
||||
approveRoutesCmd.MarkFlagRequired("identifier")
|
||||
_ = approveRoutesCmd.MarkFlagRequired("identifier")
|
||||
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
|
||||
nodeCmd.AddCommand(approveRoutesCmd)
|
||||
|
||||
@@ -233,10 +243,7 @@ var listNodeRoutesCmd = &cobra.Command{
|
||||
return
|
||||
}
|
||||
|
||||
tableData, err := nodeRoutesToPtables(nodes)
|
||||
if err != nil {
|
||||
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
|
||||
}
|
||||
tableData := nodeRoutesToPtables(nodes)
|
||||
|
||||
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
|
||||
if err != nil {
|
||||
@@ -506,15 +513,21 @@ func nodesToPtables(
|
||||
ephemeral = true
|
||||
}
|
||||
|
||||
var lastSeen time.Time
|
||||
var lastSeenTime string
|
||||
var (
|
||||
lastSeen time.Time
|
||||
lastSeenTime string
|
||||
)
|
||||
|
||||
if node.GetLastSeen() != nil {
|
||||
lastSeen = node.GetLastSeen().AsTime()
|
||||
lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
var expiry time.Time
|
||||
var expiryTime string
|
||||
var (
|
||||
expiry time.Time
|
||||
expiryTime string
|
||||
)
|
||||
|
||||
if node.GetExpiry() != nil {
|
||||
expiry = node.GetExpiry().AsTime()
|
||||
expiryTime = expiry.Format("2006-01-02 15:04:05")
|
||||
@@ -523,6 +536,7 @@ func nodesToPtables(
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
|
||||
err := machineKey.UnmarshalText(
|
||||
[]byte(node.GetMachineKey()),
|
||||
)
|
||||
@@ -531,6 +545,7 @@ func nodesToPtables(
|
||||
}
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
|
||||
err = nodeKey.UnmarshalText(
|
||||
[]byte(node.GetNodeKey()),
|
||||
)
|
||||
@@ -572,8 +587,11 @@ func nodesToPtables(
|
||||
user = pterm.LightYellow(node.GetUser().GetName())
|
||||
}
|
||||
|
||||
var IPV4Address string
|
||||
var IPV6Address string
|
||||
var (
|
||||
IPV4Address string
|
||||
IPV6Address string
|
||||
)
|
||||
|
||||
for _, addr := range node.GetIpAddresses() {
|
||||
if netip.MustParseAddr(addr).Is4() {
|
||||
IPV4Address = addr
|
||||
@@ -608,7 +626,7 @@ func nodesToPtables(
|
||||
|
||||
func nodeRoutesToPtables(
|
||||
nodes []*v1.Node,
|
||||
) (pterm.TableData, error) {
|
||||
) pterm.TableData {
|
||||
tableHeader := []string{
|
||||
"ID",
|
||||
"Hostname",
|
||||
@@ -632,7 +650,7 @@ func nodeRoutesToPtables(
|
||||
)
|
||||
}
|
||||
|
||||
return tableData, nil
|
||||
return tableData
|
||||
}
|
||||
|
||||
var tagCmd = &cobra.Command{
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
bypassFlag = "bypass-grpc-and-access-database-directly"
|
||||
bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -26,16 +26,22 @@ func init() {
|
||||
policyCmd.AddCommand(getPolicy)
|
||||
|
||||
setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
||||
if err := setPolicy.MarkFlagRequired("file"); err != nil {
|
||||
|
||||
err := setPolicy.MarkFlagRequired("file")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
|
||||
policyCmd.AddCommand(setPolicy)
|
||||
|
||||
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
||||
if err := checkPolicy.MarkFlagRequired("file"); err != nil {
|
||||
|
||||
err = checkPolicy.MarkFlagRequired("file")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
policyCmd.AddCommand(checkPolicy)
|
||||
}
|
||||
|
||||
@@ -173,7 +179,7 @@ var setPolicy = &cobra.Command{
|
||||
defer cancel()
|
||||
defer conn.Close()
|
||||
|
||||
if _, err := client.SetPolicy(ctx, request); err != nil {
|
||||
if _, err := client.SetPolicy(ctx, request); err != nil { //nolint:noinlineerr
|
||||
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ func initConfig() {
|
||||
if cfgFile == "" {
|
||||
cfgFile = os.Getenv("HEADSCALE_CONFIG")
|
||||
}
|
||||
|
||||
if cfgFile != "" {
|
||||
err := types.LoadConfig(cfgFile, true)
|
||||
if err != nil {
|
||||
@@ -80,6 +81,7 @@ func initConfig() {
|
||||
Repository: "headscale",
|
||||
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
|
||||
}
|
||||
|
||||
res, err := latest.Check(githubTag, versionInfo.Version)
|
||||
if err == nil && res.Outdated {
|
||||
//nolint
|
||||
@@ -101,6 +103,7 @@ func isPreReleaseVersion(version string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -140,7 +143,8 @@ https://github.com/juanfont/headscale`,
|
||||
}
|
||||
|
||||
func Execute() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// CLI user errors.
|
||||
var (
|
||||
errFlagRequired = errors.New("--name or --identifier flag is required")
|
||||
errMultipleUsersMatch = errors.New("multiple users match query, specify an ID")
|
||||
)
|
||||
|
||||
func usernameAndIDFlag(cmd *cobra.Command) {
|
||||
cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
|
||||
cmd.Flags().StringP("name", "n", "", "Username")
|
||||
@@ -24,12 +30,12 @@ func usernameAndIDFlag(cmd *cobra.Command) {
|
||||
// If both are empty, it will exit the program with an error.
|
||||
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
|
||||
username, _ := cmd.Flags().GetString("name")
|
||||
|
||||
identifier, _ := cmd.Flags().GetInt64("identifier")
|
||||
if username == "" && identifier < 0 {
|
||||
err := errors.New("--name or --identifier flag is required")
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Cannot rename user: "+status.Convert(err).Message(),
|
||||
errFlagRequired,
|
||||
"Cannot rename user: "+status.Convert(errFlagRequired).Message(),
|
||||
"",
|
||||
)
|
||||
}
|
||||
@@ -51,7 +57,8 @@ func init() {
|
||||
userCmd.AddCommand(renameUserCmd)
|
||||
usernameAndIDFlag(renameUserCmd)
|
||||
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
|
||||
renameNodeCmd.MarkFlagRequired("new-name")
|
||||
|
||||
_ = renameNodeCmd.MarkFlagRequired("new-name")
|
||||
}
|
||||
|
||||
var errMissingParameter = errors.New("missing parameters")
|
||||
@@ -95,7 +102,7 @@ var createUserCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
|
||||
if _, err := url.Parse(pictureURL); err != nil {
|
||||
if _, err := url.Parse(pictureURL); err != nil { //nolint:noinlineerr
|
||||
ErrorOutput(
|
||||
err,
|
||||
fmt.Sprintf(
|
||||
@@ -149,7 +156,7 @@ var destroyUserCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if len(users.GetUsers()) != 1 {
|
||||
err := errors.New("multiple users match query, specify an ID")
|
||||
err := errMultipleUsersMatch
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
@@ -277,7 +284,7 @@ var renameUserCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if len(users.GetUsers()) != 1 {
|
||||
err := errors.New("multiple users match query, specify an ID")
|
||||
err := errMultipleUsersMatch
|
||||
ErrorOutput(
|
||||
err,
|
||||
"Error: "+status.Convert(err).Message(),
|
||||
|
||||
@@ -58,7 +58,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
|
||||
|
||||
grpcOptions := []grpc.DialOption{
|
||||
grpc.WithBlock(),
|
||||
grpc.WithBlock(), //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||
}
|
||||
|
||||
address := cfg.CLI.Address
|
||||
@@ -82,6 +82,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
||||
Msgf("Unable to read/write to headscale socket, do you have the correct permissions?")
|
||||
}
|
||||
}
|
||||
|
||||
socket.Close()
|
||||
|
||||
grpcOptions = append(
|
||||
@@ -95,6 +96,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
||||
if apiKey == "" {
|
||||
log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set")
|
||||
}
|
||||
|
||||
grpcOptions = append(grpcOptions,
|
||||
grpc.WithPerRPCCredentials(tokenAuth{
|
||||
token: apiKey,
|
||||
@@ -120,7 +122,8 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
||||
}
|
||||
|
||||
log.Trace().Caller().Str(zf.Address, address).Msg("connecting via gRPC")
|
||||
conn, err := grpc.DialContext(ctx, address, grpcOptions...)
|
||||
|
||||
conn, err := grpc.DialContext(ctx, address, grpcOptions...) //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||
if err != nil {
|
||||
log.Fatal().Caller().Err(err).Msgf("could not connect: %v", err)
|
||||
os.Exit(-1) // we get here if logging is suppressed (i.e., json output)
|
||||
@@ -132,8 +135,11 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
|
||||
}
|
||||
|
||||
func output(result any, override string, outputFormat string) string {
|
||||
var jsonBytes []byte
|
||||
var err error
|
||||
var (
|
||||
jsonBytes []byte
|
||||
err error
|
||||
)
|
||||
|
||||
switch outputFormat {
|
||||
case "json":
|
||||
jsonBytes, err = json.MarshalIndent(result, "", "\t")
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
func main() {
|
||||
var colors bool
|
||||
|
||||
switch l := termcolor.SupportLevel(os.Stderr); l {
|
||||
case termcolor.Level16M:
|
||||
colors = true
|
||||
|
||||
@@ -14,9 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func TestConfigFileLoading(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
path, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
@@ -48,9 +46,7 @@ func TestConfigFileLoading(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigLoading(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
path, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -25,7 +25,7 @@ func cleanupBeforeTest(ctx context.Context) error {
|
||||
return fmt.Errorf("cleaning stale test containers: %w", err)
|
||||
}
|
||||
|
||||
if err := pruneDockerNetworks(ctx); err != nil {
|
||||
if err := pruneDockerNetworks(ctx); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("pruning networks: %w", err)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
|
||||
|
||||
// killTestContainers terminates and removes all test containers.
|
||||
func killTestContainers(ctx context.Context) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -69,8 +69,10 @@ func killTestContainers(ctx context.Context) error {
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
for _, cont := range containers {
|
||||
shouldRemove := false
|
||||
|
||||
for _, name := range cont.Names {
|
||||
if strings.Contains(name, "headscale-test-suite") ||
|
||||
strings.Contains(name, "hs-") ||
|
||||
@@ -107,7 +109,7 @@ func killTestContainers(ctx context.Context) error {
|
||||
// This function filters containers by the hi.run-id label to only affect containers
|
||||
// belonging to the specified test run, leaving other concurrent test runs untouched.
|
||||
func killTestContainersByRunID(ctx context.Context, runID string) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -149,7 +151,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
|
||||
// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
|
||||
// without interfering with currently running concurrent tests.
|
||||
func cleanupStaleTestContainers(ctx context.Context) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -223,7 +225,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container
|
||||
|
||||
// pruneDockerNetworks removes unused Docker networks.
|
||||
func pruneDockerNetworks(ctx context.Context) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -245,7 +247,7 @@ func pruneDockerNetworks(ctx context.Context) error {
|
||||
|
||||
// cleanOldImages removes test-related and old dangling Docker images.
|
||||
func cleanOldImages(ctx context.Context) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -259,8 +261,10 @@ func cleanOldImages(ctx context.Context) error {
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
for _, img := range images {
|
||||
shouldRemove := false
|
||||
|
||||
for _, tag := range img.RepoTags {
|
||||
if strings.Contains(tag, "hs-") ||
|
||||
strings.Contains(tag, "headscale-integration") ||
|
||||
@@ -295,18 +299,19 @@ func cleanOldImages(ctx context.Context) error {
|
||||
|
||||
// cleanCacheVolume removes the Docker volume used for Go module cache.
|
||||
func cleanCacheVolume(ctx context.Context) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
defer cli.Close()
|
||||
|
||||
volumeName := "hs-integration-go-cache"
|
||||
|
||||
err = cli.VolumeRemove(ctx, volumeName, true)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
if errdefs.IsNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||
fmt.Printf("Go module cache volume not found: %s\n", volumeName)
|
||||
} else if errdefs.IsConflict(err) {
|
||||
} else if errdefs.IsConflict(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||
fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
|
||||
} else {
|
||||
fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)
|
||||
|
||||
@@ -22,15 +22,20 @@ import (
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
)
|
||||
|
||||
const defaultDirPerm = 0o755
|
||||
|
||||
var (
|
||||
ErrTestFailed = errors.New("test failed")
|
||||
ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
|
||||
ErrNoDockerContext = errors.New("no docker context found")
|
||||
ErrMemoryLimitViolations = errors.New("container(s) exceeded memory limits")
|
||||
)
|
||||
|
||||
// runTestContainer executes integration tests in a Docker container.
|
||||
//
|
||||
//nolint:gocyclo // complex test orchestration function
|
||||
func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -52,7 +57,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
}
|
||||
|
||||
const dirPerm = 0o755
|
||||
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil {
|
||||
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("creating logs directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -60,7 +65,9 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
if config.Verbose {
|
||||
log.Printf("Running pre-test cleanup...")
|
||||
}
|
||||
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
|
||||
|
||||
err := cleanupBeforeTest(ctx)
|
||||
if err != nil && config.Verbose {
|
||||
log.Printf("Warning: pre-test cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -71,7 +78,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
}
|
||||
|
||||
imageName := "golang:" + config.GoVersion
|
||||
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil {
|
||||
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("ensuring image availability: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,7 +91,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
log.Printf("Created container: %s", resp.ID)
|
||||
}
|
||||
|
||||
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
|
||||
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("starting container: %w", err)
|
||||
}
|
||||
|
||||
@@ -95,13 +102,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
// Start stats collection for container resource monitoring (if enabled)
|
||||
var statsCollector *StatsCollector
|
||||
|
||||
if config.Stats {
|
||||
var err error
|
||||
statsCollector, err = NewStatsCollector()
|
||||
|
||||
statsCollector, err = NewStatsCollector(ctx)
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to create stats collector: %v", err)
|
||||
}
|
||||
|
||||
statsCollector = nil
|
||||
}
|
||||
|
||||
@@ -110,7 +120,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
// Start stats collection immediately - no need for complex retry logic
|
||||
// The new implementation monitors Docker events and will catch containers as they start
|
||||
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
|
||||
err := statsCollector.StartCollection(ctx, runID, config.Verbose)
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to start stats collection: %v", err)
|
||||
}
|
||||
@@ -122,12 +133,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
exitCode, err := streamAndWait(ctx, cli, resp.ID)
|
||||
|
||||
// Ensure all containers have finished and logs are flushed before extracting artifacts
|
||||
if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose {
|
||||
waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose)
|
||||
if waitErr != nil && config.Verbose {
|
||||
log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
|
||||
}
|
||||
|
||||
// Extract artifacts from test containers before cleanup
|
||||
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose {
|
||||
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { //nolint:noinlineerr
|
||||
log.Printf("Warning: failed to extract artifacts from containers: %v", err)
|
||||
}
|
||||
|
||||
@@ -140,12 +152,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
if len(violations) > 0 {
|
||||
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
||||
log.Printf("=================================")
|
||||
|
||||
for _, violation := range violations {
|
||||
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
||||
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
||||
}
|
||||
|
||||
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
|
||||
return fmt.Errorf("test failed: %d %w", len(violations), ErrMemoryLimitViolations)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,6 +360,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
maxWaitTime := 10 * time.Second
|
||||
checkInterval := 500 * time.Millisecond
|
||||
timeout := time.After(maxWaitTime)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -356,6 +370,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
if verbose {
|
||||
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
allFinalized := true
|
||||
@@ -366,12 +381,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
if verbose {
|
||||
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if container is in a final state
|
||||
if !isContainerFinalized(inspect.State) {
|
||||
allFinalized = false
|
||||
|
||||
if verbose {
|
||||
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
|
||||
}
|
||||
@@ -384,6 +401,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
||||
if verbose {
|
||||
log.Printf("All test containers finalized, ready for artifact extraction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -400,13 +418,15 @@ func isContainerFinalized(state *container.State) bool {
|
||||
func findProjectRoot(startPath string) string {
|
||||
current := startPath
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
|
||||
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { //nolint:noinlineerr
|
||||
return current
|
||||
}
|
||||
|
||||
parent := filepath.Dir(current)
|
||||
if parent == current {
|
||||
return startPath
|
||||
}
|
||||
|
||||
current = parent
|
||||
}
|
||||
}
|
||||
@@ -416,6 +436,7 @@ func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -428,13 +449,14 @@ type DockerContext struct {
|
||||
}
|
||||
|
||||
// createDockerClient creates a Docker client with context detection.
|
||||
func createDockerClient() (*client.Client, error) {
|
||||
contextInfo, err := getCurrentDockerContext()
|
||||
func createDockerClient(ctx context.Context) (*client.Client, error) {
|
||||
contextInfo, err := getCurrentDockerContext(ctx)
|
||||
if err != nil {
|
||||
return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
}
|
||||
|
||||
var clientOpts []client.Opt
|
||||
|
||||
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
|
||||
|
||||
if contextInfo != nil {
|
||||
@@ -444,6 +466,7 @@ func createDockerClient() (*client.Client, error) {
|
||||
if runConfig.Verbose {
|
||||
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
||||
}
|
||||
|
||||
clientOpts = append(clientOpts, client.WithHost(host))
|
||||
}
|
||||
}
|
||||
@@ -458,15 +481,16 @@ func createDockerClient() (*client.Client, error) {
|
||||
}
|
||||
|
||||
// getCurrentDockerContext retrieves the current Docker context information.
|
||||
func getCurrentDockerContext() (*DockerContext, error) {
|
||||
cmd := exec.Command("docker", "context", "inspect")
|
||||
func getCurrentDockerContext(ctx context.Context) (*DockerContext, error) {
|
||||
cmd := exec.CommandContext(ctx, "docker", "context", "inspect")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting docker context: %w", err)
|
||||
}
|
||||
|
||||
var contexts []DockerContext
|
||||
if err := json.Unmarshal(output, &contexts); err != nil {
|
||||
if err := json.Unmarshal(output, &contexts); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("parsing docker context: %w", err)
|
||||
}
|
||||
|
||||
@@ -486,11 +510,12 @@ func getDockerSocketPath() string {
|
||||
|
||||
// checkImageAvailableLocally checks if the specified Docker image is available locally.
|
||||
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
|
||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
|
||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName) //nolint:staticcheck // SA1019: deprecated but functional
|
||||
if err != nil {
|
||||
if client.IsErrNotFound(err) {
|
||||
if client.IsErrNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("inspecting image %s: %w", imageName, err)
|
||||
}
|
||||
|
||||
@@ -509,6 +534,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
||||
if verbose {
|
||||
log.Printf("Image %s is available locally", imageName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -533,6 +559,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading pull output: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Image %s pulled successfully", imageName)
|
||||
}
|
||||
|
||||
@@ -547,9 +574,11 @@ func listControlFiles(logsDir string) {
|
||||
return
|
||||
}
|
||||
|
||||
var logFiles []string
|
||||
var dataFiles []string
|
||||
var dataDirs []string
|
||||
var (
|
||||
logFiles []string
|
||||
dataFiles []string
|
||||
dataDirs []string
|
||||
)
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
@@ -578,6 +607,7 @@ func listControlFiles(logsDir string) {
|
||||
|
||||
if len(logFiles) > 0 {
|
||||
log.Printf("Headscale logs:")
|
||||
|
||||
for _, file := range logFiles {
|
||||
log.Printf(" %s", file)
|
||||
}
|
||||
@@ -585,9 +615,11 @@ func listControlFiles(logsDir string) {
|
||||
|
||||
if len(dataFiles) > 0 || len(dataDirs) > 0 {
|
||||
log.Printf("Headscale data:")
|
||||
|
||||
for _, file := range dataFiles {
|
||||
log.Printf(" %s", file)
|
||||
}
|
||||
|
||||
for _, dir := range dataDirs {
|
||||
log.Printf(" %s/", dir)
|
||||
}
|
||||
@@ -596,7 +628,7 @@ func listControlFiles(logsDir string) {
|
||||
|
||||
// extractArtifactsFromContainers collects container logs and files from the specific test run.
|
||||
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -612,9 +644,11 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
|
||||
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
||||
|
||||
extractedCount := 0
|
||||
|
||||
for _, cont := range currentTestContainers {
|
||||
// Extract container logs and tar files
|
||||
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
|
||||
err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose)
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
|
||||
}
|
||||
@@ -622,6 +656,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
|
||||
if verbose {
|
||||
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
|
||||
}
|
||||
|
||||
extractedCount++
|
||||
}
|
||||
}
|
||||
@@ -645,11 +680,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
||||
|
||||
// Find the test container to get its run ID label
|
||||
var runID string
|
||||
|
||||
for _, cont := range containers {
|
||||
if cont.ID == testContainerID {
|
||||
if cont.Labels != nil {
|
||||
runID = cont.Labels["hi.run-id"]
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -690,18 +727,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
||||
// extractContainerArtifacts saves logs and tar files from a container.
|
||||
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
|
||||
// Ensure the logs directory exists
|
||||
if err := os.MkdirAll(logsDir, 0o755); err != nil {
|
||||
err := os.MkdirAll(logsDir, defaultDirPerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating logs directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract container logs
|
||||
if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
|
||||
err = extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extracting logs: %w", err)
|
||||
}
|
||||
|
||||
// Extract tar files for headscale containers only
|
||||
if strings.HasPrefix(containerName, "hs-") {
|
||||
if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
|
||||
err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose)
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
|
||||
}
|
||||
@@ -741,12 +781,12 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
|
||||
}
|
||||
|
||||
// Write stdout logs
|
||||
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
|
||||
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
|
||||
return fmt.Errorf("writing stdout log: %w", err)
|
||||
}
|
||||
|
||||
// Write stderr logs
|
||||
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
|
||||
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
|
||||
return fmt.Errorf("writing stderr log: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -38,13 +38,13 @@ func runDoctorCheck(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Check 3: Go installation
|
||||
results = append(results, checkGoInstallation())
|
||||
results = append(results, checkGoInstallation(ctx))
|
||||
|
||||
// Check 4: Git repository
|
||||
results = append(results, checkGitRepository())
|
||||
results = append(results, checkGitRepository(ctx))
|
||||
|
||||
// Check 5: Required files
|
||||
results = append(results, checkRequiredFiles())
|
||||
results = append(results, checkRequiredFiles(ctx))
|
||||
|
||||
// Display results
|
||||
displayDoctorResults(results)
|
||||
@@ -86,7 +86,7 @@ func checkDockerBinary() DoctorResult {
|
||||
|
||||
// checkDockerDaemon verifies Docker daemon is running and accessible.
|
||||
func checkDockerDaemon(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Docker Daemon",
|
||||
@@ -124,8 +124,8 @@ func checkDockerDaemon(ctx context.Context) DoctorResult {
|
||||
}
|
||||
|
||||
// checkDockerContext verifies Docker context configuration.
|
||||
func checkDockerContext(_ context.Context) DoctorResult {
|
||||
contextInfo, err := getCurrentDockerContext()
|
||||
func checkDockerContext(ctx context.Context) DoctorResult {
|
||||
contextInfo, err := getCurrentDockerContext(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Docker Context",
|
||||
@@ -155,7 +155,7 @@ func checkDockerContext(_ context.Context) DoctorResult {
|
||||
|
||||
// checkDockerSocket verifies Docker socket accessibility.
|
||||
func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Docker Socket",
|
||||
@@ -192,7 +192,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||
|
||||
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
|
||||
func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
@@ -251,7 +251,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
}
|
||||
|
||||
// checkGoInstallation verifies Go is installed and working.
|
||||
func checkGoInstallation() DoctorResult {
|
||||
func checkGoInstallation(ctx context.Context) DoctorResult {
|
||||
_, err := exec.LookPath("go")
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
@@ -265,7 +265,8 @@ func checkGoInstallation() DoctorResult {
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", "version")
|
||||
cmd := exec.CommandContext(ctx, "go", "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
@@ -285,8 +286,9 @@ func checkGoInstallation() DoctorResult {
|
||||
}
|
||||
|
||||
// checkGitRepository verifies we're in a git repository.
|
||||
func checkGitRepository() DoctorResult {
|
||||
cmd := exec.Command("git", "rev-parse", "--git-dir")
|
||||
func checkGitRepository(ctx context.Context) DoctorResult {
|
||||
cmd := exec.CommandContext(ctx, "git", "rev-parse", "--git-dir")
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
@@ -308,7 +310,7 @@ func checkGitRepository() DoctorResult {
|
||||
}
|
||||
|
||||
// checkRequiredFiles verifies required files exist.
|
||||
func checkRequiredFiles() DoctorResult {
|
||||
func checkRequiredFiles(ctx context.Context) DoctorResult {
|
||||
requiredFiles := []string{
|
||||
"go.mod",
|
||||
"integration/",
|
||||
@@ -316,9 +318,12 @@ func checkRequiredFiles() DoctorResult {
|
||||
}
|
||||
|
||||
var missingFiles []string
|
||||
|
||||
for _, file := range requiredFiles {
|
||||
cmd := exec.Command("test", "-e", file)
|
||||
if err := cmd.Run(); err != nil {
|
||||
cmd := exec.CommandContext(ctx, "test", "-e", file)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
missingFiles = append(missingFiles, file)
|
||||
}
|
||||
}
|
||||
@@ -350,6 +355,7 @@ func displayDoctorResults(results []DoctorResult) {
|
||||
|
||||
for _, result := range results {
|
||||
var icon string
|
||||
|
||||
switch result.Status {
|
||||
case "PASS":
|
||||
icon = "✅"
|
||||
|
||||
@@ -79,13 +79,18 @@ func main() {
|
||||
}
|
||||
|
||||
func cleanAll(ctx context.Context) error {
|
||||
if err := killTestContainers(ctx); err != nil {
|
||||
err := killTestContainers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := pruneDockerNetworks(ctx); err != nil {
|
||||
|
||||
err = pruneDockerNetworks(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cleanOldImages(ctx); err != nil {
|
||||
|
||||
err = cleanOldImages(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error {
|
||||
if runConfig.Verbose {
|
||||
log.Printf("Running pre-flight system checks...")
|
||||
}
|
||||
if err := runDoctorCheck(env.Context()); err != nil {
|
||||
|
||||
err := runDoctorCheck(env.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("pre-flight checks failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -66,9 +68,9 @@ func runIntegrationTest(env *command.Env) error {
|
||||
func detectGoVersion() string {
|
||||
goModPath := filepath.Join("..", "..", "go.mod")
|
||||
|
||||
if _, err := os.Stat("go.mod"); err == nil {
|
||||
if _, err := os.Stat("go.mod"); err == nil { //nolint:noinlineerr
|
||||
goModPath = "go.mod"
|
||||
} else if _, err := os.Stat("../../go.mod"); err == nil {
|
||||
} else if _, err := os.Stat("../../go.mod"); err == nil { //nolint:noinlineerr
|
||||
goModPath = "../../go.mod"
|
||||
}
|
||||
|
||||
@@ -94,8 +96,10 @@ func detectGoVersion() string {
|
||||
|
||||
// splitLines splits a string into lines without using strings.Split.
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
var current string
|
||||
var (
|
||||
lines []string
|
||||
current string
|
||||
)
|
||||
|
||||
for _, char := range s {
|
||||
if char == '\n' {
|
||||
|
||||
@@ -18,6 +18,9 @@ import (
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
// ErrStatsCollectionAlreadyStarted is returned when trying to start stats collection that is already running.
|
||||
var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started")
|
||||
|
||||
// ContainerStats represents statistics for a single container.
|
||||
type ContainerStats struct {
|
||||
ContainerID string
|
||||
@@ -44,8 +47,8 @@ type StatsCollector struct {
|
||||
}
|
||||
|
||||
// NewStatsCollector creates a new stats collector instance.
|
||||
func NewStatsCollector() (*StatsCollector, error) {
|
||||
cli, err := createDockerClient()
|
||||
func NewStatsCollector(ctx context.Context) (*StatsCollector, error) {
|
||||
cli, err := createDockerClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating Docker client: %w", err)
|
||||
}
|
||||
@@ -63,17 +66,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
||||
defer sc.mutex.Unlock()
|
||||
|
||||
if sc.collectionStarted {
|
||||
return errors.New("stats collection already started")
|
||||
return ErrStatsCollectionAlreadyStarted
|
||||
}
|
||||
|
||||
sc.collectionStarted = true
|
||||
|
||||
// Start monitoring existing containers
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.monitorExistingContainers(ctx, runID, verbose)
|
||||
|
||||
// Start Docker events monitoring for new containers
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.monitorDockerEvents(ctx, runID, verbose)
|
||||
|
||||
if verbose {
|
||||
@@ -87,10 +92,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
||||
func (sc *StatsCollector) StopCollection() {
|
||||
// Check if already stopped without holding lock
|
||||
sc.mutex.RLock()
|
||||
|
||||
if !sc.collectionStarted {
|
||||
sc.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
// Signal stop to all goroutines
|
||||
@@ -114,6 +121,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
|
||||
if verbose {
|
||||
log.Printf("Failed to list existing containers: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -147,13 +155,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
||||
case event := <-events:
|
||||
if event.Type == "container" && event.Action == "start" {
|
||||
// Get container details
|
||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
|
||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) //nolint:staticcheck // SA1019: use Actor.ID
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to types.Container format for consistency
|
||||
cont := types.Container{
|
||||
cont := types.Container{ //nolint:staticcheck // SA1019: use container.Summary
|
||||
ID: containerInfo.ID,
|
||||
Names: []string{containerInfo.Name},
|
||||
Labels: containerInfo.Config.Labels,
|
||||
@@ -167,13 +175,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
||||
if verbose {
|
||||
log.Printf("Error in Docker events stream: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldMonitorContainer determines if a container should be monitored.
|
||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
|
||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { //nolint:staticcheck // SA1019: use container.Summary
|
||||
// Check if it has the correct run ID label
|
||||
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
|
||||
return false
|
||||
@@ -213,6 +222,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
|
||||
}
|
||||
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
||||
}
|
||||
|
||||
@@ -226,12 +236,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
if verbose {
|
||||
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
defer statsResponse.Body.Close()
|
||||
|
||||
decoder := json.NewDecoder(statsResponse.Body)
|
||||
var prevStats *container.Stats
|
||||
|
||||
var prevStats *container.Stats //nolint:staticcheck // SA1019: use StatsResponse
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -240,12 +252,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
var stats container.Stats
|
||||
if err := decoder.Decode(&stats); err != nil {
|
||||
var stats container.Stats //nolint:staticcheck // SA1019: use StatsResponse
|
||||
|
||||
err := decoder.Decode(&stats)
|
||||
if err != nil {
|
||||
// EOF is expected when container stops or stream ends
|
||||
if err.Error() != "EOF" && verbose {
|
||||
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -261,8 +276,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
// Store the sample (skip first sample since CPU calculation needs previous stats)
|
||||
if prevStats != nil {
|
||||
// Get container stats reference without holding the main mutex
|
||||
var containerStats *ContainerStats
|
||||
var exists bool
|
||||
var (
|
||||
containerStats *ContainerStats
|
||||
exists bool
|
||||
)
|
||||
|
||||
sc.mutex.RLock()
|
||||
containerStats, exists = sc.containers[containerID]
|
||||
@@ -286,7 +303,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
||||
}
|
||||
|
||||
// calculateCPUPercent calculates CPU usage percentage from Docker stats.
|
||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
|
||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 { //nolint:staticcheck // SA1019: use StatsResponse
|
||||
// CPU calculation based on Docker's implementation
|
||||
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
|
||||
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
|
||||
@@ -331,10 +348,12 @@ type StatsSummary struct {
|
||||
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
||||
// Take snapshot of container references without holding main lock long
|
||||
sc.mutex.RLock()
|
||||
|
||||
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
||||
for _, containerStats := range sc.containers {
|
||||
containerRefs = append(containerRefs, containerStats)
|
||||
}
|
||||
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
||||
@@ -384,23 +403,25 @@ func calculateStatsSummary(values []float64) StatsSummary {
|
||||
return StatsSummary{}
|
||||
}
|
||||
|
||||
min := values[0]
|
||||
max := values[0]
|
||||
minVal := values[0]
|
||||
maxVal := values[0]
|
||||
sum := 0.0
|
||||
|
||||
for _, value := range values {
|
||||
if value < min {
|
||||
min = value
|
||||
if value < minVal {
|
||||
minVal = value
|
||||
}
|
||||
if value > max {
|
||||
max = value
|
||||
|
||||
if value > maxVal {
|
||||
maxVal = value
|
||||
}
|
||||
|
||||
sum += value
|
||||
}
|
||||
|
||||
return StatsSummary{
|
||||
Min: min,
|
||||
Max: max,
|
||||
Min: minVal,
|
||||
Max: maxVal,
|
||||
Average: sum / float64(len(values)),
|
||||
}
|
||||
}
|
||||
@@ -434,6 +455,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
|
||||
}
|
||||
|
||||
summaries := sc.GetSummary()
|
||||
|
||||
var violations []MemoryViolation
|
||||
|
||||
for _, summary := range summaries {
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -15,7 +16,10 @@ type MapConfig struct {
|
||||
Directory string `flag:"directory,Directory to read map responses from"`
|
||||
}
|
||||
|
||||
var mapConfig MapConfig
|
||||
var (
|
||||
mapConfig MapConfig
|
||||
errDirectoryRequired = errors.New("directory is required")
|
||||
)
|
||||
|
||||
func main() {
|
||||
root := command.C{
|
||||
@@ -40,7 +44,7 @@ func main() {
|
||||
// runIntegrationTest executes the integration test workflow.
|
||||
func runOnline(env *command.Env) error {
|
||||
if mapConfig.Directory == "" {
|
||||
return fmt.Errorf("directory is required")
|
||||
return errDirectoryRequired
|
||||
}
|
||||
|
||||
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
|
||||
@@ -57,5 +61,6 @@ func runOnline(env *command.Env) error {
|
||||
|
||||
os.Stderr.Write(out)
|
||||
os.Stderr.Write([]byte("\n"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
let
|
||||
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
|
||||
buildGo = pkgs.buildGo125Module;
|
||||
vendorHash = "sha256-jkeB9XUTEGt58fPOMpE4/e3+JQoMQTgf0RlthVBmfG0=";
|
||||
vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
|
||||
in
|
||||
{
|
||||
headscale = buildGo {
|
||||
|
||||
@@ -115,6 +115,7 @@ var (
|
||||
|
||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
var err error
|
||||
|
||||
if profilingEnabled {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
}
|
||||
@@ -142,6 +143,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
if !ok {
|
||||
log.Error().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed")
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed because node not found in NodeStore")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,10 +159,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
app.ephemeralGC = ephemeralGC
|
||||
|
||||
var authProvider AuthProvider
|
||||
|
||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
&app,
|
||||
@@ -177,17 +181,18 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
authProvider = oidcProvider
|
||||
}
|
||||
}
|
||||
|
||||
app.authProvider = authProvider
|
||||
|
||||
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
|
||||
// TODO(kradalby): revisit why this takes a list.
|
||||
|
||||
var magicDNSDomains []dnsname.FQDN
|
||||
if cfg.PrefixV4 != nil {
|
||||
magicDNSDomains = append(
|
||||
magicDNSDomains,
|
||||
util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
|
||||
}
|
||||
|
||||
if cfg.PrefixV6 != nil {
|
||||
magicDNSDomains = append(
|
||||
magicDNSDomains,
|
||||
@@ -198,6 +203,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
if app.cfg.TailcfgDNSConfig.Routes == nil {
|
||||
app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver)
|
||||
}
|
||||
|
||||
for _, d := range magicDNSDomains {
|
||||
app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil
|
||||
}
|
||||
@@ -232,6 +238,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app.DERPServer = embeddedDERPServer
|
||||
}
|
||||
|
||||
@@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
lastExpiryCheck := time.Unix(0, 0)
|
||||
|
||||
derpTickerChan := make(<-chan time.Time)
|
||||
|
||||
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
|
||||
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
|
||||
defer derpTicker.Stop()
|
||||
|
||||
derpTickerChan = derpTicker.C
|
||||
}
|
||||
|
||||
@@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
return
|
||||
|
||||
case <-expireTicker.C:
|
||||
var expiredNodeChanges []change.Change
|
||||
var changed bool
|
||||
var (
|
||||
expiredNodeChanges []change.Change
|
||||
changed bool
|
||||
)
|
||||
|
||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
|
||||
@@ -287,11 +298,13 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
|
||||
case <-derpTickerChan:
|
||||
log.Info().Msg("fetching DERPMap updates")
|
||||
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
|
||||
|
||||
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { //nolint:contextcheck
|
||||
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
region, _ := h.DERPServer.GenerateRegion()
|
||||
derpMap.Regions[region.RegionID] = ®ion
|
||||
@@ -303,6 +316,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
|
||||
continue
|
||||
}
|
||||
|
||||
h.state.SetDERPMap(derpMap)
|
||||
|
||||
h.Change(change.DERPMap())
|
||||
@@ -311,6 +325,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||
|
||||
h.Change(change.ExtraRecords())
|
||||
@@ -390,7 +405,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
|
||||
writeUnauthorized := func(statusCode int) {
|
||||
writer.WriteHeader(statusCode)
|
||||
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
|
||||
|
||||
if _, err := writer.Write([]byte("Unauthorized")); err != nil { //nolint:noinlineerr
|
||||
log.Error().Err(err).Msg("writing HTTP response failed")
|
||||
}
|
||||
}
|
||||
@@ -401,6 +417,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg(`missing "Bearer " prefix in "Authorization" header`)
|
||||
writeUnauthorized(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -412,6 +429,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("failed to validate token")
|
||||
writeUnauthorized(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -420,6 +438,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
Str("client_address", req.RemoteAddr).
|
||||
Msg("invalid token")
|
||||
writeUnauthorized(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -431,7 +450,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
||||
// and will remove it if it is not.
|
||||
func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
||||
// File does not exist, all fine
|
||||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
|
||||
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { //nolint:noinlineerr
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -455,6 +474,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
|
||||
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
|
||||
Methods(http.MethodGet)
|
||||
@@ -484,8 +504,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
}
|
||||
|
||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
//
|
||||
//nolint:gocyclo // complex server startup function
|
||||
func (h *Headscale) Serve() error {
|
||||
var err error
|
||||
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
|
||||
if profilingEnabled {
|
||||
@@ -512,6 +535,7 @@ func (h *Headscale) Serve() error {
|
||||
Msg("Clients with a lower minimum version will be rejected")
|
||||
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
|
||||
h.mapBatcher.Start()
|
||||
defer h.mapBatcher.Close()
|
||||
|
||||
@@ -545,6 +569,7 @@ func (h *Headscale) Serve() error {
|
||||
// around between restarts, they will reconnect and the GC will
|
||||
// be cancelled.
|
||||
go h.ephemeralGC.Start()
|
||||
|
||||
ephmNodes := h.state.ListEphemeralNodes()
|
||||
for _, node := range ephmNodes.All() {
|
||||
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
||||
@@ -555,7 +580,9 @@ func (h *Headscale) Serve() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up extrarecord manager: %w", err)
|
||||
}
|
||||
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
|
||||
|
||||
go h.extraRecordMan.Run()
|
||||
defer h.extraRecordMan.Close()
|
||||
}
|
||||
@@ -564,6 +591,7 @@ func (h *Headscale) Serve() error {
|
||||
// records updates
|
||||
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
|
||||
defer scheduleCancel()
|
||||
|
||||
go h.scheduledTasks(scheduleCtx)
|
||||
|
||||
if zl.GlobalLevel() == zl.TraceLevel {
|
||||
@@ -576,6 +604,7 @@ func (h *Headscale) Serve() error {
|
||||
errorGroup := new(errgroup.Group)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -590,25 +619,26 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
socketDir := filepath.Dir(h.cfg.UnixSocket)
|
||||
|
||||
err = util.EnsureDir(socketDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up unix socket: %w", err)
|
||||
}
|
||||
|
||||
socketListener, err := net.Listen("unix", h.cfg.UnixSocket)
|
||||
socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting up gRPC socket: %w", err)
|
||||
}
|
||||
|
||||
// Change socket permissions
|
||||
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil {
|
||||
if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("changing gRPC socket permission: %w", err)
|
||||
}
|
||||
|
||||
grpcGatewayMux := grpcRuntime.NewServeMux()
|
||||
|
||||
// Make the grpc-gateway connect to grpc over socket
|
||||
grpcGatewayConn, err := grpc.Dial(
|
||||
grpcGatewayConn, err := grpc.Dial( //nolint:staticcheck // SA1019: deprecated but supported in 1.x
|
||||
h.cfg.UnixSocket,
|
||||
[]grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
@@ -659,8 +689,11 @@ func (h *Headscale) Serve() error {
|
||||
// https://github.com/soheilhy/cmux/issues/68
|
||||
// https://github.com/soheilhy/cmux/issues/91
|
||||
|
||||
var grpcServer *grpc.Server
|
||||
var grpcListener net.Listener
|
||||
var (
|
||||
grpcServer *grpc.Server
|
||||
grpcListener net.Listener
|
||||
)
|
||||
|
||||
if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
|
||||
log.Info().Msgf("enabling remote gRPC at %s", h.cfg.GRPCAddr)
|
||||
|
||||
@@ -685,7 +718,7 @@ func (h *Headscale) Serve() error {
|
||||
v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
|
||||
reflection.Register(grpcServer)
|
||||
|
||||
grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
|
||||
grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("binding to TCP address: %w", err)
|
||||
}
|
||||
@@ -715,12 +748,14 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
var httpListener net.Listener
|
||||
|
||||
if tlsConfig != nil {
|
||||
httpServer.TLSConfig = tlsConfig
|
||||
httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
|
||||
} else {
|
||||
httpListener, err = net.Listen("tcp", h.cfg.Addr)
|
||||
httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("binding to TCP address: %w", err)
|
||||
}
|
||||
@@ -751,19 +786,24 @@ func (h *Headscale) Serve() error {
|
||||
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
|
||||
}
|
||||
|
||||
|
||||
var tailsqlContext context.Context
|
||||
|
||||
if tailsqlEnabled {
|
||||
if h.cfg.Database.Type != types.DatabaseSqlite {
|
||||
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
||||
log.Fatal().
|
||||
Str("type", h.cfg.Database.Type).
|
||||
Msgf("tailsql only support %q", types.DatabaseSqlite)
|
||||
}
|
||||
|
||||
if tailsqlTSKey == "" {
|
||||
//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
|
||||
log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
|
||||
}
|
||||
|
||||
tailsqlContext = context.Background()
|
||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
|
||||
|
||||
go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) //nolint:errcheck
|
||||
}
|
||||
|
||||
// Handle common process-killing signals so we can gracefully shut down:
|
||||
@@ -774,6 +814,7 @@ func (h *Headscale) Serve() error {
|
||||
syscall.SIGTERM,
|
||||
syscall.SIGQUIT,
|
||||
syscall.SIGHUP)
|
||||
|
||||
sigFunc := func(c chan os.Signal) {
|
||||
// Wait for a SIGINT or SIGKILL:
|
||||
for {
|
||||
@@ -798,6 +839,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
default:
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
|
||||
log.Info().
|
||||
Str("signal", sig.String()).
|
||||
Msg("Received signal to stop, shutting down gracefully")
|
||||
@@ -854,6 +896,7 @@ func (h *Headscale) Serve() error {
|
||||
if debugHTTPListener != nil {
|
||||
debugHTTPListener.Close()
|
||||
}
|
||||
|
||||
httpListener.Close()
|
||||
grpcGatewayConn.Close()
|
||||
|
||||
@@ -863,6 +906,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
// Close state connections
|
||||
info("closing state and database")
|
||||
|
||||
err = h.state.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to close state")
|
||||
@@ -875,6 +919,7 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errorGroup.Go(func() error {
|
||||
sigFunc(sigc)
|
||||
|
||||
@@ -886,6 +931,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
var err error
|
||||
|
||||
if h.cfg.TLS.LetsEncrypt.Hostname != "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Warn().
|
||||
@@ -918,7 +964,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
// Configuration via autocert with HTTP-01. This requires listening on
|
||||
// port 80 for the certificate validation in addition to the headscale
|
||||
// service, which can be configured to run on any other port.
|
||||
|
||||
server := &http.Server{
|
||||
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
||||
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
||||
@@ -963,6 +1008,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
|
||||
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
err := util.EnsureDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ensuring private key directory: %w", err)
|
||||
@@ -981,6 +1027,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
@@ -998,7 +1045,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
trimmedPrivateKey := strings.TrimSpace(string(privateKey))
|
||||
|
||||
var machineKey key.MachinePrivate
|
||||
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
|
||||
if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("parsing private key: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
)
|
||||
|
||||
type AuthProvider interface {
|
||||
RegisterHandler(http.ResponseWriter, *http.Request)
|
||||
AuthURL(types.RegistrationID) string
|
||||
RegisterHandler(w http.ResponseWriter, r *http.Request)
|
||||
AuthURL(regID types.RegistrationID) string
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegister(
|
||||
@@ -51,6 +51,7 @@ func (h *Headscale) handleRegister(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling logout: %w", err)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
@@ -132,7 +133,7 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
// handleLogout checks if the [tailcfg.RegisterRequest] is a
|
||||
// logout attempt from a node. If the node is not attempting to
|
||||
// logout attempt from a node. If the node is not attempting to.
|
||||
func (h *Headscale) handleLogout(
|
||||
node types.NodeView,
|
||||
req tailcfg.RegisterRequest,
|
||||
@@ -159,6 +160,7 @@ func (h *Headscale) handleLogout(
|
||||
Interface("reg.req", req).
|
||||
Bool("unexpected", true).
|
||||
Msg("Node key expired, forcing re-authentication")
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
NodeKeyExpired: true,
|
||||
MachineAuthorized: false,
|
||||
@@ -275,6 +277,7 @@ func (h *Headscale) waitForFollowup(
|
||||
// registration is expired in the cache, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(req, machineKey)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node.View()), nil
|
||||
}
|
||||
}
|
||||
@@ -340,6 +343,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
|
||||
var perr types.PAKError
|
||||
if errors.As(err, &perr) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
||||
@@ -351,7 +355,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// If node is not valid, it means an ephemeral node was deleted during logout
|
||||
if !node.Valid() {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no node to return when ephemeral deleted
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
@@ -430,6 +434,7 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
Str("generated.hostname", hostname).
|
||||
Msg("Received registration request with empty hostname, generated default")
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -77,7 +77,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
if err := hsdb.DB.Save(&key).Error; err != nil {
|
||||
if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr
|
||||
return "", nil, fmt.Errorf("saving API key to database: %w", err)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||
keys := []types.APIKey{}
|
||||
if err := hsdb.DB.Find(&keys).Error; err != nil {
|
||||
|
||||
err := hsdb.DB.Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
||||
|
||||
// ExpireAPIKey marks a ApiKey as expired.
|
||||
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||
err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ type HSDatabase struct {
|
||||
|
||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||
// It accepts the full configuration to allow migrations access to policy settings.
|
||||
//
|
||||
//nolint:gocyclo // complex database initialization with many migrations
|
||||
func NewHeadscaleDatabase(
|
||||
cfg *types.Config,
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
||||
@@ -76,7 +78,7 @@ func NewHeadscaleDatabase(
|
||||
ID: "202501221827",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Remove any invalid routes associated with a node that does not exist.
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -84,14 +86,14 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
|
||||
// Remove any invalid routes without a node_id.
|
||||
if tx.Migrator().HasTable(&types.Route{}) {
|
||||
if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
err := tx.Exec("delete from routes where node_id is null").Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := tx.AutoMigrate(&types.Route{})
|
||||
err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Route: %w", err)
|
||||
}
|
||||
@@ -109,6 +111,7 @@ func NewHeadscaleDatabase(
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
|
||||
}
|
||||
|
||||
err = tx.AutoMigrate(&types.Node{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Node: %w", err)
|
||||
@@ -155,7 +158,8 @@ AND auth_key_id NOT IN (
|
||||
|
||||
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||
|
||||
var routes []types.Route
|
||||
var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
|
||||
err = tx.Find(&routes).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching routes: %w", err)
|
||||
@@ -171,7 +175,7 @@ AND auth_key_id NOT IN (
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
|
||||
data, err := json.Marshal(routes)
|
||||
data, _ := json.Marshal(routes)
|
||||
|
||||
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
||||
if err != nil {
|
||||
@@ -180,7 +184,7 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
// Drop the old table.
|
||||
_ = tx.Migrator().DropTable(&types.Route{})
|
||||
_ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
|
||||
return nil
|
||||
},
|
||||
@@ -256,10 +260,13 @@ AND auth_key_id NOT IN (
|
||||
|
||||
// Check if routes table exists and drop it (should have been migrated already)
|
||||
var routesExists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
||||
if err == nil && routesExists {
|
||||
log.Info().Msg("dropping leftover routes table")
|
||||
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
|
||||
|
||||
err := tx.Exec("DROP TABLE routes").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("dropping routes table: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -281,6 +288,7 @@ AND auth_key_id NOT IN (
|
||||
for _, table := range tablesToRename {
|
||||
// Check if table exists before renaming
|
||||
var exists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
||||
@@ -291,7 +299,8 @@ AND auth_key_id NOT IN (
|
||||
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
|
||||
// Rename current table to _old
|
||||
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
|
||||
}
|
||||
}
|
||||
@@ -365,7 +374,8 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, createSQL := range tableCreationSQL {
|
||||
if err := tx.Exec(createSQL).Error; err != nil {
|
||||
err := tx.Exec(createSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating new table: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -394,7 +404,8 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, copySQL := range dataCopySQL {
|
||||
if err := tx.Exec(copySQL).Error; err != nil {
|
||||
err := tx.Exec(copySQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("copying data: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -417,14 +428,16 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, indexSQL := range indexes {
|
||||
if err := tx.Exec(indexSQL).Error; err != nil {
|
||||
err := tx.Exec(indexSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Drop old tables only after everything succeeds
|
||||
for _, table := range tablesToRename {
|
||||
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
if err != nil {
|
||||
log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded")
|
||||
}
|
||||
}
|
||||
@@ -760,6 +773,7 @@ AND auth_key_id NOT IN (
|
||||
|
||||
// or else it blocks...
|
||||
sqlConn.SetMaxIdleConns(maxIdleConns)
|
||||
|
||||
sqlConn.SetMaxOpenConns(maxOpenConns)
|
||||
defer sqlConn.SetMaxIdleConns(1)
|
||||
defer sqlConn.SetMaxOpenConns(1)
|
||||
@@ -777,7 +791,7 @@ AND auth_key_id NOT IN (
|
||||
},
|
||||
}
|
||||
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("validating schema: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -803,6 +817,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
switch cfg.Type {
|
||||
case types.DatabaseSqlite:
|
||||
dir := filepath.Dir(cfg.Sqlite.Path)
|
||||
|
||||
err := util.EnsureDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
||||
@@ -856,7 +871,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
Str("path", dbString).
|
||||
Msg("Opening database")
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
}
|
||||
@@ -911,7 +926,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
|
||||
// Get the current foreign key status
|
||||
var fkOriginallyEnabled int
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("checking foreign key status: %w", err)
|
||||
}
|
||||
|
||||
@@ -940,28 +955,31 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
|
||||
if needsFKDisabled {
|
||||
// Disable foreign keys for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
} else {
|
||||
// Ensure foreign keys are enabled for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run up to this specific migration (will only run the next pending migration)
|
||||
if err := migrations.MigrateTo(migrationID); err != nil {
|
||||
err := migrations.MigrateTo(migrationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("restoring foreign keys: %w", err)
|
||||
}
|
||||
|
||||
// Run the rest of the migrations
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
if err := migrations.Migrate(); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -979,16 +997,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var violation constraintViolation
|
||||
if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil {
|
||||
|
||||
err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
violatedConstraints = append(violatedConstraints, violation)
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
if err := rows.Err(); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
if len(violatedConstraints) > 0 {
|
||||
for _, violation := range violatedConstraints {
|
||||
@@ -1003,7 +1027,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
}
|
||||
} else {
|
||||
// PostgreSQL can run all migrations in one block - no foreign key issues
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
err := migrations.Migrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1014,6 +1039,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
sqlDB, err := hsdb.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1029,7 +1055,7 @@ func (hsdb *HSDatabase) Close() error {
|
||||
}
|
||||
|
||||
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
||||
db.Exec("VACUUM")
|
||||
db.Exec("VACUUM") //nolint:errcheck,noctx
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
@@ -1038,12 +1064,14 @@ func (hsdb *HSDatabase) Close() error {
|
||||
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
|
||||
rx := hsdb.DB.Begin()
|
||||
defer rx.Rollback()
|
||||
|
||||
return fn(rx)
|
||||
}
|
||||
|
||||
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
rx := db.Begin()
|
||||
defer rx.Rollback()
|
||||
|
||||
ret, err := fn(rx)
|
||||
if err != nil {
|
||||
var no T
|
||||
@@ -1056,7 +1084,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
tx := hsdb.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
if err := fn(tx); err != nil {
|
||||
|
||||
err := fn(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1066,6 +1096,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
|
||||
tx := db.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
ret, err := fn(tx)
|
||||
if err != nil {
|
||||
var no T
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
|
||||
// Verify api_keys data preservation
|
||||
var apiKeyCount int
|
||||
|
||||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||||
@@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(schemaContent))
|
||||
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
func requireConstraintFailed(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||
}
|
||||
@@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "no-duplicate-username-if-no-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
_, err := CreateUser(db, types.User{Name: "user1"})
|
||||
require.NoError(t, err)
|
||||
_, err = CreateUser(db, types.User{Name: "user1"})
|
||||
@@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-username-and-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
@@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
@@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-cli-then-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-oidc-then-cli",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Name: "user1",
|
||||
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||||
@@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Construct the pg_restore command
|
||||
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
|
||||
// Set the output streams
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
||||
// skip already-applied migrations and only run new ones.
|
||||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemas, err := os.ReadDir("testdata/sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Basic deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var deletionWg sync.WaitGroup
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
deletionWg sync.WaitGroup
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
deletionWg.Done()
|
||||
}
|
||||
@@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
go gc.Start()
|
||||
|
||||
// Schedule several nodes for deletion with short expiry
|
||||
const expiry = fifty
|
||||
const numNodes = 100
|
||||
const (
|
||||
expiry = fifty
|
||||
numNodes = 100
|
||||
)
|
||||
|
||||
// Set up wait group for expected deletions
|
||||
|
||||
deletionWg.Add(numNodes)
|
||||
|
||||
for i := 1; i <= numNodes; i++ {
|
||||
gc.Schedule(types.NodeID(i), expiry)
|
||||
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
|
||||
// Wait for all scheduled deletions to complete
|
||||
@@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
|
||||
// Schedule and immediately cancel to test that part of the code
|
||||
for i := numNodes + 1; i <= numNodes*2; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test
|
||||
gc.Schedule(nodeID, time.Hour)
|
||||
gc.Cancel(nodeID)
|
||||
}
|
||||
@@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
@@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
|
||||
// Start GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
const shortExpiry = fifty
|
||||
const longExpiry = 1 * time.Hour
|
||||
const (
|
||||
shortExpiry = fifty
|
||||
longExpiry = 1 * time.Hour
|
||||
)
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
@@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
}
|
||||
|
||||
// Start the GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
const expiry = fifty
|
||||
|
||||
// Schedule node for deletion
|
||||
@@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
||||
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
@@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
nodeDeleted := make(chan struct{})
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
close(nodeDeleted) // Signal that deletion happened
|
||||
}
|
||||
@@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
// Use a WaitGroup to ensure the GC has started
|
||||
var startWg sync.WaitGroup
|
||||
startWg.Add(1)
|
||||
|
||||
go func() {
|
||||
startWg.Done() // Signal that the goroutine has started
|
||||
gc.Start()
|
||||
}()
|
||||
|
||||
startWg.Wait() // Wait for the GC to start
|
||||
|
||||
// Close GC right away
|
||||
@@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
|
||||
// Check no node was deleted
|
||||
deleteMutex.Lock()
|
||||
|
||||
nodesDeleted := len(deletedIDs)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
||||
|
||||
@@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
go gc.Start()
|
||||
|
||||
// Number of concurrent scheduling goroutines
|
||||
const numSchedulers = 10
|
||||
const nodesPerScheduler = 50
|
||||
const (
|
||||
numSchedulers = 10
|
||||
nodesPerScheduler = 50
|
||||
)
|
||||
|
||||
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
||||
|
||||
@@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
case <-stopScheduling:
|
||||
return
|
||||
default:
|
||||
nodeID := types.NodeID(baseNodeID + j + 1)
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
atomic.AddInt64(&scheduledCount, 1)
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
|
||||
@@ -17,7 +17,11 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
|
||||
var (
|
||||
errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
|
||||
errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix")
|
||||
errIPAllocatorNil = errors.New("ip allocator was nil")
|
||||
)
|
||||
|
||||
// IPAllocator is a singleton responsible for allocating
|
||||
// IP addresses for nodes and making sure the same
|
||||
@@ -62,8 +66,10 @@ func NewIPAllocator(
|
||||
strategy: strategy,
|
||||
}
|
||||
|
||||
var v4s []sql.NullString
|
||||
var v6s []sql.NullString
|
||||
var (
|
||||
v4s []sql.NullString
|
||||
v6s []sql.NullString
|
||||
)
|
||||
|
||||
if db != nil {
|
||||
err := db.Read(func(rx *gorm.DB) error {
|
||||
@@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var ret4 *netip.Addr
|
||||
var ret6 *netip.Addr
|
||||
var (
|
||||
err error
|
||||
ret4 *netip.Addr
|
||||
ret6 *netip.Addr
|
||||
)
|
||||
|
||||
if i.prefix4 != nil {
|
||||
ret4, err = i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
}
|
||||
|
||||
i.prev4 = *ret4
|
||||
}
|
||||
|
||||
@@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
}
|
||||
|
||||
i.prev6 = *ret6
|
||||
}
|
||||
|
||||
@@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.
|
||||
}
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
||||
var err error
|
||||
var ip netip.Addr
|
||||
var (
|
||||
err error
|
||||
ip netip.Addr
|
||||
)
|
||||
|
||||
switch i.strategy {
|
||||
case types.IPAllocationStrategySequential:
|
||||
@@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {
|
||||
|
||||
if !pfx.Contains(ip) {
|
||||
return netip.Addr{}, fmt.Errorf(
|
||||
"generated ip(%s) not in prefix(%s)",
|
||||
"%w: ip(%s) not in prefix(%s)",
|
||||
errGeneratedIPNotInPrefix,
|
||||
ip.String(),
|
||||
pfx.String(),
|
||||
)
|
||||
@@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool {
|
||||
// If a prefix type has been removed (IPv4 or IPv6), it
|
||||
// will remove the IPs in that family from the node.
|
||||
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
var err error
|
||||
var ret []string
|
||||
var (
|
||||
err error
|
||||
ret []string
|
||||
)
|
||||
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
if i == nil {
|
||||
return errors.New("backfilling IPs: ip allocator was nil")
|
||||
return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil)
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("starting to backfill IPs")
|
||||
@@ -295,6 +311,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
|
||||
node.IPv4 = ret4
|
||||
changed = true
|
||||
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
@@ -307,6 +324,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
|
||||
node.IPv6 = ret6
|
||||
changed = true
|
||||
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
|
||||
@@ -21,9 +21,7 @@ var mpp = func(pref string) *netip.Prefix {
|
||||
return &p
|
||||
}
|
||||
|
||||
var na = func(pref string) netip.Addr {
|
||||
return netip.MustParseAddr(pref)
|
||||
}
|
||||
var na = netip.MustParseAddr
|
||||
|
||||
var nap = func(pref string) *netip.Addr {
|
||||
n := na(pref)
|
||||
@@ -158,8 +156,10 @@ func TestIPAllocatorSequential(t *testing.T) {
|
||||
types.IPAllocationStrategySequential,
|
||||
)
|
||||
|
||||
var got4s []netip.Addr
|
||||
var got6s []netip.Addr
|
||||
var (
|
||||
got4s []netip.Addr
|
||||
got6s []netip.Addr
|
||||
)
|
||||
|
||||
for range tt.getCount {
|
||||
got4, got6, err := alloc.Next()
|
||||
@@ -175,6 +175,7 @@ func TestIPAllocatorSequential(t *testing.T) {
|
||||
got6s = append(got6s, *got6)
|
||||
}
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -288,6 +289,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||
fullNodeP := func(i int) *types.Node {
|
||||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||
|
||||
return &types.Node{
|
||||
IPv4: nap(v4),
|
||||
IPv6: nap(v6),
|
||||
@@ -484,6 +486,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer db.Close()
|
||||
|
||||
alloc, err := NewIPAllocator(
|
||||
|
||||
@@ -27,8 +27,14 @@ import (
|
||||
const (
|
||||
NodeGivenNameHashLength = 8
|
||||
NodeGivenNameTrimSize = 2
|
||||
|
||||
// defaultTestNodePrefix is the default hostname prefix for nodes created in tests.
|
||||
defaultTestNodePrefix = "testnode"
|
||||
)
|
||||
|
||||
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
||||
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||
|
||||
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
|
||||
var (
|
||||
@@ -52,12 +58,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID)
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
|
||||
err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Where("id <> ?", nodeID).
|
||||
Where(peerIDs).Find(&nodes).Error; err != nil {
|
||||
Where(peerIDs).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return types.Nodes{}, err
|
||||
}
|
||||
|
||||
@@ -76,11 +84,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error)
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
|
||||
err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
||||
Where(nodeIDs).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -90,7 +100,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
|
||||
|
||||
err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -222,7 +234,7 @@ func SetTags(
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
// SetApprovedRoutes takes a Node struct pointer and updates the approved routes.
|
||||
func SetApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
@@ -254,7 +266,7 @@ func SetApprovedRoutes(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("updating approved routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -294,10 +306,10 @@ func RenameNode(tx *gorm.DB,
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("name is not unique")
|
||||
return ErrNodeNameNotUnique
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("renaming node in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -329,7 +341,8 @@ func DeleteNode(tx *gorm.DB,
|
||||
node *types.Node,
|
||||
) error {
|
||||
// Unscoped causes the node to be fully removed from the database.
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
||||
err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -343,9 +356,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
||||
nodeID types.NodeID,
|
||||
) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
|
||||
err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -395,7 +410,8 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
// so we store the node.Expire and node.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if node.IPv4 != nil || node.IPv6 != nil {
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
err := tx.Save(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registering existing node in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -431,7 +447,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
node.GivenName = givenName
|
||||
}
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("saving node to database: %w", err)
|
||||
}
|
||||
|
||||
@@ -656,7 +672,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
||||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
nodeName := defaultTestNodePrefix
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
@@ -728,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname
|
||||
panic("CreateNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
prefix := defaultTestNodePrefix
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
@@ -751,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
|
||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
prefix := defaultTestNodePrefix
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
@@ -187,6 +187,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
suppliedName string
|
||||
randomSuffix bool
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -467,10 +468,10 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
@@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -520,6 +523,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
want := []types.NodeID{1, 3}
|
||||
got := []types.NodeID{}
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
deletionCount := make(chan struct{}, 10)
|
||||
@@ -527,6 +531,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
got = append(got, ni)
|
||||
|
||||
deletionCount <- struct{}{}
|
||||
@@ -576,8 +581,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
var got []types.NodeID
|
||||
var mu sync.Mutex
|
||||
var (
|
||||
got []types.NodeID
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
want := 1000
|
||||
|
||||
@@ -589,6 +596,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
runtime.Gosched()
|
||||
|
||||
got = append(got, ni)
|
||||
|
||||
atomic.AddInt64(&deletedCount, 1)
|
||||
@@ -616,9 +624,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateRandomNumber(t *testing.T, max int64) int64 {
|
||||
//nolint:unused
|
||||
func generateRandomNumber(t *testing.T, maxVal int64) int64 {
|
||||
t.Helper()
|
||||
maxB := big.NewInt(max)
|
||||
|
||||
maxB := big.NewInt(maxVal)
|
||||
|
||||
n, err := rand.Int(rand.Reader, maxB)
|
||||
if err != nil {
|
||||
t.Fatalf("getting random number: %s", err)
|
||||
@@ -722,7 +733,7 @@ func TestNodeNaming(t *testing.T) {
|
||||
nodeInvalidHostname := types.Node{
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
UserID: &user2.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
@@ -746,12 +757,15 @@ func TestNodeNaming(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
|
||||
_, _ = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -810,25 +824,25 @@ func TestNodeNaming(t *testing.T) {
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "test")
|
||||
})
|
||||
assert.ErrorContains(t, err, "name is not unique")
|
||||
require.ErrorContains(t, err, "name is not unique")
|
||||
|
||||
// Rename invalid chars
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑")
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename too short
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[3].ID, "a")
|
||||
})
|
||||
assert.ErrorContains(t, err, "at least 2 characters")
|
||||
require.ErrorContains(t, err, "at least 2 characters")
|
||||
|
||||
// Rename with emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename with only emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
@@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "chinese_chars_with_dash_rejected",
|
||||
newName: "server-北京-01",
|
||||
newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "chinese_only_rejected",
|
||||
newName: "我的电脑",
|
||||
newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
@@ -911,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "mixed_chinese_emoji_rejected",
|
||||
newName: "测试💻机器",
|
||||
newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
@@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
@@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
||||
@@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
|
||||
Data: policy,
|
||||
}
|
||||
|
||||
if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil {
|
||||
err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ func CreatePreAuthKey(
|
||||
Hash: hash, // Store hash
|
||||
}
|
||||
|
||||
if err := tx.Save(&key).Error; err != nil {
|
||||
if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("creating key in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -155,9 +155,7 @@ func CreatePreAuthKey(
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
return ListPreAuthKeys(rx)
|
||||
})
|
||||
return Read(hsdb.DB, ListPreAuthKeys)
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns all PreAuthKeys in the database.
|
||||
@@ -329,10 +327,11 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
}
|
||||
|
||||
k.Used = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
// ExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
||||
|
||||
@@ -362,7 +362,8 @@ func (c *Config) Validate() error {
|
||||
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
|
||||
// compatible with modernc.org/sqlite driver.
|
||||
func (c *Config) ToURL() (string, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
err := c.Validate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
@@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) {
|
||||
if c.BusyTimeout > 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
|
||||
}
|
||||
|
||||
if c.JournalMode != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
|
||||
}
|
||||
|
||||
if c.AutoVacuum != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
|
||||
}
|
||||
|
||||
if c.WALAutocheckpoint >= 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
|
||||
}
|
||||
|
||||
if c.Synchronous != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
|
||||
}
|
||||
|
||||
if c.ForeignKeys {
|
||||
pragmas = append(pragmas, "foreign_keys=ON")
|
||||
}
|
||||
|
||||
@@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
|
||||
t.Errorf("Config.ToURL() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
@@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
|
||||
Path: "",
|
||||
BusyTimeout: -1,
|
||||
}
|
||||
|
||||
_, err := config.ToURL()
|
||||
if err == nil {
|
||||
t.Error("Config.ToURL() with invalid config should return error")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqliteconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
ctx := context.Background()
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
@@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
||||
for pragma, expectedValue := range tt.expected {
|
||||
t.Run("pragma_"+pragma, func(t *testing.T) {
|
||||
var actualValue any
|
||||
|
||||
query := "PRAGMA " + pragma
|
||||
err := db.QueryRow(query).Scan(&actualValue)
|
||||
|
||||
err := db.QueryRowContext(ctx, query).Scan(&actualValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query %s: %v", query, err)
|
||||
}
|
||||
@@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test tables with foreign key relationship
|
||||
schema := `
|
||||
CREATE TABLE parent (
|
||||
@@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
_, err = db.ExecContext(ctx, schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create schema: %v", err)
|
||||
}
|
||||
|
||||
// Insert parent record
|
||||
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert parent: %v", err)
|
||||
}
|
||||
|
||||
// Test 1: Valid foreign key should work
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Valid foreign key insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Test 2: Invalid foreign key should fail
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
@@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test 3: Deleting referenced parent should fail
|
||||
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
|
||||
_, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
@@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
var actualMode string
|
||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
||||
|
||||
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||
}
|
||||
|
||||
@@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL {
|
||||
t.Helper()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
srv, err := postgrestest.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(srv.Cleanup)
|
||||
|
||||
u, err := srv.CreateDatabase(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("created local postgres: %s", u)
|
||||
pu, _ := url.Parse(u)
|
||||
|
||||
|
||||
@@ -3,12 +3,19 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnmarshalTextValue = errors.New("unmarshalling text value")
|
||||
errUnsupportedType = errors.New("unsupported type")
|
||||
errTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
|
||||
)
|
||||
|
||||
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
||||
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
|
||||
|
||||
@@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("unmarshalling text value: %#v", dbValue)
|
||||
return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue)
|
||||
}
|
||||
|
||||
if isTextUnmarshaler(fieldValue) {
|
||||
maybeInstantiatePtr(fieldValue)
|
||||
f := fieldValue.MethodByName("UnmarshalText")
|
||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||
|
||||
ret := f.Call(args)
|
||||
if !ret[0].IsNil() {
|
||||
return decodingError(field.Name, ret[0].Interface().(error))
|
||||
if err, ok := ret[0].Interface().(error); ok {
|
||||
return decodingError(field.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If the underlying field is to a pointer type, we need to
|
||||
@@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||||
return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
||||
// always comparable, particularly when reflection is involved:
|
||||
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer
|
||||
}
|
||||
|
||||
b, err := v.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
||||
|
||||
return string(b), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||||
return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
|
||||
ErrUserNotUnique = errors.New("expected exactly one user")
|
||||
)
|
||||
|
||||
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
||||
@@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
|
||||
if err := util.ValidateHostname(user.Name); err != nil {
|
||||
err := util.ValidateHostname(user.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
|
||||
err = tx.Create(&user).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
||||
@@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
return ErrUserStillHasNodes
|
||||
}
|
||||
@@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
err = DestroyPreAuthKey(tx, key.ID)
|
||||
if err != nil {
|
||||
@@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
||||
// not exist or if another User exists with the new name.
|
||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
var err error
|
||||
|
||||
oldUser, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = util.ValidateHostname(newName); err != nil {
|
||||
|
||||
if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||
// ListUsers gets all the existing users.
|
||||
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
if len(where) > 1 {
|
||||
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
||||
return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
|
||||
}
|
||||
|
||||
var user *types.User
|
||||
@@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
}
|
||||
|
||||
users := []types.User{}
|
||||
if err := tx.Where(user).Find(&users).Error; err != nil {
|
||||
|
||||
err := tx.Where(user).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||
}
|
||||
|
||||
if len(users) != 1 {
|
||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||
return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
|
||||
}
|
||||
|
||||
return &users[0], nil
|
||||
|
||||
@@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
overview := h.state.DebugOverviewJSON()
|
||||
|
||||
overviewJSON, err := json.MarshalIndent(overview, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(overviewJSON)
|
||||
_, _ = w.Write(overviewJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
overview := h.state.DebugOverview()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(overview))
|
||||
_, _ = w.Write([]byte(overview))
|
||||
}
|
||||
}))
|
||||
|
||||
// Configuration endpoint
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config := h.state.DebugConfig()
|
||||
|
||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(configJSON)
|
||||
_, _ = w.Write(configJSON)
|
||||
}))
|
||||
|
||||
// Policy endpoint
|
||||
@@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policy))
|
||||
_, _ = w.Write([]byte(policy))
|
||||
}))
|
||||
|
||||
// Filter rules endpoint
|
||||
@@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(filterJSON)
|
||||
_, _ = w.Write(filterJSON)
|
||||
}))
|
||||
|
||||
// SSH policies endpoint
|
||||
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sshPolicies := h.state.DebugSSHPolicies()
|
||||
|
||||
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(sshJSON)
|
||||
_, _ = w.Write(sshJSON)
|
||||
}))
|
||||
|
||||
// DERP map endpoint
|
||||
@@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
derpInfo := h.state.DebugDERPJSON()
|
||||
|
||||
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(derpJSON)
|
||||
_, _ = w.Write(derpJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
derpInfo := h.state.DebugDERPMap()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(derpInfo))
|
||||
_, _ = w.Write([]byte(derpInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
nodeStoreNodes := h.state.DebugNodeStoreJSON()
|
||||
|
||||
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(nodeStoreJSON)
|
||||
_, _ = w.Write(nodeStoreJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
nodeStoreInfo := h.state.DebugNodeStore()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(nodeStoreInfo))
|
||||
_, _ = w.Write([]byte(nodeStoreInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
// Registration cache endpoint
|
||||
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cacheInfo := h.state.DebugRegistrationCache()
|
||||
|
||||
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(cacheJSON)
|
||||
_, _ = w.Write(cacheJSON)
|
||||
}))
|
||||
|
||||
// Routes endpoint
|
||||
@@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
routes := h.state.DebugRoutes()
|
||||
|
||||
routesJSON, err := json.MarshalIndent(routes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(routesJSON)
|
||||
_, _ = w.Write(routesJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
routes := h.state.DebugRoutesString()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(routes))
|
||||
_, _ = w.Write([]byte(routes))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if wantsJSON {
|
||||
policyManagerInfo := h.state.DebugPolicyManagerJSON()
|
||||
|
||||
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(policyManagerJSON)
|
||||
_, _ = w.Write(policyManagerJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
policyManagerInfo := h.state.DebugPolicyManager()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policyManagerInfo))
|
||||
_, _ = w.Write([]byte(policyManagerInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
if res == nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
_, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(resJSON)
|
||||
_, _ = w.Write(resJSON)
|
||||
}))
|
||||
|
||||
// Batcher endpoint
|
||||
@@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(batcherJSON)
|
||||
_, _ = w.Write(batcherJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
batcherInfo := h.debugBatcher()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(batcherInfo))
|
||||
_, _ = w.Write([]byte(batcherInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
|
||||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
@@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
|
||||
activeConnections: 0,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
info.TotalNodes++
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,11 +28,14 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
|
||||
return nil, err
|
||||
}
|
||||
defer derpFile.Close()
|
||||
|
||||
var derpMap tailcfg.DERPMap
|
||||
|
||||
b, err := io.ReadAll(derpFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(b, &derpMap)
|
||||
|
||||
return &derpMap, err
|
||||
@@ -57,12 +60,14 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var derpMap tailcfg.DERPMap
|
||||
|
||||
err = json.Unmarshal(body, &derpMap)
|
||||
|
||||
return &derpMap, err
|
||||
@@ -134,6 +139,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
|
||||
for id := range dm.Regions {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
@@ -160,16 +166,18 @@ func derpRandom() *rand.Rand {
|
||||
|
||||
derpRandomOnce.Do(func() {
|
||||
seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
|
||||
rnd := rand.New(rand.NewSource(0))
|
||||
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
|
||||
rnd := rand.New(rand.NewSource(0)) //nolint:gosec // weak random is fine for DERP scrambling
|
||||
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) //nolint:gosec // safe conversion
|
||||
derpRandomInst = rnd
|
||||
})
|
||||
|
||||
return derpRandomInst
|
||||
}
|
||||
|
||||
func resetDerpRandomForTesting() {
|
||||
derpRandomMu.Lock()
|
||||
defer derpRandomMu.Unlock()
|
||||
|
||||
derpRandomOnce = sync.Once{}
|
||||
derpRandomInst = nil
|
||||
}
|
||||
|
||||
@@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
viper.Set("dns.base_domain", tt.baseDomain)
|
||||
|
||||
defer viper.Reset()
|
||||
|
||||
resetDerpRandomForTesting()
|
||||
|
||||
testMap := tt.derpMap.View().AsStruct()
|
||||
|
||||
@@ -75,9 +75,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
var host string
|
||||
var port int
|
||||
var portStr string
|
||||
|
||||
var (
|
||||
host string
|
||||
port int
|
||||
portStr string
|
||||
)
|
||||
|
||||
// Extract hostname and port from URL
|
||||
host, portStr, err = net.SplitHostPort(serverURL.Host)
|
||||
@@ -98,12 +101,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
|
||||
// If debug flag is set, resolve hostname to IP address
|
||||
if debugUseDERPIP {
|
||||
ips, err := net.LookupIP(host)
|
||||
ips, err := new(net.Resolver).LookupIPAddr(context.Background(), host)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msgf("failed to resolve DERP hostname %s to IP, using hostname", host)
|
||||
} else if len(ips) > 0 {
|
||||
// Use the first IP address
|
||||
ipStr := ips[0].String()
|
||||
ipStr := ips[0].IP.String()
|
||||
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: resolved %s to %s", host, ipStr)
|
||||
host = ipStr
|
||||
}
|
||||
@@ -130,10 +133,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
||||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
|
||||
portSTUN, err := strconv.Atoi(portSTUNStr)
|
||||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
|
||||
localDERPregion.Nodes[0].STUNPort = portSTUN
|
||||
|
||||
log.Info().Caller().Msgf("derp region: %+v", localDERPregion)
|
||||
@@ -155,8 +160,10 @@ func (d *DERPServer) DERPHandler(
|
||||
Caller().
|
||||
Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain")
|
||||
writer.WriteHeader(http.StatusUpgradeRequired)
|
||||
|
||||
_, err := writer.Write([]byte("DERP requires connection upgrade"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -206,6 +213,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
|
||||
return
|
||||
}
|
||||
defer websocketConn.Close(websocket.StatusInternalError, "closing")
|
||||
|
||||
if websocketConn.Subprotocol() != "derp" {
|
||||
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
|
||||
|
||||
@@ -225,6 +233,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
log.Error().Caller().Msg("derp requires Hijacker interface from Gin")
|
||||
writer.Header().Set("Content-Type", "text/plain")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, err := writer.Write([]byte("HTTP does not support general TCP support"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -241,6 +250,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
|
||||
log.Error().Caller().Err(err).Msgf("hijack failed")
|
||||
writer.Header().Set("Content-Type", "text/plain")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, err = writer.Write([]byte("HTTP does not support general TCP support"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -281,6 +291,7 @@ func DERPProbeHandler(
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
writer.WriteHeader(http.StatusMethodNotAllowed)
|
||||
|
||||
_, err := writer.Write([]byte("bogus probe method"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -310,9 +321,11 @@ func DERPBootstrapDNSHandler(
|
||||
|
||||
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var resolver net.Resolver
|
||||
for _, region := range derpMap.Regions().All() {
|
||||
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
|
||||
|
||||
for _, region := range derpMap.Regions().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
|
||||
for _, node := range region.Nodes().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
|
||||
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
|
||||
if err != nil {
|
||||
log.Trace().
|
||||
@@ -322,11 +335,14 @@ func DERPBootstrapDNSHandler(
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
dnsEntries[node.HostName()] = addrs
|
||||
}
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
err := json.NewEncoder(writer).Encode(dnsEntries)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -339,7 +355,7 @@ func DERPBootstrapDNSHandler(
|
||||
|
||||
// ServeSTUN starts a STUN server on the configured addr.
|
||||
func (d *DERPServer) ServeSTUN() {
|
||||
packetConn, err := net.ListenPacket("udp", d.cfg.STUNAddr)
|
||||
packetConn, err := new(net.ListenConfig).ListenPacket(context.Background(), "udp", d.cfg.STUNAddr)
|
||||
if err != nil {
|
||||
log.Fatal().Msgf("failed to open STUN listener: %v", err)
|
||||
}
|
||||
@@ -350,16 +366,18 @@ func (d *DERPServer) ServeSTUN() {
|
||||
if !ok {
|
||||
log.Fatal().Msg("stun listener is not a UDP listener")
|
||||
}
|
||||
|
||||
serverSTUNListener(context.Background(), udpConn)
|
||||
}
|
||||
|
||||
func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
||||
var buf [64 << 10]byte
|
||||
var (
|
||||
buf [64 << 10]byte
|
||||
bytesRead int
|
||||
udpAddr *net.UDPAddr
|
||||
err error
|
||||
)
|
||||
|
||||
for {
|
||||
bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:])
|
||||
if err != nil {
|
||||
@@ -380,12 +398,14 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("stun request from %v", udpAddr)
|
||||
|
||||
pkt := buf[:bytesRead]
|
||||
if !stun.Is(pkt) {
|
||||
log.Trace().Caller().Msgf("udp packet is not stun")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
txid, err := stun.ParseBindingRequest(pkt)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("stun parse error")
|
||||
@@ -394,7 +414,8 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
|
||||
}
|
||||
|
||||
addr, _ := netip.AddrFromSlice(udpAddr.IP)
|
||||
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port)))
|
||||
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port))) //nolint:gosec // port is always <=65535
|
||||
|
||||
_, err = packetConn.WriteTo(res, udpAddr)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("issue writing to UDP")
|
||||
@@ -416,7 +437,9 @@ type DERPVerifyTransport struct {
|
||||
|
||||
func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
if err := t.handleVerifyRequest(req, buf); err != nil {
|
||||
|
||||
err := t.handleVerifyRequest(req, buf)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msg("failed to handle client verify request")
|
||||
|
||||
return nil, err
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -15,6 +16,9 @@ import (
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
// ErrPathIsDirectory is returned when a directory path is provided where a file is expected.
|
||||
var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")
|
||||
|
||||
type ExtraRecordsMan struct {
|
||||
mu sync.RWMutex
|
||||
records set.Set[tailcfg.DNSRecord]
|
||||
@@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
|
||||
return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
|
||||
}
|
||||
|
||||
records, hash, err := readExtraRecordsFromPath(path)
|
||||
@@ -85,19 +89,22 @@ func (e *ExtraRecordsMan) Run() {
|
||||
log.Error().Caller().Msgf("file watcher event channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Op {
|
||||
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
|
||||
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
|
||||
|
||||
if event.Name != e.path {
|
||||
continue
|
||||
}
|
||||
|
||||
e.updateRecords()
|
||||
|
||||
// If a file is removed or renamed, fsnotify will loose track of it
|
||||
// and not watch it. We will therefore attempt to re-add it with a backoff.
|
||||
case fsnotify.Remove, fsnotify.Rename:
|
||||
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
|
||||
if _, err := os.Stat(e.path); err != nil {
|
||||
if _, err := os.Stat(e.path); err != nil { //nolint:noinlineerr
|
||||
return struct{}{}, err
|
||||
}
|
||||
|
||||
@@ -123,6 +130,7 @@ func (e *ExtraRecordsMan) Run() {
|
||||
log.Error().Caller().Msgf("file watcher error channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
|
||||
}
|
||||
}
|
||||
@@ -165,6 +173,7 @@ func (e *ExtraRecordsMan) updateRecords() {
|
||||
e.hashes[e.path] = newHash
|
||||
|
||||
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
|
||||
|
||||
e.updateCh <- e.records.Slice()
|
||||
}
|
||||
|
||||
@@ -183,6 +192,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
|
||||
}
|
||||
|
||||
var records []tailcfg.DNSRecord
|
||||
|
||||
err = json.Unmarshal(b, &records)
|
||||
if err != nil {
|
||||
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
|
||||
|
||||
@@ -17,6 +17,7 @@ func Test_validateTag(t *testing.T) {
|
||||
type args struct {
|
||||
tag string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -45,7 +46,8 @@ func Test_validateTag(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
|
||||
err := validateTag(tt.args.tag)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// The CapabilityVersion is used by Tailscale clients to indicate
|
||||
// NoiseCapabilityVersion is used by Tailscale clients to indicate
|
||||
// their codebase version. Tailscale clients can communicate over TS2021
|
||||
// from CapabilityVersion 28, but we only have good support for it
|
||||
// since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
|
||||
@@ -56,7 +56,7 @@ type HTTPError struct {
|
||||
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
|
||||
func (e HTTPError) Unwrap() error { return e.Err }
|
||||
|
||||
// Error returns an HTTPError containing the given information.
|
||||
// NewHTTPError returns an HTTPError containing the given information.
|
||||
func NewHTTPError(code int, msg string, err error) HTTPError {
|
||||
return HTTPError{Code: code, Msg: msg, Err: err}
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func (h *Headscale) handleVerifyRequest(
|
||||
}
|
||||
|
||||
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
|
||||
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
|
||||
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { //nolint:noinlineerr
|
||||
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("parsing DERP client request: %w", err))
|
||||
}
|
||||
|
||||
@@ -155,7 +155,11 @@ func (h *Headscale) KeyHandler(
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(writer).Encode(resp)
|
||||
|
||||
err := json.NewEncoder(writer).Encode(resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to encode public key response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -180,8 +184,12 @@ func (h *Headscale) HealthHandler(
|
||||
res.Status = "fail"
|
||||
}
|
||||
|
||||
json.NewEncoder(writer).Encode(res)
|
||||
encErr := json.NewEncoder(writer).Encode(res)
|
||||
if encErr != nil {
|
||||
log.Error().Err(encErr).Msg("failed to encode health response")
|
||||
}
|
||||
}
|
||||
|
||||
err := h.state.PingDB(req.Context())
|
||||
if err != nil {
|
||||
respond(err)
|
||||
@@ -218,6 +226,7 @@ func (h *Headscale) VersionHandler(
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
versionInfo := types.GetVersionInfo()
|
||||
|
||||
err := json.NewEncoder(writer).Encode(versionInfo)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -244,7 +253,7 @@ func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
|
||||
registrationId.String())
|
||||
}
|
||||
|
||||
// RegisterWebAPI shows a simple message in the browser to point to the CLI
|
||||
// RegisterHandler shows a simple message in the browser to point to the CLI
|
||||
// Listens in /register/:registration_id.
|
||||
//
|
||||
// This is not part of the Tailscale control API, as we could send whatever URL
|
||||
@@ -267,7 +276,11 @@ func (a *AuthProviderWeb) RegisterHandler(
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||
|
||||
_, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to write register response")
|
||||
}
|
||||
}
|
||||
|
||||
func FaviconHandler(writer http.ResponseWriter, req *http.Request) {
|
||||
|
||||
@@ -16,6 +16,14 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Mapper errors.
|
||||
var (
|
||||
ErrInvalidNodeID = errors.New("invalid nodeID")
|
||||
ErrMapperNil = errors.New("mapper is nil")
|
||||
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
|
||||
ErrNodeNotFoundMapper = errors.New("node not found")
|
||||
)
|
||||
|
||||
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "headscale",
|
||||
Name: "mapresponse_generated_total",
|
||||
@@ -81,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
||||
}
|
||||
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
||||
}
|
||||
|
||||
// Handle self-only responses
|
||||
@@ -136,7 +144,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
||||
if nc == nil {
|
||||
return errors.New("nodeConnection is nil")
|
||||
return ErrNodeConnectionNil
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
|
||||
@@ -2,6 +2,7 @@ package mapper
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
@@ -18,7 +19,13 @@ import (
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var errConnectionClosed = errors.New("connection channel already closed")
|
||||
// LockFreeBatcher errors.
|
||||
var (
|
||||
errConnectionClosed = errors.New("connection channel already closed")
|
||||
ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
|
||||
ErrBatcherShuttingDown = errors.New("batcher shutting down")
|
||||
ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
|
||||
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
@@ -81,6 +88,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
if err != nil {
|
||||
nlog.Error().Err(err).Msg("initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("generating initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
@@ -90,11 +98,12 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second): //nolint:mnd
|
||||
nlog.Error().Err(errors.New("timeout")).Msg("initial map send timeout") //nolint:err113
|
||||
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
||||
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
||||
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
|
||||
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
||||
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("sending initial map to node %d: timeout", id)
|
||||
|
||||
return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
@@ -135,6 +144,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
nlog.Debug().Caller().
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("node connection removed but keeping online, other connections remain")
|
||||
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
@@ -219,10 +229,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
var err error
|
||||
|
||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
@@ -235,7 +247,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
nc.updateSentPeers(result.mapResponse)
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
wlog.Error().Err(result.err).
|
||||
@@ -402,6 +414,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -454,6 +467,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -469,6 +483,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -488,7 +503,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.done:
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,8 +538,9 @@ type multiChannelNodeConn struct {
|
||||
// generateConnectionID generates a unique connection identifier.
|
||||
func generateConnectionID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return fmt.Sprintf("%x", bytes)
|
||||
_, _ = rand.Read(bytes)
|
||||
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
@@ -557,7 +573,9 @@ func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
@@ -579,9 +597,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("successfully removed connection")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -615,6 +635,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
// This is not an error - the node will receive a full map when it reconnects
|
||||
mc.log.Debug().Caller().
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
@@ -623,7 +644,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
|
||||
successCount := 0
|
||||
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
@@ -632,8 +655,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
err := conn.send(data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
|
||||
failedConnections = append(failedConnections, i)
|
||||
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
@@ -695,7 +720,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
||||
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -805,6 +830,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -819,6 +845,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ type batcherTestCase struct {
|
||||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
|
||||
state *state.State
|
||||
}
|
||||
|
||||
@@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
||||
}
|
||||
|
||||
// Finally remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return t.Batcher.RemoveNode(id, c)
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
@@ -129,8 +125,6 @@ const (
|
||||
SMALL_BUFFER_SIZE = 3
|
||||
TINY_BUFFER_SIZE = 1 // For maximum contention
|
||||
LARGE_BUFFER_SIZE = 200
|
||||
|
||||
reservedResponseHeaderSize = 4
|
||||
)
|
||||
|
||||
// TestData contains all test entities created for a test scenario.
|
||||
@@ -241,8 +235,8 @@ func setupBatcherWithTestData(
|
||||
}
|
||||
|
||||
derpMap, err := derp.GetDERPMap(cfg.DERP)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, derpMap)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, derpMap)
|
||||
|
||||
state.SetDERPMap(derpMap)
|
||||
|
||||
@@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
|
||||
}
|
||||
|
||||
// getStats returns a copy of the statistics for a node.
|
||||
//
|
||||
//nolint:unused
|
||||
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
@@ -386,16 +382,14 @@ type UpdateInfo struct {
|
||||
}
|
||||
|
||||
// parseUpdateAndAnalyze parses an update and returns detailed information.
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
|
||||
info := UpdateInfo{
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
|
||||
return UpdateInfo{
|
||||
PeerCount: len(resp.Peers),
|
||||
PatchCount: len(resp.PeersChangedPatch),
|
||||
IsFull: len(resp.Peers) > 0,
|
||||
IsPatch: len(resp.PeersChangedPatch) > 0,
|
||||
IsDERP: resp.DERPMap != nil,
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// start begins consuming updates from the node's channel and tracking stats.
|
||||
@@ -417,7 +411,8 @@ func (n *node) start() {
|
||||
atomic.AddInt64(&n.updateCount, 1)
|
||||
|
||||
// Parse update and track detailed stats
|
||||
if info, err := parseUpdateAndAnalyze(data); err == nil {
|
||||
info := parseUpdateAndAnalyze(data)
|
||||
{
|
||||
// Track update types
|
||||
if info.IsFull {
|
||||
atomic.AddInt64(&n.fullCount, 1)
|
||||
@@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
||||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Wait for connection to be established
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
@@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullUpdate())
|
||||
@@ -676,6 +671,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
connectedCount := 0
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
@@ -693,6 +689,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
||||
|
||||
t.Logf("✅ All nodes achieved full connectivity!")
|
||||
|
||||
totalTime := time.Since(startTime)
|
||||
|
||||
// Disconnect all nodes
|
||||
@@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
tn := testData.Nodes[0]
|
||||
tn2 := testData.Nodes[1]
|
||||
tn := &testData.Nodes[0]
|
||||
tn2 := &testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
@@ -842,10 +839,10 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Drain any initial messages from first node
|
||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
drainChannelTimeout(tn.ch, 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
@@ -911,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
||||
count := 0
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
count++
|
||||
// Optional: add debug output if needed
|
||||
_ = data
|
||||
case <-ch:
|
||||
// Drain message
|
||||
case <-timer.C:
|
||||
return
|
||||
}
|
||||
@@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
@@ -1131,6 +1124,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
// even when real node updates are being processed, ensuring no race conditions
|
||||
// occur during channel replacement with actual workload.
|
||||
func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
// Create test environment with real database and nodes
|
||||
@@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
testNode := &testData.Nodes[0]
|
||||
|
||||
var (
|
||||
channelIssues int
|
||||
@@ -1154,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Go(func() {
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Add real work during connection chaos
|
||||
@@ -1167,7 +1162,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
wg.Go(func() {
|
||||
runtime.Gosched() // Yield to introduce timing variability
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
|
||||
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Remove second connection
|
||||
@@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
testNode := &testData.Nodes[0]
|
||||
|
||||
var (
|
||||
panics int
|
||||
@@ -1258,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPMap())
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
@@ -1308,6 +1304,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
for range i % 3 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Yield to allow workers to process and close channels
|
||||
@@ -1350,6 +1347,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
// real node data. The test validates that stable clients continue to function
|
||||
// normally and receive proper updates despite the connection churn from other clients,
|
||||
// ensuring system stability under concurrent load.
|
||||
//
|
||||
//nolint:gocyclo // complex concurrent test scenario
|
||||
func TestBatcherConcurrentClients(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent client test in short mode")
|
||||
@@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
|
||||
stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
@@ -1391,6 +1391,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, reason := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
@@ -1427,7 +1428,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
// Connection churn cycles - rapidly connect/disconnect to test concurrency safety
|
||||
for i := range numCycles {
|
||||
for _, node := range churningNodes {
|
||||
for j := range churningNodes {
|
||||
node := &churningNodes[j]
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// Connect churning node
|
||||
@@ -1448,10 +1451,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
@@ -1462,6 +1467,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
@@ -1494,6 +1500,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
for range i % 5 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
@@ -1519,7 +1526,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
if i%7 == 0 && len(allNodes) > 0 {
|
||||
// Node-specific changes using real nodes
|
||||
node := allNodes[i%len(allNodes)]
|
||||
node := &allNodes[i%len(allNodes)]
|
||||
// Use a valid expiry time for testing since test nodes don't have expiry set
|
||||
testExpiry := time.Now().Add(24 * time.Hour)
|
||||
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
|
||||
@@ -1567,7 +1574,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
|
||||
expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
|
||||
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
if stats, exists := allStats[node.n.ID]; exists {
|
||||
stableUpdateCount += stats.TotalUpdates
|
||||
t.Logf("Stable node %d: %d updates",
|
||||
@@ -1580,7 +1588,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, node := range churningNodes {
|
||||
for i := range churningNodes {
|
||||
node := &churningNodes[i]
|
||||
if stats, exists := allStats[node.n.ID]; exists {
|
||||
churningUpdateCount += stats.TotalUpdates
|
||||
}
|
||||
@@ -1605,7 +1614,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify all stable clients are still functional
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Errorf("Stable node %d lost connection during racing", node.n.ID)
|
||||
}
|
||||
@@ -1623,6 +1633,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// It validates that the system remains stable with no deadlocks, panics, or
|
||||
// missed updates under sustained high load. The test uses real node data to
|
||||
// generate authentic update scenarios and tracks comprehensive statistics.
|
||||
//
|
||||
//nolint:gocyclo,thelper // complex scalability test scenario
|
||||
func XTestBatcherScalability(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping scalability test in short mode")
|
||||
@@ -1651,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
description string
|
||||
}
|
||||
|
||||
var testCases []testCase
|
||||
testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes))
|
||||
|
||||
// Generate all combinations of the test matrix
|
||||
for _, nodeCount := range nodes {
|
||||
@@ -1762,7 +1774,8 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
@@ -1824,7 +1837,8 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
}
|
||||
|
||||
// Connection/disconnection cycles for subset of nodes
|
||||
for i, node := range chaosNodes {
|
||||
for i := range chaosNodes {
|
||||
node := &chaosNodes[i]
|
||||
// Only add work if this is connection chaos or mixed
|
||||
if tc.chaosType == "connection" || tc.chaosType == "mixed" {
|
||||
wg.Add(2)
|
||||
@@ -1878,6 +1892,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
channel,
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
@@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
t.Logf("Created %d nodes in database", len(allNodes))
|
||||
|
||||
// Connect nodes one at a time and wait for each to be connected
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
|
||||
// Wait for node to be connected
|
||||
@@ -2157,7 +2173,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
|
||||
|
||||
// Check how many peers each node should see
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
peers := testData.State.ListPeers(node.n.ID)
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
@@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
for i, node := range allNodes {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node %d: %v", i, err)
|
||||
@@ -2302,16 +2322,21 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
for i, node := range allNodes {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
}
|
||||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
@@ -2334,7 +2359,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
disconnectedCount := 0
|
||||
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
|
||||
@@ -2342,11 +2368,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
@@ -2381,6 +2409,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
@@ -2399,6 +2428,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo // complex multi-connection test scenario
|
||||
func TestBatcherMultiConnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
@@ -2406,13 +2436,14 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
node1 := &testData.Nodes[0]
|
||||
node2 := &testData.Nodes[1]
|
||||
|
||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
@@ -2432,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
@@ -2443,7 +2476,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
@@ -2454,6 +2489,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
@@ -2461,6 +2497,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
@@ -2469,6 +2506,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -36,6 +35,7 @@ const (
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
@@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
return nil, ErrNodeNotFoundMapper
|
||||
}
|
||||
|
||||
// Get unreduced matchers for peer relationship determination.
|
||||
@@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs.
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs))
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
// Build finalizes the response and returns marshaled bytes.
|
||||
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||
}
|
||||
|
||||
@@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build()
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
require.Nil(t, data)
|
||||
require.Error(t, err)
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
)
|
||||
|
||||
@@ -50,6 +49,7 @@ type mapper struct {
|
||||
created time.Time
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
type patch struct {
|
||||
timestamp time.Time
|
||||
change *tailcfg.PeerChange
|
||||
@@ -60,7 +60,6 @@ func newMapper(
|
||||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
@@ -76,6 +75,7 @@ func generateUserProfiles(
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.UserView)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
|
||||
user := node.Owner()
|
||||
if !user.Valid() {
|
||||
log.Error().
|
||||
@@ -84,14 +84,17 @@ func generateUserProfiles(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
userID := user.Model().ID
|
||||
userMap[userID] = &user
|
||||
ids = append(ids, userID)
|
||||
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.Owner()
|
||||
if !peerUser.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
peerUserID := peerUser.Model().ID
|
||||
userMap[peerUserID] = &peerUser
|
||||
ids = append(ids, peerUserID)
|
||||
@@ -99,7 +102,9 @@ func generateUserProfiles(
|
||||
|
||||
slices.Sort(ids)
|
||||
ids = slices.Compact(ids)
|
||||
|
||||
var profiles []tailcfg.UserProfile
|
||||
|
||||
for _, id := range ids {
|
||||
if userMap[id] != nil {
|
||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||
@@ -149,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
//
|
||||
//nolint:unused
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
@@ -316,6 +323,7 @@ func writeDebugMapResponse(
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -329,6 +337,7 @@ func writeDebugMapResponse(
|
||||
)
|
||||
|
||||
log.Trace().Msgf("writing MapResponse to %s", mapResponsePath)
|
||||
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -337,7 +346,7 @@ func writeDebugMapResponse(
|
||||
|
||||
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
if debugDumpMapResponsePath == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no data when debug path not set
|
||||
}
|
||||
|
||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||
@@ -350,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
}
|
||||
|
||||
result := make(map[types.NodeID][]tailcfg.MapResponse)
|
||||
|
||||
for _, node := range nodes {
|
||||
if !node.IsDir() {
|
||||
continue
|
||||
@@ -385,6 +395,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
}
|
||||
|
||||
var resp tailcfg.MapResponse
|
||||
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("unmarshalling file %s", file.Name())
|
||||
|
||||
@@ -3,14 +3,10 @@ package mapper
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
@@ -81,90 +77,3 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockState is a mock implementation that provides the required methods.
|
||||
type mockState struct {
|
||||
polMan policy.PolicyManager
|
||||
derpMap *tailcfg.DERPMap
|
||||
primary *routes.PrimaryRoutes
|
||||
nodes types.Nodes
|
||||
peers types.Nodes
|
||||
}
|
||||
|
||||
func (m *mockState) DERPMap() *tailcfg.DERPMap {
|
||||
return m.derpMap
|
||||
}
|
||||
|
||||
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
if m.polMan == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
return m.polMan.Filter()
|
||||
}
|
||||
|
||||
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
if m.polMan == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.polMan.SSHPolicy(node)
|
||||
}
|
||||
|
||||
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
if m.polMan == nil {
|
||||
return false
|
||||
}
|
||||
return m.polMan.NodeCanHaveTag(node, tag)
|
||||
}
|
||||
|
||||
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
|
||||
if m.primary == nil {
|
||||
return nil
|
||||
}
|
||||
return m.primary.PrimaryRoutes(nodeID)
|
||||
}
|
||||
|
||||
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(peerIDs) > 0 {
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
// Return all peers except the node itself
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if peer.ID != nodeID {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(nodeIDs) > 0 {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
return m.nodes, nil
|
||||
}
|
||||
|
||||
func Test_fullMapResponse(t *testing.T) {
|
||||
t.Skip("Test needs to be refactored for new state-based architecture")
|
||||
// TODO: Refactor this test to work with the new state-based mapper
|
||||
// The test architecture needs to be updated to work with the state interface
|
||||
// instead of the old direct dependency injection pattern
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
func TestTailNode(t *testing.T) {
|
||||
mustNK := func(str string) key.NodePublic {
|
||||
var k key.NodePublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -26,6 +27,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
mustDK := func(str string) key.DiscoPublic {
|
||||
var k key.DiscoPublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -33,6 +35,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
mustMK := func(str string) key.MachinePublic {
|
||||
var k key.MachinePublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -255,7 +258,7 @@ func TestNodeExpiry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "localtime",
|
||||
exp: tp(time.Time{}.Local()),
|
||||
exp: tp(time.Time{}.Local()), //nolint:gosmopolitan
|
||||
wantTimeZero: true,
|
||||
},
|
||||
}
|
||||
@@ -284,7 +287,9 @@ func TestNodeExpiry(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("nodeExpiry() error = %v", err)
|
||||
}
|
||||
|
||||
var deseri tailcfg.Node
|
||||
|
||||
err = json.Unmarshal(seri, &deseri)
|
||||
if err != nil {
|
||||
t.Fatalf("nodeExpiry() error = %v", err)
|
||||
|
||||
@@ -71,6 +71,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
|
||||
rw := &respWriterProm{ResponseWriter: w}
|
||||
|
||||
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
timer.ObserveDuration()
|
||||
httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
|
||||
@@ -79,6 +80,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
type respWriterProm struct {
|
||||
http.ResponseWriter
|
||||
|
||||
status int
|
||||
written int64
|
||||
wroteHeader bool
|
||||
@@ -94,6 +96,7 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
|
||||
if !r.wroteHeader {
|
||||
r.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
n, err := r.ResponseWriter.Write(b)
|
||||
r.written += int64(n)
|
||||
|
||||
|
||||
@@ -19,6 +19,9 @@ import (
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
|
||||
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
|
||||
|
||||
const (
|
||||
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
|
||||
ts2021UpgradePath = "/ts2021"
|
||||
@@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||
}
|
||||
|
||||
func unsupportedClientError(version tailcfg.CapabilityVersion) error {
|
||||
return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version)
|
||||
return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version)
|
||||
}
|
||||
|
||||
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
||||
@@ -137,17 +140,20 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
||||
// an HTTP/2 settings frame, which isn't of type 'T')
|
||||
var notH2Frame [5]byte
|
||||
copy(notH2Frame[:], earlyPayloadMagic)
|
||||
|
||||
var lenBuf [4]byte
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) //nolint:gosec // JSON length is bounded
|
||||
// These writes are all buffered by caller, so fine to do them
|
||||
// separately:
|
||||
if _, err := writer.Write(notH2Frame[:]); err != nil {
|
||||
if _, err := writer.Write(notH2Frame[:]); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write(lenBuf[:]); err != nil {
|
||||
|
||||
if _, err := writer.Write(lenBuf[:]); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write(earlyJSON); err != nil {
|
||||
|
||||
if _, err := writer.Write(earlyJSON); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -199,7 +205,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
|
||||
var mapRequest tailcfg.MapRequest
|
||||
if err := json.Unmarshal(body, &mapRequest); err != nil {
|
||||
if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -219,6 +225,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||
sess.log.Trace().Caller().Msg("a node sending a MapRequest with Noise protocol")
|
||||
|
||||
if !sess.isStreaming() {
|
||||
sess.serve()
|
||||
} else {
|
||||
@@ -241,14 +248,16 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
return
|
||||
}
|
||||
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck
|
||||
var resp *tailcfg.RegisterResponse
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return &tailcfg.RegisterRequest{}, regErr(err)
|
||||
}
|
||||
|
||||
var regReq tailcfg.RegisterRequest
|
||||
if err := json.Unmarshal(body, ®Req); err != nil {
|
||||
if err := json.Unmarshal(body, ®Req); err != nil { //nolint:noinlineerr
|
||||
return ®Req, regErr(err)
|
||||
}
|
||||
|
||||
@@ -261,6 +270,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
resp = &tailcfg.RegisterResponse{
|
||||
Error: httpErr.Msg,
|
||||
}
|
||||
|
||||
return ®Req, resp
|
||||
}
|
||||
|
||||
@@ -278,7 +288,8 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
err := json.NewEncoder(writer).Encode(registerResponse)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msg("noise registration handler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func NewAuthProviderOIDC(
|
||||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) //nolint:contextcheck
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||
}
|
||||
@@ -163,6 +163,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
@@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
stateCookieName := getCookieName("state", state)
|
||||
|
||||
cookieState, err := req.Cookie(stateCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||
@@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce == "" {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
||||
return
|
||||
}
|
||||
|
||||
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
||||
|
||||
nonce, err := req.Cookie(nonceCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce != nonce.Value {
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
||||
return
|
||||
@@ -231,7 +236,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
||||
|
||||
var claims types.OIDCClaims
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
if err := idToken.Claims(&claims); err != nil { //nolint:noinlineerr
|
||||
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
|
||||
return
|
||||
}
|
||||
@@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
var userinfo *oidc.UserInfo
|
||||
|
||||
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||
@@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
|
||||
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
|
||||
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
|
||||
|
||||
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
|
||||
if userinfo2.Groups != nil {
|
||||
claims.Groups = userinfo2.Groups
|
||||
@@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
Msgf("could not create or update user")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, werr := writer.Write([]byte("Could not create or update user"))
|
||||
if werr != nil {
|
||||
log.Error().
|
||||
@@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
verb := "Reauthenticated"
|
||||
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
@@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
httpError(writer, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -316,15 +327,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
// TODO(kradalby): replace with go-elem
|
||||
content, err := renderOIDCCallbackTemplate(user, verb)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
content := renderOIDCCallbackTemplate(user, verb)
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
||||
|
||||
if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr
|
||||
util.LogErr(err, "Failed to write HTTP response")
|
||||
}
|
||||
|
||||
@@ -370,6 +378,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
||||
if !ok {
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
}
|
||||
@@ -394,6 +403,7 @@ func (a *AuthProviderOIDC) extractIDToken(
|
||||
}
|
||||
|
||||
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("verifying ID token: %w", err))
|
||||
@@ -516,6 +526,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
newUser bool
|
||||
c change.Change
|
||||
)
|
||||
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
|
||||
@@ -589,9 +600,9 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
func renderOIDCCallbackTemplate(
|
||||
user *types.User,
|
||||
verb string,
|
||||
) (*bytes.Buffer, error) {
|
||||
) *bytes.Buffer {
|
||||
html := templates.OIDCCallback(user.Display(), verb).Render()
|
||||
return bytes.NewBufferString(html), nil
|
||||
return bytes.NewBufferString(html)
|
||||
}
|
||||
|
||||
// getCookieName generates a unique cookie name based on a cookie value.
|
||||
|
||||
@@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
|
||||
) {
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
|
||||
_, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
|
||||
}
|
||||
|
||||
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
|
||||
@@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage(
|
||||
) {
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
|
||||
_, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
|
||||
}
|
||||
|
||||
func (h *Headscale) ApplePlatformConfig(
|
||||
@@ -37,6 +37,7 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
|
||||
platform, ok := vars["platform"]
|
||||
if !ok {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
|
||||
@@ -64,17 +65,20 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
|
||||
switch platform {
|
||||
case "macos-standalone":
|
||||
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
err := macosStandaloneTemplate.Execute(&payload, platformConfig)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
case "macos-app-store":
|
||||
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
err := macosAppStoreTemplate.Execute(&payload, platformConfig)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
case "ios":
|
||||
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
|
||||
err := iosTemplate.Execute(&payload, platformConfig)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -90,7 +94,7 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
}
|
||||
|
||||
var content bytes.Buffer
|
||||
if err := commonTemplate.Execute(&content, config); err != nil {
|
||||
if err := commonTemplate.Execute(&content, config); err != nil { //nolint:noinlineerr
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
@@ -98,7 +102,7 @@ func (h *Headscale) ApplePlatformConfig(
|
||||
writer.Header().
|
||||
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write(content.Bytes())
|
||||
_, _ = writer.Write(content.Bytes())
|
||||
}
|
||||
|
||||
type AppleMobileConfig struct {
|
||||
|
||||
@@ -16,15 +16,18 @@ type Match struct {
|
||||
dests *netipx.IPSet
|
||||
}
|
||||
|
||||
func (m Match) DebugString() string {
|
||||
func (m *Match) DebugString() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("Match:\n")
|
||||
sb.WriteString(" Sources:\n")
|
||||
|
||||
for _, prefix := range m.srcs.Prefixes() {
|
||||
sb.WriteString(" " + prefix.String() + "\n")
|
||||
}
|
||||
|
||||
sb.WriteString(" Destinations:\n")
|
||||
|
||||
for _, prefix := range m.dests.Prefixes() {
|
||||
sb.WriteString(" " + prefix.String() + "\n")
|
||||
}
|
||||
@@ -42,7 +45,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
|
||||
}
|
||||
|
||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||
dests := []string{}
|
||||
dests := make([]string, 0, len(rule.DstPorts))
|
||||
for _, dest := range rule.DstPorts {
|
||||
dests = append(dests, dest.IP)
|
||||
}
|
||||
@@ -98,7 +101,7 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
|
||||
// cased for exit nodes.
|
||||
// This checks if dests is a superset of TheInternet(), which handles
|
||||
// merged filter rules where TheInternet is combined with other destinations.
|
||||
func (m Match) DestsIsTheInternet() bool {
|
||||
func (m *Match) DestsIsTheInternet() bool {
|
||||
if m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
|
||||
m.dests.ContainsPrefix(tsaddr.AllIPv6()) {
|
||||
return true
|
||||
|
||||
@@ -19,18 +19,18 @@ type PolicyManager interface {
|
||||
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
|
||||
// BuildPeerMap constructs peer relationship maps for the given nodes
|
||||
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
|
||||
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy(pol []byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
|
||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||
NodeCanHaveTag(types.NodeView, string) bool
|
||||
NodeCanHaveTag(node types.NodeView, tag string) bool
|
||||
|
||||
// TagExists reports whether the given tag is defined in the policy.
|
||||
TagExists(tag string) bool
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
|
||||
NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool
|
||||
|
||||
Version() int
|
||||
DebugString() string
|
||||
@@ -38,8 +38,11 @@ type PolicyManager interface {
|
||||
|
||||
// NewPolicyManager returns a new policy manager.
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
var (
|
||||
polMan PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
polMans = append(polMans, pm)
|
||||
}
|
||||
|
||||
@@ -66,7 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
||||
}
|
||||
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
|
||||
polmanFuncs := make([]func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error), 0, 1)
|
||||
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
return policyv2.NewPolicyManager(pol, u, n)
|
||||
|
||||
@@ -126,6 +126,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
|
||||
if !slices.Equal(sortedCurrent, newApproved) {
|
||||
// Log what changed
|
||||
var added, kept []netip.Prefix
|
||||
|
||||
for _, route := range newApproved {
|
||||
if !slices.Contains(sortedCurrent, route) {
|
||||
added = append(added, route)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
@@ -76,7 +77,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
||||
}`
|
||||
|
||||
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -313,11 +314,14 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager or use nil if specified
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
if tt.name != "nil_policy_manager" {
|
||||
pm, err = pmf(users, nodes.ViewSlice())
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
pm = nil
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ func TestReduceNodes(t *testing.T) {
|
||||
rules []tailcfg.FilterRule
|
||||
node *types.Node
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -783,9 +784,11 @@ func TestReduceNodes(t *testing.T) {
|
||||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
@@ -796,7 +799,7 @@ func TestReduceNodes(t *testing.T) {
|
||||
|
||||
func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
|
||||
var routes []netip.Prefix
|
||||
routes := make([]netip.Prefix, 0, len(routess))
|
||||
for _, route := range routess {
|
||||
routes = append(routes, netip.MustParsePrefix(route))
|
||||
}
|
||||
@@ -1034,8 +1037,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(nil, tt.nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1053,9 +1059,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
@@ -1233,7 +1241,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
]
|
||||
}`,
|
||||
expectErr: true,
|
||||
errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
|
||||
errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
|
||||
},
|
||||
{
|
||||
name: "invalid-check-period",
|
||||
@@ -1280,7 +1288,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
]
|
||||
}`,
|
||||
expectErr: true,
|
||||
errorMessage: "autogroup \"autogroup:invalid\" is not supported",
|
||||
errorMessage: "autogroup not supported for SSH user",
|
||||
},
|
||||
{
|
||||
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
|
||||
@@ -1453,13 +1461,17 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMessage)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1482,6 +1494,7 @@ func TestReduceRoutes(t *testing.T) {
|
||||
routes []netip.Prefix
|
||||
rules []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -2103,6 +2116,7 @@ func TestReduceRoutes(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
|
||||
got := ReduceRoutes(
|
||||
tt.args.node.View(),
|
||||
tt.args.routes,
|
||||
|
||||
@@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
|
||||
@@ -798,10 +798,14 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm policy.PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm policy.PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
got, _ := pm.Filter()
|
||||
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
|
||||
got = policyutil.ReduceFilterRules(tt.node.View(), got)
|
||||
|
||||
@@ -830,6 +830,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
if tt.name == "empty policy" {
|
||||
// We expect this one to have a valid but empty policy
|
||||
require.NoError(t, err)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -844,6 +845,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
|
||||
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.canApprove, result, "Unexpected route approval result")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,7 +17,10 @@ import (
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
var ErrInvalidAction = errors.New("invalid action")
|
||||
var (
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
errSelfInSources = errors.New("autogroup:self cannot be used in sources")
|
||||
)
|
||||
|
||||
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
@@ -45,9 +48,10 @@ func (pol *Policy) compileFilterRules(
|
||||
continue
|
||||
}
|
||||
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
protocols := acl.Protocol.parseProtocol()
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
// Check if destination is a wildcard - use "*" directly instead of expanding
|
||||
if _, isWildcard := dest.Alias.(Asterix); isWildcard {
|
||||
@@ -142,14 +146,18 @@ func (pol *Policy) compileFilterRulesForNode(
|
||||
// It returns a slice of filter rules because when an ACL has both autogroup:self
|
||||
// and other destinations, they need to be split into separate rules with different
|
||||
// source filtering logic.
|
||||
//
|
||||
//nolint:gocyclo // complex ACL compilation logic
|
||||
func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
acl ACL,
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) ([]*tailcfg.FilterRule, error) {
|
||||
var autogroupSelfDests []AliasWithPorts
|
||||
var otherDests []AliasWithPorts
|
||||
var (
|
||||
autogroupSelfDests []AliasWithPorts
|
||||
otherDests []AliasWithPorts
|
||||
)
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
@@ -159,14 +167,15 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
}
|
||||
}
|
||||
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
protocols := acl.Protocol.parseProtocol()
|
||||
|
||||
var rules []*tailcfg.FilterRule
|
||||
|
||||
var resolvedSrcIPs []*netipx.IPSet
|
||||
|
||||
for _, src := range acl.Sources {
|
||||
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return nil, fmt.Errorf("autogroup:self cannot be used in sources")
|
||||
return nil, errSelfInSources
|
||||
}
|
||||
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
@@ -188,6 +197,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
@@ -197,6 +207,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
if len(sameUserNodes) > 0 {
|
||||
// Filter sources to only same-user untagged devices
|
||||
var srcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, ips := range resolvedSrcIPs {
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
@@ -213,6 +224,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
|
||||
if srcSet != nil && len(srcSet.Prefixes()) > 0 {
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range autogroupSelfDests {
|
||||
for _, n := range sameUserNodes {
|
||||
for _, port := range dest.Ports {
|
||||
@@ -318,13 +330,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo // complex SSH policy compilation logic
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no SSH policy when none configured
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
@@ -335,8 +348,10 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// Separate destinations into autogroup:self and others
|
||||
// This is needed because autogroup:self requires filtering sources to same-user only,
|
||||
// while other destinations should use all resolved sources
|
||||
var autogroupSelfDests []Alias
|
||||
var otherDests []Alias
|
||||
var (
|
||||
autogroupSelfDests []Alias
|
||||
otherDests []Alias
|
||||
)
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
@@ -359,6 +374,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
|
||||
switch rule.Action {
|
||||
case SSHActionAccept:
|
||||
action = sshAction(true, 0)
|
||||
@@ -374,9 +390,11 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// by default, we do not allow root unless explicitly stated
|
||||
userMap["root"] = ""
|
||||
}
|
||||
|
||||
if rule.Users.ContainsRoot() {
|
||||
userMap["root"] = "root"
|
||||
}
|
||||
|
||||
for _, u := range rule.Users.NormalUsers() {
|
||||
userMap[u.String()] = u.String()
|
||||
}
|
||||
@@ -386,6 +404,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Build destination set for autogroup:self (same-user untagged devices only)
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
n.AppendToIPSet(&dest)
|
||||
@@ -402,6 +421,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
// Filter sources to only same-user untagged devices
|
||||
// Pre-filter to same-user untagged devices for efficiency
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if !n.IsTagged() && n.User().ID() == node.User().ID() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
@@ -409,6 +429,7 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
|
||||
var filteredSrcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
@@ -444,11 +465,13 @@ func (pol *Policy) compileSSHPolicy(
|
||||
if len(otherDests) > 0 {
|
||||
// Build destination set for other destinations
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, dst := range otherDests {
|
||||
ips, err := dst.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
}
|
||||
|
||||
if ips != nil {
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
|
||||
@@ -623,7 +623,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
if sshPolicy == nil {
|
||||
return // Expected empty result
|
||||
}
|
||||
|
||||
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -709,7 +711,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers.
|
||||
func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
// Create users matching the integration test
|
||||
users := types.Users{
|
||||
@@ -775,7 +777,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
|
||||
// to JSON and that the sshUsers field is not empty
|
||||
// to JSON and that the sshUsers field is not empty.
|
||||
func TestSSHJSONSerialization(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
@@ -815,6 +817,7 @@ func TestSSHJSONSerialization(t *testing.T) {
|
||||
|
||||
// Parse back to verify structure
|
||||
var parsed tailcfg.SSHPolicy
|
||||
|
||||
err = json.Unmarshal(jsonData, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -899,6 +902,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(rules) != 1 {
|
||||
t.Fatalf("expected 1 rule, got %d", len(rules))
|
||||
}
|
||||
@@ -915,6 +919,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
found := false
|
||||
|
||||
addr := netip.MustParseAddr(expectedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
@@ -932,6 +937,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
||||
excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
|
||||
for _, excludedIP := range excludedSourceIPs {
|
||||
addr := netip.MustParseAddr(excludedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
@@ -1144,7 +1150,8 @@ func TestAutogroupTagged(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify autogroup:tagged includes all tagged nodes
|
||||
taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice())
|
||||
ag := AutoGroupTagged
|
||||
taggedIPs, err := ag.Resolve(policy, users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, taggedIPs)
|
||||
|
||||
@@ -1366,14 +1373,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
|
||||
assert.Empty(t, rules3, "user3 should have no rules")
|
||||
}
|
||||
|
||||
// Helper function to create IP addresses for testing
|
||||
// Helper function to create IP addresses for testing.
|
||||
func createAddr(ip string) *netip.Addr {
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
return &addr
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly
|
||||
// with autogroup:self in destinations
|
||||
// with autogroup:self in destinations.
|
||||
func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1421,6 +1428,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// Test for user2's first node
|
||||
@@ -1439,12 +1447,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
for i, p := range rule2.Principals {
|
||||
principalIPs2[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2)
|
||||
|
||||
// Test for tagged node (should have no SSH rules)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy3 != nil {
|
||||
assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self")
|
||||
}
|
||||
@@ -1452,7 +1462,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
|
||||
// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user
|
||||
// is in the source and autogroup:self in destination, only that user's devices
|
||||
// can SSH (and only if they match the target user)
|
||||
// can SSH (and only if they match the target user).
|
||||
func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1494,18 +1504,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user2's node: should have no rules (user1's devices can't match user2's self)
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations.
|
||||
func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1552,19 +1564,21 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user3's node: should have no rules (not in group:admins)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices
|
||||
// are excluded from both sources and destinations when autogroup:self is used
|
||||
// are excluded from both sources and destinations when autogroup:self is used.
|
||||
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
@@ -1609,6 +1623,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs,
|
||||
"should only include untagged devices")
|
||||
|
||||
@@ -1616,6 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self")
|
||||
}
|
||||
@@ -1664,10 +1680,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
// Verify autogroup:self rule has filtered sources (only same-user devices)
|
||||
selfRule := sshPolicy1.Rules[0]
|
||||
require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices")
|
||||
|
||||
selfPrincipals := make([]string, len(selfRule.Principals))
|
||||
for i, p := range selfRule.Principals {
|
||||
selfPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals,
|
||||
"autogroup:self rule should only include same-user untagged devices")
|
||||
|
||||
@@ -1679,10 +1697,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
||||
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
|
||||
|
||||
routerRule := sshPolicyRouter.Rules[0]
|
||||
|
||||
routerPrincipals := make([]string, len(routerRule.Principals))
|
||||
for i, p := range routerRule.Principals {
|
||||
routerPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)")
|
||||
|
||||
@@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Filter: filter,
|
||||
Policy: pm.pol,
|
||||
})
|
||||
|
||||
filterChanged := filterHash != pm.filterHash
|
||||
if filterChanged {
|
||||
log.Debug().
|
||||
@@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Int("filter.rules.new", len(filter)).
|
||||
Msg("Policy filter hash changed")
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
|
||||
pm.filterHash = filterHash
|
||||
if filterChanged {
|
||||
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
|
||||
@@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
}
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
if tagOwnerChanged {
|
||||
log.Debug().
|
||||
@@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Int("tagOwners.new", len(tagMap)).
|
||||
Msg("Tag owner hash changed")
|
||||
}
|
||||
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
@@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
}
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
if autoApproveChanged {
|
||||
log.Debug().
|
||||
@@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Int("autoApprovers.new", len(autoMap)).
|
||||
Msg("Auto-approvers hash changed")
|
||||
}
|
||||
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
exitSetHash := deephash.Hash(&exitSet)
|
||||
|
||||
exitSetChanged := exitSetHash != pm.exitSetHash
|
||||
if exitSetChanged {
|
||||
log.Debug().
|
||||
@@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
Str("exitSet.hash.new", exitSetHash.String()[:8]).
|
||||
Msg("Exit node set hash changed")
|
||||
}
|
||||
|
||||
pm.exitSet = exitSet
|
||||
pm.exitSetHash = exitSetHash
|
||||
|
||||
@@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
if !needsUpdate {
|
||||
log.Trace().
|
||||
Msg("Policy evaluation detected no changes - all hashes match")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
|
||||
pm.sshPolicyMap[node.ID()] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
@@ -403,6 +414,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil
|
||||
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
|
||||
|
||||
pm.filterRulesMap[node.ID()] = reducedFilter
|
||||
|
||||
return reducedFilter, nil
|
||||
}
|
||||
|
||||
@@ -447,7 +459,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
|
||||
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
|
||||
//
|
||||
// For global policies: returns the global matchers (same for all nodes)
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules.
|
||||
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
|
||||
if pm == nil {
|
||||
return nil, nil
|
||||
@@ -479,6 +491,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
|
||||
// Clear SSH policy map when users change to force SSH policy recomputation
|
||||
@@ -690,6 +703,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
||||
if pm.exitSet == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) {
|
||||
return true
|
||||
}
|
||||
@@ -753,8 +767,10 @@ func (pm *PolicyManager) DebugString() string {
|
||||
}
|
||||
|
||||
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range approveAddrs.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
@@ -763,14 +779,17 @@ func (pm *PolicyManager) DebugString() string {
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||
|
||||
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range tagOwners.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if pm.filter != nil {
|
||||
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||
if err == nil {
|
||||
@@ -783,6 +802,7 @@ func (pm *PolicyManager) DebugString() string {
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Matchers:\n")
|
||||
sb.WriteString("an internal structure used to filter nodes and routes\n")
|
||||
|
||||
for _, match := range pm.matchers {
|
||||
sb.WriteString(match.DebugString())
|
||||
sb.WriteString("\n")
|
||||
@@ -790,6 +810,7 @@ func (pm *PolicyManager) DebugString() string {
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range pm.nodes.All() {
|
||||
sb.WriteString(node.String())
|
||||
sb.WriteString("\n")
|
||||
@@ -867,6 +888,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
|
||||
// Check if IPs changed (simple check - could be more sophisticated)
|
||||
oldIPs := oldNode.IPs()
|
||||
|
||||
newIPs := newNode.IPs()
|
||||
if len(oldIPs) != len(newIPs) {
|
||||
affectedUsers[newNode.User().ID()] = struct{}{}
|
||||
@@ -888,6 +910,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
for nodeID := range pm.filterRulesMap {
|
||||
// Find the user for this cached node
|
||||
var nodeUserID uint
|
||||
|
||||
found := false
|
||||
|
||||
// Check in new nodes first
|
||||
@@ -899,8 +922,10 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -913,8 +938,10 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
||||
found = true
|
||||
break
|
||||
}
|
||||
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||
func node(name, ipv4, ipv6 string, user types.User) *types.Node {
|
||||
return &types.Node{
|
||||
ID: 0,
|
||||
Hostname: name,
|
||||
@@ -22,7 +22,6 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo)
|
||||
IPv6: ap(ipv6),
|
||||
User: ptr.To(user),
|
||||
UserID: ptr.To(user.ID),
|
||||
Hostinfo: hostinfo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +56,7 @@ func TestPolicyManager(t *testing.T) {
|
||||
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(
|
||||
tt.wantMatchers,
|
||||
matchers,
|
||||
@@ -77,6 +77,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"},
|
||||
}
|
||||
|
||||
//nolint:goconst // test-specific inline policy for clarity
|
||||
policy := `{
|
||||
"acls": [
|
||||
{
|
||||
@@ -88,14 +89,14 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
}`
|
||||
|
||||
initialNodes := types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
}
|
||||
|
||||
for i, n := range initialNodes {
|
||||
n.ID = types.NodeID(i + 1)
|
||||
n.ID = types.NodeID(i + 1) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
|
||||
pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice())
|
||||
@@ -107,7 +108,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, len(initialNodes), len(pm.filterRulesMap))
|
||||
require.Len(t, pm.filterRulesMap, len(initialNodes))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -118,10 +119,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "no_changes",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 0,
|
||||
description: "No changes should clear no cache entries",
|
||||
@@ -129,11 +130,11 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "node_added",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0], nil), // New node
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0]), // New node
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 2, // user1's existing nodes should be cleared
|
||||
description: "Adding a node should clear cache for that user's existing nodes",
|
||||
@@ -141,10 +142,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "node_removed",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
// user1-node2 removed
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 2, // user1's remaining node + removed node should be cleared
|
||||
description: "Removing a node should clear cache for that user's remaining nodes",
|
||||
@@ -152,10 +153,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "user_changed",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2], nil), // Changed to user3
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]),
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2]), // Changed to user3
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 3, // user1's node + user2's node + user3's nodes should be cleared
|
||||
description: "Changing a node's user should clear cache for both old and new users",
|
||||
@@ -163,10 +164,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
{
|
||||
name: "ip_changed",
|
||||
newNodes: types.Nodes{
|
||||
node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0], nil), // IP changed
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil),
|
||||
node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0]), // IP changed
|
||||
node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]),
|
||||
node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]),
|
||||
node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]),
|
||||
},
|
||||
expectedCleared: 2, // user1's nodes should be cleared
|
||||
description: "Changing a node's IP should clear cache for that user's nodes",
|
||||
@@ -177,15 +178,18 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for i, n := range tt.newNodes {
|
||||
found := false
|
||||
|
||||
for _, origNode := range initialNodes {
|
||||
if n.Hostname == origNode.Hostname {
|
||||
n.ID = origNode.ID
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1)
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,16 +374,16 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) {
|
||||
|
||||
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
|
||||
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters.
|
||||
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
||||
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
|
||||
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
|
||||
users := types.Users{user1, user2}
|
||||
|
||||
// Create two nodes
|
||||
node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil)
|
||||
node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1)
|
||||
node1.ID = 1
|
||||
node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil)
|
||||
node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2)
|
||||
node2.ID = 2
|
||||
nodes := types.Nodes{node1, node2}
|
||||
|
||||
@@ -410,6 +414,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
||||
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
|
||||
// For node1, destinations should only be node1's IPs
|
||||
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
|
||||
|
||||
for _, rule := range filterNode1 {
|
||||
for _, dst := range rule.DstPorts {
|
||||
require.Contains(t, node1IPs, dst.IP,
|
||||
@@ -419,6 +424,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
||||
|
||||
// For node2, destinations should only be node2's IPs
|
||||
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
|
||||
|
||||
for _, rule := range filterNode2 {
|
||||
for _, dst := range rule.DstPorts {
|
||||
require.Contains(t, node2IPs, dst.IP,
|
||||
|
||||
@@ -9655,7 +9655,7 @@ func TestTailscaleCompatErrorCases(t *testing.T) {
|
||||
{"action": "accept", "src": ["tag:nonexistent"], "dst": ["tag:server:22"]}
|
||||
]
|
||||
}`,
|
||||
wantErr: `Tag "tag:nonexistent" is not defined in the Policy`,
|
||||
wantErr: `tag not defined in policy: "tag:nonexistent"`,
|
||||
reference: "Test 6.4: tag:nonexistent → tag:server:22",
|
||||
},
|
||||
|
||||
@@ -9674,7 +9674,7 @@ func TestTailscaleCompatErrorCases(t *testing.T) {
|
||||
{"action": "accept", "src": ["autogroup:self"], "dst": ["tag:server:22"]}
|
||||
]
|
||||
}`,
|
||||
wantErr: `"autogroup:self" used in source, it can only be used in ACL destinations`,
|
||||
wantErr: `autogroup:self can only be used in ACL destinations`,
|
||||
reference: "Test 13.41: autogroup:self as SOURCE",
|
||||
},
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -82,6 +82,7 @@ func TestMarshalJSON(t *testing.T) {
|
||||
|
||||
// Unmarshal back to verify round trip
|
||||
var roundTripped Policy
|
||||
|
||||
err = json.Unmarshal(marshalled, &roundTripped)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -366,7 +367,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: "alias v2.Asterix is not supported for SSH source",
|
||||
wantErr: "alias not supported for SSH source: v2.Asterix",
|
||||
},
|
||||
{
|
||||
name: "invalid-username",
|
||||
@@ -393,7 +394,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `group must start with "group:", got: "grou:example"`,
|
||||
wantErr: `group must start with 'group:', got: "grou:example"`,
|
||||
},
|
||||
{
|
||||
name: "group-in-group",
|
||||
@@ -408,7 +409,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
`,
|
||||
// wantErr: `username must contain @, got: "group:inner"`,
|
||||
wantErr: `nested groups are not allowed, found "group:inner" inside "group:example"`,
|
||||
wantErr: `nested groups are not allowed: found "group:inner" inside "group:example"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-addr",
|
||||
@@ -419,7 +420,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `hostname "derp" contains an invalid IP address: "10.0"`,
|
||||
wantErr: `hostname contains invalid IP address: hostname "derp" address "10.0"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-prefix",
|
||||
@@ -430,7 +431,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `hostname "derp" contains an invalid IP address: "10.0/42"`,
|
||||
wantErr: `hostname contains invalid IP address: hostname "derp" address "10.0/42"`,
|
||||
},
|
||||
// TODO(kradalby): Figure out why this doesn't work.
|
||||
// {
|
||||
@@ -459,7 +460,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
],
|
||||
}
|
||||
`,
|
||||
wantErr: `autogroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`,
|
||||
wantErr: `invalid autogroup: got "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`,
|
||||
},
|
||||
{
|
||||
name: "undefined-hostname-errors-2490",
|
||||
@@ -478,7 +479,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `host "user1" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `host not defined in policy: "user1"`,
|
||||
},
|
||||
{
|
||||
name: "defined-hostname-does-not-err-2490",
|
||||
@@ -571,7 +572,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`,
|
||||
wantErr: `autogroup:internet can only be used in ACL destinations`,
|
||||
},
|
||||
{
|
||||
name: "autogroup:internet-in-ssh-src-not-allowed",
|
||||
@@ -590,7 +591,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`,
|
||||
wantErr: `tag not defined in policy: "tag:test"`,
|
||||
},
|
||||
{
|
||||
name: "autogroup:internet-in-ssh-dst-not-allowed",
|
||||
@@ -609,7 +610,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`,
|
||||
wantErr: `autogroup:internet can only be used in ACL destinations`,
|
||||
},
|
||||
{
|
||||
name: "ssh-basic",
|
||||
@@ -762,7 +763,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-dst",
|
||||
@@ -781,7 +782,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-ssh-src",
|
||||
@@ -800,7 +801,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `user destination requires source to contain only that same user "user@"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-tagOwner",
|
||||
@@ -811,7 +812,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-autoapprover-route",
|
||||
@@ -824,7 +825,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "group-must-be-defined-acl-autoapprover-exitnode",
|
||||
@@ -835,7 +836,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`,
|
||||
wantErr: `group not defined in policy: "group:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-src",
|
||||
@@ -854,7 +855,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-dst",
|
||||
@@ -873,7 +874,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-ssh-src",
|
||||
@@ -892,7 +893,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-ssh-dst",
|
||||
@@ -914,7 +915,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-autoapprover-route",
|
||||
@@ -927,7 +928,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "tag-must-be-defined-acl-autoapprover-exitnode",
|
||||
@@ -938,7 +939,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
`,
|
||||
wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`,
|
||||
wantErr: `tag not defined in policy: "tag:notdefined"`,
|
||||
},
|
||||
{
|
||||
name: "missing-dst-port-is-err",
|
||||
@@ -957,7 +958,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `hostport must contain a colon (":")`,
|
||||
wantErr: `hostport must contain a colon`,
|
||||
},
|
||||
{
|
||||
name: "dst-port-zero-is-err",
|
||||
@@ -987,7 +988,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "rules"`,
|
||||
wantErr: `unknown field: "rules"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-nested",
|
||||
@@ -1010,7 +1011,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `group must start with "group:", got: "INVALID_GROUP_FIELD"`,
|
||||
wantErr: `group must start with 'group:', got: "INVALID_GROUP_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-group-datatype",
|
||||
@@ -1022,7 +1023,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `group "group:invalid" value must be an array of users, got string: "should fail"`,
|
||||
wantErr: `group value must be an array of users: group "group:invalid" got string: "should fail"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-group-name-and-datatype-fails-on-name-first",
|
||||
@@ -1034,7 +1035,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `group must start with "group:", got: "INVALID_GROUP_FIELD"`,
|
||||
wantErr: `group must start with 'group:', got: "INVALID_GROUP_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-hosts-level",
|
||||
@@ -1046,7 +1047,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`,
|
||||
wantErr: `hostname contains invalid IP address: hostname "INVALID_HOST_FIELD" address "should fail"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-tagowners-level",
|
||||
@@ -1058,7 +1059,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`,
|
||||
wantErr: `tag must start with 'tag:', got: "INVALID_TAG_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-acls-level",
|
||||
@@ -1075,7 +1076,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_ACL_FIELD"`,
|
||||
wantErr: `unknown field: "INVALID_ACL_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-ssh-level",
|
||||
@@ -1092,7 +1093,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_SSH_FIELD"`,
|
||||
wantErr: `unknown field: "INVALID_SSH_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-policy-level",
|
||||
@@ -1109,7 +1110,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
"INVALID_POLICY_FIELD": "should fail at policy level"
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_POLICY_FIELD"`,
|
||||
wantErr: `unknown field: "INVALID_POLICY_FIELD"`,
|
||||
},
|
||||
{
|
||||
name: "disallow-unsupported-fields-autoapprovers-level",
|
||||
@@ -1124,7 +1125,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`,
|
||||
wantErr: `unknown field: "INVALID_AUTO_APPROVER_FIELD"`,
|
||||
},
|
||||
// headscale-admin uses # in some field names to add metadata, so we will ignore
|
||||
// those to ensure it doesnt break.
|
||||
@@ -1183,7 +1184,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `unknown field "proto"`,
|
||||
wantErr: `unknown field: "proto"`,
|
||||
},
|
||||
{
|
||||
name: "protocol-wildcard-not-allowed",
|
||||
@@ -1279,7 +1280,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `leading 0 not permitted in protocol number "0"`,
|
||||
wantErr: `leading 0 not permitted in protocol number: "0"`,
|
||||
},
|
||||
{
|
||||
name: "protocol-empty-applies-to-tcp-udp-only",
|
||||
@@ -1326,7 +1327,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`,
|
||||
wantErr: `protocol does not support specific ports: "icmp", only "*" is allowed`,
|
||||
},
|
||||
{
|
||||
name: "protocol-icmp-with-wildcard-port-allowed",
|
||||
@@ -1374,7 +1375,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`,
|
||||
wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`,
|
||||
wantErr: `protocol does not support specific ports: "gre", only "*" is allowed`,
|
||||
},
|
||||
{
|
||||
name: "protocol-tcp-with-specific-port-allowed",
|
||||
@@ -2081,7 +2082,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
IPv4: ap("100.100.101.103"),
|
||||
},
|
||||
},
|
||||
wantErr: `user with token "invaliduser@" not found`,
|
||||
wantErr: `user not found: token "invaliduser@"`,
|
||||
},
|
||||
{
|
||||
name: "invalid-tag",
|
||||
@@ -2105,7 +2106,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "autogroup-member-comprehensive",
|
||||
toResolve: ptr.To(AutoGroup(AutoGroupMember)),
|
||||
toResolve: ptr.To(AutoGroupMember),
|
||||
nodes: types.Nodes{
|
||||
// Node with no tags (should be included - is a member)
|
||||
{
|
||||
@@ -2155,7 +2156,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "autogroup-tagged",
|
||||
toResolve: ptr.To(AutoGroup(AutoGroupTagged)),
|
||||
toResolve: ptr.To(AutoGroupTagged),
|
||||
nodes: types.Nodes{
|
||||
// Node with no tags (should be excluded - not tagged)
|
||||
{
|
||||
@@ -2266,6 +2267,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
var prefs []netip.Prefix
|
||||
|
||||
if ips != nil {
|
||||
if p := ips.Prefixes(); len(p) > 0 {
|
||||
prefs = p
|
||||
@@ -2437,9 +2439,11 @@ func TestResolveAutoApprovers(t *testing.T) {
|
||||
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
|
||||
t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if tt.wantAllIPRoutes != nil {
|
||||
if gotAllIPRoutes == nil {
|
||||
t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil")
|
||||
@@ -2586,6 +2590,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
|
||||
for _, p := range prefixes {
|
||||
builder.AddPrefix(mp(p))
|
||||
}
|
||||
|
||||
ipSet, _ := builder.IPSet()
|
||||
|
||||
return ipSet
|
||||
@@ -2595,6 +2600,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool {
|
||||
if x == nil || y == nil {
|
||||
return x == y
|
||||
}
|
||||
|
||||
return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...)
|
||||
}
|
||||
|
||||
@@ -2823,6 +2829,7 @@ func TestResolveTagOwners(t *testing.T) {
|
||||
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
|
||||
t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -3098,6 +3105,7 @@ func TestNodeCanHaveTag(t *testing.T) {
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
got := pm.NodeCanHaveTag(tt.node.View(), tt.tag)
|
||||
@@ -3358,6 +3366,7 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var acl ACL
|
||||
|
||||
err := json.Unmarshal([]byte(tt.input), &acl)
|
||||
|
||||
if tt.wantErr {
|
||||
@@ -3368,8 +3377,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected.Action, acl.Action)
|
||||
assert.Equal(t, tt.expected.Protocol, acl.Protocol)
|
||||
assert.Equal(t, len(tt.expected.Sources), len(acl.Sources))
|
||||
assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations))
|
||||
assert.Len(t, acl.Sources, len(tt.expected.Sources))
|
||||
assert.Len(t, acl.Destinations, len(tt.expected.Destinations))
|
||||
|
||||
// Compare sources
|
||||
for i, expectedSrc := range tt.expected.Sources {
|
||||
@@ -3409,14 +3418,15 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) {
|
||||
|
||||
// Unmarshal back
|
||||
var unmarshaled ACL
|
||||
|
||||
err = json.Unmarshal(jsonBytes, &unmarshaled)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be equal
|
||||
assert.Equal(t, original.Action, unmarshaled.Action)
|
||||
assert.Equal(t, original.Protocol, unmarshaled.Protocol)
|
||||
assert.Equal(t, len(original.Sources), len(unmarshaled.Sources))
|
||||
assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations))
|
||||
assert.Len(t, unmarshaled.Sources, len(original.Sources))
|
||||
assert.Len(t, unmarshaled.Destinations, len(original.Destinations))
|
||||
}
|
||||
|
||||
func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
|
||||
@@ -3484,15 +3494,16 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
|
||||
|
||||
_, err := unmarshalPolicy([]byte(policyJSON))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `invalid action "deny"`)
|
||||
assert.Contains(t, err.Error(), `invalid ACL action: "deny"`)
|
||||
}
|
||||
|
||||
// Helper function to parse aliases for testing
|
||||
// Helper function to parse aliases for testing.
|
||||
func mustParseAlias(s string) Alias {
|
||||
alias, err := parseAlias(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return alias
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,18 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Port parsing errors.
|
||||
var (
|
||||
ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port")
|
||||
ErrInputStartsWithColon = errors.New("input cannot start with a colon character")
|
||||
ErrInputEndsWithColon = errors.New("input cannot end with a colon character")
|
||||
ErrInvalidPortRangeFormat = errors.New("invalid port range format")
|
||||
ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port")
|
||||
ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard")
|
||||
ErrInvalidPortNumber = errors.New("invalid port number")
|
||||
ErrPortNumberOutOfRange = errors.New("port number out of range")
|
||||
)
|
||||
|
||||
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
|
||||
func splitDestinationAndPort(input string) (string, string, error) {
|
||||
// Find the last occurrence of the colon character
|
||||
@@ -16,13 +28,15 @@ func splitDestinationAndPort(input string) (string, string, error) {
|
||||
|
||||
// Check if the colon character is present and not at the beginning or end of the string
|
||||
if lastColonIndex == -1 {
|
||||
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||
return "", "", ErrInputMissingColon
|
||||
}
|
||||
|
||||
if lastColonIndex == 0 {
|
||||
return "", "", errors.New("input cannot start with a colon character")
|
||||
return "", "", ErrInputStartsWithColon
|
||||
}
|
||||
|
||||
if lastColonIndex == len(input)-1 {
|
||||
return "", "", errors.New("input cannot end with a colon character")
|
||||
return "", "", ErrInputEndsWithColon
|
||||
}
|
||||
|
||||
// Split the string into destination and port based on the last colon
|
||||
@@ -45,11 +59,12 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||
for part := range parts {
|
||||
if strings.Contains(part, "-") {
|
||||
rangeParts := strings.Split(part, "-")
|
||||
|
||||
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||
return e == ""
|
||||
})
|
||||
if len(rangeParts) != 2 {
|
||||
return nil, errors.New("invalid port range format")
|
||||
return nil, ErrInvalidPortRangeFormat
|
||||
}
|
||||
|
||||
first, err := parsePort(rangeParts[0])
|
||||
@@ -63,7 +78,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||
}
|
||||
|
||||
if first > last {
|
||||
return nil, errors.New("invalid port range: first port is greater than last port")
|
||||
return nil, ErrPortRangeInverted
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
|
||||
@@ -74,7 +89,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||
}
|
||||
|
||||
if port < 1 {
|
||||
return nil, errors.New("first port must be >0, or use '*' for wildcard")
|
||||
return nil, ErrPortMustBePositive
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
|
||||
@@ -88,11 +103,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
||||
func parsePort(portStr string) (uint16, error) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid port number")
|
||||
return 0, ErrInvalidPortNumber
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return 0, errors.New("port number out of range")
|
||||
return 0, ErrPortNumberOutOfRange
|
||||
}
|
||||
|
||||
return uint16(port), nil
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -24,9 +23,9 @@ func TestParseDestinationAndPort(t *testing.T) {
|
||||
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||
{"example-host-1:*", "example-host-1", "*", nil},
|
||||
{"hostname:80-90", "hostname", "80-90", nil},
|
||||
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||
{"invalidinput", "", "", ErrInputMissingColon},
|
||||
{":invalid", "", "", ErrInputStartsWithColon},
|
||||
{"invalid:", "", "", ErrInputEndsWithColon},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
@@ -58,9 +57,11 @@ func TestParsePort(t *testing.T) {
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
@@ -92,9 +93,11 @@ func TestParsePortRange(t *testing.T) {
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ const nodeNameContextKey = contextKey("nodeName")
|
||||
type mapSession struct {
|
||||
h *Headscale
|
||||
req tailcfg.MapRequest
|
||||
ctx context.Context
|
||||
ctx context.Context //nolint:containedctx
|
||||
capVer tailcfg.CapabilityVersion
|
||||
|
||||
cancelChMu deadlock.Mutex
|
||||
@@ -54,7 +54,7 @@ func (h *Headscale) newMapSession(
|
||||
w http.ResponseWriter,
|
||||
node *types.Node,
|
||||
) *mapSession {
|
||||
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
|
||||
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) //nolint:gosec // weak random is fine for jitter
|
||||
|
||||
return &mapSession{
|
||||
h: h,
|
||||
@@ -162,6 +162,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
// This is not my favourite solution, but it kind of works in our eventually consistent world.
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
disconnected := true
|
||||
// Wait up to 10 seconds for the node to reconnect.
|
||||
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
|
||||
@@ -170,6 +171,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
disconnected = false
|
||||
break
|
||||
}
|
||||
|
||||
<-ticker.C
|
||||
}
|
||||
|
||||
@@ -222,7 +224,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
// adding this before connecting it to the state ensure that
|
||||
// it does not miss any updates that might be sent in the split
|
||||
// time between the node connecting and the batcher being ready.
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { //nolint:noinlineerr
|
||||
m.log.Error().Caller().Err(err).Msg("failed to add node to batcher")
|
||||
return
|
||||
}
|
||||
@@ -240,22 +242,26 @@ func (m *mapSession) serveLongPoll() {
|
||||
case <-m.cancelCh:
|
||||
m.log.Trace().Caller().Msg("poll cancelled received")
|
||||
mapResponseEnded.WithLabelValues("cancelled").Inc()
|
||||
|
||||
return
|
||||
|
||||
case <-ctx.Done():
|
||||
m.log.Trace().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("poll context done")
|
||||
mapResponseEnded.WithLabelValues("done").Inc()
|
||||
|
||||
return
|
||||
|
||||
// Consume updates sent to node
|
||||
case update, ok := <-m.ch:
|
||||
m.log.Trace().Caller().Bool(zf.OK, ok).Msg("received update from channel")
|
||||
|
||||
if !ok {
|
||||
m.log.Trace().Caller().Msg("update channel closed, streaming session is likely being replaced")
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.writeMap(update); err != nil {
|
||||
err := m.writeMap(update)
|
||||
if err != nil {
|
||||
m.log.Error().Caller().Err(err).Msg("cannot write update to client")
|
||||
return
|
||||
}
|
||||
@@ -264,7 +270,8 @@ func (m *mapSession) serveLongPoll() {
|
||||
m.resetKeepAlive()
|
||||
|
||||
case <-m.keepAliveTicker.C:
|
||||
if err := m.writeMap(&keepAlive); err != nil {
|
||||
err := m.writeMap(&keepAlive)
|
||||
if err != nil {
|
||||
m.log.Error().Caller().Err(err).Msg("cannot write keep alive")
|
||||
return
|
||||
}
|
||||
@@ -272,6 +279,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
|
||||
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||
m.resetKeepAlive()
|
||||
}
|
||||
@@ -292,7 +300,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
|
||||
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
data := make([]byte, reservedResponseHeaderSize, reservedResponseHeaderSize+len(jsonBody))
|
||||
//nolint:gosec // G115: JSON response size will not exceed uint32 max
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
|
||||
data = append(data, jsonBody...)
|
||||
|
||||
@@ -109,9 +109,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
||||
Msg("current primary no longer available")
|
||||
}
|
||||
}
|
||||
|
||||
if len(nodes) >= 1 {
|
||||
pr.primaries[prefix] = nodes[0]
|
||||
changed = true
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str(zf.Prefix, prefix.String()).
|
||||
@@ -128,6 +130,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
||||
Str(zf.Prefix, prefix.String()).
|
||||
Msg("cleaning up primary route that no longer has available nodes")
|
||||
delete(pr.primaries, prefix)
|
||||
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
@@ -164,14 +167,17 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix)
|
||||
// If no routes are being set, remove the node from the routes map.
|
||||
if len(prefixes) == 0 {
|
||||
wasPresent := false
|
||||
|
||||
if _, ok := pr.routes[node]; ok {
|
||||
delete(pr.routes, node)
|
||||
|
||||
wasPresent = true
|
||||
|
||||
nlog.Debug().
|
||||
Caller().
|
||||
Msg("removed node from primary routes (no prefixes)")
|
||||
}
|
||||
|
||||
changed := pr.updatePrimaryLocked()
|
||||
nlog.Debug().
|
||||
Caller().
|
||||
@@ -253,12 +259,14 @@ func (pr *PrimaryRoutes) stringLocked() string {
|
||||
|
||||
ids := types.NodeIDs(xmaps.Keys(pr.routes))
|
||||
sort.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
prefixes := pr.routes[id]
|
||||
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
|
||||
}
|
||||
|
||||
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
|
||||
|
||||
for route, nodeID := range pr.primaries {
|
||||
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
|
||||
}
|
||||
|
||||
@@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
pr.SetRoutes(1) // Deregister by setting no routes
|
||||
|
||||
return pr.SetRoutes(1, mp("192.168.3.0/24"))
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
@@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) {
|
||||
{
|
||||
name: "multiple-nodes-register-same-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
|
||||
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
@@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
2: {
|
||||
@@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) {
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
|
||||
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
|
||||
|
||||
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
@@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) {
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var change1, change2 bool
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return change1 || change2
|
||||
@@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pr := New()
|
||||
|
||||
change := tt.operations(pr)
|
||||
if change != tt.expectedChange {
|
||||
t.Errorf("change = %v, want %v", change, tt.expectedChange)
|
||||
}
|
||||
|
||||
comps := append(util.Comparers, cmpopts.EquateEmpty())
|
||||
if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
|
||||
t.Errorf("routes mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
|
||||
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
|
||||
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ func (s *State) DebugOverview() string {
|
||||
ephemeralCount := 0
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for _, node := range allNodes.All() {
|
||||
if node.Valid() {
|
||||
userName := node.Owner().Name()
|
||||
@@ -103,17 +104,21 @@ func (s *State) DebugOverview() string {
|
||||
|
||||
// User statistics
|
||||
sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users)))
|
||||
|
||||
for userName, nodeCount := range userNodeCounts {
|
||||
sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Policy information
|
||||
sb.WriteString("Policy:\n")
|
||||
sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode))
|
||||
|
||||
if s.cfg.Policy.Mode == types.PolicyModeFile {
|
||||
sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// DERP information
|
||||
@@ -123,6 +128,7 @@ func (s *State) DebugOverview() string {
|
||||
} else {
|
||||
sb.WriteString("DERP: not configured\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Route information
|
||||
@@ -130,6 +136,7 @@ func (s *State) DebugOverview() string {
|
||||
if s.primaryRoutes.String() == "" {
|
||||
routeCount = 0
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount))
|
||||
sb.WriteString("\n")
|
||||
|
||||
@@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string {
|
||||
for _, node := range region.Nodes {
|
||||
sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n",
|
||||
node.Name, node.HostName, node.DERPPort))
|
||||
|
||||
if node.STUNPort != 0 {
|
||||
sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
@@ -236,7 +245,7 @@ func (s *State) DebugPolicy() (string, error) {
|
||||
|
||||
return string(pol), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode)
|
||||
return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, s.cfg.Policy.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
|
||||
if s.primaryRoutes.String() == "" {
|
||||
routeCount = 0
|
||||
}
|
||||
|
||||
info.PrimaryRoutes = routeCount
|
||||
|
||||
return info
|
||||
|
||||
@@ -21,6 +21,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
||||
|
||||
// Create NodeStore
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -44,20 +45,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
||||
// 6. If DELETE came after UPDATE, the returned node should be invalid
|
||||
|
||||
done := make(chan bool, 2)
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
|
||||
var (
|
||||
updatedNode types.NodeView
|
||||
updateOk bool
|
||||
)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
|
||||
go func() {
|
||||
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
store.DeleteNode(node.ID)
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
@@ -91,6 +98,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
|
||||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -148,6 +156,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
|
||||
node := createTestNode(3, 1, "test-user", "test-node-3")
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -204,6 +213,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
||||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -214,8 +224,11 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
||||
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
|
||||
// 2. DeleteNode (from handleLogout when client sends logout request)
|
||||
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
var (
|
||||
updatedNode types.NodeView
|
||||
updateOk bool
|
||||
)
|
||||
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
@@ -223,12 +236,14 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
||||
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
})
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
store.DeleteNode(ephemeralNode.ID)
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
@@ -267,7 +282,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
||||
// 5. UpdateNode and DeleteNode batch together
|
||||
// 6. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 7. persistNodeToDB is called with the stale valid node
|
||||
// 8. Node gets re-inserted into database instead of staying deleted
|
||||
// 8. Node gets re-inserted into database instead of staying deleted.
|
||||
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
||||
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
@@ -279,6 +294,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -349,6 +365,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
||||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -399,7 +416,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
||||
// 3. UpdateNode and DeleteNode batch together
|
||||
// 4. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist.
|
||||
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
||||
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
@@ -409,6 +426,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
||||
}
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ func netInfoFromMapRequest(
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP).
|
||||
Msg("using NetInfo from previous Hostinfo in MapRequest")
|
||||
|
||||
return currentHostinfo.NetInfo
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestNetInfoFromMapRequest(t *testing.T) {
|
||||
@@ -136,26 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
|
||||
assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node")
|
||||
})
|
||||
}
|
||||
|
||||
// Simple helper function for tests
|
||||
func createTestNodeSimple(id types.NodeID) *types.Node {
|
||||
user := types.User{
|
||||
Name: "test-user",
|
||||
}
|
||||
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
node := &types.Node{
|
||||
ID: id,
|
||||
Hostname: "test-node",
|
||||
UserID: ptr.To(uint(id)),
|
||||
User: &user,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPv4: &netip.Addr{},
|
||||
IPv6: &netip.Addr{},
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
@@ -55,8 +55,8 @@ var (
|
||||
})
|
||||
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_nodes_total",
|
||||
Help: "Total number of nodes in the NodeStore",
|
||||
Name: "nodestore_nodes",
|
||||
Help: "Number of nodes in the NodeStore",
|
||||
})
|
||||
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
@@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc
|
||||
for _, n := range allNodes {
|
||||
nodes[n.ID] = *n
|
||||
}
|
||||
|
||||
snap := snapshotFromNodes(nodes, peersFunc)
|
||||
|
||||
store := &NodeStore{
|
||||
@@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView {
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
|
||||
nodeStoreOperations.WithLabelValues("put").Inc()
|
||||
|
||||
return resultNode
|
||||
@@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
|
||||
nodeStoreOperations.WithLabelValues("update").Inc()
|
||||
|
||||
// Return the node and whether it exists (is valid)
|
||||
@@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) {
|
||||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
@@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() {
|
||||
if len(batch) != 0 {
|
||||
s.applyBatch(batch)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
batch = append(batch, w)
|
||||
if len(batch) >= s.batchSize {
|
||||
s.applyBatch(batch)
|
||||
@@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) {
|
||||
w.updateFn(&n)
|
||||
nodes[w.nodeID] = n
|
||||
}
|
||||
|
||||
if w.nodeResult != nil {
|
||||
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
|
||||
}
|
||||
@@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) {
|
||||
nodeView := node.View()
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- nodeView
|
||||
|
||||
close(w.nodeResult)
|
||||
}
|
||||
} else {
|
||||
// Node was deleted or doesn't exist
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- types.NodeView{} // Send invalid view
|
||||
|
||||
close(w.nodeResult)
|
||||
}
|
||||
}
|
||||
@@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
|
||||
peersByNode: func() map[types.NodeID][]types.NodeView {
|
||||
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
|
||||
defer peersTimer.ObserveDuration()
|
||||
|
||||
return peersFunc(allNodes)
|
||||
}(),
|
||||
nodesByUser: make(map[types.UserID][]types.NodeView),
|
||||
@@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
|
||||
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
|
||||
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
|
||||
}
|
||||
|
||||
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
|
||||
}
|
||||
|
||||
@@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string {
|
||||
|
||||
// User distribution (shows internal UserID tracking, not display owner)
|
||||
sb.WriteString("Nodes by Internal User ID:\n")
|
||||
|
||||
for userID, nodes := range snapshot.nodesByUser {
|
||||
if len(nodes) > 0 {
|
||||
userName := "unknown"
|
||||
taggedCount := 0
|
||||
|
||||
if len(nodes) > 0 && nodes[0].Valid() {
|
||||
userName = nodes[0].User().Name()
|
||||
// Count tagged nodes (which have UserID set but are owned by "tagged-devices")
|
||||
@@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Peer relationships summary
|
||||
sb.WriteString("Peer Relationships:\n")
|
||||
|
||||
totalPeers := 0
|
||||
|
||||
for nodeID, peers := range snapshot.peersByNode {
|
||||
peerCount := len(peers)
|
||||
|
||||
totalPeers += peerCount
|
||||
if node, exists := snapshot.nodesByID[nodeID]; exists {
|
||||
sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n",
|
||||
nodeID, node.Hostname, peerCount))
|
||||
}
|
||||
}
|
||||
|
||||
if len(snapshot.peersByNode) > 0 {
|
||||
avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode))
|
||||
sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Node key index
|
||||
@@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() {
|
||||
}
|
||||
|
||||
s.writeQueue <- w
|
||||
|
||||
<-result
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
||||
|
||||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
@@ -45,9 +45,10 @@ func TestSnapshotFromNodes(t *testing.T) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
}
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
@@ -70,7 +71,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
@@ -95,7 +96,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
@@ -124,7 +125,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
||||
|
||||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
assert.Len(t, snapshot.allNodes, 4)
|
||||
assert.Len(t, snapshot.peersByNode, 4)
|
||||
@@ -193,11 +194,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
||||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID() != node.ID() {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
@@ -208,6 +211,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
||||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
|
||||
nodeIsOdd := node.ID()%2 == 1
|
||||
|
||||
for _, n := range nodes {
|
||||
@@ -222,6 +226,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
@@ -236,7 +241,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "create empty store and add single node",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
},
|
||||
steps: []testStep{
|
||||
@@ -274,7 +279,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "create store with initial node and add more",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
initialNodes := types.Nodes{&node1}
|
||||
|
||||
@@ -342,7 +347,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "test node deletion",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
node3 := createTestNode(3, 2, "user2", "node3")
|
||||
@@ -403,7 +408,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "test node updates",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
@@ -445,7 +450,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "test with odd-even peers filtering",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
return NewNodeStore(nil, oddEvenPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
},
|
||||
steps: []testStep{
|
||||
@@ -455,10 +460,13 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
// Add nodes in sequence
|
||||
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
assert.True(t, n1.Valid())
|
||||
|
||||
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
assert.True(t, n2.Valid())
|
||||
|
||||
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
assert.True(t, n3.Valid())
|
||||
|
||||
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
assert.True(t, n4.Valid())
|
||||
|
||||
@@ -501,7 +509,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "test batch modifications return correct node state",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
@@ -526,16 +534,20 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2 types.NodeView
|
||||
var newNode3 types.NodeView
|
||||
var ok1, ok2 bool
|
||||
var (
|
||||
resultNode1, resultNode2 types.NodeView
|
||||
newNode3 types.NodeView
|
||||
ok1, ok2 bool
|
||||
)
|
||||
|
||||
// These should all be processed in the same batch
|
||||
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "batch-updated-node1"
|
||||
n.GivenName = "batch-given-1"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
@@ -544,12 +556,14 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
n.Hostname = "batch-updated-node2"
|
||||
n.GivenName = "batch-given-2"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
node3 := createTestNode(3, 1, "user1", "node3")
|
||||
newNode3 = store.PutNode(node3)
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
@@ -602,20 +616,23 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
// This test verifies that when multiple updates to the same node
|
||||
// are batched together, each returned node reflects ALL changes
|
||||
// in the batch, not just the individual update's changes.
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
var (
|
||||
resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
ok1, ok2, ok3 bool
|
||||
)
|
||||
|
||||
// These updates all modify node 1 and should be batched together
|
||||
// The final state should have all three modifications applied
|
||||
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "multi-update-hostname"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
@@ -623,6 +640,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "multi-update-givenname"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
@@ -630,6 +648,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Tags = []string{"tag1", "tag2"}
|
||||
})
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
@@ -673,7 +692,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "test UpdateNode result is immutable for database save",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
@@ -723,14 +742,18 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var result1, result2, result3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
var (
|
||||
result1, result2, result3 types.NodeView
|
||||
ok1, ok2, ok3 bool
|
||||
)
|
||||
|
||||
// Start concurrent updates
|
||||
|
||||
go func() {
|
||||
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "concurrent-db-hostname"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
@@ -738,6 +761,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "concurrent-db-given"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
@@ -745,6 +769,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Tags = []string{"concurrent-tag"}
|
||||
})
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
@@ -828,6 +853,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := tt.setupFunc(t)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -847,10 +873,11 @@ type testStep struct {
|
||||
|
||||
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
|
||||
|
||||
// Helper for concurrent test nodes
|
||||
// Helper for concurrent test nodes.
|
||||
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
return types.Node{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
@@ -863,72 +890,88 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Concurrency: concurrent PutNode operations ---
|
||||
// --- Concurrency: concurrent PutNode operations ---.
|
||||
func TestNodeStoreConcurrentPutNode(t *testing.T) {
|
||||
const concurrentOps = 20
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
results := make(chan bool, concurrentOps)
|
||||
for i := range concurrentOps {
|
||||
wg.Add(1)
|
||||
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
|
||||
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") //nolint:gosec // safe conversion in test
|
||||
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Batching: concurrent ops fit in one batch ---
|
||||
// --- Batching: concurrent ops fit in one batch ---.
|
||||
func TestNodeStoreBatchingEfficiency(t *testing.T) {
|
||||
const batchSize = 10
|
||||
const ops = 15 // more than batchSize
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
results := make(chan bool, ops)
|
||||
for i := range ops {
|
||||
wg.Add(1)
|
||||
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
|
||||
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") //nolint:gosec // test code with small integers
|
||||
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Race conditions: many goroutines on same node ---
|
||||
// --- Race conditions: many goroutines on same node ---.
|
||||
func TestNodeStoreRaceConditions(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -937,13 +980,18 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
||||
resultNode := store.PutNode(node)
|
||||
require.True(t, resultNode.Valid())
|
||||
|
||||
const numGoroutines = 30
|
||||
const opsPerGoroutine = 10
|
||||
const (
|
||||
numGoroutines = 30
|
||||
opsPerGoroutine = 10
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
errors := make(chan error, numGoroutines*opsPerGoroutine)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -954,40 +1002,46 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
||||
n.Hostname = "race-updated"
|
||||
})
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j)
|
||||
errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) //nolint:err113
|
||||
}
|
||||
case 1:
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
if !found || !retrieved.Valid() {
|
||||
errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j)
|
||||
errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) //nolint:err113
|
||||
}
|
||||
case 2:
|
||||
newNode := createConcurrentTestNode(nodeID, "race-put")
|
||||
|
||||
resultNode := store.PutNode(newNode)
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
|
||||
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) //nolint:err113
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
errorCount := 0
|
||||
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
|
||||
if errorCount > 0 {
|
||||
t.Fatalf("Race condition test failed with %d errors", errorCount)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Resource cleanup: goroutine leak detection ---
|
||||
// --- Resource cleanup: goroutine leak detection ---.
|
||||
func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
// initialGoroutines := runtime.NumGoroutine()
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -1001,7 +1055,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
|
||||
const ops = 100
|
||||
for i := range ops {
|
||||
nodeID := types.NodeID(i + 1)
|
||||
nodeID := types.NodeID(i + 1) //nolint:gosec // test code with small integers
|
||||
node := createConcurrentTestNode(nodeID, "cleanup-node")
|
||||
resultNode := store.PutNode(node)
|
||||
assert.True(t, resultNode.Valid())
|
||||
@@ -1010,10 +1064,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
})
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
assert.True(t, found && retrieved.Valid())
|
||||
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
|
||||
// Wait for goroutines to settle and check for leaks
|
||||
@@ -1024,9 +1080,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
}, time.Second, 10*time.Millisecond, "goroutines should not leak")
|
||||
}
|
||||
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---.
|
||||
func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
@@ -1034,36 +1091,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
const ops = 30
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
putResults := make([]error, ops)
|
||||
updateResults := make([]error, ops)
|
||||
|
||||
// Launch all PutNode operations concurrently
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec // test code with small integers
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
|
||||
startPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
|
||||
node := createConcurrentTestNode(id, "timeout-node")
|
||||
resultNode := store.PutNode(node)
|
||||
endPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
|
||||
|
||||
if !resultNode.Valid() {
|
||||
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
|
||||
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) //nolint:err113
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Launch all UpdateNode operations concurrently
|
||||
wg = sync.WaitGroup{}
|
||||
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec // test code with small integers
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
|
||||
startUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
|
||||
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
|
||||
@@ -1071,31 +1139,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
})
|
||||
endUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
|
||||
|
||||
if !ok || !resultNode.Valid() {
|
||||
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
|
||||
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) //nolint:err113
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
errorCount := 0
|
||||
|
||||
for _, err := range putResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, err := range updateResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
|
||||
if errorCount == 0 {
|
||||
t.Log("All concurrent operations completed successfully within timeout")
|
||||
} else {
|
||||
@@ -1107,13 +1184,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Edge case: update non-existent node ---
|
||||
// --- Edge case: update non-existent node ---.
|
||||
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
||||
for i := range 10 {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
store.Start()
|
||||
nonExistentID := types.NodeID(999 + i)
|
||||
|
||||
nonExistentID := types.NodeID(999 + i) //nolint:gosec // test code with small integers
|
||||
updateCallCount := 0
|
||||
|
||||
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
|
||||
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
|
||||
updateCallCount++
|
||||
@@ -1127,20 +1206,22 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Allocation benchmark ---
|
||||
// --- Allocation benchmark ---.
|
||||
func BenchmarkNodeStoreAllocations(b *testing.B) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; b.Loop(); i++ {
|
||||
nodeID := types.NodeID(i + 1)
|
||||
nodeID := types.NodeID(i + 1) //nolint:gosec // benchmark code with small integers
|
||||
node := createConcurrentTestNode(nodeID, "bench-node")
|
||||
store.PutNode(node)
|
||||
store.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.Hostname = "bench-updated"
|
||||
})
|
||||
store.GetNode(nodeID)
|
||||
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
|
||||
@@ -230,6 +230,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
|
||||
// propagate correctly when switching between policy types.
|
||||
s.nodeStore.RebuildPeerMaps()
|
||||
|
||||
//nolint:prealloc // cs starts with one element and may grow
|
||||
cs := []change.Change{change.PolicyChange()}
|
||||
|
||||
// Always call autoApproveNodes during policy reload, regardless of whether
|
||||
@@ -260,7 +261,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
|
||||
// CreateUser creates a new user and updates the policy manager.
|
||||
// Returns the created user, change set, and any error.
|
||||
func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) {
|
||||
if err := s.db.DB.Save(&user).Error; err != nil {
|
||||
if err := s.db.DB.Save(&user).Error; err != nil { //nolint:noinlineerr
|
||||
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
||||
@@ -294,7 +295,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updateFn(user); err != nil {
|
||||
if err := updateFn(user); err != nil { //nolint:noinlineerr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -512,7 +513,7 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node not found: %d", id)
|
||||
return nil, fmt.Errorf("%w: %d", ErrNodeNotFound, id)
|
||||
}
|
||||
|
||||
log.Info().EmbedObject(node).Msg("node disconnected")
|
||||
@@ -765,7 +766,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
|
||||
|
||||
// Check name uniqueness against NodeStore
|
||||
allNodes := s.nodeStore.ListNodes()
|
||||
for i := 0; i < allNodes.Len(); i++ {
|
||||
for i := range allNodes.Len() {
|
||||
node := allNodes.At(i)
|
||||
if node.ID() != nodeID && node.AsStruct().GivenName == newName {
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName)
|
||||
@@ -832,7 +833,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
|
||||
|
||||
var updates []change.Change
|
||||
|
||||
for _, node := range s.nodeStore.ListNodes().All() {
|
||||
for _, node := range s.nodeStore.ListNodes().All() { //nolint:unqueryvet // NodeStore.ListNodes not a SQL query
|
||||
if !node.Valid() {
|
||||
continue
|
||||
}
|
||||
@@ -1850,7 +1851,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: transaction success
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err)
|
||||
|
||||
@@ -13,6 +13,9 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// ErrNoCertDomains is returned when no cert domains are available for HTTPS.
|
||||
var ErrNoCertDomains = errors.New("no cert domains available for HTTPS")
|
||||
|
||||
func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error {
|
||||
opts := tailsql.Options{
|
||||
Hostname: "tailsql-headscale",
|
||||
@@ -41,15 +44,17 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
||||
defer tsNode.Close()
|
||||
|
||||
logf("Starting tailscale (hostname=%q)", opts.Hostname)
|
||||
|
||||
lc, err := tsNode.LocalClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect local client: %w", err)
|
||||
}
|
||||
|
||||
opts.LocalClient = lc // for authentication
|
||||
|
||||
// Make sure the Tailscale node starts up. It might not, if it is a new node
|
||||
// and the user did not provide an auth key.
|
||||
if st, err := tsNode.Up(ctx); err != nil {
|
||||
if st, err := tsNode.Up(ctx); err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("starting tailscale: %w", err)
|
||||
} else {
|
||||
logf("tailscale started, node state %q", st.BackendState)
|
||||
@@ -71,28 +76,38 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
||||
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
|
||||
certDomains := tsNode.CertDomains()
|
||||
if len(certDomains) == 0 {
|
||||
return errors.New("no cert domains available for HTTPS")
|
||||
return ErrNoCertDomains
|
||||
}
|
||||
|
||||
base := "https://" + certDomains[0]
|
||||
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
target := base + r.RequestURI
|
||||
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||
}))
|
||||
|
||||
go func() {
|
||||
_ = http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { //nolint:gosec
|
||||
target := base + r.RequestURI
|
||||
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||
}))
|
||||
}()
|
||||
// log.Printf("Redirecting HTTP to HTTPS at %q", base)
|
||||
|
||||
// For the real service, start a separate listener.
|
||||
// Note: Replaces the port 80 listener.
|
||||
var err error
|
||||
|
||||
lst, err = tsNode.ListenTLS("tcp", ":443")
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen TLS: %w", err)
|
||||
}
|
||||
|
||||
logf("enabled serving via HTTPS")
|
||||
}
|
||||
|
||||
mux := tsql.NewMux()
|
||||
tsweb.Debugger(mux)
|
||||
go http.Serve(lst, mux)
|
||||
|
||||
go func() {
|
||||
_ = http.Serve(lst, mux) //nolint:gosec
|
||||
}()
|
||||
|
||||
logf("TailSQL started")
|
||||
<-ctx.Done()
|
||||
logf("TailSQL shutting down...")
|
||||
|
||||
@@ -20,7 +20,11 @@ const (
|
||||
DatabaseSqlite = "sqlite3"
|
||||
)
|
||||
|
||||
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
// Common errors.
|
||||
var (
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
|
||||
)
|
||||
|
||||
type StateUpdateType int
|
||||
|
||||
@@ -100,6 +104,10 @@ func (su *StateUpdate) Empty() bool {
|
||||
return len(su.ChangePatches) == 0
|
||||
case StatePeerRemoved:
|
||||
return len(su.Removed) == 0
|
||||
case StateFullUpdate, StateSelfUpdate, StateDERPUpdated:
|
||||
// These update types don't have associated data to check,
|
||||
// so they are never considered empty.
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -175,8 +183,9 @@ func MustRegistrationID() RegistrationID {
|
||||
|
||||
func RegistrationIDFromString(str string) (RegistrationID, error) {
|
||||
if len(str) != RegistrationIDLength {
|
||||
return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength)
|
||||
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
|
||||
}
|
||||
|
||||
return RegistrationID(str), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -33,10 +33,12 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
||||
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
||||
errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
|
||||
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
|
||||
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
||||
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
||||
errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
|
||||
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
|
||||
ErrNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
ErrInvalidAllocationStrategy = errors.New("invalid prefix allocation strategy")
|
||||
)
|
||||
|
||||
type IPAllocationStrategy string
|
||||
@@ -301,6 +303,7 @@ func validatePKCEMethod(method string) error {
|
||||
if method != PKCEMethodPlain && method != PKCEMethodS256 {
|
||||
return errInvalidPKCEMethod
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -326,6 +329,7 @@ func LoadConfig(path string, isFile bool) error {
|
||||
viper.SetConfigFile(path)
|
||||
} else {
|
||||
viper.SetConfigName("config")
|
||||
|
||||
if path == "" {
|
||||
viper.AddConfigPath("/etc/headscale/")
|
||||
viper.AddConfigPath("$HOME/.headscale")
|
||||
@@ -401,8 +405,10 @@ func LoadConfig(path string, isFile bool) error {
|
||||
|
||||
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
var configFileNotFoundError viper.ConfigFileNotFoundError
|
||||
if errors.As(err, &configFileNotFoundError) {
|
||||
log.Warn().Msg("no config file found, using defaults")
|
||||
return nil
|
||||
}
|
||||
@@ -442,7 +448,8 @@ func validateServerConfig() error {
|
||||
depr.fatal("oidc.map_legacy_users")
|
||||
|
||||
if viper.GetBool("oidc.enabled") {
|
||||
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
|
||||
err := validatePKCEMethod(viper.GetString("oidc.pkce.method"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -556,6 +563,7 @@ func derpConfig() DERPConfig {
|
||||
automaticallyAddEmbeddedDerpRegion := viper.GetBool(
|
||||
"derp.server.automatically_add_embedded_derp_region",
|
||||
)
|
||||
|
||||
if serverEnabled && stunAddr == "" {
|
||||
log.Fatal().
|
||||
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
|
||||
@@ -625,13 +633,16 @@ func policyConfig() PolicyConfig {
|
||||
|
||||
func logConfig() LogConfig {
|
||||
logLevelStr := viper.GetString("log.level")
|
||||
|
||||
logLevel, err := zerolog.ParseLevel(logLevelStr)
|
||||
if err != nil {
|
||||
logLevel = zerolog.DebugLevel
|
||||
}
|
||||
|
||||
logFormatOpt := viper.GetString("log.format")
|
||||
|
||||
var logFormat string
|
||||
|
||||
switch logFormatOpt {
|
||||
case JSONLogFormat:
|
||||
logFormat = JSONLogFormat
|
||||
@@ -658,7 +669,7 @@ func databaseConfig() DatabaseConfig {
|
||||
type_ := viper.GetString("database.type")
|
||||
|
||||
skipErrRecordNotFound := viper.GetBool("database.gorm.skip_err_record_not_found")
|
||||
slowThreshold := viper.GetDuration("database.gorm.slow_threshold") * time.Millisecond
|
||||
slowThreshold := time.Duration(viper.GetInt64("database.gorm.slow_threshold")) * time.Millisecond
|
||||
parameterizedQueries := viper.GetBool("database.gorm.parameterized_queries")
|
||||
prepareStmt := viper.GetBool("database.gorm.prepare_stmt")
|
||||
|
||||
@@ -730,6 +741,7 @@ func dns() (DNSConfig, error) {
|
||||
if err != nil {
|
||||
return DNSConfig{}, fmt.Errorf("unmarshalling dns extra records: %w", err)
|
||||
}
|
||||
|
||||
dns.ExtraRecords = extraRecords
|
||||
}
|
||||
|
||||
@@ -745,30 +757,23 @@ func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
|
||||
var resolvers []*dnstype.Resolver
|
||||
|
||||
for _, nsStr := range d.Nameservers.Global {
|
||||
warn := ""
|
||||
if _, err := netip.ParseAddr(nsStr); err == nil {
|
||||
if _, err := netip.ParseAddr(nsStr); err == nil { //nolint:noinlineerr
|
||||
resolvers = append(resolvers, &dnstype.Resolver{
|
||||
Addr: nsStr,
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(nsStr); err == nil {
|
||||
if _, err := url.Parse(nsStr); err == nil { //nolint:noinlineerr
|
||||
resolvers = append(resolvers, &dnstype.Resolver{
|
||||
Addr: nsStr,
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
||||
}
|
||||
|
||||
if warn != "" {
|
||||
log.Warn().Msg(warn)
|
||||
}
|
||||
log.Warn().Str("nameserver", nsStr).Msg("invalid global nameserver, ignoring")
|
||||
}
|
||||
|
||||
return resolvers
|
||||
@@ -780,34 +785,30 @@ func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
|
||||
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
||||
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
|
||||
routes := make(map[string][]*dnstype.Resolver)
|
||||
|
||||
for domain, nameservers := range d.Nameservers.Split {
|
||||
var resolvers []*dnstype.Resolver
|
||||
|
||||
for _, nsStr := range nameservers {
|
||||
warn := ""
|
||||
if _, err := netip.ParseAddr(nsStr); err == nil {
|
||||
if _, err := netip.ParseAddr(nsStr); err == nil { //nolint:noinlineerr
|
||||
resolvers = append(resolvers, &dnstype.Resolver{
|
||||
Addr: nsStr,
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(nsStr); err == nil {
|
||||
if _, err := url.Parse(nsStr); err == nil { //nolint:noinlineerr
|
||||
resolvers = append(resolvers, &dnstype.Resolver{
|
||||
Addr: nsStr,
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
||||
}
|
||||
|
||||
if warn != "" {
|
||||
log.Warn().Msg(warn)
|
||||
}
|
||||
log.Warn().Str("nameserver", nsStr).Str("domain", domain).Msg("invalid split dns nameserver, ignoring")
|
||||
}
|
||||
|
||||
routes[domain] = resolvers
|
||||
}
|
||||
|
||||
@@ -822,6 +823,7 @@ func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
||||
}
|
||||
|
||||
cfg.Proxied = dns.MagicDNS
|
||||
|
||||
cfg.ExtraRecords = dns.ExtraRecords
|
||||
if dns.OverrideLocalDNS {
|
||||
cfg.Resolvers = dns.globalResolvers()
|
||||
@@ -830,10 +832,12 @@ func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
||||
}
|
||||
|
||||
routes := dns.splitResolvers()
|
||||
|
||||
cfg.Routes = routes
|
||||
if dns.BaseDomain != "" {
|
||||
cfg.Domains = []string{dns.BaseDomain}
|
||||
}
|
||||
|
||||
cfg.Domains = append(cfg.Domains, dns.SearchDomains...)
|
||||
|
||||
return &cfg
|
||||
@@ -843,7 +847,7 @@ func prefixV4() (*netip.Prefix, error) {
|
||||
prefixV4Str := viper.GetString("prefixes.v4")
|
||||
|
||||
if prefixV4Str == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // empty prefix is valid, not an error
|
||||
}
|
||||
|
||||
prefixV4, err := netip.ParsePrefix(prefixV4Str)
|
||||
@@ -853,6 +857,7 @@ func prefixV4() (*netip.Prefix, error) {
|
||||
|
||||
builder := netipx.IPSetBuilder{}
|
||||
builder.AddPrefix(tsaddr.CGNATRange())
|
||||
|
||||
ipSet, _ := builder.IPSet()
|
||||
if !ipSet.ContainsPrefix(prefixV4) {
|
||||
log.Warn().
|
||||
@@ -867,7 +872,7 @@ func prefixV6() (*netip.Prefix, error) {
|
||||
prefixV6Str := viper.GetString("prefixes.v6")
|
||||
|
||||
if prefixV6Str == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // empty prefix is valid, not an error
|
||||
}
|
||||
|
||||
prefixV6, err := netip.ParsePrefix(prefixV6Str)
|
||||
@@ -910,7 +915,7 @@ func LoadCLIConfig() (*Config, error) {
|
||||
// LoadServerConfig returns the full Headscale configuration to
|
||||
// host a Headscale server. This is called as part of `headscale serve`.
|
||||
func LoadServerConfig() (*Config, error) {
|
||||
if err := validateServerConfig(); err != nil {
|
||||
if err := validateServerConfig(); err != nil { //nolint:noinlineerr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -928,11 +933,13 @@ func LoadServerConfig() (*Config, error) {
|
||||
}
|
||||
|
||||
if prefix4 == nil && prefix6 == nil {
|
||||
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
return nil, ErrNoPrefixConfigured
|
||||
}
|
||||
|
||||
allocStr := viper.GetString("prefixes.allocation")
|
||||
|
||||
var alloc IPAllocationStrategy
|
||||
|
||||
switch allocStr {
|
||||
case string(IPAllocationStrategySequential):
|
||||
alloc = IPAllocationStrategySequential
|
||||
@@ -940,7 +947,8 @@ func LoadServerConfig() (*Config, error) {
|
||||
alloc = IPAllocationStrategyRandom
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
|
||||
"%w: %q, allowed options: %s, %s",
|
||||
ErrInvalidAllocationStrategy,
|
||||
allocStr,
|
||||
IPAllocationStrategySequential,
|
||||
IPAllocationStrategyRandom,
|
||||
@@ -957,15 +965,18 @@ func LoadServerConfig() (*Config, error) {
|
||||
randomizeClientPort := viper.GetBool("randomize_client_port")
|
||||
|
||||
oidcClientSecret := viper.GetString("oidc.client_secret")
|
||||
|
||||
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
|
||||
if oidcClientSecretPath != "" && oidcClientSecret != "" {
|
||||
return nil, errOidcMutuallyExclusive
|
||||
}
|
||||
|
||||
if oidcClientSecretPath != "" {
|
||||
secretBytes, err := os.ReadFile(os.ExpandEnv(oidcClientSecretPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oidcClientSecret = strings.TrimSpace(string(secretBytes))
|
||||
}
|
||||
|
||||
@@ -979,7 +990,8 @@ func LoadServerConfig() (*Config, error) {
|
||||
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
||||
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
||||
if dnsConfig.BaseDomain != "" {
|
||||
if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
|
||||
err := isSafeServerURL(serverURL, dnsConfig.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -994,7 +1006,7 @@ func LoadServerConfig() (*Config, error) {
|
||||
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
IPAllocation: IPAllocationStrategy(alloc),
|
||||
IPAllocation: alloc,
|
||||
|
||||
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
|
||||
viper.GetString("noise.private_key_path"),
|
||||
@@ -1082,6 +1094,7 @@ func LoadServerConfig() (*Config, error) {
|
||||
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
|
||||
return workers
|
||||
}
|
||||
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
||||
@@ -1117,6 +1130,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
|
||||
}
|
||||
|
||||
s := len(serverDomainParts)
|
||||
|
||||
b := len(baseDomainParts)
|
||||
for i := range baseDomainParts {
|
||||
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
|
||||
@@ -1134,9 +1148,12 @@ type deprecator struct {
|
||||
|
||||
// warnWithAlias will register an alias between the newKey and the oldKey,
|
||||
// and log a deprecation warning if the oldKey is set.
|
||||
//
|
||||
//nolint:unused
|
||||
func (d *deprecator) warnWithAlias(newKey, oldKey string) {
|
||||
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
|
||||
viper.RegisterAlias(newKey, oldKey)
|
||||
|
||||
if viper.IsSet(oldKey) {
|
||||
d.warns.Add(
|
||||
fmt.Sprintf(
|
||||
@@ -1179,6 +1196,8 @@ func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
||||
}
|
||||
|
||||
// warn deprecates and adds an option to log a warning if the oldKey is set.
|
||||
//
|
||||
//nolint:unused
|
||||
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
|
||||
if viper.IsSet(oldKey) {
|
||||
d.warns.Add(
|
||||
@@ -1193,6 +1212,8 @@ func (d *deprecator) warnNoAlias(newKey, oldKey string) {
|
||||
}
|
||||
|
||||
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
|
||||
//
|
||||
//nolint:unused
|
||||
func (d *deprecator) warn(oldKey string) {
|
||||
if viper.IsSet(oldKey) {
|
||||
d.warns.Add(
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "unmarshal-dns-full-config",
|
||||
configPath: "testdata/dns_full.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -61,7 +61,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "dns-to-tailcfg.DNSConfig",
|
||||
configPath: "testdata/dns_full.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -92,7 +92,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "unmarshal-dns-full-no-magic",
|
||||
configPath: "testdata/dns_full_no_magic.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -127,7 +127,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "dns-to-tailcfg.DNSConfig",
|
||||
configPath: "testdata/dns_full_no_magic.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
dns, err := dns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -158,7 +158,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "base-domain-in-server-url-err",
|
||||
configPath: "testdata/base-domain-in-server-url.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
return LoadServerConfig()
|
||||
},
|
||||
want: nil,
|
||||
@@ -167,7 +167,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "base-domain-not-in-server-url",
|
||||
configPath: "testdata/base-domain-not-in-server-url.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
cfg, err := LoadServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -187,7 +187,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "dns-override-true-errors",
|
||||
configPath: "testdata/dns-override-true-error.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
return LoadServerConfig()
|
||||
},
|
||||
wantErr: "Fatal config error: dns.nameservers.global must be set when dns.override_local_dns is true",
|
||||
@@ -195,7 +195,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "dns-override-true",
|
||||
configPath: "testdata/dns-override-true.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper
|
||||
_, err := LoadServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -221,7 +221,7 @@ func TestReadConfig(t *testing.T) {
|
||||
{
|
||||
name: "policy-path-is-loaded",
|
||||
configPath: "testdata/policy-path-is-loaded.yaml",
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper // inline test closure
|
||||
cfg, err := LoadServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -242,6 +242,7 @@ func TestReadConfig(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
err := LoadConfig(tt.configPath, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -276,14 +277,14 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||
"HEADSCALE_DATABASE_SQLITE_WRITE_AHEAD_LOG": "false",
|
||||
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
|
||||
},
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper // inline test closure
|
||||
t.Logf("all settings: %#v", viper.AllSettings())
|
||||
|
||||
assert.Equal(t, "trace", viper.GetString("log.level"))
|
||||
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
|
||||
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
|
||||
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // test setup returns nil to indicate no expected value
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
@@ -300,7 +301,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||
// "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`,
|
||||
// "HEADSCALE_DNS_EXTRA_RECORDS": `[{ name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" }]`,
|
||||
},
|
||||
setup: func(t *testing.T) (any, error) {
|
||||
setup: func(t *testing.T) (any, error) { //nolint:thelper // inline test closure
|
||||
t.Logf("all settings: %#v", viper.AllSettings())
|
||||
|
||||
dns, err := dns()
|
||||
@@ -335,6 +336,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||
}
|
||||
|
||||
viper.Reset()
|
||||
|
||||
err := LoadConfig("testdata/minimal.yaml", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -349,11 +351,10 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTLSConfigValidation(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// defer os.RemoveAll(tmpDir)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
var err error
|
||||
|
||||
configYaml := []byte(`---
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: ""
|
||||
@@ -363,6 +364,7 @@ noise:
|
||||
|
||||
// Populate a custom config file
|
||||
configFilePath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
@@ -398,10 +400,12 @@ server_url: http://127.0.0.1:8080
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||
`)
|
||||
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
}
|
||||
|
||||
err = LoadConfig(tmpDir, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -463,6 +467,7 @@ func TestSafeServerURL(t *testing.T) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (id NodeID) StableID() tailcfg.StableNodeID {
|
||||
}
|
||||
|
||||
func (id NodeID) NodeID() tailcfg.NodeID {
|
||||
return tailcfg.NodeID(id)
|
||||
return tailcfg.NodeID(id) //nolint:gosec // NodeID is bounded
|
||||
}
|
||||
|
||||
func (id NodeID) Uint64() uint64 {
|
||||
@@ -162,11 +162,12 @@ func (node *Node) GivenNameHasBeenChanged() bool {
|
||||
// Strip invalid DNS characters for givenName comparison
|
||||
normalised := strings.ToLower(node.Hostname)
|
||||
normalised = invalidDNSRegex.ReplaceAllString(normalised, "")
|
||||
|
||||
return node.GivenName == normalised
|
||||
}
|
||||
|
||||
// IsExpired returns whether the node registration has expired.
|
||||
func (node Node) IsExpired() bool {
|
||||
func (node *Node) IsExpired() bool {
|
||||
// If Expiry is not set, the client has not indicated that
|
||||
// it wants an expiry time, it is therefore considered
|
||||
// to mean "not expired"
|
||||
@@ -245,8 +246,14 @@ func (node *Node) RequestTags() []string {
|
||||
}
|
||||
|
||||
func (node *Node) Prefixes() []netip.Prefix {
|
||||
var addrs []netip.Prefix
|
||||
for _, nodeAddress := range node.IPs() {
|
||||
ips := node.IPs()
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
addrs := make([]netip.Prefix, 0, len(ips))
|
||||
|
||||
for _, nodeAddress := range ips {
|
||||
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
@@ -274,9 +281,14 @@ func (node *Node) IsExitNode() bool {
|
||||
}
|
||||
|
||||
func (node *Node) IPsAsString() []string {
|
||||
var ret []string
|
||||
ips := node.IPs()
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range node.IPs() {
|
||||
ret := make([]string, 0, len(ips))
|
||||
|
||||
for _, ip := range ips {
|
||||
ret = append(ret, ip.String())
|
||||
}
|
||||
|
||||
@@ -480,7 +492,7 @@ func (node *Node) IsSubnetRouter() bool {
|
||||
return len(node.SubnetRoutes()) > 0
|
||||
}
|
||||
|
||||
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes
|
||||
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes.
|
||||
func (node *Node) AllApprovedRoutes() []netip.Prefix {
|
||||
return append(node.SubnetRoutes(), node.ExitRoutes()...)
|
||||
}
|
||||
@@ -527,7 +539,7 @@ func (node *Node) MarshalZerologObject(e *zerolog.Event) {
|
||||
// - logTracePeerChange in poll.go.
|
||||
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
|
||||
ret := tailcfg.PeerChange{
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
NodeID: tailcfg.NodeID(node.ID), //nolint:gosec // NodeID is bounded
|
||||
}
|
||||
|
||||
if node.NodeKey.String() != req.NodeKey.String() {
|
||||
@@ -553,11 +565,9 @@ func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
|
||||
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
|
||||
} else if node.Hostinfo.NetInfo == nil {
|
||||
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
|
||||
} else {
|
||||
} else if node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP {
|
||||
// If there is a PreferredDERP check if it has changed.
|
||||
if node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP {
|
||||
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
|
||||
}
|
||||
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,13 +628,16 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
|
||||
}
|
||||
|
||||
newHostname := strings.ToLower(hostInfo.Hostname)
|
||||
if err := util.ValidateHostname(newHostname); err != nil {
|
||||
|
||||
err := util.ValidateHostname(newHostname)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Str("node.id", node.ID.String()).
|
||||
Str("current_hostname", node.Hostname).
|
||||
Str("rejected_hostname", hostInfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Rejecting invalid hostname update from hostinfo")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -716,6 +729,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
|
||||
func (nodes Nodes) DebugString() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range nodes {
|
||||
sb.WriteString(node.DebugString())
|
||||
sb.WriteString("\n")
|
||||
@@ -724,7 +738,7 @@ func (nodes Nodes) DebugString() string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (node Node) DebugString() string {
|
||||
func (node *Node) DebugString() string {
|
||||
var sb strings.Builder
|
||||
fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID)
|
||||
|
||||
@@ -897,7 +911,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer
|
||||
// GetFQDN returns the fully qualified domain name for the node.
|
||||
func (nv NodeView) GetFQDN(baseDomain string) (string, error) {
|
||||
if !nv.Valid() {
|
||||
return "", errors.New("creating valid FQDN: node view is invalid")
|
||||
return "", fmt.Errorf("creating valid FQDN: %w", ErrInvalidNodeView)
|
||||
}
|
||||
|
||||
return nv.ж.GetFQDN(baseDomain)
|
||||
|
||||
@@ -407,7 +407,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
||||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
@@ -491,7 +491,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
||||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "server-北京-01",
|
||||
Hostname: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
@@ -505,7 +505,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
||||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
@@ -533,7 +533,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
||||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "测试💻机器",
|
||||
Hostname: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
|
||||
@@ -116,7 +116,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
||||
return &protoKey
|
||||
}
|
||||
|
||||
// canUsePreAuthKey checks if a pre auth key can be used.
|
||||
// Validate checks if a pre auth key can be used.
|
||||
func (pak *PreAuthKey) Validate() error {
|
||||
if pak == nil {
|
||||
return PAKError("invalid authkey")
|
||||
|
||||
@@ -111,6 +111,7 @@ func TestCanUsePreAuthKey(t *testing.T) {
|
||||
t.Errorf("expected error but got none")
|
||||
} else {
|
||||
var httpErr PAKError
|
||||
|
||||
ok := errors.As(err, &httpErr)
|
||||
if !ok {
|
||||
t.Errorf("expected HTTPError but got %T", err)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"cmp"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
@@ -20,6 +21,9 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// ErrCannotParseBoolean is returned when a value cannot be parsed as boolean.
|
||||
var ErrCannotParseBoolean = errors.New("cannot parse value as boolean")
|
||||
|
||||
type UserID uint64
|
||||
|
||||
type Users []User
|
||||
@@ -42,9 +46,11 @@ var TaggedDevices = User{
|
||||
func (u Users) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[ ")
|
||||
|
||||
for _, user := range u {
|
||||
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
|
||||
}
|
||||
|
||||
sb.WriteString(" ]")
|
||||
|
||||
return sb.String()
|
||||
@@ -55,7 +61,8 @@ func (u Users) String() string {
|
||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||
// that contain our machines.
|
||||
type User struct {
|
||||
gorm.Model
|
||||
gorm.Model //nolint:embeddedstructfieldcheck
|
||||
|
||||
// The index `idx_name_provider_identifier` is to enforce uniqueness
|
||||
// between Name and ProviderIdentifier. This ensures that
|
||||
// you can have multiple users with the same name in OIDC,
|
||||
@@ -91,6 +98,7 @@ func (u *User) StringID() string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strconv.FormatUint(uint64(u.ID), 10)
|
||||
}
|
||||
|
||||
@@ -130,7 +138,7 @@ func (u *User) profilePicURL() string {
|
||||
|
||||
func (u *User) TailscaleUser() tailcfg.User {
|
||||
return tailcfg.User{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
ID: tailcfg.UserID(u.ID), //nolint:gosec // UserID is bounded
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
Created: u.CreatedAt,
|
||||
@@ -150,7 +158,7 @@ func (u UserView) ID() uint {
|
||||
|
||||
func (u *User) TailscaleLogin() tailcfg.Login {
|
||||
return tailcfg.Login{
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
ID: tailcfg.LoginID(u.ID), //nolint:gosec // safe conversion for user ID
|
||||
Provider: u.Provider,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.Display(),
|
||||
@@ -164,7 +172,7 @@ func (u UserView) TailscaleLogin() tailcfg.Login {
|
||||
|
||||
func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
|
||||
return tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
ID: tailcfg.UserID(u.ID), //nolint:gosec // UserID is bounded
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
@@ -184,6 +192,7 @@ func (u *User) Proto() *v1.User {
|
||||
if name == "" {
|
||||
name = u.Username()
|
||||
}
|
||||
|
||||
return &v1.User{
|
||||
Id: uint64(u.ID),
|
||||
Name: name,
|
||||
@@ -220,7 +229,7 @@ func (u UserView) MarshalZerologObject(e *zerolog.Event) {
|
||||
u.ж.MarshalZerologObject(e)
|
||||
}
|
||||
|
||||
// JumpCloud returns a JSON where email_verified is returned as a
|
||||
// FlexibleBoolean handles JumpCloud's JSON where email_verified is returned as a
|
||||
// string "true" or "false" instead of a boolean.
|
||||
// This maps bool to a specific type with a custom unmarshaler to
|
||||
// ensure we can decode it from a string.
|
||||
@@ -229,6 +238,7 @@ type FlexibleBoolean bool
|
||||
|
||||
func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
|
||||
var val any
|
||||
|
||||
err := json.Unmarshal(data, &val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshalling data: %w", err)
|
||||
@@ -242,10 +252,11 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing %s as boolean: %w", v, err)
|
||||
}
|
||||
|
||||
*bit = FlexibleBoolean(pv)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("parsing %v as boolean", v)
|
||||
return fmt.Errorf("%w: %v", ErrCannotParseBoolean, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -279,9 +290,11 @@ func (c *OIDCClaims) Identifier() string {
|
||||
if c.Iss == "" && c.Sub == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if c.Iss == "" {
|
||||
return CleanIdentifier(c.Sub)
|
||||
}
|
||||
|
||||
if c.Sub == "" {
|
||||
return CleanIdentifier(c.Iss)
|
||||
}
|
||||
@@ -292,9 +305,9 @@ func (c *OIDCClaims) Identifier() string {
|
||||
|
||||
var result string
|
||||
// Try to parse as URL to handle URL joining correctly
|
||||
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
|
||||
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { //nolint:noinlineerr
|
||||
// For URLs, use proper URL path joining
|
||||
if joined, err := url.JoinPath(issuer, subject); err == nil {
|
||||
if joined, err := url.JoinPath(issuer, subject); err == nil { //nolint:noinlineerr
|
||||
result = joined
|
||||
}
|
||||
}
|
||||
@@ -366,6 +379,7 @@ func CleanIdentifier(identifier string) string {
|
||||
cleanParts = append(cleanParts, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cleanParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -408,6 +422,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) {
|
||||
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
|
||||
identifier = "/" + identifier
|
||||
}
|
||||
|
||||
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
|
||||
u.DisplayName = claims.Name
|
||||
u.ProfilePicURL = claims.ProfilePictureURL
|
||||
|
||||
@@ -66,10 +66,13 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got OIDCClaims
|
||||
if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil {
|
||||
|
||||
err := json.Unmarshal([]byte(tt.jsonstr), &got)
|
||||
if err != nil {
|
||||
t.Errorf("UnmarshallOIDCClaims() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -190,6 +193,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) {
|
||||
}
|
||||
result := claims.Identifier()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -282,6 +286,7 @@ func TestCleanIdentifier(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CleanIdentifier(tt.identifier)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -479,7 +484,9 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got OIDCClaims
|
||||
if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil {
|
||||
|
||||
err := json.Unmarshal([]byte(tt.jsonstr), &got)
|
||||
if err != nil {
|
||||
t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err)
|
||||
return
|
||||
}
|
||||
@@ -487,6 +494,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
||||
var user User
|
||||
|
||||
user.FromClaim(&got, tt.emailVerifiedRequired)
|
||||
|
||||
if diff := cmp.Diff(user, tt.want); diff != "" {
|
||||
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -38,9 +38,7 @@ func (v *VersionInfo) String() string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) {
|
||||
return debug.ReadBuildInfo()
|
||||
})
|
||||
var buildInfo = sync.OnceValues(debug.ReadBuildInfo)
|
||||
|
||||
var GetVersionInfo = sync.OnceValue(func() *VersionInfo {
|
||||
info := &VersionInfo{
|
||||
|
||||
@@ -91,6 +91,7 @@ func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
|
||||
|
||||
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
|
||||
var network, broadcast netip.Addr
|
||||
|
||||
ipRange := netipx.RangeOfPrefix(na)
|
||||
network = ipRange.From()
|
||||
broadcast = ipRange.To()
|
||||
|
||||
@@ -29,6 +29,7 @@ func Test_parseIPSet(t *testing.T) {
|
||||
arg string
|
||||
bits *int
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -111,6 +112,7 @@ func Test_parseIPSet(t *testing.T) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("parseIPSet() = (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
@@ -18,13 +18,27 @@ const (
|
||||
ipv4AddressLength = 32
|
||||
ipv6AddressLength = 128
|
||||
|
||||
// LabelHostnameLength is the maximum length for a DNS label,
|
||||
// value related to RFC 1123 and 952.
|
||||
LabelHostnameLength = 63
|
||||
)
|
||||
|
||||
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
|
||||
var ErrInvalidHostName = errors.New("invalid hostname")
|
||||
// DNS validation errors.
|
||||
var (
|
||||
ErrInvalidHostName = errors.New("invalid hostname")
|
||||
ErrUsernameTooShort = errors.New("username must be at least 2 characters long")
|
||||
ErrUsernameMustStartLetter = errors.New("username must start with a letter")
|
||||
ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'")
|
||||
ErrUsernameInvalidChar = errors.New("username contains invalid character")
|
||||
ErrHostnameTooShort = errors.New("hostname is too short, must be at least 2 characters")
|
||||
ErrHostnameTooLong = errors.New("hostname is too long, must not exceed 63 characters")
|
||||
ErrHostnameMustBeLowercase = errors.New("hostname must be lowercase")
|
||||
ErrHostnameHyphenBoundary = errors.New("hostname cannot start or end with a hyphen")
|
||||
ErrHostnameDotBoundary = errors.New("hostname cannot start or end with a dot")
|
||||
ErrHostnameInvalidChars = errors.New("hostname contains invalid characters")
|
||||
)
|
||||
|
||||
// ValidateUsername checks if a username is valid.
|
||||
// It must be at least 2 characters long, start with a letter, and contain
|
||||
@@ -34,12 +48,12 @@ var ErrInvalidHostName = errors.New("invalid hostname")
|
||||
func ValidateUsername(username string) error {
|
||||
// Ensure the username meets the minimum length requirement
|
||||
if len(username) < 2 {
|
||||
return errors.New("username must be at least 2 characters long")
|
||||
return ErrUsernameTooShort
|
||||
}
|
||||
|
||||
// Ensure the username starts with a letter
|
||||
if !unicode.IsLetter(rune(username[0])) {
|
||||
return errors.New("username must start with a letter")
|
||||
return ErrUsernameMustStartLetter
|
||||
}
|
||||
|
||||
atCount := 0
|
||||
@@ -55,10 +69,10 @@ func ValidateUsername(username string) error {
|
||||
case char == '@':
|
||||
atCount++
|
||||
if atCount > 1 {
|
||||
return errors.New("username cannot contain more than one '@'")
|
||||
return ErrUsernameTooManyAt
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("username contains invalid character: '%c'", char)
|
||||
return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,44 +84,27 @@ func ValidateUsername(username string) error {
|
||||
// The hostname must already be lowercase and contain only valid characters.
|
||||
func ValidateHostname(name string) error {
|
||||
if len(name) < 2 {
|
||||
return fmt.Errorf(
|
||||
"hostname %q is too short, must be at least 2 characters",
|
||||
name,
|
||||
)
|
||||
return fmt.Errorf("%w: %q", ErrHostnameTooShort, name)
|
||||
}
|
||||
|
||||
if len(name) > LabelHostnameLength {
|
||||
return fmt.Errorf(
|
||||
"hostname %q is too long, must not exceed 63 characters",
|
||||
name,
|
||||
)
|
||||
return fmt.Errorf("%w: %q", ErrHostnameTooLong, name)
|
||||
}
|
||||
|
||||
if strings.ToLower(name) != name {
|
||||
return fmt.Errorf(
|
||||
"hostname %q must be lowercase (try %q)",
|
||||
name,
|
||||
strings.ToLower(name),
|
||||
)
|
||||
return fmt.Errorf("%w: %q (try %q)", ErrHostnameMustBeLowercase, name, strings.ToLower(name))
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") {
|
||||
return fmt.Errorf(
|
||||
"hostname %q cannot start or end with a hyphen",
|
||||
name,
|
||||
)
|
||||
return fmt.Errorf("%w: %q", ErrHostnameHyphenBoundary, name)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") {
|
||||
return fmt.Errorf(
|
||||
"hostname %q cannot start or end with a dot",
|
||||
name,
|
||||
)
|
||||
return fmt.Errorf("%w: %q", ErrHostnameDotBoundary, name)
|
||||
}
|
||||
|
||||
if invalidDNSRegex.MatchString(name) {
|
||||
return fmt.Errorf(
|
||||
"hostname %q contains invalid characters, only lowercase letters, numbers, hyphens and dots are allowed",
|
||||
name,
|
||||
)
|
||||
return fmt.Errorf("%w: %q", ErrHostnameInvalidChars, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -170,6 +167,7 @@ func NormaliseHostname(name string) (string, error) {
|
||||
// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
|
||||
// class block only.
|
||||
|
||||
// GenerateIPv4DNSRootDomain generates the IPv4 reverse DNS root domains.
|
||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
||||
func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
@@ -183,25 +181,27 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
// wildcardBits is the number of bits not under the mask in the lastOctet
|
||||
wildcardBits := ByteSize - maskBits%ByteSize
|
||||
|
||||
// min is the value in the lastOctet byte of the IP
|
||||
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
|
||||
min := uint(netRange.IP[lastOctet])
|
||||
max := (min + 1<<uint(wildcardBits)) - 1
|
||||
// minVal is the value in the lastOctet byte of the IP
|
||||
// maxVal is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
|
||||
minVal := uint(netRange.IP[lastOctet])
|
||||
maxVal := (minVal + 1<<uint(wildcardBits)) - 1 //nolint:gosec // wildcardBits is always < 8, no overflow
|
||||
|
||||
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
|
||||
rdnsSlice := []string{}
|
||||
for i := lastOctet - 1; i >= 0; i-- {
|
||||
rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
|
||||
}
|
||||
|
||||
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
|
||||
rdnsBase := strings.Join(rdnsSlice, ".")
|
||||
|
||||
fqdns := make([]dnsname.FQDN, 0, max-min+1)
|
||||
for i := min; i <= max; i++ {
|
||||
fqdns := make([]dnsname.FQDN, 0, maxVal-minVal+1)
|
||||
for i := minVal; i <= maxVal; i++ {
|
||||
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fqdns = append(fqdns, fqdn)
|
||||
}
|
||||
|
||||
@@ -226,6 +226,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
|
||||
// class block only.
|
||||
|
||||
// GenerateIPv6DNSRootDomain generates the IPv6 reverse DNS root domains.
|
||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
|
||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
|
||||
func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
@@ -259,18 +260,22 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
}
|
||||
|
||||
var fqdns []dnsname.FQDN
|
||||
|
||||
if maskBits%4 == 0 {
|
||||
dom, _ := makeDomain()
|
||||
fqdns = append(fqdns, dom)
|
||||
} else {
|
||||
domCount := 1 << (maskBits % nibbleLen)
|
||||
|
||||
fqdns = make([]dnsname.FQDN, 0, domCount)
|
||||
for i := range domCount {
|
||||
varNibble := fmt.Sprintf("%x", i)
|
||||
|
||||
dom, err := makeDomain(varNibble)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fqdns = append(fqdns, dom)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func TestNormaliseHostname(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -90,6 +91,7 @@ func TestNormaliseHostname(t *testing.T) {
|
||||
t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && got != tt.want {
|
||||
t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want)
|
||||
}
|
||||
@@ -172,6 +174,7 @@ func TestValidateHostname(t *testing.T) {
|
||||
t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr && tt.errorContains != "" {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
|
||||
t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains)
|
||||
|
||||
@@ -21,6 +21,9 @@ const (
|
||||
PermissionFallback = 0o700
|
||||
)
|
||||
|
||||
// ErrDirectoryPermission is returned when creating a directory fails due to permission issues.
|
||||
var ErrDirectoryPermission = errors.New("creating directory failed with permission error")
|
||||
|
||||
func AbsolutePathFromConfigPath(path string) string {
|
||||
// If a relative path is provided, prefix it with the directory where
|
||||
// the config file was found.
|
||||
@@ -42,18 +45,15 @@ func GetFileMode(key string) fs.FileMode {
|
||||
return PermissionFallback
|
||||
}
|
||||
|
||||
return fs.FileMode(mode)
|
||||
return fs.FileMode(mode) //nolint:gosec // file mode is bounded by ParseUint
|
||||
}
|
||||
|
||||
func EnsureDir(dir string) error {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) { //nolint:noinlineerr
|
||||
err := os.MkdirAll(dir, PermissionFallback)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
return fmt.Errorf(
|
||||
"creating directory %s, failed with permission error, is it located somewhere Headscale can write?",
|
||||
dir,
|
||||
)
|
||||
return fmt.Errorf("%w: %s", ErrDirectoryPermission, dir)
|
||||
}
|
||||
|
||||
return fmt.Errorf("creating directory %s: %w", dir, err)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user