Compare commits

...

15 Commits

Author SHA1 Message Date
crn4
7566afd7d0 components approach - we are sending all components needed for nmap assembling on client side 2025-12-29 17:31:16 +01:00
crn4
e93d4132d3 Merge branch 'main' into nmap/compaction 2025-12-18 16:50:41 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
Pascal Fischer
c29bb1a289 [management] use xid as request id for logging (#4955) 2025-12-16 14:02:37 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
48 changed files with 2806 additions and 211 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy
```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -56,6 +56,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -109,6 +110,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -327,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -522,6 +530,18 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
// mu protects cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -1,9 +1,12 @@
//go:build ios
package NetBirdSDK
import (
"context"
"fmt"
"net/netip"
"os"
"sort"
"strings"
"sync"
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string) error {
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
}
return netIDs
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
log.Debugf("Setting env variable %s: %s", k, v)
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
} else {
log.Debugf("Env variable %s was set successfully", k)
}
}
}

View File

@@ -0,0 +1,34 @@
//go:build ios
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import _ "golang.org/x/mobile/bind"

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection

View File

@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// detectUtilLinuxLogin always returns false on JS/WASM
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")

View File

@@ -10,6 +10,7 @@ import (
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"sync"
"syscall"
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
return supported
}
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
// See https://bugs.debian.org/1078023 for details.
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
if runtime.GOOS != "linux" {
return false
}
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "login", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("login --version failed (likely shadow-utils): %v", err)
return false
}
isUtilLinux := strings.Contains(string(output), "util-linux")
log.Debugf("util-linux login detected: %v", isUtilLinux)
return isUtilLinux
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false
}
logger.Infof("starting interactive shell: %s", execCmd.Path)
logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}

View File

@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// detectUtilLinuxLogin always returns false on Windows
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()

View File

@@ -138,7 +138,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
suSupportsPty bool
suSupportsPty bool
loginIsUtilLinux bool
}
type JWTConfig struct {
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
}
s.suSupportsPty = s.detectSuPtySupport(ctx)
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {

View File

@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
switch runtime.GOOS {
case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
return p, a, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default:
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
}
}
// fileExists checks if a file exists (helper for login command logic)
// getLinuxLoginCmd returns the login command for Linux systems.
// Handles differences between util-linux and shadow-utils login implementations.
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
// Special handling for Arch Linux without /etc/pam.d/remote
var loginArgs []string
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
loginArgs = []string{"-f", username, "-p"}
} else {
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
}
// util-linux login requires setsid -c to create a new session and set the
// controlling terminal. Without this, vhangup() kills the parent process.
// See https://bugs.debian.org/1078023 for details.
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
// to avoid external setsid dependency.
if !s.loginIsUtilLinux {
return loginPath, loginArgs
}
setsidPath, err := exec.LookPath("setsid")
if err != nil {
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
return loginPath, loginArgs
}
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
return setsidPath, args
}
// fileExists checks if a file exists
func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil

View File

@@ -60,14 +60,7 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error {
entry.Data["context"] = source
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
addFields(entry)
return nil
}
@@ -99,7 +92,7 @@ func (hook ContextHook) parseSrc(filePath string) string {
return fmt.Sprintf("%s/%s", pkg, file)
}
func addHTTPFields(entry *logrus.Entry) {
func addFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
@@ -109,30 +102,6 @@ func addHTTPFields(entry *logrus.Entry) {
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -10,9 +10,9 @@ import (
"slices"
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -180,7 +180,7 @@ func unaryInterceptor(
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
@@ -194,7 +194,7 @@ func streamInterceptor(
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
reqID := xid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)

View File

@@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
}
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
@@ -194,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
@@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err
}
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
@@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
}
}()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
@@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
}
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -525,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
reqStart := time.Now()
realIP := getRealIP(ctx)
sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -537,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
@@ -561,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}()
if loginReq.GetMeta() == nil {
@@ -604,16 +592,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
key, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
@@ -730,12 +714,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
@@ -750,10 +732,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
// which will be used by our clients to Login
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -813,10 +791,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
// which will be used by our clients to Login
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {

View File

@@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
}
}
@@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
@@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID)

View File

@@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to approve pending peers: %w", err)
}
if approvedCount > 0 {
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
updateAccountPeers = true
}
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
@@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
@@ -787,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
if ctx == nil {
ctx = context.Background()
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err

View File

@@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
accountID := account.Id
userID := account.Users[account.CreatedBy].Id
ctx := context.Background()
newSettings := account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: true,
}
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
peer1.Status.RequiresApproval = true
peer2.Status.RequiresApproval = true
peer3.Status.RequiresApproval = false
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
newSettings = account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: false,
}
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range accountPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
}
}
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string

View File

@@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}

View File

@@ -127,7 +127,7 @@ type MockIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
return nil
}

View File

@@ -10,7 +10,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)

View File

@@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
}
}
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil {
@@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
@@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {

View File

@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}

View File

@@ -5,9 +5,6 @@ package settings
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
}
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {

View File

@@ -27,7 +27,6 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds())
log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
return err
}
@@ -413,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return nil
}
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
Update("peer_status_requires_approval", false)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
}
return int(result.RowsAffected), nil
}
// SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 {
@@ -583,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -2152,16 +2160,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
@@ -2171,16 +2176,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -2229,11 +2231,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
@@ -2491,16 +2490,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey
result := tx.WithContext(ctx).
result := tx.
Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil {
@@ -2514,10 +2510,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
result := s.db.Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@@ -2537,11 +2530,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string
_ = s.db.WithContext(ctx).Model(types.Group{}).
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
@@ -2569,9 +2559,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
@@ -2768,10 +2755,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -2897,10 +2881,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -4022,36 +4003,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil
}
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
}
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
}
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
}
go func() {
select {
case <-ctx.Done():
case <-grpcCtx.Done():
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
}
}()
return ctx, cancel
}
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
@@ -4091,7 +4042,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
Network: &types.Network{Net: ipNet},
}
result := s.db.WithContext(ctx).
result := s.db.
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)

View File

@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
})
}
}
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
accountID := "test-account"
ctx := context.Background()
account := newAccountWithId(ctx, accountID, "testuser", "example.com")
err := store.SaveAccount(ctx, account)
require.NoError(t, err)
peers := []*nbpeer.Peer{
{
ID: "peer1",
AccountID: accountID,
DNSLabel: "peer1.netbird.cloud",
Key: "peer1-key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer2",
AccountID: accountID,
DNSLabel: "peer2.netbird.cloud",
Key: "peer2-key",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer3",
AccountID: accountID,
DNSLabel: "peer3.netbird.cloud",
Key: "peer3-key",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{
RequiresApproval: false,
LastSeen: time.Now().UTC(),
},
},
}
for _, peer := range peers {
err = store.AddPeerToAccount(ctx, peer)
require.NoError(t, err)
}
t.Run("approve all pending peers", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 2, count)
allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range allPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID)
}
})
t.Run("no peers to approve", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 0, count)
})
t.Run("non-existent account", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, "non-existent")
require.NoError(t, err)
assert.Equal(t, 0, count)
})
})
}

View File

@@ -143,6 +143,7 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)

View File

@@ -16,7 +16,6 @@ type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
)
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
if duration > HighLatencyThreshold {
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -7,8 +7,8 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -169,7 +169,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
//nolint
ctx := context.WithValue(r.Context(), hook.ExecutionContextKey, hook.HTTPSource)
reqID := uuid.New().String()
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, r.WithContext(ctx))
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
} else {

View File

@@ -355,6 +355,119 @@ func (a *Account) GetPeerNetworkMap(
return nm
}
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMapCompacted(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
if _, ok := validatedPeersMap[peerID]; !ok {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
}
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
type crt struct {
route *route.Route
peerIds []string
}
var routes []*route.Route
rtfilter := make(map[string]crt)
otherRoutesIDs := slices.Concat(networkResourcesRoutes, routesUpdate)
for _, route := range otherRoutesIDs {
rid, pid := splitRouteAndPeer(route)
if pid == peerID || len(pid) == 0 {
routes = append(routes, route)
continue
}
crt := rtfilter[rid]
crt.peerIds = append(crt.peerIds, pid)
crt.route = route.CopyClean()
rtfilter[rid] = crt
}
for rid, crt := range rtfilter {
crt.route.ApplicablePeerIDs = crt.peerIds
crt.route.ID = route.ID(rid)
routes = append(routes, crt.route)
}
nm := &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: a.Network.Copy(),
Routes: routes,
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
}
if metrics != nil {
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
}
}
return nm
}
func (a *Account) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer,

View File

@@ -0,0 +1,442 @@
package types
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/route"
)
func (a *Account) GetPeerNetworkMapComponents(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
) *NetworkMapComponents {
peer := a.Peers[peerID]
if peer == nil {
return nil
}
if _, ok := validatedPeersMap[peerID]; !ok {
return nil
}
components := &NetworkMapComponents{
PeerID: peerID,
Serial: a.Network.Serial,
Network: a.Network.Copy(),
Peers: make(map[string]*nbpeer.Peer),
Groups: make(map[string]*Group),
Policies: make([]*Policy, 0),
Routes: make([]*route.Route, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
CustomZoneDomain: peersCustomZone.Domain,
AllDNSRecords: peersCustomZone.Records,
ResourcePoliciesMap: make(map[string][]*Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
}
components.AccountSettings = &AccountSettingsInfo{
PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: a.Settings.PeerLoginExpiration,
PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: a.Settings.PeerInactivityExpiration,
}
components.DNSSettings = &a.DNSSettings
relevantPeerIDsList, relevantGroupIDs := a.findRelevantPeersAndGroups(ctx, peerID, validatedPeersMap)
relevantPeerIDsMap := make(map[string]struct{})
for _, pid := range relevantPeerIDsList {
relevantPeerIDsMap[pid] = struct{}{}
}
_, _, networkResourcesSourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
for sourcePeerID := range networkResourcesSourcePeers {
relevantPeerIDsMap[sourcePeerID] = struct{}{}
}
for pid := range relevantPeerIDsMap {
if p := a.Peers[pid]; p != nil {
components.Peers[pid] = p
}
}
for gid := range relevantGroupIDs {
if g := a.Groups[gid]; g != nil {
components.Groups[gid] = g
}
}
for _, policy := range a.Policies {
if a.isPolicyRelevantForPeer(ctx, policy, peerID, relevantGroupIDs) {
components.Policies = append(components.Policies, policy)
}
}
for _, r := range a.Routes {
if a.isRouteRelevantForPeer(ctx, r, peerID, relevantGroupIDs) {
components.Routes = append(components.Routes, r)
}
}
for _, nsGroup := range a.NameServerGroups {
if nsGroup.Enabled {
for _, gID := range nsGroup.Groups {
if _, found := relevantGroupIDs[gID]; found {
components.NameServerGroups = append(components.NameServerGroups, nsGroup.Copy())
break
}
}
}
}
relevantResourceIDs := make(map[string]struct{})
relevantNetworkIDs := make(map[string]struct{})
for _, resource := range a.NetworkResources {
if !resource.Enabled {
continue
}
policies, exists := resourcePolicies[resource.ID]
if !exists {
continue
}
isRelevant := false
networkRoutingPeers, routerExists := routers[resource.NetworkID]
if routerExists {
if _, ok := networkRoutingPeers[peerID]; ok {
isRelevant = true
}
}
if !isRelevant {
for _, policy := range policies {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
for _, p := range peers {
if p == peerID && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
isRelevant = true
break
}
}
if isRelevant {
break
}
}
}
if isRelevant {
relevantResourceIDs[resource.ID] = struct{}{}
relevantNetworkIDs[resource.NetworkID] = struct{}{}
components.NetworkResources = append(components.NetworkResources, resource)
}
}
for resID, policies := range resourcePolicies {
if _, isRelevant := relevantResourceIDs[resID]; !isRelevant {
continue
}
for _, p := range policies {
for _, rule := range p.Rules {
for _, srcGroupID := range rule.Sources {
if g := a.Groups[srcGroupID]; g != nil {
if _, exists := components.Groups[srcGroupID]; !exists {
components.Groups[srcGroupID] = g
}
}
}
for _, dstGroupID := range rule.Destinations {
if g := a.Groups[dstGroupID]; g != nil {
if _, exists := components.Groups[dstGroupID]; !exists {
components.Groups[dstGroupID] = g
}
}
}
}
}
components.ResourcePoliciesMap[resID] = policies
}
for networkID, networkRouters := range routers {
if _, isRelevant := relevantNetworkIDs[networkID]; !isRelevant {
continue
}
components.RoutersMap[networkID] = networkRouters
for peerIDKey := range networkRouters {
if _, exists := components.Peers[peerIDKey]; !exists {
if p := a.Peers[peerIDKey]; p != nil {
components.Peers[peerIDKey] = p
}
}
}
}
for groupID, groupInfo := range components.Groups {
filteredPeers := make([]string, 0, len(groupInfo.Peers))
for _, peerID := range groupInfo.Peers {
if _, exists := components.Peers[peerID]; exists {
filteredPeers = append(filteredPeers, peerID)
}
}
if len(filteredPeers) == 0 {
delete(components.Groups, groupID)
} else {
groupInfo.Peers = filteredPeers
components.Groups[groupID] = groupInfo
}
}
return components
}
func (a *Account) findRelevantPeersAndGroups(
ctx context.Context,
peerID string,
validatedPeersMap map[string]struct{},
) ([]string, map[string]struct{}) {
relevantPeerIDs := make(map[string]struct{})
relevantGroupIDs := make(map[string]struct{})
relevantPeerIDs[peerID] = struct{}{}
for groupID, group := range a.Groups {
for _, pid := range group.Peers {
if pid == peerID {
relevantGroupIDs[groupID] = struct{}{}
break
}
}
}
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
var sourcePeers, destinationPeers []string
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers = []string{rule.SourceResource.ID}
if rule.SourceResource.ID == peerID {
peerInSources = true
}
} else {
sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers = []string{rule.DestinationResource.ID}
if rule.DestinationResource.ID == peerID {
peerInDestinations = true
}
} else {
destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
}
if rule.Bidirectional {
if peerInSources {
for _, pid := range destinationPeers {
relevantPeerIDs[pid] = struct{}{}
}
for _, dstGroupID := range rule.Destinations {
relevantGroupIDs[dstGroupID] = struct{}{}
}
}
if peerInDestinations {
for _, pid := range sourcePeers {
relevantPeerIDs[pid] = struct{}{}
}
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = struct{}{}
}
}
}
if peerInSources {
for _, pid := range destinationPeers {
relevantPeerIDs[pid] = struct{}{}
}
for _, dstGroupID := range rule.Destinations {
relevantGroupIDs[dstGroupID] = struct{}{}
}
}
if peerInDestinations {
for _, pid := range sourcePeers {
relevantPeerIDs[pid] = struct{}{}
}
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = struct{}{}
}
}
}
}
for _, r := range a.Routes {
isRelevant := false
for _, groupID := range r.Groups {
if _, found := relevantGroupIDs[groupID]; found {
isRelevant = true
break
}
}
if r.Peer == peerID || r.PeerID == peerID {
isRelevant = true
}
for _, groupID := range r.PeerGroups {
if group := a.Groups[groupID]; group != nil {
for _, pid := range group.Peers {
if pid == peerID {
isRelevant = true
break
}
}
}
}
if isRelevant {
for _, groupID := range r.Groups {
relevantGroupIDs[groupID] = struct{}{}
}
for _, groupID := range r.PeerGroups {
relevantGroupIDs[groupID] = struct{}{}
}
for _, groupID := range r.AccessControlGroups {
relevantGroupIDs[groupID] = struct{}{}
}
if r.Peer != "" {
relevantPeerIDs[r.Peer] = struct{}{}
}
if r.PeerID != "" {
relevantPeerIDs[r.PeerID] = struct{}{}
}
}
}
peerIDsList := make([]string, 0, len(relevantPeerIDs))
for pid := range relevantPeerIDs {
peerIDsList = append(peerIDsList, pid)
}
return peerIDsList, relevantGroupIDs
}
func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]string, bool) {
peerInGroups := false
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
filteredPeerIDs := make([]string, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
}
return filteredPeerIDs, peerInGroups
}
func (a *Account) isPolicyRelevantForPeer(ctx context.Context, policy *Policy, peerID string, relevantGroupIDs map[string]struct{}) bool {
for _, rule := range policy.Rules {
for _, srcGroupID := range rule.Sources {
if _, found := relevantGroupIDs[srcGroupID]; found {
return true
}
}
for _, dstGroupID := range rule.Destinations {
if _, found := relevantGroupIDs[dstGroupID]; found {
return true
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == peerID {
return true
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == peerID {
return true
}
}
return false
}
func (a *Account) isRouteRelevantForPeer(ctx context.Context, r *route.Route, peerID string, relevantGroupIDs map[string]struct{}) bool {
if r.Peer == peerID || r.PeerID == peerID {
return true
}
for _, groupID := range r.Groups {
if _, found := relevantGroupIDs[groupID]; found {
return true
}
}
for _, groupID := range r.PeerGroups {
if group := a.Groups[groupID]; group != nil {
for _, pid := range group.Peers {
if pid == peerID {
return true
}
}
}
}
for _, groupID := range r.AccessControlGroups {
if _, found := relevantGroupIDs[groupID]; found {
return true
}
}
return false
}

View File

@@ -0,0 +1,769 @@
package types
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
nil,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
)
if components == nil {
t.Fatal("GetPeerNetworkMapComponents returned nil")
}
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
if newNetworkMap == nil {
t.Fatal("CalculateNetworkMapFromComponents returned nil")
}
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
}
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
nil,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
)
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
normalizeAndSortNetworkMap(legacyNetworkMap)
normalizeAndSortNetworkMap(newNetworkMap)
componentsJSON, err := json.MarshalIndent(components, "", " ")
require.NoError(t, err, "error marshaling components to JSON")
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
goldenDir := filepath.Join("testdata", "comparison")
err = os.MkdirAll(goldenDir, 0755)
require.NoError(t, err)
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
require.NoError(t, err, "error writing legacy golden file")
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
err = os.WriteFile(newGoldenPath, newJSON, 0644)
require.NoError(t, err, "error writing components golden file")
componentsPath := filepath.Join(goldenDir, "components.json")
err = os.WriteFile(componentsPath, componentsJSON, 0644)
require.NoError(t, err, "error writing components golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON),
"NetworkMaps from legacy and components approaches do not match.\n"+
"Legacy JSON saved to: %s\n"+
"Components JSON saved to: %s",
legacyGoldenPath, newGoldenPath)
t.Logf("✅ NetworkMaps are identical")
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
t.Logf(" Components NetworkMap: %s", newGoldenPath)
}
func normalizeAndSortNetworkMap(nm *NetworkMap) {
if nm == nil {
return
}
sort.Slice(nm.Peers, func(i, j int) bool {
return nm.Peers[i].ID < nm.Peers[j].ID
})
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
})
sort.Slice(nm.Routes, func(i, j int) bool {
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
})
sort.Slice(nm.FirewallRules, func(i, j int) bool {
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
}
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
}
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
})
for i := range nm.RoutesFirewallRules {
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
}
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
}
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
}
for k := 0; k < minLen; k++ {
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
}
}
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
}
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
})
if nm.DNSConfig.CustomZones != nil {
for i := range nm.DNSConfig.CustomZones {
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
})
}
}
}
func compareNetworkMaps(t *testing.T, legacy, new *NetworkMap) {
t.Helper()
if legacy.Network.Serial != new.Network.Serial {
t.Errorf("Network Serial mismatch: legacy=%d, new=%d", legacy.Network.Serial, new.Network.Serial)
}
if len(legacy.Peers) != len(new.Peers) {
t.Errorf("Peers count mismatch: legacy=%d, new=%d", len(legacy.Peers), len(new.Peers))
}
legacyPeerIDs := make(map[string]bool)
for _, p := range legacy.Peers {
legacyPeerIDs[p.ID] = true
}
for _, p := range new.Peers {
if !legacyPeerIDs[p.ID] {
t.Errorf("New NetworkMap contains peer %s not in legacy", p.ID)
}
}
if len(legacy.OfflinePeers) != len(new.OfflinePeers) {
t.Errorf("OfflinePeers count mismatch: legacy=%d, new=%d", len(legacy.OfflinePeers), len(new.OfflinePeers))
}
if len(legacy.FirewallRules) != len(new.FirewallRules) {
t.Logf("FirewallRules count mismatch: legacy=%d, new=%d", len(legacy.FirewallRules), len(new.FirewallRules))
}
if len(legacy.Routes) != len(new.Routes) {
t.Logf("Routes count mismatch: legacy=%d, new=%d", len(legacy.Routes), len(new.Routes))
}
if len(legacy.RoutesFirewallRules) != len(new.RoutesFirewallRules) {
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, new=%d", len(legacy.RoutesFirewallRules), len(new.RoutesFirewallRules))
}
if legacy.DNSConfig.ServiceEnable != new.DNSConfig.ServiceEnable {
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, new=%v", legacy.DNSConfig.ServiceEnable, new.DNSConfig.ServiceEnable)
}
}
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60"
expiredPeerID = "peer-98"
offlinePeerID = "peer-99"
routingPeerID = "peer-95"
testAccountID = "account-comparison-test"
)
func createTestAccount() *Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
}
policies := []*Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
account := &Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Network: &Network{
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func BenchmarkLegacyNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
nil,
)
}
}
func BenchmarkComponentsNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
)
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}
func BenchmarkComponentsCreation(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
)
}
}
func BenchmarkCalculationFromComponents(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
validatedPeersMap,
resourcePolicies,
routers,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}
func TestGetPeerNetworkMap_ProdAccount_CompareImplementations(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
ctx := context.Background()
testAccount := loadProdAccountFromJSON(t)
testingPeerID := "cq3526bl0ubs73bbtpbg"
require.Contains(t, testAccount.Peers, testingPeerID, "Testing peer should exist in account")
validatedPeersMap := make(map[string]struct{})
for peerID := range testAccount.Peers {
validatedPeersMap[peerID] = struct{}{}
}
resourcePolicies := testAccount.GetResourcePoliciesMap()
routers := testAccount.GetResourceRoutersMap()
legacyNetworkMap := testAccount.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
require.NotNil(t, legacyNetworkMap, "GetPeerNetworkMap returned nil")
components := testAccount.GetPeerNetworkMapComponents(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers)
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
normalizeAndSortNetworkMap(legacyNetworkMap)
normalizeAndSortNetworkMap(newNetworkMap)
componentsJSON, err := json.MarshalIndent(components, "", " ")
require.NoError(t, err, "error marshaling components to JSON")
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
outputDir := filepath.Join("testdata", fmt.Sprintf("compare_peer_%s", testingPeerID))
err = os.MkdirAll(outputDir, 0755)
require.NoError(t, err)
legacyFilePath := filepath.Join(outputDir, "legacy_networkmap.json")
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
componentsPath := filepath.Join(outputDir, "components.json")
err = os.WriteFile(componentsPath, componentsJSON, 0644)
require.NoError(t, err)
newFilePath := filepath.Join(outputDir, "components_networkmap.json")
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Files saved to:\n Legacy NetworkMap: %s\n Components: %s\n Components NetworkMap: %s",
legacyFilePath, componentsPath, newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON),
"NetworkMaps from legacy and components approaches do not match for peer %s.\n"+
"Legacy JSON saved to: %s\n"+
"Components JSON saved to: %s\n"+
"Components NetworkMap saved to: %s",
testingPeerID, legacyFilePath, componentsPath, newFilePath)
t.Logf("✅ NetworkMaps are identical for peer %s", testingPeerID)
}
func loadProdAccountFromJSON(t testing.TB) *Account {
t.Helper()
testDataPath := filepath.Join("testdata", "account_cnlf3j3l0ubs738o5d4g.json")
data, err := os.ReadFile(testDataPath)
require.NoError(t, err, "Failed to read prod account JSON file")
var account Account
err = json.Unmarshal(data, &account)
require.NoError(t, err, "Failed to unmarshal prod account")
if account.Groups == nil {
account.Groups = make(map[string]*Group)
}
if account.Peers == nil {
account.Peers = make(map[string]*nbpeer.Peer)
}
if account.Policies == nil {
account.Policies = []*Policy{}
}
return &account
}
func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) {
account := loadProdAccountFromJSON(b)
ctx := context.Background()
validatedPeersMap := make(map[string]struct{}, len(account.Peers))
for _, peer := range account.Peers {
validatedPeersMap[peer.ID] = struct{}{}
}
dnsDomain := account.Settings.DNSDomain
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
builder := NewNetworkMapBuilder(account, validatedPeersMap)
testingPeerID := "d3knp53l0ubs738a3n6g"
regularNm := builder.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, nil)
compactNm := builder.GetPeerNetworkMapCompact(ctx, testingPeerID, customZone, validatedPeersMap, nil)
compactCachedNm := builder.GetPeerNetworkMapCompactCached(ctx, testingPeerID, customZone, validatedPeersMap, nil)
regularJSON, err := json.Marshal(regularNm)
require.NoError(b, err)
compactJSON, err := json.Marshal(compactNm)
require.NoError(b, err)
compactCachedJSON, err := json.Marshal(compactCachedNm)
require.NoError(b, err)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, validatedPeersMap, resourcePolicies, routers)
componentsJSON, err := json.Marshal(components)
require.NoError(b, err)
regularSize := len(regularJSON)
compactSize := len(compactJSON)
compactCachedSize := len(compactCachedJSON)
componentsSize := len(componentsJSON)
compactSavingsPercent := 100 - int(float64(compactCachedSize)/float64(regularSize)*100)
componentsSavingsPercent := 100 - int(float64(componentsSize)/float64(regularSize)*100)
b.ReportMetric(float64(regularSize), "regular_bytes")
b.ReportMetric(float64(compactCachedSize), "compact_cached_bytes")
b.ReportMetric(float64(componentsSize), "components_bytes")
b.ReportMetric(float64(compactSavingsPercent), "compact_savings_%")
b.ReportMetric(float64(componentsSavingsPercent), "components_savings_%")
b.Logf("========== Network Map Size Comparison ==========")
b.Logf("Regular network map: %d bytes", regularSize)
b.Logf("Compact network map: %d bytes (-%d%%)", compactSize, 100-int(float64(compactSize)/float64(regularSize)*100))
b.Logf("Compact cached network map: %d bytes (-%d%%)", compactCachedSize, compactSavingsPercent)
b.Logf("Components: %d bytes (-%d%%)", componentsSize, componentsSavingsPercent)
b.Logf("")
b.Logf("Bandwidth savings (Compact cached): %d bytes saved (%d%%)", regularSize-compactCachedSize, compactSavingsPercent)
b.Logf("Bandwidth savings (Components): %d bytes saved (%d%%)", regularSize-componentsSize, componentsSavingsPercent)
b.Logf("=================================================")
b.Run("Regular", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, nil)
}
})
b.Run("CompactOnDemand", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = builder.GetPeerNetworkMapCompact(ctx, testingPeerID, customZone, validatedPeersMap, nil)
}
})
b.Run("CompactCached", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = builder.GetPeerNetworkMapCompactCached(ctx, testingPeerID, customZone, validatedPeersMap, nil)
}
})
b.Run("Legacy", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
})
b.Run("LegacyCompacted", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapCompacted(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
})
b.Run("ComponentsNetworkMap", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
testingPeerID,
customZone,
validatedPeersMap,
resourcePolicies,
routers,
)
_ = CalculateNetworkMapFromComponents(ctx, components)
}
})
b.Run("ComponentsCreation", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
testingPeerID,
customZone,
validatedPeersMap,
resourcePolicies,
routers,
)
}
})
b.Run("CalculationFromComponents", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = CalculateNetworkMapFromComponents(ctx, components)
}
})
}

View File

@@ -0,0 +1,951 @@
package types
import (
"context"
"net"
"net/netip"
"strings"
"time"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
type NetworkMapComponents struct {
PeerID string
Serial uint64
Network *Network
AccountSettings *AccountSettingsInfo
DNSSettings *DNSSettings
CustomZoneDomain string
Peers map[string]*nbpeer.Peer
Groups map[string]*Group
Policies []*Policy
Routes []*route.Route
NameServerGroups []*nbdns.NameServerGroup
AllDNSRecords []nbdns.SimpleRecord
ResourcePoliciesMap map[string][]*Policy
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
NetworkResources []*resourceTypes.NetworkResource
}
type AccountSettingsInfo struct {
PeerLoginExpirationEnabled bool
PeerLoginExpiration time.Duration
PeerInactivityExpirationEnabled bool
PeerInactivityExpiration time.Duration
}
func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer {
return c.Peers[peerID]
}
func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group {
return c.Groups[groupID]
}
func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool {
group := c.GetGroupInfo(groupID)
if group == nil {
return false
}
for _, pid := range group.Peers {
if pid == peerID {
return true
}
}
return false
}
func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} {
groups := make(map[string]struct{})
for groupID, group := range c.Groups {
for _, pid := range group.Peers {
if pid == peerID {
groups[groupID] = struct{}{}
break
}
}
}
return groups
}
func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool {
if len(postureCheckIDs) == 0 {
return true
}
_, exists := c.Peers[peerID]
return exists
}
type NetworkMapCalculator struct {
components *NetworkMapComponents
}
func NewNetworkMapCalculator(components *NetworkMapComponents) *NetworkMapCalculator {
return &NetworkMapCalculator{
components: components,
}
}
func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap {
calculator := NewNetworkMapCalculator(components)
return calculator.Calculate(ctx)
}
func (calc *NetworkMapCalculator) Calculate(ctx context.Context) *NetworkMap {
targetPeerID := calc.components.PeerID
aclPeers, firewallRules := calc.getPeerConnectionResources(ctx, targetPeerID)
peersToConnect, expiredPeers := calc.filterPeersByLoginExpiration(aclPeers)
routesUpdate := calc.getRoutesToSync(ctx, targetPeerID, peersToConnect)
routesFirewallRules := calc.getPeerRoutesFirewallRules(ctx, targetPeerID)
isRouter, networkResourcesRoutes, sourcePeers := calc.getNetworkResourcesRoutesToSync(ctx, targetPeerID)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = calc.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes)
}
peersToConnectIncludingRouters := calc.addNetworksRoutingPeers(
networkResourcesRoutes,
targetPeerID,
peersToConnect,
expiredPeers,
isRouter,
sourcePeers,
)
dnsManagementStatus := calc.getPeerDNSManagementStatus(targetPeerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
if calc.components.CustomZoneDomain != "" {
records := calc.filterZoneRecordsForPeers(targetPeerID, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: calc.components.CustomZoneDomain,
Records: records,
})
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = calc.getPeerNSGroups(targetPeerID)
}
return &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: calc.components.Network.Copy(),
Routes: append(networkResourcesRoutes, routesUpdate...),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
}
}
func (calc *NetworkMapCalculator) getPeerConnectionResources(ctx context.Context, targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule) {
targetPeer := calc.components.GetPeerInfo(targetPeerID)
if targetPeer == nil {
return nil, nil
}
generateResources, getAccumulatedResources := calc.connResourcesGenerator(ctx, targetPeer)
for _, policy := range calc.components.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
var sourcePeers, destinationPeers []*nbpeer.Peer
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers, peerInSources = calc.getPeerFromResource(rule.SourceResource, targetPeerID)
} else {
sourcePeers, peerInSources = calc.getAllPeersFromGroups(ctx, rule.Sources, targetPeerID, policy.SourcePostureChecks)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers, peerInDestinations = calc.getPeerFromResource(rule.DestinationResource, targetPeerID)
} else {
destinationPeers, peerInDestinations = calc.getAllPeersFromGroups(ctx, rule.Destinations, targetPeerID, nil)
}
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
}
}
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
}
}
return getAccumulatedResources()
}
func (calc *NetworkMapCalculator) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
for _, peer := range groupPeers {
if peer == nil {
continue
}
if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
peersExists[peer.ID] = struct{}{}
}
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: net.IP(peer.IP).String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
}
ruleID := rule.ID + fr.PeerIP + string(rune(direction)) +
fr.Protocol + fr.Action
for _, port := range rule.Ports {
ruleID += port
}
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
rules = append(rules, &fr)
continue
}
rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{
ID: rule.ID,
Ports: rule.Ports,
PortRanges: rule.PortRanges,
Protocol: rule.Protocol,
Action: rule.Action,
}, targetPeer)...)
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
}
}
func (calc *NetworkMapCalculator) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) {
peerInGroups := false
uniquePeerIDs := calc.getUniquePeerIDsFromGroupsIDs(ctx, groups)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peerInfo := calc.components.GetPeerInfo(p)
if peerInfo == nil {
continue
}
if _, ok := calc.components.Peers[p]; !ok {
continue
}
if !calc.components.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) {
continue
}
if p == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peerInfo)
}
return filteredPeers, peerInGroups
}
func (calc *NetworkMapCalculator) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups))
for _, groupID := range groups {
group := calc.components.GetGroupInfo(groupID)
if group == nil {
continue
}
if len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
func (calc *NetworkMapCalculator) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
peerInfo := calc.components.GetPeerInfo(resource.ID)
if peerInfo == nil {
return []*nbpeer.Peer{}, false
}
return []*nbpeer.Peer{peerInfo}, resource.ID == peerID
}
func (calc *NetworkMapCalculator) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) {
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(calc.components.AccountSettings.PeerLoginExpiration)
if calc.components.AccountSettings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
return peersToConnect, expiredPeers
}
func (calc *NetworkMapCalculator) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := calc.components.GetPeerGroups(peerID)
enabled := true
for _, groupID := range calc.components.DNSSettings.DisabledManagementGroups {
if _, found := peerGroups[groupID]; found {
enabled = false
break
}
}
return enabled
}
func (calc *NetworkMapCalculator) filterZoneRecordsForPeers(peerID string, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(calc.components.AllDNSRecords))
peerIPs := make(map[string]struct{})
targetPeerInfo := calc.components.GetPeerInfo(peerID)
if targetPeerInfo != nil {
peerIPs[string(targetPeerInfo.IP)] = struct{}{}
}
for _, peer := range peersToConnect {
peerIPs[string(peer.IP)] = struct{}{}
}
for _, peer := range expiredPeers {
peerIPs[string(peer.IP)] = struct{}{}
}
for _, record := range calc.components.AllDNSRecords {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
}
}
return filteredRecords
}
func (calc *NetworkMapCalculator) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup {
groupList := calc.components.GetPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range calc.components.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
targetPeerInfo := calc.components.GetPeerInfo(peerID)
if targetPeerInfo != nil && !calc.peerIsNameserver(targetPeerInfo, nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
func (calc *NetworkMapCalculator) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peerInfo.IP.String() == ns.IP.String() {
return true
}
}
return false
}
func (calc *NetworkMapCalculator) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
routes, peerDisabledRoutes := calc.getRoutingPeerRoutes(ctx, peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
groupListMap := calc.components.GetPeerGroups(peerID)
for _, peer := range aclPeers {
activeRoutes, _ := calc.getRoutingPeerRoutes(ctx, peer.ID)
groupFilteredRoutes := calc.filterRoutesByGroups(activeRoutes, groupListMap)
filteredRoutes := calc.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
func (calc *NetworkMapCalculator) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peerInfo := calc.components.GetPeerInfo(peerID)
if peerInfo == nil {
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok {
return
}
seenRoute[r.ID] = struct{}{}
routeObj := calc.copyRoute(r)
routeObj.Peer = peerInfo.Key
if r.Enabled {
enabledRoutes = append(enabledRoutes, routeObj)
return
}
disabledRoutes = append(disabledRoutes, routeObj)
}
for _, r := range calc.components.Routes {
for _, groupID := range r.PeerGroups {
group := calc.components.GetGroupInfo(groupID)
if group == nil {
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := calc.copyRoute(r)
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id)
takeRoute(newPeerRoute, id)
break
}
}
if r.Peer == peerID || r.PeerID == peerID {
takeRoute(calc.copyRoute(r), peerID)
}
}
return enabledRoutes, disabledRoutes
}
func (calc *NetworkMapCalculator) copyRoute(r *route.Route) *route.Route {
var groups, accessControlGroups, peerGroups []string
var domains domain.List
if r.Groups != nil {
groups = append([]string{}, r.Groups...)
}
if r.AccessControlGroups != nil {
accessControlGroups = append([]string{}, r.AccessControlGroups...)
}
if r.PeerGroups != nil {
peerGroups = append([]string{}, r.PeerGroups...)
}
if r.Domains != nil {
domains = append(domain.List{}, r.Domains...)
}
return &route.Route{
ID: r.ID,
AccountID: r.AccountID,
Network: r.Network,
NetworkType: r.NetworkType,
Description: r.Description,
Peer: r.Peer,
PeerID: r.PeerID,
Metric: r.Metric,
Masquerade: r.Masquerade,
NetID: r.NetID,
Enabled: r.Enabled,
Groups: groups,
AccessControlGroups: accessControlGroups,
PeerGroups: peerGroups,
Domains: domains,
KeepRoute: r.KeepRoute,
}
}
func (calc *NetworkMapCalculator) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
func (calc *NetworkMapCalculator) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
func (calc *NetworkMapCalculator) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0)
enabledRoutes, _ := calc.getRoutingPeerRoutes(ctx, peerID)
for _, r := range enabledRoutes {
if len(r.AccessControlGroups) == 0 {
defaultPermit := calc.getDefaultPermit(r)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := calc.getDistributionGroupsPeers(r)
for _, accessGroup := range r.AccessControlGroups {
policies := calc.getAllRoutePoliciesFromGroups([]string{accessGroup})
rules := calc.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (calc *NetworkMapCalculator) findRoute(routeID route.ID) *route.Route {
for _, r := range calc.components.Routes {
if r.ID == routeID {
return r
}
}
parts := strings.Split(string(routeID), ":")
if len(parts) > 1 {
baseRouteID := route.ID(parts[0])
for _, r := range calc.components.Routes {
if r.ID == baseRouteID {
return r
}
}
}
return nil
}
func (calc *NetworkMapCalculator) getDefaultPermit(r *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: r.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
Domains: r.Domains,
IsDynamic: len(r.Domains) > 0,
RouteID: r.ID,
}
rules = append(rules, &rule)
if len(r.Domains) > 0 {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
func (calc *NetworkMapCalculator) getDistributionGroupsPeers(r *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range r.Groups {
group := calc.components.GetGroupInfo(id)
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func (calc *NetworkMapCalculator) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy {
routePolicies := make([]*Policy, 0)
for _, groupID := range accessControlGroups {
group := calc.components.GetGroupInfo(groupID)
if group == nil {
continue
}
for _, policy := range calc.components.Policies {
for _, rule := range policy.Rules {
for _, destGroupID := range rule.Destinations {
if destGroupID == group.ID {
routePolicies = append(routePolicies, policy)
break
}
}
}
}
}
return routePolicies
}
func (calc *NetworkMapCalculator) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
rulePeers := calc.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers)
rules := calc.generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...)
}
}
return fwRules
}
func (calc *NetworkMapCalculator) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources {
group := calc.components.GetGroupInfo(id)
if group == nil {
continue
}
for _, pID := range group.Peers {
if pID == peerID {
continue
}
_, distPeer := distributionPeers[pID]
_, valid := calc.components.Peers[pID]
if distPeer && valid && calc.components.ValidatePostureChecksOnPeer(pID, postureChecks) {
distPeersWithPolicy[pID] = struct{}{}
}
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
_, distPeer := distributionPeers[rule.SourceResource.ID]
_, valid := calc.components.Peers[rule.SourceResource.ID]
if distPeer && valid && calc.components.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) {
distPeersWithPolicy[rule.SourceResource.ID] = struct{}{}
}
}
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
for pID := range distPeersWithPolicy {
peerInfo := calc.components.GetPeerInfo(pID)
if peerInfo == nil {
continue
}
distributionGroupPeers = append(distributionGroupPeers, peerInfo)
}
return distributionGroupPeers
}
func (calc *NetworkMapCalculator) generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, rulePeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
sourceRanges := make([]string, 0, len(rulePeers))
for _, peer := range rulePeers {
if peer == nil {
continue
}
sourceRanges = append(sourceRanges, peer.IP.String()+"/32")
}
if len(sourceRanges) == 0 {
return nil
}
baseRule := &RouteFirewallRule{
RouteID: route.ID,
SourceRanges: sourceRanges,
Action: string(rule.Action),
Destination: route.Network.String(),
Protocol: string(rule.Protocol),
Domains: route.Domains,
IsDynamic: len(route.Domains) > 0,
}
return []*RouteFirewallRule{baseRule}
}
func (calc *NetworkMapCalculator) getNetworkResourcesRoutesToSync(ctx context.Context, peerID string) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{})
for _, resource := range calc.components.NetworkResources {
if !resource.Enabled {
continue
}
var addSourcePeers bool
networkRoutingPeers, exists := calc.components.RoutersMap[resource.NetworkID]
if exists {
if router, ok := networkRoutingPeers[peerID]; ok {
isRoutingPeer, addSourcePeers = true, true
routes = append(routes, calc.getNetworkResourcesRoutes(resource, peerID, router)...)
}
}
addedResourceRoute := false
for _, policy := range calc.components.ResourcePoliciesMap[resource.ID] {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = calc.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range calc.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if calc.peerInSlice(peerID, peers) && calc.components.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) {
for peerId, router := range networkRoutingPeers {
routes = append(routes, calc.getNetworkResourcesRoutes(resource, peerId, router)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
return isRoutingPeer, routes, allSourcePeers
}
func (calc *NetworkMapCalculator) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route {
resourceAppliedPolicies := calc.components.ResourcePoliciesMap[resource.ID]
var routes []*route.Route
if len(resourceAppliedPolicies) > 0 {
peerInfo := calc.components.GetPeerInfo(peerID)
if peerInfo != nil {
routes = append(routes, calc.networkResourceToRoute(resource, peerInfo, router))
}
}
return routes
}
func (calc *NetworkMapCalculator) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route {
r := &route.Route{
ID: route.ID(resource.ID + ":" + peer.ID),
AccountID: resource.AccountID,
Peer: peer.Key,
PeerID: peer.ID,
Metric: router.Metric,
Masquerade: router.Masquerade,
Enabled: resource.Enabled,
KeepRoute: true,
NetID: route.NetID(resource.Name),
Description: resource.Description,
}
if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet {
r.Network = resource.Prefix
r.NetworkType = route.IPv4Network
if resource.Prefix.Addr().Is6() {
r.NetworkType = route.IPv6Network
}
}
if resource.Type == resourceTypes.Domain {
domainList, err := domain.FromStringList([]string{resource.Domain})
if err == nil {
r.Domains = domainList
r.NetworkType = route.DomainNetwork
r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
}
return r
}
func (calc *NetworkMapCalculator) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if calc.components.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) {
dest = append(dest, peerID)
}
}
return dest
}
func (calc *NetworkMapCalculator) peerInSlice(peerID string, peers []string) bool {
for _, p := range peers {
if p == peerID {
return true
}
}
return false
}
func (calc *NetworkMapCalculator) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0)
peerInfo := calc.components.GetPeerInfo(peerID)
if peerInfo == nil {
return routesFirewallRules
}
for _, r := range routes {
if r.Peer != peerInfo.Key {
continue
}
resourceID := string(r.GetResourceID())
resourcePolicies := calc.components.ResourcePoliciesMap[resourceID]
distributionPeers := calc.getPoliciesSourcePeers(resourcePolicies)
rules := calc.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers)
for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
}
return routesFirewallRules
}
func (calc *NetworkMapCalculator) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} {
sourcePeers := make(map[string]struct{})
for _, policy := range policies {
for _, rule := range policy.Rules {
for _, sourceGroup := range rule.Sources {
group := calc.components.GetGroupInfo(sourceGroup)
if group == nil {
continue
}
for _, peer := range group.Peers {
sourcePeers[peer] = struct{}{}
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers[rule.SourceResource.ID] = struct{}{}
}
}
}
return sourcePeers
}
func (calc *NetworkMapCalculator) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peerID string,
peersToConnect []*nbpeer.Peer,
expiredPeers []*nbpeer.Peer,
isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, r := range networkResourcesRoutes {
networkRoutesPeers[r.PeerID] = struct{}{}
}
delete(sourcePeers, peerID)
delete(networkRoutesPeers, peerID)
for _, existingPeer := range peersToConnect {
delete(sourcePeers, existingPeer.ID)
delete(networkRoutesPeers, existingPeer.ID)
}
for _, expPeer := range expiredPeers {
delete(sourcePeers, expPeer.ID)
delete(networkRoutesPeers, expPeer.ID)
}
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
if isRouter {
for p := range sourcePeers {
missingPeers[p] = struct{}{}
}
}
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
}
for p := range missingPeers {
peerInfo := calc.components.GetPeerInfo(p)
if peerInfo != nil {
peersToConnect = append(peersToConnect, peerInfo)
}
}
return peersToConnect
}

View File

@@ -1290,19 +1290,35 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) {
compactCachedJSON, err := json.Marshal(compactCachedNm)
require.NoError(b, err)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, validatedPeersMap, resourcePolicies, routers)
componentsJSON, err := json.Marshal(components)
require.NoError(b, err)
regularSize := len(regularJSON)
compactSize := len(compactJSON)
compactCachedSize := len(compactCachedJSON)
savingsPercent := 100 - int(float64(compactCachedSize)/float64(regularSize)*100)
componentsSize := len(componentsJSON)
compactSavingsPercent := 100 - int(float64(compactCachedSize)/float64(regularSize)*100)
componentsSavingsPercent := 100 - int(float64(componentsSize)/float64(regularSize)*100)
b.ReportMetric(float64(regularSize), "regular_bytes")
b.ReportMetric(float64(compactCachedSize), "compact_cached_bytes")
b.ReportMetric(float64(savingsPercent), "savings_%")
b.ReportMetric(float64(componentsSize), "components_bytes")
b.ReportMetric(float64(compactSavingsPercent), "compact_savings_%")
b.ReportMetric(float64(componentsSavingsPercent), "components_savings_%")
b.Logf("Regular network map: %d bytes", regularSize)
b.Logf("Compact network map: %d bytes", compactSize)
b.Logf("Compact cached network map: %d bytes", compactCachedSize)
b.Logf("Data savings: %d%% (%d bytes saved)", savingsPercent, regularSize-compactCachedSize)
b.Logf("========== Network Map Size Comparison ==========")
b.Logf("Regular network map: %d bytes", regularSize)
b.Logf("Compact network map: %d bytes (-%d%%)", compactSize, 100-int(float64(compactSize)/float64(regularSize)*100))
b.Logf("Compact cached network map: %d bytes (-%d%%)", compactCachedSize, compactSavingsPercent)
b.Logf("Components: %d bytes (-%d%%)", componentsSize, componentsSavingsPercent)
b.Logf("")
b.Logf("Bandwidth savings (Compact cached): %d bytes saved (%d%%)", regularSize-compactCachedSize, compactSavingsPercent)
b.Logf("Bandwidth savings (Components): %d bytes saved (%d%%)", regularSize-componentsSize, componentsSavingsPercent)
b.Logf("=================================================")
b.Run("Regular", func(b *testing.B) {
b.ResetTimer()
@@ -1324,4 +1340,52 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) {
_ = builder.GetPeerNetworkMapCompactCached(ctx, testingPeerID, customZone, validatedPeersMap, nil)
}
})
b.Run("Legacy", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
})
b.Run("LegacyCompacted", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapCompacted(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
})
b.Run("ComponentsNetworkMap", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
testingPeerID,
customZone,
validatedPeersMap,
resourcePolicies,
routers,
)
_ = types.CalculateNetworkMapFromComponents(ctx, components)
}
})
b.Run("ComponentsCreation", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
testingPeerID,
customZone,
validatedPeersMap,
resourcePolicies,
routers,
)
}
})
b.Run("CalculationFromComponents", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = types.CalculateNetworkMapFromComponents(ctx, components)
}
})
}

View File

@@ -0,0 +1,31 @@
package peerid
import (
"crypto/sha256"
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
"github.com/netbirdio/netbird/shared/relay/messages"
)
var (
// HealthCheckPeerID is the hashed peer ID for health check connections
HealthCheckPeerID = messages.HashID("healthcheck-agent")
// DummyAuthToken is a structurally valid auth token for health check.
// The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload).
DummyAuthToken = createDummyToken()
)
func createDummyToken() []byte {
token := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: make([]byte, sha256.Size),
Payload: []byte("healthcheck"),
}
return token.Marshal()
}
// IsHealthCheck checks if the given peer ID is the health check agent
func IsHealthCheck(peerID *messages.PeerID) bool {
return peerID != nil && *peerID == HealthCheckPeerID
}

View File

@@ -7,8 +7,10 @@ import (
"github.com/coder/websocket"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay"
"github.com/netbirdio/netbird/shared/relay/messages"
)
func dialWS(ctx context.Context, address url.URL) error {
@@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error {
if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err)
}
defer func() {
_ = conn.CloseNow()
}()
authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken)
if err != nil {
return fmt.Errorf("failed to marshal auth message: %w", err)
}
if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil {
return fmt.Errorf("failed to write auth message: %w", err)
}
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil
}

View File

@@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
}
if err != nil {
return nil, err
return peerID, err
}
h.peerID = peerID
return peerID, nil
@@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
}
if err := h.validator.Validate(authPayload); err != nil {
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
}
return rawPeerID, nil

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
@@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) {
}
peerID, err := h.handshakeReceive()
if err != nil {
log.Errorf("failed to handshake: %s", err)
if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr())
} else {
log.Errorf("failed to handshake: %s", err)
}
if cErr := conn.Close(); cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}