Compare commits

...

20 Commits

Author SHA1 Message Date
pascal
373f014aea Merge branch 'feature/check-cert-locker-before-acme' into test/proxy-fixes 2026-03-09 17:24:31 +01:00
pascal
2df3fb959b Merge branch 'feature/domain-activity-events' into test/proxy-fixes 2026-03-09 17:23:54 +01:00
pascal
14b0e9462b Merge branch 'feature/move-proxy-to-opentelemetry' into test/proxy-fixes 2026-03-09 17:23:42 +01:00
pascal
8db71b545e add certificate issue duration metrics 2026-03-09 16:59:33 +01:00
pascal
7c3532d8e5 fix test 2026-03-09 16:18:49 +01:00
pascal
1aa1eef2c5 account for streaming 2026-03-09 16:08:30 +01:00
pascal
d9418ddc1e do log throughput and requests. Also add throughput to the log entries 2026-03-09 15:39:29 +01:00
pascal
1b4c831976 add activity events for domains 2026-03-09 14:18:05 +01:00
pascal
a19611d8e0 check the cert locker before starting the acme challenge 2026-03-09 13:40:41 +01:00
pascal
9ab6138040 fix mapping counting and metrics registry 2026-03-09 13:17:50 +01:00
Pascal Fischer
30c02ab78c [management] use the cache for the pkce state (#5516) 2026-03-09 12:23:06 +01:00
Zoltan Papp
3acd86e346 [client] "reset connection" error on wake from sleep (#5522)
Capture engine reference before actCancel() in cleanupConnection().

After actCancel(), the connectWithRetryRuns goroutine sets engine to nil,
causing connectClient.Stop() to skip shutdown. This allows the goroutine
to set ErrResetConnection on the shared state after Down() clears it,
causing the next Up() to fail.
2026-03-09 10:25:51 +01:00
pascal
c2fec57c0f switch proxy to use opentelemetry 2026-03-07 11:16:40 +01:00
Pascal Fischer
5c20f13c48 [management] fix domain uniqueness (#5529) 2026-03-07 10:46:37 +01:00
Pascal Fischer
e6587b071d [management] use realip for proxy registration (#5525) 2026-03-06 16:11:44 +01:00
Maycon Santos
85451ab4cd [management] Add stable domain resolution for combined server (#5515)
The combined server was using the hostname from exposedAddress for both
singleAccountModeDomain and dnsDomain, causing fresh installs to get
the wrong domain and existing installs to break if the config changed.
Add resolveDomains() to BaseServer that reads domain from the store:
  - Fresh install (0 accounts): uses "netbird.selfhosted" default
  - Existing install: reads persisted domain from the account in DB
  - Store errors: falls back to default safely

The combined server opts in via AutoResolveDomains flag, while the
standalone management server is unaffected.
2026-03-06 08:43:46 +01:00
Pascal Fischer
a7f3ba03eb [management] aggregate grpc metrics by accountID (#5486) 2026-03-05 22:10:45 +01:00
Maycon Santos
4f0a3a77ad [management] Avoid breaking single acc mode when switching domains (#5511)
* **Bug Fixes**
  * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data.
  * Improved error handling when account data is unavailable with fallback to configured default domain.

* **Tests**
  * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data.
2026-03-05 14:30:31 +01:00
Maycon Santos
44655ca9b5 [misc] add PR title validation workflow (#5503) 2026-03-05 11:43:18 +01:00
Viktor Liu
e601278117 [management,proxy] Add per-target options to reverse proxy (#5501) 2026-03-05 10:03:26 +01:00
55 changed files with 3899 additions and 640 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
skip: go.mod,go.sum,**/proxy/web/** skip: go.mod,go.sum,**/proxy/web/**
golangci: golangci:
strategy: strategy:

51
.github/workflows/pr-title-check.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -849,15 +849,27 @@ func (s *Server) cleanupConnection() error {
if s.actCancel == nil { if s.actCancel == nil {
return ErrServiceNotUp return ErrServiceNotUp
} }
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
}
s.actCancel() s.actCancel()
if s.connectClient == nil { if s.connectClient == nil {
return nil return nil
} }
if err := s.connectClient.Stop(); err != nil { if engine != nil {
if err := engine.Stop(); err != nil {
return err return err
} }
}
s.connectClient = nil s.connectClient = nil
s.isSessionActive.Store(false) s.isSessionActive.Store(false)

View File

@@ -493,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) { func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address // Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress) _, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil { if err != nil {
@@ -507,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
mgmtSrv := mgmtServer.NewServer( mgmtSrv := mgmtServer.NewServer(
&mgmtServer.Config{ &mgmtServer.Config{
NbConfig: mgmtConfig, NbConfig: mgmtConfig,
DNSDomain: dnsDomain, DNSDomain: "",
MgmtSingleAccModeDomain: singleAccModeDomain, MgmtSingleAccModeDomain: "",
AutoResolveDomains: true,
MgmtPort: mgmtPort, MgmtPort: mgmtPort,
MgmtMetricsPort: cfg.Server.MetricsPort, MgmtMetricsPort: cfg.Server.MetricsPort,
DisableMetrics: mgmt.DisableAnonymousMetrics, DisableMetrics: mgmt.DisableAnonymousMetrics,

View File

@@ -24,6 +24,8 @@ type AccessLogEntry struct {
Reason string Reason string
UserId string `gorm:"index"` UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"` AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
} }
// FromProto creates an AccessLogEntry from a proto.AccessLog // FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -39,6 +41,8 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.UserId = serviceLog.GetUserId() a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism() a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId() a.AccountID = serviceLog.GetAccountId()
a.BytesUpload = serviceLog.GetBytesUpload()
a.BytesDownload = serviceLog.GetBytesDownload()
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" { if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil { if ip, err := netip.ParseAddr(sourceIP); err == nil {
@@ -101,5 +105,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
AuthMethodUsed: authMethod, AuthMethodUsed: authMethod,
CountryCode: countryCode, CountryCode: countryCode,
CityName: cityName, CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
} }
} }

View File

@@ -15,3 +15,12 @@ type Domain struct {
Type Type `gorm:"-"` Type Type `gorm:"-"`
Validated bool Validated bool
} }
// EventMeta returns activity event metadata for a domain
func (d *Domain) EventMeta() map[string]any {
return map[string]any{
"domain": d.Domain,
"target_cluster": d.TargetCluster,
"validated": d.Validated,
}
}

View File

@@ -9,6 +9,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -36,16 +38,16 @@ type Manager struct {
validator domain.Validator validator domain.Validator
proxyManager proxyManager proxyManager proxyManager
permissionsManager permissions.Manager permissionsManager permissions.Manager
accountManager account.Manager
} }
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager { func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
return Manager{ return Manager{
store: store, store: store,
proxyManager: proxyMgr, proxyManager: proxyMgr,
validator: domain.Validator{ validator: domain.Validator{Resolver: net.DefaultResolver},
Resolver: net.DefaultResolver,
},
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
accountManager: accountManager,
} }
} }
@@ -136,6 +138,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
if err != nil { if err != nil {
return d, fmt.Errorf("create domain in store: %w", err) return d, fmt.Errorf("create domain in store: %w", err)
} }
m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta())
return d, nil return d, nil
} }
@@ -148,10 +153,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil {
return fmt.Errorf("get domain from store: %w", err)
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil { if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition. // TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err) return fmt.Errorf("delete domain from store: %w", err)
} }
m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta())
return nil return nil
} }
@@ -218,6 +231,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}).WithError(err).Error("update custom domain in store") }).WithError(err).Error("update custom domain in store")
return return
} }
m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta())
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,

View File

@@ -73,7 +73,10 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
} }
service := new(rpservice.Service) service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId) if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil { if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
@@ -132,7 +135,10 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
service := new(rpservice.Service) service := new(rpservice.Service)
service.ID = serviceID service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId) if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil { if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)

View File

@@ -36,11 +36,11 @@ func TestReapExpiredExposes(t *testing.T) {
mgr.exposeReaper.reapExpiredExposes(ctx) mgr.exposeReaper.reapExpiredExposes(ctx)
// Expired service should be deleted // Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) _, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired service should be deleted") require.Error(t, err, "expired service should be deleted")
// Non-expired service should remain // Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain) _, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
require.NoError(t, err, "active service should remain") require.NoError(t, err, "active service should remain")
} }
@@ -191,14 +191,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
// Reaper should skip it because the re-check sees a fresh timestamp // Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx) mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) _, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping") require.NoError(t, err, "renewed service should survive reaping")
} }
// expireEphemeralService backdates meta_last_renewed_at to force expiration. // expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper() t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain) svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err) require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL) expired := time.Now().Add(-2 * exposeTTL)

View File

@@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
return err return err
} }
@@ -245,7 +245,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
} }
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err return err
} }
@@ -261,8 +261,8 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
}) })
} }
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err) return fmt.Errorf("failed to check existing service: %w", err)
@@ -271,7 +271,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
} }
if existingService != nil && existingService.ID != excludeServiceID { if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain) return status.Errorf(status.AlreadyExists, "domain already taken")
} }
return nil return nil
@@ -352,7 +352,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
} }
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
return err return err
} }
@@ -805,7 +805,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID,
// lookupPeerService finds a peer-initiated service by domain and validates ownership. // lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain) svc, err := m.store.GetServiceByDomain(ctx, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) {
func TestCheckDomainAvailable(t *testing.T) { func TestCheckDomainAvailable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
accountID := "test-account"
tests := []struct { tests := []struct {
name string name string
@@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "", excludeServiceID: "",
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "available.com"). GetServiceByDomain(ctx, "available.com").
Return(nil, status.Errorf(status.NotFound, "not found")) Return(nil, status.Errorf(status.NotFound, "not found"))
}, },
expectedError: false, expectedError: false,
@@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "", excludeServiceID: "",
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, "exists.com").
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
}, },
expectedError: true, expectedError: true,
@@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "service-123", excludeServiceID: "service-123",
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, "exists.com").
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
}, },
expectedError: false, expectedError: false,
@@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "service-456", excludeServiceID: "service-456",
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, "exists.com").
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
}, },
expectedError: true, expectedError: true,
@@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "", excludeServiceID: "",
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "error.com"). GetServiceByDomain(ctx, "error.com").
Return(nil, errors.New("database error")) Return(nil, errors.New("database error"))
}, },
expectedError: true, expectedError: true,
@@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
tt.setupMock(mockStore) tt.setupMock(mockStore)
mgr := &Manager{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID)
if tt.expectedError { if tt.expectedError {
require.Error(t, err) require.Error(t, err)
@@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) {
func TestCheckDomainAvailable_EdgeCases(t *testing.T) { func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
ctx := context.Background() ctx := context.Background()
accountID := "test-account"
t.Run("empty domain", func(t *testing.T) { t.Run("empty domain", func(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
@@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT(). mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, ""). GetServiceByDomain(ctx, "").
Return(nil, status.Errorf(status.NotFound, "not found")) Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &Manager{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") err := mgr.checkDomainAvailable(ctx, mockStore, "", "")
assert.NoError(t, err) assert.NoError(t, err)
}) })
@@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT(). mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com"). GetServiceByDomain(ctx, "test.com").
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil) Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &Manager{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "")
assert.Error(t, err) assert.Error(t, err)
sErr, ok := status.FromError(err) sErr, ok := status.FromError(err)
@@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT(). mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "nil.com"). GetServiceByDomain(ctx, "nil.com").
Return(nil, nil) Return(nil, nil)
mgr := &Manager{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "")
assert.NoError(t, err) assert.NoError(t, err)
}) })
@@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) {
// Create another mock for the transaction // Create another mock for the transaction
txMock := store.NewMockStore(ctrl) txMock := store.NewMockStore(ctrl)
txMock.EXPECT(). txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "new.com"). GetServiceByDomain(ctx, "new.com").
Return(nil, status.Errorf(status.NotFound, "not found")) Return(nil, status.Errorf(status.NotFound, "not found"))
txMock.EXPECT(). txMock.EXPECT().
CreateService(ctx, service). CreateService(ctx, service).
@@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) {
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error { DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl) txMock := store.NewMockStore(ctrl)
txMock.EXPECT(). txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com"). GetServiceByDomain(ctx, "existing.com").
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil) Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock) return fn(txMock)
@@ -425,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper() t.Helper()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
t.Cleanup(srv.Close) require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
return srv return srv
} }
@@ -705,8 +704,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
t.Cleanup(proxySrv.Close) require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err) require.NoError(t, err)
@@ -814,7 +814,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set") assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
// Verify service is persisted in store // Verify service is persisted in store
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain) assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral") assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
@@ -977,7 +977,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify service is deleted // Verify service is deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) _, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "service should be deleted") require.Error(t, err, "service should be deleted")
}) })
@@ -1012,7 +1012,7 @@ func TestStopServiceFromPeer(t *testing.T) {
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err) require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) _, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "service should be deleted") require.Error(t, err, "service should be deleted")
}) })
} }
@@ -1031,7 +1031,7 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create") assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err) require.NoError(t, err)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
@@ -1136,8 +1136,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
t.Cleanup(proxySrv.Close) require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err) require.NoError(t, err)

View File

@@ -6,13 +6,16 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"net" "net"
"net/http"
"net/url" "net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/hash/argon2id"
@@ -49,6 +52,13 @@ const (
SourceEphemeral = "ephemeral" SourceEphemeral = "ephemeral"
) )
type TargetOptions struct {
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}
type Target struct { type Target struct {
ID uint `gorm:"primaryKey" json:"-"` ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"` AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
@@ -60,6 +70,7 @@ type Target struct {
TargetId string `gorm:"index:idx_target_id" json:"target_id"` TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"` TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
} }
type PasswordAuthConfig struct { type PasswordAuthConfig struct {
@@ -123,7 +134,7 @@ type Service struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Name string Name string
Domain string `gorm:"index"` Domain string `gorm:"type:varchar(255);uniqueIndex"`
ProxyCluster string `gorm:"index"` ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"` Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool Enabled bool
@@ -194,7 +205,7 @@ func (s *Service) ToAPIResponse() *api.Service {
// Convert internal targets to API targets // Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets)) apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets { for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{ st := api.ServiceTarget{
Path: target.Path, Path: target.Path,
Host: &target.Host, Host: &target.Host,
Port: target.Port, Port: target.Port,
@@ -202,7 +213,9 @@ func (s *Service) ToAPIResponse() *api.Service {
TargetId: target.TargetId, TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType), TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled, Enabled: target.Enabled,
}) }
st.Options = targetOptionsToAPI(target.Options)
apiTargets = append(apiTargets, st)
} }
meta := api.ServiceMeta{ meta := api.ServiceMeta{
@@ -256,10 +269,14 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
if target.Path != nil { if target.Path != nil {
path = *target.Path path = *target.Path
} }
pathMappings = append(pathMappings, &proto.PathMapping{
pm := &proto.PathMapping{
Path: path, Path: path,
Target: targetURL.String(), Target: targetURL.String(),
}) }
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
} }
auth := &proto.Authentication{ auth := &proto.Authentication{
@@ -312,13 +329,87 @@ func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80) return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
} }
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) { // PathRewriteMode controls how the request path is rewritten before forwarding.
// PathRewriteMode controls how the request path is rewritten before it is
// forwarded to the target. The empty string selects the proxy's default
// rewrite behavior.
type PathRewriteMode string

const (
	// PathRewritePreserve forwards the original request path unchanged.
	PathRewritePreserve PathRewriteMode = "preserve"
)
// pathRewriteToProto maps the internal PathRewriteMode to its protobuf
// representation. Any mode other than PathRewritePreserve (including the
// empty default) maps to the proto default mode.
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
	if mode == PathRewritePreserve {
		return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
	}
	return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
}
// targetOptionsToAPI converts internal TargetOptions into their API
// representation. It returns nil when every option holds its zero value so
// that fully-default options are omitted from API responses.
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
	allDefault := !opts.SkipTLSVerify &&
		opts.RequestTimeout == 0 &&
		opts.PathRewrite == "" &&
		len(opts.CustomHeaders) == 0
	if allDefault {
		return nil
	}

	out := &api.ServiceTargetOptions{}
	if opts.SkipTLSVerify {
		out.SkipTlsVerify = &opts.SkipTLSVerify
	}
	if opts.RequestTimeout != 0 {
		timeout := opts.RequestTimeout.String()
		out.RequestTimeout = &timeout
	}
	if opts.PathRewrite != "" {
		mode := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
		out.PathRewrite = &mode
	}
	if len(opts.CustomHeaders) > 0 {
		out.CustomHeaders = &opts.CustomHeaders
	}
	return out
}
// targetOptionsToProto converts internal TargetOptions into the protobuf
// per-path options sent to the proxy. It returns nil when all options are at
// their zero values so the proxy can skip per-path option handling.
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
	allDefault := !opts.SkipTLSVerify &&
		opts.PathRewrite == "" &&
		opts.RequestTimeout == 0 &&
		len(opts.CustomHeaders) == 0
	if allDefault {
		return nil
	}

	out := &proto.PathTargetOptions{
		SkipTlsVerify: opts.SkipTLSVerify,
		PathRewrite:   pathRewriteToProto(opts.PathRewrite),
		CustomHeaders: opts.CustomHeaders,
	}
	if opts.RequestTimeout != 0 {
		out.RequestTimeout = durationpb.New(opts.RequestTimeout)
	}
	return out
}
// targetOptionsFromAPI builds TargetOptions from the optional API options of
// target idx. request_timeout is parsed as a Go duration string; a parse
// failure is reported with the target index for context. Fields are applied
// in declaration order, so on a timeout parse error only SkipTLSVerify has
// been populated in the returned (discarded-by-callers) value.
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
	var opts TargetOptions
	if v := o.SkipTlsVerify; v != nil {
		opts.SkipTLSVerify = *v
	}
	if raw := o.RequestTimeout; raw != nil {
		timeout, err := time.ParseDuration(*raw)
		if err != nil {
			return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *raw, err)
		}
		opts.RequestTimeout = timeout
	}
	if pr := o.PathRewrite; pr != nil {
		opts.PathRewrite = PathRewriteMode(*pr)
	}
	if h := o.CustomHeaders; h != nil {
		opts.CustomHeaders = *h
	}
	return opts, nil
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
s.Name = req.Name s.Name = req.Name
s.Domain = req.Domain s.Domain = req.Domain
s.AccountID = accountID s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets)) targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets { for i, apiTarget := range req.Targets {
target := &Target{ target := &Target{
AccountID: accountID, AccountID: accountID,
Path: apiTarget.Path, Path: apiTarget.Path,
@@ -331,6 +422,13 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
if apiTarget.Host != nil { if apiTarget.Host != nil {
target.Host = *apiTarget.Host target.Host = *apiTarget.Host
} }
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
}
target.Options = opts
}
targets = append(targets, target) targets = append(targets, target)
} }
s.Targets = targets s.Targets = targets
@@ -368,6 +466,8 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
} }
s.Auth.BearerAuth = bearerAuth s.Auth.BearerAuth = bearerAuth
} }
return nil
} }
func (s *Service) Validate() error { func (s *Service) Validate() error {
@@ -400,11 +500,113 @@ func (s *Service) Validate() error {
if target.TargetId == "" { if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i) return fmt.Errorf("target %d has empty target_id", i)
} }
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
} }
return nil return nil
} }
const (
	// maxRequestTimeout caps per-target request timeouts so proxied requests
	// cannot hold upstream connections open indefinitely.
	maxRequestTimeout = 5 * time.Minute
	// maxCustomHeaders bounds the number of custom headers per target.
	maxCustomHeaders = 16
	// maxHeaderKeyLen bounds the byte length of a custom header name.
	maxHeaderKeyLen = 128
	// maxHeaderValueLen bounds the byte length of a custom header value.
	maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)

// hopByHopHeaders are headers that must not be set as custom headers
// because they are connection-level and stripped by the proxy.
// Keys are stored in canonical form as produced by http.CanonicalHeaderKey
// (e.g. "Te" is the canonical form of "TE"), matching the lookup in
// validateCustomHeaders.
var hopByHopHeaders = map[string]struct{}{
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Proxy-Connection":    {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
}

// reservedHeaders are set authoritatively by the proxy or control HTTP framing
// and cannot be overridden.
// Keys are stored in canonical form ("X-Real-Ip" is the canonical form of
// "X-Real-IP"), matching the lookup in validateCustomHeaders.
var reservedHeaders = map[string]struct{}{
	"Content-Length":    {},
	"Content-Type":      {},
	"Cookie":            {},
	"Forwarded":         {},
	"X-Forwarded-For":   {},
	"X-Forwarded-Host":  {},
	"X-Forwarded-Port":  {},
	"X-Forwarded-Proto": {},
	"X-Real-Ip":         {},
}
// validateTargetOptions checks the options of target idx: the path-rewrite
// mode must be empty or a known mode, the request timeout must be
// non-negative and at most maxRequestTimeout, and the custom headers must
// pass validateCustomHeaders.
func validateTargetOptions(idx int, opts *TargetOptions) error {
	switch opts.PathRewrite {
	case "", PathRewritePreserve:
		// valid modes
	default:
		return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
	}

	switch {
	case opts.RequestTimeout < 0:
		return fmt.Errorf("target %d: request_timeout must be positive", idx)
	case opts.RequestTimeout > maxRequestTimeout:
		return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
	}

	return validateCustomHeaders(idx, opts.CustomHeaders)
}
// validateCustomHeaders validates the custom header set of target idx:
// bounded count, RFC 7230 token names, bounded key/value lengths, no CR/LF
// (header-injection guard), no canonical-key collisions, and no hop-by-hop,
// proxy-managed, or Host headers.
func validateCustomHeaders(idx int, headers map[string]string) error {
	if len(headers) > maxCustomHeaders {
		return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
	}

	// Track canonical forms so keys that differ only in case are rejected.
	byCanonical := make(map[string]string, len(headers))
	for key, value := range headers {
		switch {
		case !httpHeaderNameRe.MatchString(key):
			return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
		case len(key) > maxHeaderKeyLen:
			return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
		case len(value) > maxHeaderValueLen:
			return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
		case containsCRLF(key) || containsCRLF(value):
			return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
		}

		canonical := http.CanonicalHeaderKey(key)
		if prev, dup := byCanonical[canonical]; dup {
			return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
		}
		byCanonical[canonical] = key

		if _, hop := hopByHopHeaders[canonical]; hop {
			return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
		}
		if _, reserved := reservedHeaders[canonical]; reserved {
			return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
		}
		if canonical == "Host" {
			return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
		}
	}
	return nil
}
// containsCRLF reports whether s contains a carriage return or line feed,
// either of which would permit HTTP header injection.
func containsCRLF(s string) bool {
	return strings.Contains(s, "\r") || strings.Contains(s, "\n")
}
func (s *Service) EventMeta() map[string]any { func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()} return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
} }
@@ -417,6 +619,12 @@ func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets)) targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets { for i, target := range s.Targets {
targetCopy := *target targetCopy := *target
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
targetCopy.Options.CustomHeaders[k] = v
}
}
targets[i] = &targetCopy targets[i] = &targetCopy
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -87,6 +88,188 @@ func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
assert.Contains(t, err.Error(), "empty target_id") assert.Contains(t, err.Error(), "empty target_id")
} }
// TestValidateTargetOptions_PathRewrite ensures only the empty mode and
// PathRewritePreserve pass validation, and that any other mode is rejected.
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
	cases := []struct {
		name    string
		mode    PathRewriteMode
		wantErr string
	}{
		{name: "empty is default", mode: "", wantErr: ""},
		{name: "preserve is valid", mode: PathRewritePreserve, wantErr: ""},
		{name: "unknown rejected", mode: "regex", wantErr: "unknown path_rewrite mode"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			svc := validProxy()
			svc.Targets[0].Options.PathRewrite = tc.mode
			err := svc.Validate()
			if tc.wantErr != "" {
				assert.ErrorContains(t, err, tc.wantErr)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestValidateTargetOptions_RequestTimeout covers the timeout bounds: zero
// and in-range values pass, negative values and values above the maximum are
// rejected.
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
	cases := []struct {
		name    string
		timeout time.Duration
		wantErr string
	}{
		{name: "valid 30s", timeout: 30 * time.Second, wantErr: ""},
		{name: "valid 2m", timeout: 2 * time.Minute, wantErr: ""},
		{name: "zero is fine", timeout: 0, wantErr: ""},
		{name: "negative", timeout: -1 * time.Second, wantErr: "must be positive"},
		{name: "exceeds max", timeout: 10 * time.Minute, wantErr: "exceeds maximum"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			svc := validProxy()
			svc.Targets[0].Options.RequestTimeout = tc.timeout
			err := svc.Validate()
			if tc.wantErr != "" {
				assert.ErrorContains(t, err, tc.wantErr)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestValidateTargetOptions_CustomHeaders covers the custom-header rules:
// name syntax, CR/LF rejection, hop-by-hop and proxy-managed header bans,
// the Host special case, count/length limits, and canonical-key collisions.
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
	t.Run("valid headers", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{
			"X-Custom": "value",
			"X-Trace":  "abc123",
		}
		assert.NoError(t, rp.Validate())
	})

	t.Run("CRLF in key", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
		// A key with CR/LF already fails the header-name regex check.
		assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
	})

	t.Run("CRLF in value", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
		assert.ErrorContains(t, rp.Validate(), "invalid characters")
	})

	t.Run("hop-by-hop header rejected", func(t *testing.T) {
		for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
			rp := validProxy()
			rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
			assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
		}
	})

	t.Run("reserved header rejected", func(t *testing.T) {
		for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
			rp := validProxy()
			rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
			assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
		}
	})

	t.Run("Host header rejected", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
		assert.ErrorContains(t, rp.Validate(), "pass_host_header")
	})

	t.Run("too many headers", func(t *testing.T) {
		rp := validProxy()
		// One more than maxCustomHeaders (16).
		headers := make(map[string]string, 17)
		for i := range 17 {
			headers[fmt.Sprintf("X-H%d", i)] = "v"
		}
		rp.Targets[0].Options.CustomHeaders = headers
		assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
	})

	t.Run("key too long", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
		assert.ErrorContains(t, rp.Validate(), "key")
		assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
	})

	t.Run("value too long", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
		assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
	})

	t.Run("duplicate canonical keys rejected", func(t *testing.T) {
		rp := validProxy()
		rp.Targets[0].Options.CustomHeaders = map[string]string{
			"x-custom": "a",
			"X-Custom": "b",
		}
		assert.ErrorContains(t, rp.Validate(), "collide")
	})
}
// TestToProtoMapping_TargetOptions checks that fully-populated target options
// are carried into the proto path mapping, including the timeout duration.
func TestToProtoMapping_TargetOptions(t *testing.T) {
	svc := &Service{
		ID:        "svc-1",
		AccountID: "acc-1",
		Domain:    "example.com",
		Targets: []*Target{{
			TargetId:   "peer-1",
			TargetType: TargetTypePeer,
			Host:       "10.0.0.1",
			Port:       8080,
			Protocol:   "http",
			Enabled:    true,
			Options: TargetOptions{
				SkipTLSVerify:  true,
				RequestTimeout: 30 * time.Second,
				PathRewrite:    PathRewritePreserve,
				CustomHeaders:  map[string]string{"X-Custom": "val"},
			},
		}},
	}

	mapping := svc.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
	require.Len(t, mapping.Path, 1)

	got := mapping.Path[0].Options
	require.NotNil(t, got, "options should be populated")
	assert.True(t, got.SkipTlsVerify)
	assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, got.PathRewrite)
	assert.Equal(t, map[string]string{"X-Custom": "val"}, got.CustomHeaders)
	require.NotNil(t, got.RequestTimeout)
	assert.Equal(t, int64(30), got.RequestTimeout.Seconds)
}
// TestToProtoMapping_NoOptionsWhenDefault verifies that a target with all
// option fields at their zero values produces a nil Options in the proto
// mapping, keeping the wire payload minimal.
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
	svc := &Service{
		ID:        "svc-1",
		AccountID: "acc-1",
		Domain:    "example.com",
		Targets: []*Target{{
			TargetId:   "peer-1",
			TargetType: TargetTypePeer,
			Host:       "10.0.0.1",
			Port:       8080,
			Protocol:   "http",
			Enabled:    true,
		}},
	}

	mapping := svc.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
	require.Len(t, mapping.Path, 1)
	assert.Nil(t, mapping.Path[0].Options, "options should be nil when all defaults")
}
func TestIsDefaultPort(t *testing.T) { func TestIsDefaultPort(t *testing.T) {
tests := []struct { tests := []struct {
scheme string scheme string

View File

@@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController()) proxyService.SetProxyController(s.ServiceProxyController())
@@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
}) })
} }
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
return Create(s, func() *nbgrpc.PKCEVerifierStore {
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create PKCE verifier store: %v", err)
}
return pkceStore
})
}
func (s *BaseServer) AccessLogsManager() accesslogs.Manager { func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager { return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())

View File

@@ -210,7 +210,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager { return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager()) m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
return &m return &m
}) })
} }

View File

@@ -28,9 +28,13 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. const (
// It is used for backward compatibility now. // ManagementLegacyPort is the port that was used before by the Management gRPC server.
const ManagementLegacyPort = 33073 // It is used for backward compatibility now.
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
)
type Server interface { type Server interface {
Start(ctx context.Context) error Start(ctx context.Context) error
@@ -58,6 +62,7 @@ type BaseServer struct {
mgmtMetricsPort int mgmtMetricsPort int
mgmtPort int mgmtPort int
disableLegacyManagementPort bool disableLegacyManagementPort bool
autoResolveDomains bool
proxyAuthClose func() proxyAuthClose func()
@@ -81,6 +86,7 @@ type Config struct {
DisableMetrics bool DisableMetrics bool
DisableGeoliteUpdate bool DisableGeoliteUpdate bool
UserDeleteFromIDPEnabled bool UserDeleteFromIDPEnabled bool
AutoResolveDomains bool
} }
// NewServer initializes and configures a new Server instance // NewServer initializes and configures a new Server instance
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
mgmtPort: cfg.MgmtPort, mgmtPort: cfg.MgmtPort,
disableLegacyManagementPort: cfg.DisableLegacyManagementPort, disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
mgmtMetricsPort: cfg.MgmtMetricsPort, mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
} }
} }
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.cancel = cancel s.cancel = cancel
s.errCh = make(chan error, 4) s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
}
s.PeersManager() s.PeersManager()
s.GeoLocationManager() s.GeoLocationManager()
@@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close() _ = s.certManager.Listener().Close()
} }
s.GRPCServer().Stop() s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil { if s.proxyAuthClose != nil {
s.proxyAuthClose() s.proxyAuthClose()
s.proxyAuthClose = nil s.proxyAuthClose = nil
@@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}() }()
} }
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// Fresh installs use the default self-hosted domain, while existing installs reuse the
// persisted account domain to keep addressing stable across config changes.
// Any store error or missing data falls back to the default domain with a warning.
func (s *BaseServer) resolveDomains(ctx context.Context) {
	st := s.Store()

	// applyDefault sets both domains to the self-hosted default, optionally
	// logging a warning first (an empty format string suppresses the warning).
	applyDefault := func(warnFmt string, warnArgs ...any) {
		if warnFmt != "" {
			log.WithContext(ctx).Warnf(warnFmt, warnArgs...)
		}
		s.dnsDomain = DefaultSelfHostedDomain
		s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
	}

	accountsCount, err := st.GetAccountsCounter(ctx)
	switch {
	case err != nil:
		applyDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
		return
	case accountsCount == 0:
		// Fresh install: no accounts persisted yet.
		applyDefault("")
		log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
		return
	}

	accountID, err := st.GetAnyAccountID(ctx)
	if err != nil {
		applyDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
		return
	}
	if accountID == "" {
		applyDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
		return
	}

	domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		applyDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
		return
	}
	if domain == "" {
		applyDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
		return
	}

	s.dnsDomain = domain
	s.mgmtSingleAccModeDomain = domain
	log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) { func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID() installationID := store.GetInstallationID()
if installationID != "" { if installationID != "" {

View File

@@ -0,0 +1,63 @@
package server
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
)
// TestResolveDomains_FreshInstallUsesDefault verifies that a store reporting
// zero accounts makes resolveDomains use DefaultSelfHostedDomain for both the
// DNS domain and the single-account-mode domain.
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)

	mockStore := store.NewMockStore(ctrl)
	mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)

	srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
	Inject[store.Store](srv, mockStore)

	srv.resolveDomains(context.Background())

	require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
	require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}
// TestResolveDomains_ExistingInstallUsesPersistedDomain verifies that when the
// store already has an account, resolveDomains adopts that account's persisted
// domain for both server domains instead of the self-hosted default.
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)

	mockStore := store.NewMockStore(ctrl)
	mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
	mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
	mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)

	srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
	Inject[store.Store](srv, mockStore)

	srv.resolveDomains(context.Background())

	require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
	require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
}
// TestResolveDomains_StoreErrorFallsBackToDefault verifies that a store error
// while counting accounts makes resolveDomains fall back to the self-hosted
// default domain rather than failing.
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
	ctrl := gomock.NewController(t)
	t.Cleanup(ctrl.Finish)

	mockStore := store.NewMockStore(ctrl)
	mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))

	srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
	Inject[store.Store](srv, mockStore)

	srv.resolveDomains(context.Background())

	require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
	require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
}

View File

@@ -0,0 +1,61 @@
package grpc
import (
"context"
"fmt"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type PKCEVerifierStore struct {
	// cache maps OAuth state -> PKCE code verifier, with a per-entry TTL.
	cache *cache.Cache[string]
	// ctx is used for all cache operations.
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// consider passing ctx to Store/LoadAndDelete instead — confirm callers.
	ctx context.Context
}
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
// (in-memory or Redis, decided by nbcache.NewStore). The provided ctx is retained
// for all subsequent cache operations.
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
	backend, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
	if err != nil {
		return nil, fmt.Errorf("failed to create cache store: %w", err)
	}

	pkceStore := &PKCEVerifierStore{
		cache: cache.New[string](backend),
		ctx:   ctx,
	}
	return pkceStore, nil
}
// Store saves a PKCE verifier associated with an OAuth state parameter.
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
	err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl))
	if err != nil {
		return fmt.Errorf("failed to store PKCE verifier: %w", err)
	}
	log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
	return nil
}
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
// Returns the verifier and true if found, or empty string and false if not found.
// This enforces single-use semantics for PKCE verifiers.
//
// NOTE(review): the get-then-delete below is not atomic; with a shared (e.g.
// Redis) backend, two concurrent callers could both read the same verifier
// before either delete runs — confirm whether strict single-use is required.
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
	verifier, err := s.cache.Get(s.ctx, state)
	if err != nil {
		// Cache miss and backend errors are both treated as "not found".
		log.Debugf("PKCE verifier not found for state")
		return "", false
	}
	if err := s.cache.Delete(s.ctx, state); err != nil {
		// Best-effort delete: the entry still expires via its TTL.
		log.Warnf("Failed to delete PKCE verifier for state: %v", err)
	}
	return verifier, true
}

View File

@@ -18,7 +18,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -83,20 +82,12 @@ type ProxyServiceServer struct {
// OIDC configuration for proxy authentication // OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig oidcConfig ProxyOIDCConfig
// TODO: use database to store these instead? // Store for PKCE verifiers
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state. pkceVerifierStore *PKCEVerifierStore
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
} }
const pkceVerifierTTL = 10 * time.Minute const pkceVerifierTTL = 10 * time.Minute
type pkceEntry struct {
verifier string
createdAt time.Time
}
// proxyConnection represents a connected proxy // proxyConnection represents a connected proxy
type proxyConnection struct { type proxyConnection struct {
proxyID string proxyID string
@@ -108,42 +99,21 @@ type proxyConnection struct {
} }
// NewProxyServiceServer creates a new proxy service server. // NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background()) ctx := context.Background()
s := &ProxyServiceServer{ s := &ProxyServiceServer{
accessLogManager: accessLogMgr, accessLogManager: accessLogMgr,
oidcConfig: oidcConfig, oidcConfig: oidcConfig,
tokenStore: tokenStore, tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
proxyManager: proxyMgr, proxyManager: proxyMgr,
pkceCleanupCancel: cancel,
} }
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx) go s.cleanupStaleProxies(ctx)
return s return s
} }
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
ticker := time.NewTicker(pkceVerifierTTL)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
s.pkceVerifiers.Range(func(key, value any) bool {
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
s.pkceVerifiers.Delete(key)
}
return true
})
}
}
}
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes // cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -160,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
} }
} }
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager s.serviceManager = manager
} }
@@ -177,11 +142,7 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
ctx := stream.Context() ctx := stream.Context()
peerInfo := "" peerInfo := PeerIPFromContext(ctx)
if p, ok := peer.FromContext(ctx); ok {
peerInfo = p.Addr.String()
}
log.Infof("New proxy connection from %s", peerInfo) log.Infof("New proxy connection from %s", peerInfo)
proxyID := req.GetProxyId() proxyID := req.GetProxyId()
@@ -795,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum) state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
codeVerifier := oauth2.GenerateVerifier() codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()}) if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
}
return &proto.GetOIDCURLResponse{ return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{ Url: (&oauth2.Config{
@@ -832,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback. // ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid. // Returns the original redirect URL if valid, or an error if invalid.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
v, ok := s.pkceVerifiers.LoadAndDelete(state) verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok { if !ok {
return "", "", errors.New("no verifier for state") return "", "", errors.New("no verifier for state")
} }
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|") parts := strings.Split(state, "|")

View File

@@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
} }
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) { func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx) clientIP := PeerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) { if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts") return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")

View File

@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
l.cancel() l.cancel()
} }
// peerIPFromContext extracts the client IP from the gRPC context. // PeerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address. // Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP { func PeerIPFromContext(ctx context.Context) string {
if addr, ok := realip.FromContext(ctx); ok { if addr, ok := realip.FromContext(ctx); ok {
return addr.String() return addr.String()
} }

View File

@@ -5,11 +5,10 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"sync"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
} }
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
} }
s.SetProxyController(newTestProxyController()) s.SetProxyController(newTestProxyController())
@@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
} }
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
} }
s.SetProxyController(newTestProxyController()) s.SetProxyController(newTestProxyController())
@@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
} }
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
tokenStore: tokenStore, tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
} }
s.SetProxyController(newTestProxyController()) s.SetProxyController(newTestProxyController())
@@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
} }
func TestOAuthState_NeverTheSame(t *testing.T) { func TestOAuthState_NeverTheSame(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{ oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
}, },
pkceVerifierStore: pkceStore,
} }
redirectURL := "https://app.example.com/callback" redirectURL := "https://app.example.com/callback"
@@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
} }
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{ oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
}, },
pkceVerifierStore: pkceStore,
} }
// Old format had only 2 parts: base64(url)|hmac // Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err := s.ValidateState("base64url|hmac") _, _, err = s.ValidateState("base64url|hmac")
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format") assert.Contains(t, err.Error(), "invalid state format")
} }
func TestValidateState_RejectsInvalidHMAC(t *testing.T) { func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{ s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{ oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"), HMACKey: []byte("test-hmac-key"),
}, },
pkceVerifierStore: pkceStore,
} }
// Store with tampered HMAC // Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac") _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature") assert.Contains(t, err.Error(), "invalid state signature")
} }

View File

@@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager) proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore) createTestProxies(t, ctx, testStore)

View File

@@ -1379,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account. // We override incoming domain claims to group users under a single account.
userAuth.Domain = am.singleAccountModeDomain err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
userAuth.DomainCategory = types.PrivateCategory if err != nil {
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") return "", "", err
}
} }
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth) accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
@@ -1414,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil return accountID, user.Id, nil
} }
// updateUserAuthWithSingleMode modifies the userAuth with the single account
// domain, or, when an account already exists, with that account's stored
// domain. The DomainCategory is always forced to PrivateCategory.
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
	// Defaults: private category plus the configured single-account domain.
	userAuth.DomainCategory = types.PrivateCategory
	userAuth.Domain = am.singleAccountModeDomain

	id, lookupErr := am.Store.GetAnyAccountID(ctx)
	switch {
	case lookupErr != nil:
		// NotFound simply means no account exists yet; keep the defaults.
		if e, ok := status.FromError(lookupErr); ok && e.Type() == status.NotFound {
			log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
			return nil
		}
		return lookupErr
	case id == "":
		log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
		return nil
	}

	accountDomain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, id)
	if err != nil {
		return err
	}
	userAuth.Domain = accountDomain
	log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
	return nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager

View File

@@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push" "github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -3132,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err return nil, nil, err
} }
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -3966,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange") t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
} }
} }
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("real-domain.com", "private", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "real-domain.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.NotFound, "no accounts"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.Internal, "db down"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "db down")
// Defaults should still be set before error path
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("", "", status.Errorf(status.Internal, "query failed"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "query failed")
})
}

View File

@@ -220,6 +220,13 @@ const (
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account // AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115 AccountPeerExposeDisabled Activity = 115
// DomainAdded indicates that a user added a custom domain
DomainAdded Activity = 116
// DomainDeleted indicates that a user deleted a custom domain
DomainDeleted Activity = 117
// DomainValidated indicates that a custom domain was validated
DomainValidated Activity = 118
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -364,6 +371,10 @@ var activityMap = map[Activity]Code{
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"}, AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"}, AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -193,6 +193,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore) usersManager := users.NewManager(testStore)
oidcConfig := nbgrpc.ProxyOIDCConfig{ oidcConfig := nbgrpc.ProxyOIDCConfig{
@@ -206,6 +209,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
proxyService := nbgrpc.NewProxyServiceServer( proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{}, &testAccessLogManager{},
tokenStore, tokenStore,
pkceStore,
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,

View File

@@ -98,13 +98,17 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err) t.Fatalf("Failed to create proxy token store: %v", err)
} }
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("") noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter) proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err) t.Fatalf("Failed to create proxy manager: %v", err)
} }
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err) t.Fatalf("Failed to create proxy controller: %v", err)

View File

@@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil return service, nil
} }
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) { func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
var service *rpservice.Service var service *rpservice.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain) return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)

View File

@@ -257,7 +257,7 @@ type Store interface {
UpdateService(ctx context.Context, service *rpservice.Service) error UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)

View File

@@ -1932,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
} }
// GetServiceByDomain mocks base method. // GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) { func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*service.Service) ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetServiceByDomain indicates an expected call of GetServiceByDomain. // GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call { func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain)
} }
// GetServiceByID mocks base method. // GetServiceByID mocks base method.

View File

@@ -0,0 +1,185 @@
package telemetry
import (
"context"
"math"
"sync"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
)
// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95
// without publishing individual account labels.
type AccountDurationAggregator struct {
	// mu guards the accounts map and serializes Record/flush operations.
	mu sync.RWMutex
	// accounts maps account ID -> histogram handle plus last-update timestamp
	// used for staleness eviction.
	accounts map[string]*accountHistogram
	// meterProvider is a private provider backing the per-account histograms.
	meterProvider *sdkmetric.MeterProvider
	// manualReader collects accumulated data on demand during flushes.
	manualReader *sdkmetric.ManualReader
	// FlushInterval is the intended cadence between flushes (set by the
	// constructor; not consumed within this file's visible methods).
	FlushInterval time.Duration
	// MaxAge is how long an account may go without a recording before it is
	// evicted during a flush.
	MaxAge time.Duration
	// ctx is reused for Record/Collect/Shutdown calls.
	// NOTE(review): storing a context in a struct is discouraged in Go —
	// consider passing ctx per call instead.
	ctx context.Context
}

// accountHistogram pairs an account's histogram instrument with the time of
// its most recent recording.
type accountHistogram struct {
	histogram  metric.Int64Histogram
	lastUpdate time.Time
}
// NewAccountDurationAggregator builds an aggregator backed by a private OTel
// MeterProvider whose ManualReader is configured for delta temporality, so
// each collection only reports data recorded since the previous one.
func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator {
	reader := sdkmetric.NewManualReader(
		sdkmetric.WithTemporalitySelector(func(sdkmetric.InstrumentKind) metricdata.Temporality {
			return metricdata.DeltaTemporality
		}),
	)
	provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))

	return &AccountDurationAggregator{
		accounts:      make(map[string]*accountHistogram),
		meterProvider: provider,
		manualReader:  reader,
		FlushInterval: flushInterval,
		MaxAge:        maxAge,
		ctx:           ctx,
	}
}
// Record registers one duration sample (in milliseconds) for accountID,
// lazily creating the per-account histogram entry on first use and stamping
// lastUpdate for staleness tracking.
func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) {
	a.mu.Lock()
	defer a.mu.Unlock()

	accHist := a.accounts[accountID]
	if accHist == nil {
		hist, err := a.meterProvider.Meter("account-aggregator").Int64Histogram(
			"sync_duration_per_account",
			metric.WithUnit("milliseconds"),
		)
		if err != nil {
			// Without an instrument there is nothing to record into.
			return
		}
		accHist = &accountHistogram{histogram: hist}
		a.accounts[accountID] = accHist
	}

	accHist.histogram.Record(a.ctx, duration.Milliseconds(),
		metric.WithAttributes(attribute.String("account_id", accountID)))
	accHist.lastUpdate = time.Now()
}
// FlushAndGetP95s collects the delta metrics accumulated since the previous
// flush, computes a P95 per account data point, evicts stale accounts, and
// returns one P95 value per account (equal weight regardless of how many
// requests each account made). Returns nil when collection fails.
func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 {
	a.mu.Lock()
	defer a.mu.Unlock()

	var rm metricdata.ResourceMetrics
	if err := a.manualReader.Collect(a.ctx, &rm); err != nil {
		return nil
	}

	now := time.Now()
	p95s := make([]int64, 0, len(a.accounts))
	for _, scopeMetrics := range rm.ScopeMetrics {
		// The loop variable is named "m" (not "metric") so it does not shadow
		// the imported otel metric package, as the previous version did.
		for _, m := range scopeMetrics.Metrics {
			histogramData, ok := m.Data.(metricdata.Histogram[int64])
			if !ok {
				continue
			}
			for _, dataPoint := range histogramData.DataPoints {
				a.processDataPoint(dataPoint, now, &p95s)
			}
		}
	}

	a.cleanupStaleAccounts(now)
	return p95s
}
// processDataPoint extracts the P95 from a single histogram data point and
// appends it to p95s. Points without an account_id attribute, or with a
// non-positive P95, are skipped. The now argument is currently unused; it is
// kept so the signature stays stable for the flush path.
func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) {
	if extractAccountID(dataPoint) == "" {
		return
	}
	p95 := calculateP95FromHistogram(dataPoint)
	if p95 > 0 {
		*p95s = append(*p95s, p95)
	}
}
// cleanupStaleAccounts drops accounts whose last recording is older than
// MaxAge, so idle accounts stop contributing P95 values and the map does not
// grow without bound. Caller must hold a.mu.
//
// Iterates the map entries directly instead of going through isStaleAccount,
// which performed a redundant second lookup per key.
func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) {
	for accountID, accHist := range a.accounts {
		// Deleting during range is well-defined in Go.
		if now.Sub(accHist.lastUpdate) > a.MaxAge {
			delete(a.accounts, accountID)
		}
	}
}
// extractAccountID returns the value of the "account_id" attribute on the
// histogram data point, or "" when the attribute is absent.
//
// Uses attribute.Set's direct Value lookup instead of materializing the whole
// attribute slice via ToSlice for a single-key probe.
func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string {
	if v, ok := dp.Attributes.Value(attribute.Key("account_id")); ok {
		return v.AsString()
	}
	return ""
}
// isStaleAccount reports whether accountID is known and has not recorded any
// sample within MaxAge of now. Unknown accounts are treated as not stale.
func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool {
	if accHist, ok := a.accounts[accountID]; ok {
		return now.Sub(accHist.lastUpdate) > a.MaxAge
	}
	return false
}
// calculateP95FromHistogram computes an approximate P95 from OTel histogram
// bucket data by walking cumulative bucket counts until the 95th-percentile
// rank is reached, then reporting that bucket's upper bound. Returns 0 for
// an empty data point.
func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 {
	if dp.Count == 0 {
		return 0
	}
	// 1-based rank of the sample sitting at the 95th percentile.
	targetCount := uint64(math.Ceil(float64(dp.Count) * 0.95))
	// Defensive only: for Count >= 1, ceil(0.95*Count) >= 1, so this branch
	// should never trigger.
	if targetCount == 0 {
		targetCount = 1
	}
	var cumulativeCount uint64
	for i, bucketCount := range dp.BucketCounts {
		cumulativeCount += bucketCount
		if cumulativeCount >= targetCount {
			// The target rank falls in bucket i: report the bucket's upper
			// bound as the P95 approximation.
			if i < len(dp.Bounds) {
				return int64(dp.Bounds[i])
			}
			// Overflow bucket (beyond the last bound): prefer the recorded
			// max if the SDK tracked one, otherwise fall back to the mean.
			if maxVal, defined := dp.Max.Value(); defined {
				return maxVal
			}
			return dp.Sum / int64(dp.Count)
		}
	}
	// Should be unreachable when bucket counts sum to Count; fall back to the
	// mean as a safe default. (Count > 0 here, so no division by zero.)
	return dp.Sum / int64(dp.Count)
}
// Shutdown cleans up resources by shutting down the private MeterProvider.
// NOTE(review): it reuses the aggregator's stored ctx; if that context has
// already been cancelled, the shutdown may fail — consider accepting a ctx
// parameter instead.
func (a *AccountDurationAggregator) Shutdown() error {
	return a.meterProvider.Shutdown(a.ctx)
}

View File

@@ -0,0 +1,219 @@
package telemetry
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDeltaTemporality_P95ReflectsCurrentWindow verifies that with delta
// temporality each flush window only reflects recordings made since the
// previous flush — not all-time data.
func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// Window 1: 100 slow recordings of 500ms each.
	for i := 0; i < 100; i++ {
		agg.Record("account-A", 500*time.Millisecond)
	}
	window1 := agg.FlushAndGetP95s()
	require.Len(t, window1, 1, "should have P95 for one account")
	firstP95 := window1[0]
	assert.GreaterOrEqual(t, firstP95, int64(200),
		"first window P95 should reflect the 500ms recordings")

	// Window 2: 100 fast recordings of 10ms each.
	for i := 0; i < 100; i++ {
		agg.Record("account-A", 10*time.Millisecond)
	}
	window2 := agg.FlushAndGetP95s()
	require.Len(t, window2, 1, "should have P95 for one account")
	secondP95 := window2[0]
	// With delta temporality the P95 should drop sharply, because the slow
	// recordings from the first window are no longer included.
	assert.Less(t, secondP95, firstP95,
		"second window P95 should be lower than first — delta temporality "+
			"ensures each window only reflects recent recordings")
}
// TestEqualWeightPerAccount verifies that each account contributes exactly
// one P95 value, regardless of how many requests it made.
func TestEqualWeightPerAccount(t *testing.T) {
	agg := NewAccountDurationAggregator(context.Background(), time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// recordN records n identical samples for one account.
	recordN := func(account string, n int, d time.Duration) {
		for i := 0; i < n; i++ {
			agg.Record(account, d)
		}
	}

	// Noisy customer: 10,000 requests at 500ms.
	recordN("account-A", 10000, 500*time.Millisecond)
	// Normal customers: 10 requests each at 50ms.
	for _, id := range []string{"account-B", "account-C", "account-D"} {
		recordN(id, 10, 50*time.Millisecond)
	}

	p95s := agg.FlushAndGetP95s()
	// Exactly 4 P95 values — one per account.
	assert.Len(t, p95s, 4, "each account should contribute exactly one P95")
}
// TestStaleAccountEviction verifies that an account that stops recording is
// evicted from the aggregator after MaxAge, while active accounts remain.
func TestStaleAccountEviction(t *testing.T) {
	ctx := context.Background()
	// Use a very short MaxAge so we can test staleness
	agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
	defer func(agg *AccountDurationAggregator) {
		err := agg.Shutdown()
		if err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}(agg)

	agg.Record("account-A", 100*time.Millisecond)
	agg.Record("account-B", 200*time.Millisecond)

	// Both accounts should appear
	p95s := agg.FlushAndGetP95s()
	assert.Len(t, p95s, 2, "both accounts should have P95 values")

	// Wait for account-A to become stale, then only update account-B
	time.Sleep(60 * time.Millisecond)
	agg.Record("account-B", 200*time.Millisecond)

	p95s = agg.FlushAndGetP95s()
	// Message fixed: it previously read "both accounts should have P95
	// values" (copy-pasted from above), contradicting the Len(1) assertion.
	assert.Len(t, p95s, 1, "only the still-active account-B should have a P95 value")

	// account-A should have been evicted from the accounts map
	agg.mu.RLock()
	_, accountAExists := agg.accounts["account-A"]
	_, accountBExists := agg.accounts["account-B"]
	agg.mu.RUnlock()
	assert.False(t, accountAExists, "stale account-A should be evicted from map")
	assert.True(t, accountBExists, "active account-B should remain in map")
}
// TestStaleAccountEviction_DoesNotReappear ensures that once a stale account
// has been evicted, delta temporality keeps it out of all later flushes.
func TestStaleAccountEviction_DoesNotReappear(t *testing.T) {
	ctx := context.Background()
	agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	agg.Record("account-stale", 100*time.Millisecond)
	// Let the account age past MaxAge.
	time.Sleep(60 * time.Millisecond)

	// The first flush detects the stale entry and evicts it.
	_ = agg.FlushAndGetP95s()
	agg.mu.RLock()
	_, stillPresent := agg.accounts["account-stale"]
	agg.mu.RUnlock()
	assert.False(t, stillPresent, "account should be evicted after first flush")

	// A second flush must not resurrect the evicted account.
	p95sSecond := agg.FlushAndGetP95s()
	assert.Empty(t, p95sSecond,
		"evicted account should not reappear in subsequent flushes with delta temporality")
}
// TestP95Calculation_SingleSample verifies that a lone recording still
// produces a positive P95 value for its account.
func TestP95Calculation_SingleSample(t *testing.T) {
	ctx := context.Background()
	agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	agg.Record("account-A", 150*time.Millisecond)

	p95s := agg.FlushAndGetP95s()
	require.Len(t, p95s, 1)
	// With a single sample, P95 resolves to the bucket bound containing 150ms.
	assert.Greater(t, p95s[0], int64(0), "P95 of a single sample should be positive")
}
// TestP95Calculation_AllSameValue verifies that identical samples still
// produce a sensible (positive) P95 for the account.
func TestP95Calculation_AllSameValue(t *testing.T) {
	ctx := context.Background()
	agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// Every sample is 100ms — P95 should land in the bucket containing 100ms.
	for i := 0; i < 100; i++ {
		agg.Record("account-A", 100*time.Millisecond)
	}

	p95s := agg.FlushAndGetP95s()
	require.Len(t, p95s, 1)
	assert.Greater(t, p95s[0], int64(0))
}
// TestMultipleAccounts_IndependentP95s verifies that each account's P95 is
// computed from its own samples only: a fast account and a slow account must
// yield clearly separated values.
func TestMultipleAccounts_IndependentP95s(t *testing.T) {
	ctx := context.Background()
	agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
	defer func() {
		if err := agg.Shutdown(); err != nil {
			t.Errorf("failed to shutdown aggregator: %v", err)
		}
	}()

	// Account "account-fast": 100 quick (10ms) requests.
	for i := 0; i < 100; i++ {
		agg.Record("account-fast", 10*time.Millisecond)
	}
	// Account "account-slow": 100 slow (5000ms) requests.
	for i := 0; i < 100; i++ {
		agg.Record("account-slow", 5000*time.Millisecond)
	}

	p95s := agg.FlushAndGetP95s()
	require.Len(t, p95s, 2, "should have two P95 values")

	// Result order is unspecified; normalize to (lo, hi) before asserting.
	lo, hi := p95s[0], p95s[1]
	if lo > hi {
		lo, hi = hi, lo
	}
	assert.Less(t, lo, int64(1000),
		"fast account P95 should be well under 1000ms")
	assert.Greater(t, hi, int64(1000),
		"slow account P95 should be well over 1000ms")
}

View File

@@ -22,9 +22,15 @@ type GRPCMetrics struct {
getKeyRequestsCounter metric.Int64Counter getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram syncRequestDuration metric.Int64Histogram
syncRequestDurationP95ByAccount metric.Int64Histogram
loginRequestDuration metric.Int64Histogram loginRequestDuration metric.Int64Histogram
loginRequestDurationP95ByAccount metric.Int64Histogram
channelQueueLength metric.Int64Histogram channelQueueLength metric.Int64Histogram
ctx context.Context ctx context.Context
// Per-account aggregation
syncDurationAggregator *AccountDurationAggregator
loginDurationAggregator *AccountDurationAggregator
} }
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server // NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
@@ -93,6 +99,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err return nil, err
} }
syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"),
)
if err != nil {
return nil, err
}
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms", loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
metric.WithUnit("milliseconds"), metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -101,6 +115,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err return nil, err
} }
loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"),
)
if err != nil {
return nil, err
}
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time // We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
// Then we should be able to extract min, max, mean and the percentiles. // Then we should be able to extract min, max, mean and the percentiles.
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100) // TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
@@ -113,7 +135,10 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err return nil, err
} }
return &GRPCMetrics{ syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
grpcMetrics := &GRPCMetrics{
meter: meter, meter: meter,
syncRequestsCounter: syncRequestsCounter, syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter, syncRequestsBlockedCounter: syncRequestsBlockedCounter,
@@ -123,10 +148,19 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
getKeyRequestsCounter: getKeyRequestsCounter, getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge, activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration, syncRequestDuration: syncRequestDuration,
syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount,
loginRequestDuration: loginRequestDuration, loginRequestDuration: loginRequestDuration,
loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount,
channelQueueLength: channelQueue, channelQueueLength: channelQueue,
ctx: ctx, ctx: ctx,
}, err syncDurationAggregator: syncDurationAggregator,
loginDurationAggregator: loginDurationAggregator,
}
go grpcMetrics.startSyncP95Flusher()
go grpcMetrics.startLoginP95Flusher()
return grpcMetrics, err
} }
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API // CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
@@ -157,6 +191,9 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
// CountLoginRequestDuration counts the duration of the login gRPC requests // CountLoginRequestDuration counts the duration of the login gRPC requests
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
grpcMetrics.loginDurationAggregator.Record(accountID, duration)
if duration > HighLatencyThreshold { if duration > HighLatencyThreshold {
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
} }
@@ -165,6 +202,44 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests // CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
grpcMetrics.syncDurationAggregator.Record(accountID, duration)
}
// startSyncP95Flusher periodically flushes per-account sync P95 values to the histogram
func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() {
ticker := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval)
defer ticker.Stop()
for {
select {
case <-grpcMetrics.ctx.Done():
return
case <-ticker.C:
p95s := grpcMetrics.syncDurationAggregator.FlushAndGetP95s()
for _, p95 := range p95s {
grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
}
}
}
}
// startLoginP95Flusher periodically flushes per-account login P95 values to the histogram
func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() {
ticker := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval)
defer ticker.Stop()
for {
select {
case <-grpcMetrics.ctx.Done():
return
case <-ticker.C:
p95s := grpcMetrics.loginDurationAggregator.FlushAndGetP95s()
for _, p95 := range p95s {
grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
}
}
}
} }
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -3,6 +3,7 @@ package accesslog
import ( import (
"context" "context"
"net/netip" "net/netip"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -13,6 +14,23 @@ import (
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
const (
requestThreshold = 10000 // Log every 10k requests
bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
)
type domainUsage struct {
requestCount int64
requestStartTime time.Time
bytesTransferred int64
bytesStartTime time.Time
lastActivity time.Time // Track last activity for cleanup
}
type gRPCClient interface { type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
} }
@@ -22,6 +40,11 @@ type Logger struct {
client gRPCClient client gRPCClient
logger *log.Logger logger *log.Logger
trustedProxies []netip.Prefix trustedProxies []netip.Prefix
usageMux sync.Mutex
domainUsage map[string]*domainUsage
cleanupCancel context.CancelFunc
} }
// NewLogger creates a new access log Logger. The trustedProxies parameter // NewLogger creates a new access log Logger. The trustedProxies parameter
@@ -31,10 +54,26 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
return &Logger{
ctx, cancel := context.WithCancel(context.Background())
l := &Logger{
client: client, client: client,
logger: logger, logger: logger,
trustedProxies: trustedProxies, trustedProxies: trustedProxies,
domainUsage: make(map[string]*domainUsage),
cleanupCancel: cancel,
}
// Start background cleanup routine
go l.cleanupStaleUsage(ctx)
return l
}
// Close stops the cleanup routine. Should be called during graceful shutdown.
func (l *Logger) Close() {
if l.cleanupCancel != nil {
l.cleanupCancel()
} }
} }
@@ -51,6 +90,8 @@ type logEntry struct {
AuthMechanism string AuthMechanism string
UserId string UserId string
AuthSuccess bool AuthSuccess bool
BytesUpload int64
BytesDownload int64
} }
func (l *Logger) log(ctx context.Context, entry logEntry) { func (l *Logger) log(ctx context.Context, entry logEntry) {
@@ -84,6 +125,8 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
AuthMechanism: entry.AuthMechanism, AuthMechanism: entry.AuthMechanism,
UserId: entry.UserId, UserId: entry.UserId,
AuthSuccess: entry.AuthSuccess, AuthSuccess: entry.AuthSuccess,
BytesUpload: entry.BytesUpload,
BytesDownload: entry.BytesDownload,
}, },
}); err != nil { }); err != nil {
// If it fails to send on the gRPC connection, then at least log it to the error log. // If it fails to send on the gRPC connection, then at least log it to the error log.
@@ -103,3 +146,82 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
} }
}() }()
} }
// trackUsage records request and byte counts per domain, logging when thresholds are hit.
func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
if domain == "" {
return
}
l.usageMux.Lock()
defer l.usageMux.Unlock()
now := time.Now()
usage, exists := l.domainUsage[domain]
if !exists {
usage = &domainUsage{
requestStartTime: now,
bytesStartTime: now,
lastActivity: now,
}
l.domainUsage[domain] = usage
}
usage.lastActivity = now
usage.requestCount++
if usage.requestCount >= requestThreshold {
elapsed := time.Since(usage.requestStartTime)
l.logger.WithFields(log.Fields{
"domain": domain,
"requests": usage.requestCount,
"duration": elapsed.String(),
}).Infof("domain %s had %d requests over %s", domain, usage.requestCount, elapsed)
usage.requestCount = 0
usage.requestStartTime = now
}
usage.bytesTransferred += bytesTransferred
if usage.bytesTransferred >= bytesThreshold {
elapsed := time.Since(usage.bytesStartTime)
bytesInGB := float64(usage.bytesTransferred) / (1024 * 1024 * 1024)
l.logger.WithFields(log.Fields{
"domain": domain,
"bytes": usage.bytesTransferred,
"bytes_gb": bytesInGB,
"duration": elapsed.String(),
}).Infof("domain %s transferred %.2f GB over %s", domain, bytesInGB, elapsed)
usage.bytesTransferred = 0
usage.bytesStartTime = now
}
}
// cleanupStaleUsage removes usage entries for domains that have been inactive.
func (l *Logger) cleanupStaleUsage(ctx context.Context) {
ticker := time.NewTicker(usageCleanupPeriod)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
l.usageMux.Lock()
now := time.Now()
removed := 0
for domain, usage := range l.domainUsage {
if now.Sub(usage.lastActivity) > usageInactiveWindow {
delete(l.domainUsage, domain)
removed++
}
}
l.usageMux.Unlock()
if removed > 0 {
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
}
}
}
}

View File

@@ -32,6 +32,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
status: http.StatusOK, status: http.StatusOK,
} }
var bytesRead int64
if r.Body != nil {
r.Body = &bodyCounter{
ReadCloser: r.Body,
bytesRead: &bytesRead,
}
}
// Resolve the source IP using trusted proxy configuration before passing // Resolve the source IP using trusted proxy configuration before passing
// the request on, as the proxy will modify forwarding headers. // the request on, as the proxy will modify forwarding headers.
sourceIp := extractSourceIP(r, l.trustedProxies) sourceIp := extractSourceIP(r, l.trustedProxies)
@@ -53,6 +61,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
host = r.Host host = r.Host
} }
bytesUpload := bytesRead
bytesDownload := sw.bytesWritten
entry := logEntry{ entry := logEntry{
ID: requestID, ID: requestID,
ServiceId: capturedData.GetServiceId(), ServiceId: capturedData.GetServiceId(),
@@ -66,10 +77,15 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
AuthMechanism: capturedData.GetAuthMethod(), AuthMechanism: capturedData.GetAuthMethod(),
UserId: capturedData.GetUserID(), UserId: capturedData.GetUserID(),
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden, AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
BytesUpload: bytesUpload,
BytesDownload: bytesDownload,
} }
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s", l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId()) requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
l.log(r.Context(), entry) l.log(r.Context(), entry)
// Track usage for cost monitoring (upload + download) by domain
l.trackUsage(host, bytesUpload+bytesDownload)
}) })
} }

View File

@@ -1,18 +1,39 @@
package accesslog package accesslog
import ( import (
"io"
"github.com/netbirdio/netbird/proxy/internal/responsewriter" "github.com/netbirdio/netbird/proxy/internal/responsewriter"
) )
// statusWriter captures the HTTP status code from WriteHeader calls. // statusWriter captures the HTTP status code and bytes written from responses.
// It embeds responsewriter.PassthroughWriter which handles all the optional // It embeds responsewriter.PassthroughWriter which handles all the optional
// interfaces (Hijacker, Flusher, Pusher) automatically. // interfaces (Hijacker, Flusher, Pusher) automatically.
type statusWriter struct { type statusWriter struct {
*responsewriter.PassthroughWriter *responsewriter.PassthroughWriter
status int status int
bytesWritten int64
} }
func (w *statusWriter) WriteHeader(status int) { func (w *statusWriter) WriteHeader(status int) {
w.status = status w.status = status
w.PassthroughWriter.WriteHeader(status) w.PassthroughWriter.WriteHeader(status)
} }
func (w *statusWriter) Write(b []byte) (int, error) {
n, err := w.PassthroughWriter.Write(b)
w.bytesWritten += int64(n)
return n, err
}
// bodyCounter wraps an io.ReadCloser and counts bytes read from the request body.
type bodyCounter struct {
io.ReadCloser
bytesRead *int64
}
func (bc *bodyCounter) Read(p []byte) (int, error) {
n, err := bc.ReadCloser.Read(p)
*bc.bytesRead += int64(n)
return n, err
}

View File

@@ -42,6 +42,10 @@ type domainInfo struct {
err string err string
} }
type metricsRecorder interface {
RecordCertificateIssuance(duration time.Duration)
}
// Manager wraps autocert.Manager with domain tracking and cross-replica // Manager wraps autocert.Manager with domain tracking and cross-replica
// coordination via a pluggable locking strategy. The locker prevents // coordination via a pluggable locking strategy. The locker prevents
// duplicate ACME requests when multiple replicas share a certificate cache. // duplicate ACME requests when multiple replicas share a certificate cache.
@@ -55,6 +59,7 @@ type Manager struct {
certNotifier certificateNotifier certNotifier certificateNotifier
logger *log.Logger logger *log.Logger
metrics metricsRecorder
} }
// NewManager creates a new ACME certificate manager. The certDir is used // NewManager creates a new ACME certificate manager. The certDir is used
@@ -63,7 +68,7 @@ type Manager struct {
// eabKID and eabHMACKey are optional External Account Binding credentials // eabKID and eabHMACKey are optional External Account Binding credentials
// required for some CAs like ZeroSSL. The eabHMACKey should be the base64 // required for some CAs like ZeroSSL. The eabHMACKey should be the base64
// URL-encoded string provided by the CA. // URL-encoded string provided by the CA.
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod, metrics metricsRecorder) *Manager {
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
@@ -73,6 +78,7 @@ func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificat
domains: make(map[domain.Domain]*domainInfo), domains: make(map[domain.Domain]*domainInfo),
certNotifier: notifier, certNotifier: notifier,
logger: logger, logger: logger,
metrics: metrics,
} }
var eab *acme.ExternalAccountBinding var eab *acme.ExternalAccountBinding
@@ -129,24 +135,45 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) {
// prefetchCertificate proactively triggers certificate generation for a domain. // prefetchCertificate proactively triggers certificate generation for a domain.
// It acquires a distributed lock to prevent multiple replicas from issuing // It acquires a distributed lock to prevent multiple replicas from issuing
// duplicate ACME requests. The second replica will block until the first // duplicate ACME requests. If the certificate appears in the cache while waiting
// finishes, then find the certificate in the cache. // for the lock, it cancels the wait and uses the cached certificate.
func (mgr *Manager) prefetchCertificate(d domain.Domain) { func (mgr *Manager) prefetchCertificate(d domain.Domain) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
name := d.PunycodeString() name := d.PunycodeString()
if mgr.certExistsInCache(ctx, name) {
mgr.logger.Infof("certificate for domain %q already exists in cache before lock attempt", name)
mgr.loadAndFinalizeCachedCert(ctx, d, name)
return
}
go mgr.pollCacheAndCancel(ctx, name, cancel)
// Acquire lock
mgr.logger.Infof("acquiring cert lock for domain %q", name) mgr.logger.Infof("acquiring cert lock for domain %q", name)
lockStart := time.Now() lockStart := time.Now()
unlock, err := mgr.locker.Lock(ctx, name) unlock, err := mgr.locker.Lock(ctx, name)
if err != nil { if err != nil {
mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", name, err) if mgr.certExistsInCache(context.Background(), name) {
mgr.logger.Infof("certificate for domain %q appeared in cache while waiting for lock", name)
mgr.loadAndFinalizeCachedCert(context.Background(), d, name)
return
}
mgr.logger.Warnf("acquire cert lock for domain %q: %v", name, err)
// Continue without lock
} else { } else {
mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockStart)) mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockStart))
defer unlock() defer unlock()
} }
if mgr.certExistsInCache(ctx, name) {
mgr.logger.Infof("certificate for domain %q already exists in cache after lock acquisition, skipping ACME request", name)
mgr.loadAndFinalizeCachedCert(ctx, d, name)
return
}
hello := &tls.ClientHelloInfo{ hello := &tls.ClientHelloInfo{
ServerName: name, ServerName: name,
Conn: &dummyConn{ctx: ctx}, Conn: &dummyConn{ctx: ctx},
@@ -161,6 +188,10 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) {
return return
} }
if mgr.metrics != nil {
mgr.metrics.RecordCertificateIssuance(elapsed)
}
mgr.setDomainState(d, domainReady, "") mgr.setDomainState(d, domainReady, "")
now := time.Now() now := time.Now()
@@ -199,6 +230,78 @@ func (mgr *Manager) setDomainState(d domain.Domain, state domainState, errMsg st
} }
} }
// certExistsInCache checks if a certificate exists on the shared disk for the given domain.
func (mgr *Manager) certExistsInCache(ctx context.Context, domain string) bool {
if mgr.Cache == nil {
return false
}
_, err := mgr.Cache.Get(ctx, domain)
return err == nil
}
// pollCacheAndCancel periodically checks if a certificate appears in the cache.
// If found, it cancels the context to abort the lock wait.
func (mgr *Manager) pollCacheAndCancel(ctx context.Context, domain string, cancel context.CancelFunc) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if mgr.certExistsInCache(context.Background(), domain) {
mgr.logger.Debugf("cert detected in cache for domain %q, cancelling lock wait", domain)
cancel()
return
}
}
}
}
// loadAndFinalizeCachedCert loads a certificate from cache and updates domain state.
func (mgr *Manager) loadAndFinalizeCachedCert(ctx context.Context, d domain.Domain, name string) {
hello := &tls.ClientHelloInfo{
ServerName: name,
Conn: &dummyConn{ctx: ctx},
}
cert, err := mgr.GetCertificate(hello)
if err != nil {
mgr.logger.Warnf("load cached certificate for domain %q: %v", name, err)
mgr.setDomainState(d, domainFailed, err.Error())
return
}
mgr.setDomainState(d, domainReady, "")
now := time.Now()
if cert != nil && cert.Leaf != nil {
leaf := cert.Leaf
mgr.logger.Infof("loaded cached certificate for domain %q: serial=%s SANs=%v notBefore=%s, notAfter=%s, now=%s",
name,
leaf.SerialNumber.Text(16),
leaf.DNSNames,
leaf.NotBefore.UTC().Format(time.RFC3339),
leaf.NotAfter.UTC().Format(time.RFC3339),
now.UTC().Format(time.RFC3339),
)
mgr.logCertificateDetails(name, leaf, now)
} else {
mgr.logger.Infof("loaded cached certificate for domain %q", name)
}
mgr.mu.RLock()
info := mgr.domains[d]
mgr.mu.RUnlock()
if info != nil && mgr.certNotifier != nil {
if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil {
mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err)
}
}
}
// logCertificateDetails logs certificate validity and SCT timestamps. // logCertificateDetails logs certificate validity and SCT timestamps.
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) { func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
if cert.NotBefore.After(now) { if cert.NotBefore.After(now) {

View File

@@ -10,7 +10,7 @@ import (
) )
func TestHostPolicy(t *testing.T) { func TestHostPolicy(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "") mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil)
mgr.AddDomain("example.com", "acc1", "rp1") mgr.AddDomain("example.com", "acc1", "rp1")
// Wait for the background prefetch goroutine to finish so the temp dir // Wait for the background prefetch goroutine to finish so the temp dir
@@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) {
} }
func TestDomainStates(t *testing.T) { func TestDomainStates(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "") mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil)
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")

View File

@@ -1,64 +1,106 @@
package metrics package metrics
import ( import (
"context"
"net/http" "net/http"
"strconv" "sync"
"time" "time"
"github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/metric"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter" "github.com/netbirdio/netbird/proxy/internal/responsewriter"
) )
type Metrics struct { type Metrics struct {
requestsTotal prometheus.Counter ctx context.Context
activeRequests prometheus.Gauge requestsTotal metric.Int64Counter
configuredDomains prometheus.Gauge activeRequests metric.Int64UpDownCounter
pathsPerDomain *prometheus.GaugeVec configuredDomains metric.Int64UpDownCounter
requestDuration *prometheus.HistogramVec totalPaths metric.Int64UpDownCounter
backendDuration *prometheus.HistogramVec requestDuration metric.Int64Histogram
backendDuration metric.Int64Histogram
certificateIssueDuration metric.Int64Histogram
mappingsMux sync.Mutex
mappingPaths map[string]int
} }
func New(reg prometheus.Registerer) *Metrics { func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
promFactory := promauto.With(reg) requestsTotal, err := meter.Int64Counter(
return &Metrics{ "proxy.http.request.counter",
requestsTotal: promFactory.NewCounter(prometheus.CounterOpts{ metric.WithUnit("1"),
Name: "netbird_proxy_requests_total", metric.WithDescription("Total number of requests made to the netbird proxy"),
Help: "Total number of requests made to the netbird proxy", )
}), if err != nil {
activeRequests: promFactory.NewGauge(prometheus.GaugeOpts{ return nil, err
Name: "netbird_proxy_active_requests_count",
Help: "Current in-flight requests handled by the netbird proxy",
}),
configuredDomains: promFactory.NewGauge(prometheus.GaugeOpts{
Name: "netbird_proxy_domains_count",
Help: "Current number of domains configured on the netbird proxy",
}),
pathsPerDomain: promFactory.NewGaugeVec(
prometheus.GaugeOpts{
Name: "netbird_proxy_paths_count",
Help: "Current number of paths configured on the netbird proxy labelled by domain",
},
[]string{"domain"},
),
requestDuration: promFactory.NewHistogramVec(
prometheus.HistogramOpts{
Name: "netbird_proxy_request_duration_seconds",
Help: "Duration of requests made to the netbird proxy",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
},
[]string{"status", "size", "method", "host", "path"},
),
backendDuration: promFactory.NewHistogramVec(prometheus.HistogramOpts{
Name: "netbird_proxy_backend_duration_seconds",
Help: "Duration of peer round trip time from the netbird proxy",
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
},
[]string{"status", "size", "method", "host", "path"},
),
} }
activeRequests, err := meter.Int64UpDownCounter(
"proxy.http.active_requests",
metric.WithUnit("1"),
metric.WithDescription("Current in-flight requests handled by the netbird proxy"),
)
if err != nil {
return nil, err
}
configuredDomains, err := meter.Int64UpDownCounter(
"proxy.domains.count",
metric.WithUnit("1"),
metric.WithDescription("Current number of domains configured on the netbird proxy"),
)
if err != nil {
return nil, err
}
totalPaths, err := meter.Int64UpDownCounter(
"proxy.paths.count",
metric.WithUnit("1"),
metric.WithDescription("Total number of paths configured on the netbird proxy"),
)
if err != nil {
return nil, err
}
requestDuration, err := meter.Int64Histogram(
"proxy.http.request.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of requests made to the netbird proxy"),
)
if err != nil {
return nil, err
}
backendDuration, err := meter.Int64Histogram(
"proxy.backend.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of peer round trip time from the netbird proxy"),
)
if err != nil {
return nil, err
}
certificateIssueDuration, err := meter.Int64Histogram(
"proxy.certificate.issue.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of ACME certificate issuance"),
)
if err != nil {
return nil, err
}
return &Metrics{
ctx: ctx,
requestsTotal: requestsTotal,
activeRequests: activeRequests,
configuredDomains: configuredDomains,
totalPaths: totalPaths,
requestDuration: requestDuration,
backendDuration: backendDuration,
certificateIssueDuration: certificateIssueDuration,
mappingPaths: make(map[string]int),
}, nil
} }
type responseInterceptor struct { type responseInterceptor struct {
@@ -80,23 +122,19 @@ func (w *responseInterceptor) Write(b []byte) (int, error) {
func (m *Metrics) Middleware(next http.Handler) http.Handler { func (m *Metrics) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.requestsTotal.Inc() m.requestsTotal.Add(m.ctx, 1)
m.activeRequests.Inc() m.activeRequests.Add(m.ctx, 1)
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)} interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
start := time.Now() start := time.Now()
next.ServeHTTP(interceptor, r) defer func() {
duration := time.Since(start) duration := time.Since(start)
m.activeRequests.Add(m.ctx, -1)
m.requestDuration.Record(m.ctx, duration.Milliseconds())
}()
m.activeRequests.Desc() next.ServeHTTP(interceptor, r)
m.requestDuration.With(prometheus.Labels{
"status": strconv.Itoa(interceptor.status),
"size": strconv.Itoa(interceptor.size),
"method": r.Method,
"host": r.Host,
"path": r.URL.Path,
}).Observe(duration.Seconds())
}) })
} }
@@ -108,44 +146,52 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper { func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (*http.Response, error) { return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
labels := prometheus.Labels{
"method": req.Method,
"host": req.Host,
// Fill potentially empty labels with default values to avoid cardinality issues.
"path": "/",
"status": "0",
"size": "0",
}
if req.URL != nil {
labels["path"] = req.URL.Path
}
start := time.Now() start := time.Now()
res, err := next.RoundTrip(req) res, err := next.RoundTrip(req)
duration := time.Since(start) duration := time.Since(start)
// Not all labels will be available if there was an error. m.backendDuration.Record(m.ctx, duration.Milliseconds())
if res != nil {
labels["status"] = strconv.Itoa(res.StatusCode)
labels["size"] = strconv.Itoa(int(res.ContentLength))
}
m.backendDuration.With(labels).Observe(duration.Seconds())
return res, err return res, err
}) })
} }
func (m *Metrics) AddMapping(mapping proxy.Mapping) { func (m *Metrics) AddMapping(mapping proxy.Mapping) {
m.configuredDomains.Inc() m.mappingsMux.Lock()
m.pathsPerDomain.With(prometheus.Labels{ defer m.mappingsMux.Unlock()
"domain": mapping.Host,
}).Set(float64(len(mapping.Paths))) newPathCount := len(mapping.Paths)
oldPathCount, exists := m.mappingPaths[mapping.Host]
if !exists {
m.configuredDomains.Add(m.ctx, 1)
}
pathDelta := newPathCount - oldPathCount
if pathDelta != 0 {
m.totalPaths.Add(m.ctx, int64(pathDelta))
}
m.mappingPaths[mapping.Host] = newPathCount
} }
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) { func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
m.configuredDomains.Dec() m.mappingsMux.Lock()
m.pathsPerDomain.With(prometheus.Labels{ defer m.mappingsMux.Unlock()
"domain": mapping.Host,
}).Set(0) oldPathCount, exists := m.mappingPaths[mapping.Host]
if !exists {
// Nothing to remove
return
}
m.configuredDomains.Add(m.ctx, -1)
m.totalPaths.Add(m.ctx, -int64(oldPathCount))
delete(m.mappingPaths, mapping.Host)
}
// RecordCertificateIssuance records the duration of a certificate issuance.
func (m *Metrics) RecordCertificateIssuance(duration time.Duration) {
m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds())
} }

View File

@@ -1,13 +1,17 @@
package metrics_test package metrics_test
import ( import (
"context"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/sdk/metric"
"github.com/netbirdio/netbird/proxy/internal/metrics" "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/prometheus/client_golang/prometheus"
) )
type testRoundTripper struct { type testRoundTripper struct {
@@ -47,7 +51,19 @@ func TestMetrics_RoundTripper(t *testing.T) {
}, },
} }
m := metrics.New(prometheus.NewRegistry()) exporter, err := prometheus.New()
if err != nil {
t.Fatalf("create prometheus exporter: %v", err)
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath()
meter := provider.Meter(pkg)
m, err := metrics.New(context.Background(), meter)
if err != nil {
t.Fatalf("create metrics: %v", err)
}
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {

View File

@@ -28,12 +28,14 @@ func BenchmarkServeHTTP(b *testing.B) {
ID: rand.Text(), ID: rand.Text(),
AccountID: types.AccountID(rand.Text()), AccountID: types.AccountID(rand.Text()),
Host: "app.example.com", Host: "app.example.com",
Paths: map[string]*url.URL{ Paths: map[string]*proxy.PathTarget{
"/": { "/": {
URL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "10.0.0.1:8080", Host: "10.0.0.1:8080",
}, },
}, },
},
}) })
req := httptest.NewRequest(http.MethodGet, "http://app.example.com", nil) req := httptest.NewRequest(http.MethodGet, "http://app.example.com", nil)
@@ -67,12 +69,14 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
ID: id, ID: id,
AccountID: types.AccountID(rand.Text()), AccountID: types.AccountID(rand.Text()),
Host: host, Host: host,
Paths: map[string]*url.URL{ Paths: map[string]*proxy.PathTarget{
"/": { "/": {
URL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "10.0.0.1:8080", Host: "10.0.0.1:8080",
}, },
}, },
},
}) })
} }
@@ -100,15 +104,17 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
paths := make(map[string]*url.URL, pathCount) paths := make(map[string]*proxy.PathTarget, pathCount)
for i := range pathCount { for i := range pathCount {
path := "/" + rand.Text() path := "/" + rand.Text()
if int64(i) == targetIndex.Int64() { if int64(i) == targetIndex.Int64() {
target = path target = path
} }
paths[path] = &url.URL{ paths[path] = &proxy.PathTarget{
URL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i), Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
},
} }
} }
rp.AddMapping(proxy.Mapping{ rp.AddMapping(proxy.Mapping{

View File

@@ -80,14 +80,30 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
capturedData.SetAccountId(result.accountID) capturedData.SetAccountId(result.accountID)
} }
pt := result.target
if pt.SkipTLSVerify {
ctx = roundtrip.WithSkipTLSVerify(ctx)
}
if pt.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
defer cancel()
}
rewriteMatchedPath := result.matchedPath
if pt.PathRewrite == PathRewritePreserve {
rewriteMatchedPath = ""
}
rp := &httputil.ReverseProxy{ rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader), Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
Transport: p.transport, Transport: p.transport,
FlushInterval: -1, FlushInterval: -1,
ErrorHandler: proxyErrorHandler, ErrorHandler: proxyErrorHandler,
} }
if result.rewriteRedirects { if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
} }
rp.ServeHTTP(w, r.WithContext(ctx)) rp.ServeHTTP(w, r.WithContext(ctx))
} }
@@ -97,17 +113,23 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// forwarding headers and stripping proxy authentication credentials. // forwarding headers and stripping proxy authentication credentials.
// When passHostHeader is true, the original client Host header is preserved // When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address. // instead of being rewritten to the backend's address.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) { // The pathRewrite parameter controls how the request path is transformed.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) { return func(r *httputil.ProxyRequest) {
switch pathRewrite {
case PathRewritePreserve:
// Keep the full original request path as-is.
default:
if matchedPath != "" && matchedPath != "/" {
// Strip the matched path prefix from the incoming request path before // Strip the matched path prefix from the incoming request path before
// SetURL joins it with the target's base path, avoiding path duplication. // SetURL joins it with the target's base path, avoiding path duplication.
if matchedPath != "" && matchedPath != "/" {
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath) r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
if r.Out.URL.Path == "" { if r.Out.URL.Path == "" {
r.Out.URL.Path = "/" r.Out.URL.Path = "/"
} }
r.Out.URL.RawPath = "" r.Out.URL.RawPath = ""
} }
}
r.SetURL(target) r.SetURL(target)
if passHostHeader { if passHostHeader {
@@ -116,6 +138,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Host = target.Host r.Out.Host = target.Host
} }
for k, v := range customHeaders {
r.Out.Header.Set(k, v)
}
clientIP := extractClientIP(r.In.RemoteAddr) clientIP := extractClientIP(r.In.RemoteAddr)
if IsTrustedProxy(clientIP, p.trustedProxies) { if IsTrustedProxy(clientIP, p.trustedProxies) {

View File

@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) { t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr) rewrite(pr)
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
}) })
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) { t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true) rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345") pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr) rewrite(pr)
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
func TestRewriteFunc_XForwardedForStripping(t *testing.T) { func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) { t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) { t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000") pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) { t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) { t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{} pr.In.TLS = &tls.ConnectionState{}
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) { t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects https from TLS", func(t *testing.T) { t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{} pr.In.TLS = &tls.ConnectionState{}
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("auto detects http without TLS", func(t *testing.T) { t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced proto overrides TLS detection", func(t *testing.T) { t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"} p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https // No TLS, but forced to https
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
t.Run("forced http proto", func(t *testing.T) { t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"} p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{} pr.In.TLS = &tls.ConnectionState{}
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
func TestRewriteFunc_SessionCookieStripping(t *testing.T) { func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
t.Run("strips nb_session cookie", func(t *testing.T) { t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) { func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"} p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
t.Run("strips session_token query parameter", func(t *testing.T) { t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("rewrites URL to target with path prefix", func(t *testing.T) { t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app") target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) { t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app") target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false) rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) { t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app") target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false) rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("appends to X-Forwarded-For", func(t *testing.T) { t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Real-IP", func(t *testing.T) { t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) { t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) { t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com") pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) { t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https") pr.In.Header.Set("X-Forwarded-Proto", "https")
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) { t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443") pr.In.Header.Set("X-Forwarded-Port", "8443")
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) { t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) { t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) { t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1") pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) { t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) { t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false) rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) { t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise" // Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise") target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("subpath under prefix also preserved", func(t *testing.T) { t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise") target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// What the behavior WOULD be if target URL had no path (true stripping) // What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) { t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443") target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) { t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443") target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false) rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
// Root path "/" — no stripping expected // Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) { t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/") target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false) rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000") pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr) rewrite(pr)
@@ -546,6 +546,82 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
}) })
} }
func TestRewriteFunc_PreservePath(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("preserve keeps full request path", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/api/users/123", pr.Out.URL.Path,
"preserve should keep the full original request path")
})
t.Run("preserve with root matchedPath", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/anything", pr.Out.URL.Path)
})
}
func TestRewriteFunc_CustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
t.Run("injects custom headers", func(t *testing.T) {
headers := map[string]string{
"X-Custom-Auth": "token-abc",
"X-Env": "production",
}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "token-abc", pr.Out.Header.Get("X-Custom-Auth"))
assert.Equal(t, "production", pr.Out.Header.Get("X-Env"))
})
t.Run("nil customHeaders is fine", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
})
t.Run("custom headers override existing request headers", func(t *testing.T) {
headers := map[string]string{"X-Override": "new-value"}
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.Header.Set("X-Override", "old-value")
rewrite(pr)
assert.Equal(t, "new-value", pr.Out.Header.Get("X-Override"))
})
}
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
target, _ := url.Parse("http://backend.internal:8080")
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/api/deep/path", pr.Out.URL.Path, "preserve should keep the full original path")
assert.Equal(t, "proxy", pr.Out.Header.Get("X-Via"), "custom header should be set")
}
func TestRewriteLocationFunc(t *testing.T) { func TestRewriteLocationFunc(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080") target, _ := url.Parse("http://backend.internal:8080")
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} } newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }

View File

@@ -6,21 +6,41 @@ import (
"net/url" "net/url"
"sort" "sort"
"strings" "strings"
"time"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
) )
// PathRewriteMode controls how the request path is rewritten before forwarding.
type PathRewriteMode int
const (
// PathRewriteDefault strips the matched prefix and joins with the target path.
PathRewriteDefault PathRewriteMode = iota
// PathRewritePreserve keeps the full original request path as-is.
PathRewritePreserve
)
// PathTarget holds a backend URL and per-target behavioral options.
type PathTarget struct {
URL *url.URL
SkipTLSVerify bool
RequestTimeout time.Duration
PathRewrite PathRewriteMode
CustomHeaders map[string]string
}
type Mapping struct { type Mapping struct {
ID string ID string
AccountID types.AccountID AccountID types.AccountID
Host string Host string
Paths map[string]*url.URL Paths map[string]*PathTarget
PassHostHeader bool PassHostHeader bool
RewriteRedirects bool RewriteRedirects bool
} }
type targetResult struct { type targetResult struct {
url *url.URL target *PathTarget
matchedPath string matchedPath string
serviceID string serviceID string
accountID types.AccountID accountID types.AccountID
@@ -55,10 +75,14 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
for _, path := range paths { for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) { if strings.HasPrefix(req.URL.Path, path) {
target := m.Paths[path] pt := m.Paths[path]
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target) if pt == nil || pt.URL == nil {
p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path)
continue
}
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL)
return targetResult{ return targetResult{
url: target, target: pt,
matchedPath: path, matchedPath: path,
serviceID: m.ID, serviceID: m.ID,
accountID: m.AccountID, accountID: m.AccountID,

View File

@@ -0,0 +1,32 @@
package roundtrip
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/proxy/internal/types"
)
func TestAccountIDContext(t *testing.T) {
t.Run("returns empty when missing", func(t *testing.T) {
assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background()))
})
t.Run("round-trips value", func(t *testing.T) {
ctx := WithAccountID(context.Background(), "acc-123")
assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx))
})
}
func TestSkipTLSVerifyContext(t *testing.T) {
t.Run("false by default", func(t *testing.T) {
assert.False(t, skipTLSVerifyFromContext(context.Background()))
})
t.Run("true when set", func(t *testing.T) {
ctx := WithSkipTLSVerify(context.Background())
assert.True(t, skipTLSVerifyFromContext(ctx))
})
}

View File

@@ -2,6 +2,7 @@ package roundtrip
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -52,6 +53,9 @@ type domainNotification struct {
type clientEntry struct { type clientEntry struct {
client *embed.Client client *embed.Client
transport *http.Transport transport *http.Transport
// insecureTransport is a clone of transport with TLS verification disabled,
// used when per-target skip_tls_verify is set.
insecureTransport *http.Transport
domains map[domain.Domain]domainInfo domains map[domain.Domain]domainInfo
createdAt time.Time createdAt time.Time
started bool started bool
@@ -130,6 +134,9 @@ type ClientDebugInfo struct {
// accountIDContextKey is the context key for storing the account ID. // accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{} type accountIDContextKey struct{}
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
type skipTLSVerifyContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet, // AddPeer registers a domain for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token. // one is created by authenticating with the management server using the provided token.
// Multiple domains can share the same client. // Multiple domains can share the same client.
@@ -249,10 +256,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create a transport using the client dialer. We do this instead of using // Create a transport using the client dialer. We do this instead of using
// the client's HTTPClient to avoid issues with request validation that do // the client's HTTPClient to avoid issues with request validation that do
// not work with reverse proxied requests. // not work with reverse proxied requests.
return &clientEntry{ transport := &http.Transport{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: &http.Transport{
DialContext: client.DialContext, DialContext: client.DialContext,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
MaxIdleConns: n.transportCfg.maxIdleConns, MaxIdleConns: n.transportCfg.maxIdleConns,
@@ -265,7 +269,16 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
WriteBufferSize: n.transportCfg.writeBufferSize, WriteBufferSize: n.transportCfg.writeBufferSize,
ReadBufferSize: n.transportCfg.readBufferSize, ReadBufferSize: n.transportCfg.readBufferSize,
DisableCompression: n.transportCfg.disableCompression, DisableCompression: n.transportCfg.disableCompression,
}, }
insecureTransport := transport.Clone()
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
return &clientEntry{
client: client,
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
transport: transport,
insecureTransport: insecureTransport,
createdAt: time.Now(), createdAt: time.Now(),
started: false, started: false,
inflightMap: make(map[backendKey]chan struct{}), inflightMap: make(map[backendKey]chan struct{}),
@@ -373,6 +386,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
client := entry.client client := entry.client
transport := entry.transport transport := entry.transport
insecureTransport := entry.insecureTransport
delete(n.clients, accountID) delete(n.clients, accountID)
n.clientsMux.Unlock() n.clientsMux.Unlock()
@@ -387,6 +401,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
} }
transport.CloseIdleConnections() transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil { if err := client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{ n.logger.WithFields(log.Fields{
@@ -415,6 +430,9 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
} }
client := entry.client client := entry.client
transport := entry.transport transport := entry.transport
if skipTLSVerifyFromContext(req.Context()) {
transport = entry.insecureTransport
}
n.clientsMux.RUnlock() n.clientsMux.RUnlock()
release, ok := entry.acquireInflight(req.URL.Host) release, ok := entry.acquireInflight(req.URL.Host)
@@ -457,6 +475,7 @@ func (n *NetBird) StopAll(ctx context.Context) error {
var merr *multierror.Error var merr *multierror.Error
for accountID, entry := range n.clients { for accountID, entry := range n.clients {
entry.transport.CloseIdleConnections() entry.transport.CloseIdleConnections()
entry.insecureTransport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil { if err := entry.client.Stop(ctx); err != nil {
n.logger.WithFields(log.Fields{ n.logger.WithFields(log.Fields{
"account_id": accountID, "account_id": accountID,
@@ -579,3 +598,14 @@ func AccountIDFromContext(ctx context.Context) types.AccountID {
} }
return accountID return accountID
} }
// WithSkipTLSVerify marks the context to use an insecure transport that skips
// TLS certificate verification for the backend connection.
func WithSkipTLSVerify(ctx context.Context) context.Context {
return context.WithValue(ctx, skipTLSVerifyContextKey{}, true)
}
func skipTLSVerifyFromContext(ctx context.Context) bool {
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
return v
}

View File

@@ -116,6 +116,9 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
require.NoError(t, err) require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
// Create real users manager // Create real users manager
usersManager := users.NewManager(testStore) usersManager := users.NewManager(testStore)
@@ -131,6 +134,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
proxyService := nbgrpc.NewProxyServiceServer( proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{}, &testAccessLogManager{},
tokenStore, tokenStore,
pkceStore,
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,

View File

@@ -19,14 +19,17 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"path/filepath" "path/filepath"
"reflect"
"sync" "sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
"github.com/prometheus/client_golang/prometheus" prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/sdk/metric"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@@ -42,7 +45,7 @@ import (
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s" "github.com/netbirdio/netbird/proxy/internal/k8s"
"github.com/netbirdio/netbird/proxy/internal/metrics" proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/internal/types"
@@ -63,7 +66,7 @@ type Server struct {
debug *http.Server debug *http.Server
healthServer *health.Server healthServer *health.Server
healthChecker *health.Checker healthChecker *health.Checker
meter *metrics.Metrics meter *proxymetrics.Metrics
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown // so they can be closed during graceful shutdown, since http.Server.Shutdown
@@ -152,8 +155,19 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.initDefaults() s.initDefaults()
reg := prometheus.NewRegistry() exporter, err := prometheus.New()
s.meter = metrics.New(reg) if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
}
mgmtConn, err := s.dialManagement() mgmtConn, err := s.dialManagement()
if err != nil { if err != nil {
@@ -193,7 +207,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.startDebugEndpoint() s.startDebugEndpoint()
if err := s.startHealthServer(reg); err != nil { if err := s.startHealthServer(); err != nil {
return err return err
} }
@@ -284,12 +298,12 @@ func (s *Server) startDebugEndpoint() {
} }
// startHealthServer launches the health probe and metrics server. // startHealthServer launches the health probe and metrics server.
func (s *Server) startHealthServer(reg *prometheus.Registry) error { func (s *Server) startHealthServer() error {
healthAddr := s.HealthAddress healthAddr := s.HealthAddress
if healthAddr == "" { if healthAddr == "" {
healthAddr = defaultHealthAddr healthAddr = defaultHealthAddr
} }
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", healthAddr) healthListener, err := net.Listen("tcp", healthAddr)
if err != nil { if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err) return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
@@ -423,7 +437,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
"acme_server": s.ACMEDirectory, "acme_server": s.ACMEDirectory,
"challenge_type": s.ACMEChallengeType, "challenge_type": s.ACMEChallengeType,
}).Debug("ACME certificates enabled, configuring certificate manager") }).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod) s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod, s.meter)
if s.ACMEChallengeType == "http-01" { if s.ACMEChallengeType == "http-01" {
s.http = &http.Server{ s.http = &http.Server{
@@ -720,7 +734,7 @@ func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping)
} }
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping { func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
paths := make(map[string]*url.URL) paths := make(map[string]*proxy.PathTarget)
for _, pathMapping := range mapping.GetPath() { for _, pathMapping := range mapping.GetPath() {
targetURL, err := url.Parse(pathMapping.GetTarget()) targetURL, err := url.Parse(pathMapping.GetTarget())
if err != nil { if err != nil {
@@ -734,7 +748,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
}).WithError(err).Error("failed to parse target URL for path, skipping") }).WithError(err).Error("failed to parse target URL for path, skipping")
continue continue
} }
paths[pathMapping.GetPath()] = targetURL
pt := &proxy.PathTarget{URL: targetURL}
if opts := pathMapping.GetOptions(); opts != nil {
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
pt.CustomHeaders = opts.GetCustomHeaders()
if d := opts.GetRequestTimeout(); d != nil {
pt.RequestTimeout = d.AsDuration()
}
}
paths[pathMapping.GetPath()] = pt
} }
return proxy.Mapping{ return proxy.Mapping{
ID: mapping.GetId(), ID: mapping.GetId(),
@@ -746,6 +770,15 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
} }
} }
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
switch mode {
case proto.PathRewriteMode_PATH_REWRITE_PRESERVE:
return proxy.PathRewritePreserve
default:
return proxy.PathRewriteDefault
}
}
// debugEndpointAddr returns the address for the debug endpoint. // debugEndpointAddr returns the address for the debug endpoint.
// If addr is empty, it defaults to localhost:8444 for security. // If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string { func debugEndpointAddr(addr string) string {

View File

@@ -0,0 +1,271 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testServiceTarget = api.ServiceTarget{
TargetId: "peer-123",
TargetType: "peer",
Protocol: "https",
Port: 8443,
Enabled: true,
}
var testService = api.Service{
Id: "svc-1",
Name: "test-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Meta: api.ServiceMeta{
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
Status: "active",
},
Targets: []api.ServiceTarget{testServiceTarget},
}
func TestReverseProxyServices_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Service{testService})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.List(context.Background())
require.NoError(t, err)
require.Len(t, ret, 1)
assert.Equal(t, testService.Id, ret[0].Id)
assert.Equal(t, testService.Name, ret[0].Name)
})
}
func TestReverseProxyServices_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestReverseProxyServices_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testService)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
require.NoError(t, err)
assert.Equal(t, testService.Id, ret.Id)
assert.Equal(t, testService.Domain, ret.Domain)
})
}
func TestReverseProxyServices_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestReverseProxyServices_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.ServiceRequest
require.NoError(t, json.Unmarshal(reqBytes, &req))
assert.Equal(t, "test-service", req.Name)
assert.Equal(t, "test.example.com", req.Domain)
retBytes, _ := json.Marshal(testService)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
Name: "test-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Targets: []api.ServiceTarget{testServiceTarget},
})
require.NoError(t, err)
assert.Equal(t, testService.Id, ret.Id)
})
}
func TestReverseProxyServices_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
Name: "test-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Targets: []api.ServiceTarget{testServiceTarget},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.ServiceRequest
require.NoError(t, json.Unmarshal(reqBytes, &req))
require.Len(t, req.Targets, 1)
target := req.Targets[0]
require.NotNil(t, target.Options, "options should be present")
opts := target.Options
require.NotNil(t, opts.SkipTlsVerify, "skip_tls_verify should be present")
assert.True(t, *opts.SkipTlsVerify)
require.NotNil(t, opts.RequestTimeout, "request_timeout should be present")
assert.Equal(t, "30s", *opts.RequestTimeout)
require.NotNil(t, opts.PathRewrite, "path_rewrite should be present")
assert.Equal(t, api.ServiceTargetOptionsPathRewrite("preserve"), *opts.PathRewrite)
require.NotNil(t, opts.CustomHeaders, "custom_headers should be present")
assert.Equal(t, "bar", (*opts.CustomHeaders)["X-Foo"])
retBytes, _ := json.Marshal(testService)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
pathRewrite := api.ServiceTargetOptionsPathRewrite("preserve")
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
Name: "test-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Targets: []api.ServiceTarget{
{
TargetId: "peer-123",
TargetType: "peer",
Protocol: "https",
Port: 8443,
Enabled: true,
Options: &api.ServiceTargetOptions{
SkipTlsVerify: ptr(true),
RequestTimeout: ptr("30s"),
PathRewrite: &pathRewrite,
CustomHeaders: &map[string]string{"X-Foo": "bar"},
},
},
},
})
require.NoError(t, err)
assert.Equal(t, testService.Id, ret.Id)
})
}
func TestReverseProxyServices_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.ServiceRequest
require.NoError(t, json.Unmarshal(reqBytes, &req))
assert.Equal(t, "updated-service", req.Name)
retBytes, _ := json.Marshal(testService)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
Name: "updated-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Targets: []api.ServiceTarget{testServiceTarget},
})
require.NoError(t, err)
assert.Equal(t, testService.Id, ret.Id)
})
}
func TestReverseProxyServices_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
Name: "updated-service",
Domain: "test.example.com",
Enabled: true,
Auth: api.ServiceAuthConfig{},
Targets: []api.ServiceTarget{testServiceTarget},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestReverseProxyServices_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
require.NoError(t, err)
})
}
func TestReverseProxyServices_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -2822,6 +2822,16 @@ components:
type: string type: string
description: "City name from geolocation" description: "City name from geolocation"
example: "San Francisco" example: "San Francisco"
bytes_upload:
type: integer
format: int64
description: "Bytes uploaded (request body size)"
example: 1024
bytes_download:
type: integer
format: int64
description: "Bytes downloaded (response body size)"
example: 8192
required: required:
- id - id
- service_id - service_id
@@ -2831,6 +2841,8 @@ components:
- path - path
- duration_ms - duration_ms
- status_code - status_code
- bytes_upload
- bytes_download
ProxyAccessLogsResponse: ProxyAccessLogsResponse:
type: object type: object
properties: properties:
@@ -3027,6 +3039,28 @@ components:
- targets - targets
- auth - auth
- enabled - enabled
ServiceTargetOptions:
type: object
properties:
skip_tls_verify:
type: boolean
description: Skip TLS certificate verification for this backend
request_timeout:
type: string
description: Per-target response timeout as a Go duration string (e.g. "30s", "2m")
path_rewrite:
type: string
description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
enum: [preserve]
custom_headers:
type: object
description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
propertyNames:
type: string
pattern: '^[!#$%&''*+.^_`|~0-9A-Za-z-]+$'
additionalProperties:
type: string
pattern: '^[^\r\n]*$'
ServiceTarget: ServiceTarget:
type: object type: object
properties: properties:
@@ -3053,6 +3087,8 @@ components:
enabled: enabled:
type: boolean type: boolean
description: Whether this target is enabled description: Whether this target is enabled
options:
$ref: '#/components/schemas/ServiceTargetOptions'
required: required:
- target_id - target_id
- target_type - target_type

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v6.33.3 // protoc v6.33.0
// source: management.proto // source: management.proto
package proto package proto

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ package management;
option go_package = "/proto"; option go_package = "/proto";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
// ProxyService - Management is the SERVER, Proxy is the CLIENT // ProxyService - Management is the SERVER, Proxy is the CLIENT
@@ -50,9 +51,22 @@ enum ProxyMappingUpdateType {
UPDATE_TYPE_REMOVED = 2; UPDATE_TYPE_REMOVED = 2;
} }
enum PathRewriteMode {
PATH_REWRITE_DEFAULT = 0;
PATH_REWRITE_PRESERVE = 1;
}
message PathTargetOptions {
bool skip_tls_verify = 1;
google.protobuf.Duration request_timeout = 2;
PathRewriteMode path_rewrite = 3;
map<string, string> custom_headers = 4;
}
message PathMapping { message PathMapping {
string path = 1; string path = 1;
string target = 2; string target = 2;
PathTargetOptions options = 3;
} }
message Authentication { message Authentication {
@@ -101,6 +115,8 @@ message AccessLog {
string auth_mechanism = 11; string auth_mechanism = 11;
string user_id = 12; string user_id = 12;
bool auth_success = 13; bool auth_success = 13;
int64 bytes_upload = 14;
int64 bytes_download = 15;
} }
message AuthenticateRequest { message AuthenticateRequest {