mirror of
https://github.com/netbirdio/netbird
synced 2026-04-22 17:44:57 +02:00
Compare commits
20 Commits
test/proxy
...
test/proxy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
373f014aea | ||
|
|
2df3fb959b | ||
|
|
14b0e9462b | ||
|
|
8db71b545e | ||
|
|
7c3532d8e5 | ||
|
|
1aa1eef2c5 | ||
|
|
d9418ddc1e | ||
|
|
1b4c831976 | ||
|
|
a19611d8e0 | ||
|
|
9ab6138040 | ||
|
|
30c02ab78c | ||
|
|
3acd86e346 | ||
|
|
c2fec57c0f | ||
|
|
5c20f13c48 | ||
|
|
e6587b071d | ||
|
|
85451ab4cd | ||
|
|
a7f3ba03eb | ||
|
|
4f0a3a77ad | ||
|
|
44655ca9b5 | ||
|
|
e601278117 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
51
.github/workflows/pr-title-check.yml
vendored
Normal file
51
.github/workflows/pr-title-check.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: PR Title Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, edited, synchronize, reopened]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-title:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Validate PR title prefix
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const title = context.payload.pull_request.title;
|
||||||
|
const allowedTags = [
|
||||||
|
'management',
|
||||||
|
'client',
|
||||||
|
'signal',
|
||||||
|
'proxy',
|
||||||
|
'relay',
|
||||||
|
'misc',
|
||||||
|
'infrastructure',
|
||||||
|
'self-hosted',
|
||||||
|
'doc',
|
||||||
|
];
|
||||||
|
|
||||||
|
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||||
|
const match = title.match(pattern);
|
||||||
|
|
||||||
|
if (!match) {
|
||||||
|
core.setFailed(
|
||||||
|
`PR title must start with a tag in brackets.\n` +
|
||||||
|
`Example: [client] fix something\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||||
|
|
||||||
|
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||||
|
if (invalid.length > 0) {
|
||||||
|
core.setFailed(
|
||||||
|
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||||
@@ -849,14 +849,26 @@ func (s *Server) cleanupConnection() error {
|
|||||||
if s.actCancel == nil {
|
if s.actCancel == nil {
|
||||||
return ErrServiceNotUp
|
return ErrServiceNotUp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capture the engine reference before cancelling the context.
|
||||||
|
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||||
|
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||||
|
// to skip the engine shutdown entirely.
|
||||||
|
var engine *internal.Engine
|
||||||
|
if s.connectClient != nil {
|
||||||
|
engine = s.connectClient.Engine()
|
||||||
|
}
|
||||||
|
|
||||||
s.actCancel()
|
s.actCancel()
|
||||||
|
|
||||||
if s.connectClient == nil {
|
if s.connectClient == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.connectClient.Stop(); err != nil {
|
if engine != nil {
|
||||||
return err
|
if err := engine.Stop(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectClient = nil
|
s.connectClient = nil
|
||||||
|
|||||||
@@ -493,9 +493,6 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
|
|||||||
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
|
||||||
mgmt := cfg.Management
|
mgmt := cfg.Management
|
||||||
|
|
||||||
dnsDomain := mgmt.DnsDomain
|
|
||||||
singleAccModeDomain := dnsDomain
|
|
||||||
|
|
||||||
// Extract port from listen address
|
// Extract port from listen address
|
||||||
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -507,8 +504,9 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
|||||||
mgmtSrv := mgmtServer.NewServer(
|
mgmtSrv := mgmtServer.NewServer(
|
||||||
&mgmtServer.Config{
|
&mgmtServer.Config{
|
||||||
NbConfig: mgmtConfig,
|
NbConfig: mgmtConfig,
|
||||||
DNSDomain: dnsDomain,
|
DNSDomain: "",
|
||||||
MgmtSingleAccModeDomain: singleAccModeDomain,
|
MgmtSingleAccModeDomain: "",
|
||||||
|
AutoResolveDomains: true,
|
||||||
MgmtPort: mgmtPort,
|
MgmtPort: mgmtPort,
|
||||||
MgmtMetricsPort: cfg.Server.MetricsPort,
|
MgmtMetricsPort: cfg.Server.MetricsPort,
|
||||||
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
DisableMetrics: mgmt.DisableAnonymousMetrics,
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ type AccessLogEntry struct {
|
|||||||
Reason string
|
Reason string
|
||||||
UserId string `gorm:"index"`
|
UserId string `gorm:"index"`
|
||||||
AuthMethodUsed string `gorm:"index"`
|
AuthMethodUsed string `gorm:"index"`
|
||||||
|
BytesUpload int64 `gorm:"index"`
|
||||||
|
BytesDownload int64 `gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
// FromProto creates an AccessLogEntry from a proto.AccessLog
|
||||||
@@ -39,6 +41,8 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
|
|||||||
a.UserId = serviceLog.GetUserId()
|
a.UserId = serviceLog.GetUserId()
|
||||||
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
|
||||||
a.AccountID = serviceLog.GetAccountId()
|
a.AccountID = serviceLog.GetAccountId()
|
||||||
|
a.BytesUpload = serviceLog.GetBytesUpload()
|
||||||
|
a.BytesDownload = serviceLog.GetBytesDownload()
|
||||||
|
|
||||||
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
|
||||||
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
if ip, err := netip.ParseAddr(sourceIP); err == nil {
|
||||||
@@ -101,5 +105,7 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
|
|||||||
AuthMethodUsed: authMethod,
|
AuthMethodUsed: authMethod,
|
||||||
CountryCode: countryCode,
|
CountryCode: countryCode,
|
||||||
CityName: cityName,
|
CityName: cityName,
|
||||||
|
BytesUpload: a.BytesUpload,
|
||||||
|
BytesDownload: a.BytesDownload,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,3 +15,12 @@ type Domain struct {
|
|||||||
Type Type `gorm:"-"`
|
Type Type `gorm:"-"`
|
||||||
Validated bool
|
Validated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EventMeta returns activity event metadata for a domain
|
||||||
|
func (d *Domain) EventMeta() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"domain": d.Domain,
|
||||||
|
"target_cluster": d.TargetCluster,
|
||||||
|
"validated": d.Validated,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -36,16 +38,16 @@ type Manager struct {
|
|||||||
validator domain.Validator
|
validator domain.Validator
|
||||||
proxyManager proxyManager
|
proxyManager proxyManager
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
|
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||||
return Manager{
|
return Manager{
|
||||||
store: store,
|
store: store,
|
||||||
proxyManager: proxyMgr,
|
proxyManager: proxyMgr,
|
||||||
validator: domain.Validator{
|
validator: domain.Validator{Resolver: net.DefaultResolver},
|
||||||
Resolver: net.DefaultResolver,
|
|
||||||
},
|
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
|
accountManager: accountManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,9 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return d, fmt.Errorf("create domain in store: %w", err)
|
return d, fmt.Errorf("create domain in store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, d.ID, accountID, activity.DomainAdded, d.EventMeta())
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,10 +153,18 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get domain from store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
|
||||||
// TODO: check for "no records" type error. Because that is a success condition.
|
// TODO: check for "no records" type error. Because that is a success condition.
|
||||||
return fmt.Errorf("delete domain from store: %w", err)
|
return fmt.Errorf("delete domain from store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(ctx, userID, domainID, accountID, activity.DomainDeleted, d.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,6 +231,8 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
|||||||
}).WithError(err).Error("update custom domain in store")
|
}).WithError(err).Error("update custom domain in store")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.StoreEvent(context.Background(), userID, domainID, accountID, activity.DomainValidated, d.EventMeta())
|
||||||
} else {
|
} else {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"accountID": accountID,
|
"accountID": accountID,
|
||||||
|
|||||||
@@ -73,7 +73,10 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
service := new(rpservice.Service)
|
service := new(rpservice.Service)
|
||||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = service.Validate(); err != nil {
|
if err = service.Validate(); err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
@@ -132,7 +135,10 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
service := new(rpservice.Service)
|
service := new(rpservice.Service)
|
||||||
service.ID = serviceID
|
service.ID = serviceID
|
||||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = service.Validate(); err != nil {
|
if err = service.Validate(); err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||||
|
|||||||
@@ -36,11 +36,11 @@ func TestReapExpiredExposes(t *testing.T) {
|
|||||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||||
|
|
||||||
// Expired service should be deleted
|
// Expired service should be deleted
|
||||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
require.Error(t, err, "expired service should be deleted")
|
require.Error(t, err, "expired service should be deleted")
|
||||||
|
|
||||||
// Non-expired service should remain
|
// Non-expired service should remain
|
||||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
|
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
|
||||||
require.NoError(t, err, "active service should remain")
|
require.NoError(t, err, "active service should remain")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,14 +191,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
|
|||||||
// Reaper should skip it because the re-check sees a fresh timestamp
|
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||||
|
|
||||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
require.NoError(t, err, "renewed service should survive reaping")
|
require.NoError(t, err, "renewed service should survive reaping")
|
||||||
}
|
}
|
||||||
|
|
||||||
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||||
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
|
svc, err := s.GetServiceByDomain(context.Background(), domain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expired := time.Now().Add(-2 * exposeTTL)
|
expired := time.Now().Add(-2 * exposeTTL)
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,7 +245,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
|||||||
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
|
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,8 +261,8 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
|
||||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
existingService, err := transaction.GetServiceByDomain(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||||
return fmt.Errorf("failed to check existing service: %w", err)
|
return fmt.Errorf("failed to check existing service: %w", err)
|
||||||
@@ -271,7 +271,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
|||||||
}
|
}
|
||||||
|
|
||||||
if existingService != nil && existingService.ID != excludeServiceID {
|
if existingService != nil && existingService.ID != excludeServiceID {
|
||||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
return status.Errorf(status.AlreadyExists, "domain already taken")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -352,7 +352,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -805,7 +805,7 @@ func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID,
|
|||||||
|
|
||||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||||
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
svc, err := m.store.GetServiceByDomain(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
|||||||
|
|
||||||
func TestCheckDomainAvailable(t *testing.T) {
|
func TestCheckDomainAvailable(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
accountID := "test-account"
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
excludeServiceID: "",
|
excludeServiceID: "",
|
||||||
setupMock: func(ms *store.MockStore) {
|
setupMock: func(ms *store.MockStore) {
|
||||||
ms.EXPECT().
|
ms.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "available.com").
|
GetServiceByDomain(ctx, "available.com").
|
||||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||||
},
|
},
|
||||||
expectedError: false,
|
expectedError: false,
|
||||||
@@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
excludeServiceID: "",
|
excludeServiceID: "",
|
||||||
setupMock: func(ms *store.MockStore) {
|
setupMock: func(ms *store.MockStore) {
|
||||||
ms.EXPECT().
|
ms.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
GetServiceByDomain(ctx, "exists.com").
|
||||||
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||||
},
|
},
|
||||||
expectedError: true,
|
expectedError: true,
|
||||||
@@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
excludeServiceID: "service-123",
|
excludeServiceID: "service-123",
|
||||||
setupMock: func(ms *store.MockStore) {
|
setupMock: func(ms *store.MockStore) {
|
||||||
ms.EXPECT().
|
ms.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
GetServiceByDomain(ctx, "exists.com").
|
||||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||||
},
|
},
|
||||||
expectedError: false,
|
expectedError: false,
|
||||||
@@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
excludeServiceID: "service-456",
|
excludeServiceID: "service-456",
|
||||||
setupMock: func(ms *store.MockStore) {
|
setupMock: func(ms *store.MockStore) {
|
||||||
ms.EXPECT().
|
ms.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
GetServiceByDomain(ctx, "exists.com").
|
||||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||||
},
|
},
|
||||||
expectedError: true,
|
expectedError: true,
|
||||||
@@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
excludeServiceID: "",
|
excludeServiceID: "",
|
||||||
setupMock: func(ms *store.MockStore) {
|
setupMock: func(ms *store.MockStore) {
|
||||||
ms.EXPECT().
|
ms.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "error.com").
|
GetServiceByDomain(ctx, "error.com").
|
||||||
Return(nil, errors.New("database error"))
|
Return(nil, errors.New("database error"))
|
||||||
},
|
},
|
||||||
expectedError: true,
|
expectedError: true,
|
||||||
@@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
tt.setupMock(mockStore)
|
tt.setupMock(mockStore)
|
||||||
|
|
||||||
mgr := &Manager{}
|
mgr := &Manager{}
|
||||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID)
|
||||||
|
|
||||||
if tt.expectedError {
|
if tt.expectedError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) {
|
|||||||
|
|
||||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
accountID := "test-account"
|
|
||||||
|
|
||||||
t.Run("empty domain", func(t *testing.T) {
|
t.Run("empty domain", func(t *testing.T) {
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
@@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
|||||||
|
|
||||||
mockStore := store.NewMockStore(ctrl)
|
mockStore := store.NewMockStore(ctrl)
|
||||||
mockStore.EXPECT().
|
mockStore.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "").
|
GetServiceByDomain(ctx, "").
|
||||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||||
|
|
||||||
mgr := &Manager{}
|
mgr := &Manager{}
|
||||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
err := mgr.checkDomainAvailable(ctx, mockStore, "", "")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
@@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
|||||||
|
|
||||||
mockStore := store.NewMockStore(ctrl)
|
mockStore := store.NewMockStore(ctrl)
|
||||||
mockStore.EXPECT().
|
mockStore.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "test.com").
|
GetServiceByDomain(ctx, "test.com").
|
||||||
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
|
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||||
|
|
||||||
mgr := &Manager{}
|
mgr := &Manager{}
|
||||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "")
|
||||||
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
sErr, ok := status.FromError(err)
|
sErr, ok := status.FromError(err)
|
||||||
@@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
|||||||
|
|
||||||
mockStore := store.NewMockStore(ctrl)
|
mockStore := store.NewMockStore(ctrl)
|
||||||
mockStore.EXPECT().
|
mockStore.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
GetServiceByDomain(ctx, "nil.com").
|
||||||
Return(nil, nil)
|
Return(nil, nil)
|
||||||
|
|
||||||
mgr := &Manager{}
|
mgr := &Manager{}
|
||||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
@@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) {
|
|||||||
// Create another mock for the transaction
|
// Create another mock for the transaction
|
||||||
txMock := store.NewMockStore(ctrl)
|
txMock := store.NewMockStore(ctrl)
|
||||||
txMock.EXPECT().
|
txMock.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "new.com").
|
GetServiceByDomain(ctx, "new.com").
|
||||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||||
txMock.EXPECT().
|
txMock.EXPECT().
|
||||||
CreateService(ctx, service).
|
CreateService(ctx, service).
|
||||||
@@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) {
|
|||||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||||
txMock := store.NewMockStore(ctrl)
|
txMock := store.NewMockStore(ctrl)
|
||||||
txMock.EXPECT().
|
txMock.EXPECT().
|
||||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
GetServiceByDomain(ctx, "existing.com").
|
||||||
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||||
|
|
||||||
return fn(txMock)
|
return fn(txMock)
|
||||||
@@ -425,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||||
t.Cleanup(srv.Close)
|
require.NoError(t, err)
|
||||||
|
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -705,8 +704,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
|||||||
|
|
||||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
t.Cleanup(proxySrv.Close)
|
require.NoError(t, err)
|
||||||
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -814,7 +814,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
|||||||
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
|
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
|
||||||
|
|
||||||
// Verify service is persisted in store
|
// Verify service is persisted in store
|
||||||
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, resp.Domain, persisted.Domain)
|
assert.Equal(t, resp.Domain, persisted.Domain)
|
||||||
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
|
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
|
||||||
@@ -977,7 +977,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify service is deleted
|
// Verify service is deleted
|
||||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
require.Error(t, err, "service should be deleted")
|
require.Error(t, err, "service should be deleted")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1012,7 +1012,7 @@ func TestStopServiceFromPeer(t *testing.T) {
|
|||||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
require.Error(t, err, "service should be deleted")
|
require.Error(t, err, "service should be deleted")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1031,7 +1031,7 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
|
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
|
||||||
|
|
||||||
svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
|
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
|
||||||
@@ -1136,8 +1136,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
|||||||
|
|
||||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
t.Cleanup(proxySrv.Close)
|
require.NoError(t, err)
|
||||||
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||||
@@ -49,17 +52,25 @@ const (
|
|||||||
SourceEphemeral = "ephemeral"
|
SourceEphemeral = "ephemeral"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type TargetOptions struct {
|
||||||
|
SkipTLSVerify bool `json:"skip_tls_verify"`
|
||||||
|
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
|
||||||
|
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||||
|
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type Target struct {
|
type Target struct {
|
||||||
ID uint `gorm:"primaryKey" json:"-"`
|
ID uint `gorm:"primaryKey" json:"-"`
|
||||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||||
Path *string `json:"path,omitempty"`
|
Path *string `json:"path,omitempty"`
|
||||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||||
|
Options TargetOptions `gorm:"embedded" json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PasswordAuthConfig struct {
|
type PasswordAuthConfig struct {
|
||||||
@@ -123,7 +134,7 @@ type Service struct {
|
|||||||
ID string `gorm:"primaryKey"`
|
ID string `gorm:"primaryKey"`
|
||||||
AccountID string `gorm:"index"`
|
AccountID string `gorm:"index"`
|
||||||
Name string
|
Name string
|
||||||
Domain string `gorm:"index"`
|
Domain string `gorm:"type:varchar(255);uniqueIndex"`
|
||||||
ProxyCluster string `gorm:"index"`
|
ProxyCluster string `gorm:"index"`
|
||||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||||
Enabled bool
|
Enabled bool
|
||||||
@@ -194,7 +205,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
|||||||
// Convert internal targets to API targets
|
// Convert internal targets to API targets
|
||||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||||
for _, target := range s.Targets {
|
for _, target := range s.Targets {
|
||||||
apiTargets = append(apiTargets, api.ServiceTarget{
|
st := api.ServiceTarget{
|
||||||
Path: target.Path,
|
Path: target.Path,
|
||||||
Host: &target.Host,
|
Host: &target.Host,
|
||||||
Port: target.Port,
|
Port: target.Port,
|
||||||
@@ -202,7 +213,9 @@ func (s *Service) ToAPIResponse() *api.Service {
|
|||||||
TargetId: target.TargetId,
|
TargetId: target.TargetId,
|
||||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||||
Enabled: target.Enabled,
|
Enabled: target.Enabled,
|
||||||
})
|
}
|
||||||
|
st.Options = targetOptionsToAPI(target.Options)
|
||||||
|
apiTargets = append(apiTargets, st)
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := api.ServiceMeta{
|
meta := api.ServiceMeta{
|
||||||
@@ -256,10 +269,14 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
|||||||
if target.Path != nil {
|
if target.Path != nil {
|
||||||
path = *target.Path
|
path = *target.Path
|
||||||
}
|
}
|
||||||
pathMappings = append(pathMappings, &proto.PathMapping{
|
|
||||||
|
pm := &proto.PathMapping{
|
||||||
Path: path,
|
Path: path,
|
||||||
Target: targetURL.String(),
|
Target: targetURL.String(),
|
||||||
})
|
}
|
||||||
|
|
||||||
|
pm.Options = targetOptionsToProto(target.Options)
|
||||||
|
pathMappings = append(pathMappings, pm)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth := &proto.Authentication{
|
auth := &proto.Authentication{
|
||||||
@@ -312,13 +329,87 @@ func isDefaultPort(scheme string, port int) bool {
|
|||||||
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||||
|
type PathRewriteMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PathRewritePreserve PathRewriteMode = "preserve"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||||
|
switch mode {
|
||||||
|
case PathRewritePreserve:
|
||||||
|
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
|
||||||
|
default:
|
||||||
|
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||||
|
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiOpts := &api.ServiceTargetOptions{}
|
||||||
|
if opts.SkipTLSVerify {
|
||||||
|
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
|
||||||
|
}
|
||||||
|
if opts.RequestTimeout != 0 {
|
||||||
|
s := opts.RequestTimeout.String()
|
||||||
|
apiOpts.RequestTimeout = &s
|
||||||
|
}
|
||||||
|
if opts.PathRewrite != "" {
|
||||||
|
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
|
||||||
|
apiOpts.PathRewrite = &pr
|
||||||
|
}
|
||||||
|
if len(opts.CustomHeaders) > 0 {
|
||||||
|
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||||
|
}
|
||||||
|
return apiOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||||
|
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
popts := &proto.PathTargetOptions{
|
||||||
|
SkipTlsVerify: opts.SkipTLSVerify,
|
||||||
|
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||||
|
CustomHeaders: opts.CustomHeaders,
|
||||||
|
}
|
||||||
|
if opts.RequestTimeout != 0 {
|
||||||
|
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||||
|
}
|
||||||
|
return popts
|
||||||
|
}
|
||||||
|
|
||||||
|
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
|
||||||
|
var opts TargetOptions
|
||||||
|
if o.SkipTlsVerify != nil {
|
||||||
|
opts.SkipTLSVerify = *o.SkipTlsVerify
|
||||||
|
}
|
||||||
|
if o.RequestTimeout != nil {
|
||||||
|
d, err := time.ParseDuration(*o.RequestTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
|
||||||
|
}
|
||||||
|
opts.RequestTimeout = d
|
||||||
|
}
|
||||||
|
if o.PathRewrite != nil {
|
||||||
|
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
|
||||||
|
}
|
||||||
|
if o.CustomHeaders != nil {
|
||||||
|
opts.CustomHeaders = *o.CustomHeaders
|
||||||
|
}
|
||||||
|
return opts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
|
||||||
s.Name = req.Name
|
s.Name = req.Name
|
||||||
s.Domain = req.Domain
|
s.Domain = req.Domain
|
||||||
s.AccountID = accountID
|
s.AccountID = accountID
|
||||||
|
|
||||||
targets := make([]*Target, 0, len(req.Targets))
|
targets := make([]*Target, 0, len(req.Targets))
|
||||||
for _, apiTarget := range req.Targets {
|
for i, apiTarget := range req.Targets {
|
||||||
target := &Target{
|
target := &Target{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Path: apiTarget.Path,
|
Path: apiTarget.Path,
|
||||||
@@ -331,6 +422,13 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
|||||||
if apiTarget.Host != nil {
|
if apiTarget.Host != nil {
|
||||||
target.Host = *apiTarget.Host
|
target.Host = *apiTarget.Host
|
||||||
}
|
}
|
||||||
|
if apiTarget.Options != nil {
|
||||||
|
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
target.Options = opts
|
||||||
|
}
|
||||||
targets = append(targets, target)
|
targets = append(targets, target)
|
||||||
}
|
}
|
||||||
s.Targets = targets
|
s.Targets = targets
|
||||||
@@ -368,6 +466,8 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
|||||||
}
|
}
|
||||||
s.Auth.BearerAuth = bearerAuth
|
s.Auth.BearerAuth = bearerAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) Validate() error {
|
func (s *Service) Validate() error {
|
||||||
@@ -400,11 +500,113 @@ func (s *Service) Validate() error {
|
|||||||
if target.TargetId == "" {
|
if target.TargetId == "" {
|
||||||
return fmt.Errorf("target %d has empty target_id", i)
|
return fmt.Errorf("target %d has empty target_id", i)
|
||||||
}
|
}
|
||||||
|
if err := validateTargetOptions(i, &target.Options); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxRequestTimeout = 5 * time.Minute
|
||||||
|
maxCustomHeaders = 16
|
||||||
|
maxHeaderKeyLen = 128
|
||||||
|
maxHeaderValueLen = 4096
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
|
||||||
|
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
|
||||||
|
|
||||||
|
// hopByHopHeaders are headers that must not be set as custom headers
|
||||||
|
// because they are connection-level and stripped by the proxy.
|
||||||
|
var hopByHopHeaders = map[string]struct{}{
|
||||||
|
"Connection": {},
|
||||||
|
"Keep-Alive": {},
|
||||||
|
"Proxy-Authenticate": {},
|
||||||
|
"Proxy-Authorization": {},
|
||||||
|
"Proxy-Connection": {},
|
||||||
|
"Te": {},
|
||||||
|
"Trailer": {},
|
||||||
|
"Transfer-Encoding": {},
|
||||||
|
"Upgrade": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
|
||||||
|
// and cannot be overridden.
|
||||||
|
var reservedHeaders = map[string]struct{}{
|
||||||
|
"Content-Length": {},
|
||||||
|
"Content-Type": {},
|
||||||
|
"Cookie": {},
|
||||||
|
"Forwarded": {},
|
||||||
|
"X-Forwarded-For": {},
|
||||||
|
"X-Forwarded-Host": {},
|
||||||
|
"X-Forwarded-Port": {},
|
||||||
|
"X-Forwarded-Proto": {},
|
||||||
|
"X-Real-Ip": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||||
|
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
|
||||||
|
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.RequestTimeout != 0 {
|
||||||
|
if opts.RequestTimeout <= 0 {
|
||||||
|
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||||
|
}
|
||||||
|
if opts.RequestTimeout > maxRequestTimeout {
|
||||||
|
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCustomHeaders(idx int, headers map[string]string) error {
|
||||||
|
if len(headers) > maxCustomHeaders {
|
||||||
|
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
|
||||||
|
}
|
||||||
|
seen := make(map[string]string, len(headers))
|
||||||
|
for key, value := range headers {
|
||||||
|
if !httpHeaderNameRe.MatchString(key) {
|
||||||
|
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
|
||||||
|
}
|
||||||
|
if len(key) > maxHeaderKeyLen {
|
||||||
|
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
|
||||||
|
}
|
||||||
|
if len(value) > maxHeaderValueLen {
|
||||||
|
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
|
||||||
|
}
|
||||||
|
if containsCRLF(key) || containsCRLF(value) {
|
||||||
|
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
|
||||||
|
}
|
||||||
|
canonical := http.CanonicalHeaderKey(key)
|
||||||
|
if prev, ok := seen[canonical]; ok {
|
||||||
|
return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
|
||||||
|
}
|
||||||
|
seen[canonical] = key
|
||||||
|
if _, ok := hopByHopHeaders[canonical]; ok {
|
||||||
|
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
|
||||||
|
}
|
||||||
|
if _, ok := reservedHeaders[canonical]; ok {
|
||||||
|
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
|
||||||
|
}
|
||||||
|
if canonical == "Host" {
|
||||||
|
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsCRLF(s string) bool {
|
||||||
|
return strings.ContainsAny(s, "\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) EventMeta() map[string]any {
|
func (s *Service) EventMeta() map[string]any {
|
||||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
|
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
|
||||||
}
|
}
|
||||||
@@ -417,6 +619,12 @@ func (s *Service) Copy() *Service {
|
|||||||
targets := make([]*Target, len(s.Targets))
|
targets := make([]*Target, len(s.Targets))
|
||||||
for i, target := range s.Targets {
|
for i, target := range s.Targets {
|
||||||
targetCopy := *target
|
targetCopy := *target
|
||||||
|
if len(target.Options.CustomHeaders) > 0 {
|
||||||
|
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
|
||||||
|
for k, v := range target.Options.CustomHeaders {
|
||||||
|
targetCopy.Options.CustomHeaders[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
targets[i] = &targetCopy
|
targets[i] = &targetCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -87,6 +88,188 @@ func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
|||||||
assert.Contains(t, err.Error(), "empty target_id")
|
assert.Contains(t, err.Error(), "empty target_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mode PathRewriteMode
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{"empty is default", "", ""},
|
||||||
|
{"preserve is valid", PathRewritePreserve, ""},
|
||||||
|
{"unknown rejected", "regex", "unknown path_rewrite mode"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.PathRewrite = tt.mode
|
||||||
|
err := rp.Validate()
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout time.Duration
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{"valid 30s", 30 * time.Second, ""},
|
||||||
|
{"valid 2m", 2 * time.Minute, ""},
|
||||||
|
{"zero is fine", 0, ""},
|
||||||
|
{"negative", -1 * time.Second, "must be positive"},
|
||||||
|
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.RequestTimeout = tt.timeout
|
||||||
|
err := rp.Validate()
|
||||||
|
if tt.wantErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.ErrorContains(t, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
|
||||||
|
t.Run("valid headers", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||||
|
"X-Custom": "value",
|
||||||
|
"X-Trace": "abc123",
|
||||||
|
}
|
||||||
|
assert.NoError(t, rp.Validate())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CRLF in key", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CRLF in value", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "invalid characters")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||||
|
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reserved header rejected", func(t *testing.T) {
|
||||||
|
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Host header rejected", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("too many headers", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
headers := make(map[string]string, 17)
|
||||||
|
for i := range 17 {
|
||||||
|
headers[fmt.Sprintf("X-H%d", i)] = "v"
|
||||||
|
}
|
||||||
|
rp.Targets[0].Options.CustomHeaders = headers
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("key too long", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "key")
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("value too long", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
|
||||||
|
rp := validProxy()
|
||||||
|
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||||
|
"x-custom": "a",
|
||||||
|
"X-Custom": "b",
|
||||||
|
}
|
||||||
|
assert.ErrorContains(t, rp.Validate(), "collide")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_TargetOptions(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "svc-1",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
Options: TargetOptions{
|
||||||
|
SkipTLSVerify: true,
|
||||||
|
RequestTimeout: 30 * time.Second,
|
||||||
|
PathRewrite: PathRewritePreserve,
|
||||||
|
CustomHeaders: map[string]string{"X-Custom": "val"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
|
||||||
|
opts := pm.Path[0].Options
|
||||||
|
require.NotNil(t, opts, "options should be populated")
|
||||||
|
assert.True(t, opts.SkipTlsVerify)
|
||||||
|
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
|
||||||
|
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
|
||||||
|
require.NotNil(t, opts.RequestTimeout)
|
||||||
|
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
|
||||||
|
rp := &Service{
|
||||||
|
ID: "svc-1",
|
||||||
|
AccountID: "acc-1",
|
||||||
|
Domain: "example.com",
|
||||||
|
Targets: []*Target{
|
||||||
|
{
|
||||||
|
TargetId: "peer-1",
|
||||||
|
TargetType: TargetTypePeer,
|
||||||
|
Host: "10.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||||
|
require.Len(t, pm.Path, 1)
|
||||||
|
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsDefaultPort(t *testing.T) {
|
func TestIsDefaultPort(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
scheme string
|
scheme string
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
|
|
||||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
proxyService.SetServiceManager(s.ServiceManager())
|
proxyService.SetServiceManager(s.ServiceManager())
|
||||||
proxyService.SetProxyController(s.ServiceProxyController())
|
proxyService.SetProxyController(s.ServiceProxyController())
|
||||||
@@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
|
||||||
|
return Create(s, func() *nbgrpc.PKCEVerifierStore {
|
||||||
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create PKCE verifier store: %v", err)
|
||||||
|
}
|
||||||
|
return pkceStore
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||||
return Create(s, func() accesslogs.Manager {
|
return Create(s, func() accesslogs.Manager {
|
||||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||||
return Create(s, func() *manager.Manager {
|
return Create(s, func() *manager.Manager {
|
||||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
|
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
|
||||||
return &m
|
return &m
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,9 +28,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
const (
|
||||||
// It is used for backward compatibility now.
|
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||||
const ManagementLegacyPort = 33073
|
// It is used for backward compatibility now.
|
||||||
|
ManagementLegacyPort = 33073
|
||||||
|
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
|
||||||
|
DefaultSelfHostedDomain = "netbird.selfhosted"
|
||||||
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Start(ctx context.Context) error
|
Start(ctx context.Context) error
|
||||||
@@ -58,6 +62,7 @@ type BaseServer struct {
|
|||||||
mgmtMetricsPort int
|
mgmtMetricsPort int
|
||||||
mgmtPort int
|
mgmtPort int
|
||||||
disableLegacyManagementPort bool
|
disableLegacyManagementPort bool
|
||||||
|
autoResolveDomains bool
|
||||||
|
|
||||||
proxyAuthClose func()
|
proxyAuthClose func()
|
||||||
|
|
||||||
@@ -81,6 +86,7 @@ type Config struct {
|
|||||||
DisableMetrics bool
|
DisableMetrics bool
|
||||||
DisableGeoliteUpdate bool
|
DisableGeoliteUpdate bool
|
||||||
UserDeleteFromIDPEnabled bool
|
UserDeleteFromIDPEnabled bool
|
||||||
|
AutoResolveDomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer initializes and configures a new Server instance
|
// NewServer initializes and configures a new Server instance
|
||||||
@@ -96,6 +102,7 @@ func NewServer(cfg *Config) *BaseServer {
|
|||||||
mgmtPort: cfg.MgmtPort,
|
mgmtPort: cfg.MgmtPort,
|
||||||
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
|
||||||
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
mgmtMetricsPort: cfg.MgmtMetricsPort,
|
||||||
|
autoResolveDomains: cfg.AutoResolveDomains,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,6 +116,10 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
s.cancel = cancel
|
s.cancel = cancel
|
||||||
s.errCh = make(chan error, 4)
|
s.errCh = make(chan error, 4)
|
||||||
|
|
||||||
|
if s.autoResolveDomains {
|
||||||
|
s.resolveDomains(srvCtx)
|
||||||
|
}
|
||||||
|
|
||||||
s.PeersManager()
|
s.PeersManager()
|
||||||
s.GeoLocationManager()
|
s.GeoLocationManager()
|
||||||
|
|
||||||
@@ -237,7 +248,6 @@ func (s *BaseServer) Stop() error {
|
|||||||
_ = s.certManager.Listener().Close()
|
_ = s.certManager.Listener().Close()
|
||||||
}
|
}
|
||||||
s.GRPCServer().Stop()
|
s.GRPCServer().Stop()
|
||||||
s.ReverseProxyGRPCServer().Close()
|
|
||||||
if s.proxyAuthClose != nil {
|
if s.proxyAuthClose != nil {
|
||||||
s.proxyAuthClose()
|
s.proxyAuthClose()
|
||||||
s.proxyAuthClose = nil
|
s.proxyAuthClose = nil
|
||||||
@@ -381,6 +391,60 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||||
|
// Fresh installs use the default self-hosted domain, while existing installs reuse the
|
||||||
|
// persisted account domain to keep addressing stable across config changes.
|
||||||
|
func (s *BaseServer) resolveDomains(ctx context.Context) {
|
||||||
|
st := s.Store()
|
||||||
|
|
||||||
|
setDefault := func(logMsg string, args ...any) {
|
||||||
|
if logMsg != "" {
|
||||||
|
log.WithContext(ctx).Warnf(logMsg, args...)
|
||||||
|
}
|
||||||
|
s.dnsDomain = DefaultSelfHostedDomain
|
||||||
|
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
accountsCount, err := st.GetAccountsCounter(ctx)
|
||||||
|
if err != nil {
|
||||||
|
setDefault("resolve domains: failed to read accounts counter: %v; using default domain %q", err, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountsCount == 0 {
|
||||||
|
s.dnsDomain = DefaultSelfHostedDomain
|
||||||
|
s.mgmtSingleAccModeDomain = DefaultSelfHostedDomain
|
||||||
|
log.WithContext(ctx).Infof("resolve domains: fresh install detected, using default domain %q", DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, err := st.GetAnyAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
setDefault("resolve domains: failed to get existing account ID: %v; using default domain %q", err, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
setDefault("resolve domains: empty account ID returned for existing accounts; using default domain %q", DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, _, err := st.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
setDefault("resolve domains: failed to get account domain for account %q: %v; using default domain %q", accountID, err, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain == "" {
|
||||||
|
setDefault("resolve domains: account %q has empty domain; using default domain %q", accountID, DefaultSelfHostedDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dnsDomain = domain
|
||||||
|
s.mgmtSingleAccModeDomain = domain
|
||||||
|
log.WithContext(ctx).Infof("resolve domains: using persisted account domain %q", domain)
|
||||||
|
}
|
||||||
|
|
||||||
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
|
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
|
||||||
installationID := store.GetInstallationID()
|
installationID := store.GetInstallationID()
|
||||||
if installationID != "" {
|
if installationID != "" {
|
||||||
|
|||||||
63
management/internals/server/server_resolve_domains_test.go
Normal file
63
management/internals/server/server_resolve_domains_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), nil)
|
||||||
|
|
||||||
|
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||||
|
Inject[store.Store](srv, mockStore)
|
||||||
|
|
||||||
|
srv.resolveDomains(context.Background())
|
||||||
|
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(1), nil)
|
||||||
|
mockStore.EXPECT().GetAnyAccountID(gomock.Any()).Return("acc-1", nil)
|
||||||
|
mockStore.EXPECT().GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "acc-1").Return("vpn.mycompany.com", "", nil)
|
||||||
|
|
||||||
|
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||||
|
Inject[store.Store](srv, mockStore)
|
||||||
|
|
||||||
|
srv.resolveDomains(context.Background())
|
||||||
|
|
||||||
|
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
|
||||||
|
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetAccountsCounter(gomock.Any()).Return(int64(0), errors.New("db failed"))
|
||||||
|
|
||||||
|
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||||
|
Inject[store.Store](srv, mockStore)
|
||||||
|
|
||||||
|
srv.resolveDomains(context.Background())
|
||||||
|
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||||
|
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||||
|
}
|
||||||
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/eko/gocache/lib/v4/cache"
|
||||||
|
"github.com/eko/gocache/lib/v4/store"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
|
||||||
|
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||||
|
type PKCEVerifierStore struct {
|
||||||
|
cache *cache.Cache[string]
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
|
||||||
|
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
|
||||||
|
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PKCEVerifierStore{
|
||||||
|
cache: cache.New[string](cacheStore),
|
||||||
|
ctx: ctx,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store saves a PKCE verifier associated with an OAuth state parameter.
|
||||||
|
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
|
||||||
|
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
|
||||||
|
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
|
||||||
|
return fmt.Errorf("failed to store PKCE verifier: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
|
||||||
|
// Returns the verifier and true if found, or empty string and false if not found.
|
||||||
|
// This enforces single-use semantics for PKCE verifiers.
|
||||||
|
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
|
||||||
|
verifier, err := s.cache.Get(s.ctx, state)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("PKCE verifier not found for state")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.cache.Delete(s.ctx, state); err != nil {
|
||||||
|
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return verifier, true
|
||||||
|
}
|
||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/peer"
|
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
@@ -83,20 +82,12 @@ type ProxyServiceServer struct {
|
|||||||
// OIDC configuration for proxy authentication
|
// OIDC configuration for proxy authentication
|
||||||
oidcConfig ProxyOIDCConfig
|
oidcConfig ProxyOIDCConfig
|
||||||
|
|
||||||
// TODO: use database to store these instead?
|
// Store for PKCE verifiers
|
||||||
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
|
pkceVerifierStore *PKCEVerifierStore
|
||||||
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
|
|
||||||
pkceVerifiers sync.Map
|
|
||||||
pkceCleanupCancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const pkceVerifierTTL = 10 * time.Minute
|
const pkceVerifierTTL = 10 * time.Minute
|
||||||
|
|
||||||
type pkceEntry struct {
|
|
||||||
verifier string
|
|
||||||
createdAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxyConnection represents a connected proxy
|
// proxyConnection represents a connected proxy
|
||||||
type proxyConnection struct {
|
type proxyConnection struct {
|
||||||
proxyID string
|
proxyID string
|
||||||
@@ -108,42 +99,21 @@ type proxyConnection struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewProxyServiceServer creates a new proxy service server.
|
// NewProxyServiceServer creates a new proxy service server.
|
||||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx := context.Background()
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
accessLogManager: accessLogMgr,
|
accessLogManager: accessLogMgr,
|
||||||
oidcConfig: oidcConfig,
|
oidcConfig: oidcConfig,
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
peersManager: peersManager,
|
peersManager: peersManager,
|
||||||
usersManager: usersManager,
|
usersManager: usersManager,
|
||||||
proxyManager: proxyMgr,
|
proxyManager: proxyMgr,
|
||||||
pkceCleanupCancel: cancel,
|
|
||||||
}
|
}
|
||||||
go s.cleanupPKCEVerifiers(ctx)
|
|
||||||
go s.cleanupStaleProxies(ctx)
|
go s.cleanupStaleProxies(ctx)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
|
|
||||||
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(pkceVerifierTTL)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
now := time.Now()
|
|
||||||
s.pkceVerifiers.Range(func(key, value any) bool {
|
|
||||||
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
|
|
||||||
s.pkceVerifiers.Delete(key)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||||
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
@@ -160,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops background goroutines.
|
|
||||||
func (s *ProxyServiceServer) Close() {
|
|
||||||
s.pkceCleanupCancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||||
s.serviceManager = manager
|
s.serviceManager = manager
|
||||||
}
|
}
|
||||||
@@ -177,11 +142,7 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
|||||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||||
ctx := stream.Context()
|
ctx := stream.Context()
|
||||||
|
|
||||||
peerInfo := ""
|
peerInfo := PeerIPFromContext(ctx)
|
||||||
if p, ok := peer.FromContext(ctx); ok {
|
|
||||||
peerInfo = p.Addr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("New proxy connection from %s", peerInfo)
|
log.Infof("New proxy connection from %s", peerInfo)
|
||||||
|
|
||||||
proxyID := req.GetProxyId()
|
proxyID := req.GetProxyId()
|
||||||
@@ -795,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
|||||||
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
||||||
|
|
||||||
codeVerifier := oauth2.GenerateVerifier()
|
codeVerifier := oauth2.GenerateVerifier()
|
||||||
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
|
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
|
||||||
|
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.GetOIDCURLResponse{
|
return &proto.GetOIDCURLResponse{
|
||||||
Url: (&oauth2.Config{
|
Url: (&oauth2.Config{
|
||||||
@@ -832,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
|
|||||||
// ValidateState validates the state parameter from an OAuth callback.
|
// ValidateState validates the state parameter from an OAuth callback.
|
||||||
// Returns the original redirect URL if valid, or an error if invalid.
|
// Returns the original redirect URL if valid, or an error if invalid.
|
||||||
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
||||||
v, ok := s.pkceVerifiers.LoadAndDelete(state)
|
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", "", errors.New("no verifier for state")
|
return "", "", errors.New("no verifier for state")
|
||||||
}
|
}
|
||||||
entry, ok := v.(pkceEntry)
|
|
||||||
if !ok {
|
|
||||||
return "", "", errors.New("invalid verifier for state")
|
|
||||||
}
|
|
||||||
if time.Since(entry.createdAt) > pkceVerifierTTL {
|
|
||||||
return "", "", errors.New("PKCE verifier expired")
|
|
||||||
}
|
|
||||||
verifier = entry.verifier
|
|
||||||
|
|
||||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||||
parts := strings.Split(state, "|")
|
parts := strings.Split(state, "|")
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInter
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
|
||||||
clientIP := peerIPFromContext(ctx)
|
clientIP := PeerIPFromContext(ctx)
|
||||||
|
|
||||||
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
|
||||||
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
|
||||||
|
|||||||
@@ -115,9 +115,9 @@ func (l *authFailureLimiter) stop() {
|
|||||||
l.cancel()
|
l.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// peerIPFromContext extracts the client IP from the gRPC context.
|
// PeerIPFromContext extracts the client IP from the gRPC context.
|
||||||
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
|
||||||
func peerIPFromContext(ctx context.Context) clientIP {
|
func PeerIPFromContext(ctx context.Context) string {
|
||||||
if addr, ok := realip.FromContext(ctx); ok {
|
if addr, ok := realip.FromContext(ctx); ok {
|
||||||
return addr.String()
|
return addr.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,11 +5,10 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
ctx := context.Background()
|
||||||
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
s.SetProxyController(newTestProxyController())
|
s.SetProxyController(newTestProxyController())
|
||||||
|
|
||||||
@@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
ctx := context.Background()
|
||||||
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
s.SetProxyController(newTestProxyController())
|
s.SetProxyController(newTestProxyController())
|
||||||
|
|
||||||
@@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
ctx := context.Background()
|
||||||
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
s.SetProxyController(newTestProxyController())
|
s.SetProxyController(newTestProxyController())
|
||||||
|
|
||||||
@@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
oidcConfig: ProxyOIDCConfig{
|
oidcConfig: ProxyOIDCConfig{
|
||||||
HMACKey: []byte("test-hmac-key"),
|
HMACKey: []byte("test-hmac-key"),
|
||||||
},
|
},
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := "https://app.example.com/callback"
|
redirectURL := "https://app.example.com/callback"
|
||||||
@@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
oidcConfig: ProxyOIDCConfig{
|
oidcConfig: ProxyOIDCConfig{
|
||||||
HMACKey: []byte("test-hmac-key"),
|
HMACKey: []byte("test-hmac-key"),
|
||||||
},
|
},
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Old format had only 2 parts: base64(url)|hmac
|
// Old format had only 2 parts: base64(url)|hmac
|
||||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, err := s.ValidateState("base64url|hmac")
|
_, _, err = s.ValidateState("base64url|hmac")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "invalid state format")
|
assert.Contains(t, err.Error(), "invalid state format")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
oidcConfig: ProxyOIDCConfig{
|
oidcConfig: ProxyOIDCConfig{
|
||||||
HMACKey: []byte("test-hmac-key"),
|
HMACKey: []byte("test-hmac-key"),
|
||||||
},
|
},
|
||||||
|
pkceVerifierStore: pkceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store with tampered HMAC
|
// Store with tampered HMAC
|
||||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "invalid state signature")
|
assert.Contains(t, err.Error(), "invalid state signature")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
|||||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||||
proxyService.SetServiceManager(serviceManager)
|
proxyService.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
createTestProxies(t, ctx, testStore)
|
createTestProxies(t, ctx, testStore)
|
||||||
|
|||||||
@@ -1379,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
// We override incoming domain claims to group users under a single account.
|
// We override incoming domain claims to group users under a single account.
|
||||||
userAuth.Domain = am.singleAccountModeDomain
|
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
|
||||||
userAuth.DomainCategory = types.PrivateCategory
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||||
@@ -1414,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
return accountID, user.Id, nil
|
return accountID, user.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
|
||||||
|
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
|
||||||
|
userAuth.DomainCategory = types.PrivateCategory
|
||||||
|
userAuth.Domain = am.singleAccountModeDomain
|
||||||
|
|
||||||
|
accountID, err := am.Store.GetAnyAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userAuth.Domain = domain
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
// and propagates changes to peers if group propagation is enabled.
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/prometheus/client_golang/prometheus/push"
|
"github.com/prometheus/client_golang/prometheus/push"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -3132,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -3966,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
|||||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
|
||||||
|
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("account-1", nil)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||||
|
Return("real-domain.com", "private", nil)
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "real-domain.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", nil)
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", status.Errorf(status.NotFound, "no accounts"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", status.Errorf(status.Internal, "db down"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "db down")
|
||||||
|
// Defaults should still be set before error path
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("account-1", nil)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||||
|
Return("", "", status.Errorf(status.Internal, "query failed"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "query failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -220,6 +220,13 @@ const (
|
|||||||
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
|
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
|
||||||
AccountPeerExposeDisabled Activity = 115
|
AccountPeerExposeDisabled Activity = 115
|
||||||
|
|
||||||
|
// DomainAdded indicates that a user added a custom domain
|
||||||
|
DomainAdded Activity = 116
|
||||||
|
// DomainDeleted indicates that a user deleted a custom domain
|
||||||
|
DomainDeleted Activity = 117
|
||||||
|
// DomainValidated indicates that a custom domain was validated
|
||||||
|
DomainValidated Activity = 118
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -364,6 +371,10 @@ var activityMap = map[Activity]Code{
|
|||||||
|
|
||||||
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
|
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
|
||||||
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
|
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
|
||||||
|
|
||||||
|
DomainAdded: {"Domain added", "domain.add"},
|
||||||
|
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||||
|
DomainValidated: {"Domain validated", "domain.validate"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -193,6 +193,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
|||||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
usersManager := users.NewManager(testStore)
|
usersManager := users.NewManager(testStore)
|
||||||
|
|
||||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||||
@@ -206,6 +209,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
|||||||
proxyService := nbgrpc.NewProxyServiceServer(
|
proxyService := nbgrpc.NewProxyServiceServer(
|
||||||
&testAccessLogManager{},
|
&testAccessLogManager{},
|
||||||
tokenStore,
|
tokenStore,
|
||||||
|
pkceStore,
|
||||||
oidcConfig,
|
oidcConfig,
|
||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
|
|||||||
@@ -98,13 +98,17 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy token store: %v", err)
|
t.Fatalf("Failed to create proxy token store: %v", err)
|
||||||
}
|
}
|
||||||
|
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create PKCE verifier store: %v", err)
|
||||||
|
}
|
||||||
noopMeter := noop.NewMeterProvider().Meter("")
|
noopMeter := noop.NewMeterProvider().Meter("")
|
||||||
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
|
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||||
}
|
}
|
||||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
|
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy controller: %v", err)
|
t.Fatalf("Failed to create proxy controller: %v", err)
|
||||||
|
|||||||
@@ -4977,9 +4977,9 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
|||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
func (s *SqlStore) GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||||
var service *rpservice.Service
|
var service *rpservice.Service
|
||||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
result := s.db.Preload("Targets").Where("domain = ?", domain).First(&service)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
|
return nil, status.Errorf(status.NotFound, "service with domain %s not found", domain)
|
||||||
|
|||||||
@@ -257,7 +257,7 @@ type Store interface {
|
|||||||
UpdateService(ctx context.Context, service *rpservice.Service) error
|
UpdateService(ctx context.Context, service *rpservice.Service) error
|
||||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
||||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
|
GetServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error)
|
||||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
||||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
||||||
|
|
||||||
|
|||||||
@@ -1932,18 +1932,18 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetServiceByDomain mocks base method.
|
// GetServiceByDomain mocks base method.
|
||||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
|
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
|
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||||
ret0, _ := ret[0].(*service.Service)
|
ret0, _ := ret[0].(*service.Service)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||||
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain interface{}) *gomock.Call {
|
func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, accountID, domain)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockStore)(nil).GetServiceByDomain), ctx, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServiceByID mocks base method.
|
// GetServiceByID mocks base method.
|
||||||
|
|||||||
185
management/server/telemetry/account_aggregator.go
Normal file
185
management/server/telemetry/account_aggregator.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package telemetry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||||
|
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95
|
||||||
|
// without publishing individual account labels
|
||||||
|
type AccountDurationAggregator struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
accounts map[string]*accountHistogram
|
||||||
|
meterProvider *sdkmetric.MeterProvider
|
||||||
|
manualReader *sdkmetric.ManualReader
|
||||||
|
|
||||||
|
FlushInterval time.Duration
|
||||||
|
MaxAge time.Duration
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
type accountHistogram struct {
|
||||||
|
histogram metric.Int64Histogram
|
||||||
|
lastUpdate time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccountDurationAggregator creates aggregator using OTel histograms
|
||||||
|
func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator {
|
||||||
|
manualReader := sdkmetric.NewManualReader(
|
||||||
|
sdkmetric.WithTemporalitySelector(func(kind sdkmetric.InstrumentKind) metricdata.Temporality {
|
||||||
|
return metricdata.DeltaTemporality
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
meterProvider := sdkmetric.NewMeterProvider(
|
||||||
|
sdkmetric.WithReader(manualReader),
|
||||||
|
)
|
||||||
|
|
||||||
|
return &AccountDurationAggregator{
|
||||||
|
accounts: make(map[string]*accountHistogram),
|
||||||
|
meterProvider: meterProvider,
|
||||||
|
manualReader: manualReader,
|
||||||
|
FlushInterval: flushInterval,
|
||||||
|
MaxAge: maxAge,
|
||||||
|
ctx: ctx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record adds a duration for an account using OTel histogram
|
||||||
|
func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
|
accHist, exists := a.accounts[accountID]
|
||||||
|
if !exists {
|
||||||
|
meter := a.meterProvider.Meter("account-aggregator")
|
||||||
|
histogram, err := meter.Int64Histogram(
|
||||||
|
"sync_duration_per_account",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accHist = &accountHistogram{
|
||||||
|
histogram: histogram,
|
||||||
|
}
|
||||||
|
a.accounts[accountID] = accHist
|
||||||
|
}
|
||||||
|
|
||||||
|
accHist.histogram.Record(a.ctx, duration.Milliseconds(),
|
||||||
|
metric.WithAttributes(attribute.String("account_id", accountID)))
|
||||||
|
accHist.lastUpdate = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlushAndGetP95s extracts P95 from each account's histogram
|
||||||
|
func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
|
var rm metricdata.ResourceMetrics
|
||||||
|
err := a.manualReader.Collect(a.ctx, &rm)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
p95s := make([]int64, 0, len(a.accounts))
|
||||||
|
|
||||||
|
for _, scopeMetrics := range rm.ScopeMetrics {
|
||||||
|
for _, metric := range scopeMetrics.Metrics {
|
||||||
|
histogramData, ok := metric.Data.(metricdata.Histogram[int64])
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dataPoint := range histogramData.DataPoints {
|
||||||
|
a.processDataPoint(dataPoint, now, &p95s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a.cleanupStaleAccounts(now)
|
||||||
|
|
||||||
|
return p95s
|
||||||
|
}
|
||||||
|
|
||||||
|
// processDataPoint extracts P95 from a single histogram data point
|
||||||
|
func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) {
|
||||||
|
accountID := extractAccountID(dataPoint)
|
||||||
|
if accountID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p95 := calculateP95FromHistogram(dataPoint); p95 > 0 {
|
||||||
|
*p95s = append(*p95s, p95)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupStaleAccounts removes accounts that haven't been updated recently
|
||||||
|
func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) {
|
||||||
|
for accountID := range a.accounts {
|
||||||
|
if a.isStaleAccount(accountID, now) {
|
||||||
|
delete(a.accounts, accountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractAccountID retrieves the account_id from histogram data point attributes
|
||||||
|
func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string {
|
||||||
|
for _, attr := range dp.Attributes.ToSlice() {
|
||||||
|
if attr.Key == "account_id" {
|
||||||
|
return attr.Value.AsString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// isStaleAccount checks if an account hasn't been updated recently
|
||||||
|
func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool {
|
||||||
|
accHist, exists := a.accounts[accountID]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return now.Sub(accHist.lastUpdate) > a.MaxAge
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateP95FromHistogram computes P95 from OTel histogram data
|
||||||
|
func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 {
|
||||||
|
if dp.Count == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
targetCount := uint64(math.Ceil(float64(dp.Count) * 0.95))
|
||||||
|
if targetCount == 0 {
|
||||||
|
targetCount = 1
|
||||||
|
}
|
||||||
|
var cumulativeCount uint64
|
||||||
|
|
||||||
|
for i, bucketCount := range dp.BucketCounts {
|
||||||
|
cumulativeCount += bucketCount
|
||||||
|
if cumulativeCount >= targetCount {
|
||||||
|
if i < len(dp.Bounds) {
|
||||||
|
return int64(dp.Bounds[i])
|
||||||
|
}
|
||||||
|
if maxVal, defined := dp.Max.Value(); defined {
|
||||||
|
return maxVal
|
||||||
|
}
|
||||||
|
return dp.Sum / int64(dp.Count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dp.Sum / int64(dp.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown cleans up resources
|
||||||
|
func (a *AccountDurationAggregator) Shutdown() error {
|
||||||
|
return a.meterProvider.Shutdown(a.ctx)
|
||||||
|
}
|
||||||
219
management/server/telemetry/account_aggregator_test.go
Normal file
219
management/server/telemetry/account_aggregator_test.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package telemetry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) {
|
||||||
|
// Verify that with delta temporality, each flush window only reflects
|
||||||
|
// recordings since the last flush — not all-time data.
|
||||||
|
ctx := context.Background()
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
// Window 1: Record 100 slow requests (500ms each)
|
||||||
|
for range 100 {
|
||||||
|
agg.Record("account-A", 500*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
p95sWindow1 := agg.FlushAndGetP95s()
|
||||||
|
require.Len(t, p95sWindow1, 1, "should have P95 for one account")
|
||||||
|
firstP95 := p95sWindow1[0]
|
||||||
|
assert.GreaterOrEqual(t, firstP95, int64(200),
|
||||||
|
"first window P95 should reflect the 500ms recordings")
|
||||||
|
|
||||||
|
// Window 2: Record 100 FAST requests (10ms each)
|
||||||
|
for range 100 {
|
||||||
|
agg.Record("account-A", 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
p95sWindow2 := agg.FlushAndGetP95s()
|
||||||
|
require.Len(t, p95sWindow2, 1, "should have P95 for one account")
|
||||||
|
secondP95 := p95sWindow2[0]
|
||||||
|
|
||||||
|
// With delta temporality the P95 should drop significantly because
|
||||||
|
// the first window's slow recordings are no longer included.
|
||||||
|
assert.Less(t, secondP95, firstP95,
|
||||||
|
"second window P95 should be lower than first — delta temporality "+
|
||||||
|
"ensures each window only reflects recent recordings")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEqualWeightPerAccount(t *testing.T) {
|
||||||
|
// Verify that each account contributes exactly one P95 value,
|
||||||
|
// regardless of how many requests it made.
|
||||||
|
ctx := context.Background()
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
// Account A: 10,000 requests at 500ms (noisy customer)
|
||||||
|
for range 10000 {
|
||||||
|
agg.Record("account-A", 500*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accounts B, C, D: 10 requests each at 50ms (normal customers)
|
||||||
|
for _, id := range []string{"account-B", "account-C", "account-D"} {
|
||||||
|
for range 10 {
|
||||||
|
agg.Record(id, 50*time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p95s := agg.FlushAndGetP95s()
|
||||||
|
|
||||||
|
// Should get exactly 4 P95 values — one per account
|
||||||
|
assert.Len(t, p95s, 4, "each account should contribute exactly one P95")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaleAccountEviction(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
// Use a very short MaxAge so we can test staleness
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
agg.Record("account-A", 100*time.Millisecond)
|
||||||
|
agg.Record("account-B", 200*time.Millisecond)
|
||||||
|
|
||||||
|
// Both accounts should appear
|
||||||
|
p95s := agg.FlushAndGetP95s()
|
||||||
|
assert.Len(t, p95s, 2, "both accounts should have P95 values")
|
||||||
|
|
||||||
|
// Wait for account-A to become stale, then only update account-B
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
agg.Record("account-B", 200*time.Millisecond)
|
||||||
|
|
||||||
|
p95s = agg.FlushAndGetP95s()
|
||||||
|
assert.Len(t, p95s, 1, "both accounts should have P95 values")
|
||||||
|
|
||||||
|
// account-A should have been evicted from the accounts map
|
||||||
|
agg.mu.RLock()
|
||||||
|
_, accountAExists := agg.accounts["account-A"]
|
||||||
|
_, accountBExists := agg.accounts["account-B"]
|
||||||
|
agg.mu.RUnlock()
|
||||||
|
|
||||||
|
assert.False(t, accountAExists, "stale account-A should be evicted from map")
|
||||||
|
assert.True(t, accountBExists, "active account-B should remain in map")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaleAccountEviction_DoesNotReappear(t *testing.T) {
|
||||||
|
// Verify that with delta temporality, an evicted stale account does not
|
||||||
|
// reappear in subsequent flushes.
|
||||||
|
ctx := context.Background()
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
agg.Record("account-stale", 100*time.Millisecond)
|
||||||
|
|
||||||
|
// Wait for it to become stale
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// First flush: should detect staleness and evict
|
||||||
|
_ = agg.FlushAndGetP95s()
|
||||||
|
|
||||||
|
agg.mu.RLock()
|
||||||
|
_, exists := agg.accounts["account-stale"]
|
||||||
|
agg.mu.RUnlock()
|
||||||
|
assert.False(t, exists, "account should be evicted after first flush")
|
||||||
|
|
||||||
|
// Second flush: with delta temporality, the stale account should NOT reappear
|
||||||
|
p95sSecond := agg.FlushAndGetP95s()
|
||||||
|
assert.Empty(t, p95sSecond,
|
||||||
|
"evicted account should not reappear in subsequent flushes with delta temporality")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestP95Calculation_SingleSample(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
agg.Record("account-A", 150*time.Millisecond)
|
||||||
|
|
||||||
|
p95s := agg.FlushAndGetP95s()
|
||||||
|
require.Len(t, p95s, 1)
|
||||||
|
// With a single sample, P95 should be the bucket bound containing 150ms
|
||||||
|
assert.Greater(t, p95s[0], int64(0), "P95 of a single sample should be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestP95Calculation_AllSameValue(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
// All samples are 100ms — P95 should be the bucket bound containing 100ms
|
||||||
|
for range 100 {
|
||||||
|
agg.Record("account-A", 100*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
p95s := agg.FlushAndGetP95s()
|
||||||
|
require.Len(t, p95s, 1)
|
||||||
|
assert.Greater(t, p95s[0], int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleAccounts_IndependentP95s(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||||
|
defer func(agg *AccountDurationAggregator) {
|
||||||
|
err := agg.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||||
|
}
|
||||||
|
}(agg)
|
||||||
|
|
||||||
|
// Account A: all fast (10ms)
|
||||||
|
for range 100 {
|
||||||
|
agg.Record("account-fast", 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Account B: all slow (5000ms)
|
||||||
|
for range 100 {
|
||||||
|
agg.Record("account-slow", 5000*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
p95s := agg.FlushAndGetP95s()
|
||||||
|
require.Len(t, p95s, 2, "should have two P95 values")
|
||||||
|
|
||||||
|
// Find min and max — they should differ significantly
|
||||||
|
minP95 := p95s[0]
|
||||||
|
maxP95 := p95s[1]
|
||||||
|
if minP95 > maxP95 {
|
||||||
|
minP95, maxP95 = maxP95, minP95
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Less(t, minP95, int64(1000),
|
||||||
|
"fast account P95 should be well under 1000ms")
|
||||||
|
assert.Greater(t, maxP95, int64(1000),
|
||||||
|
"slow account P95 should be well over 1000ms")
|
||||||
|
}
|
||||||
@@ -13,18 +13,24 @@ const HighLatencyThreshold = time.Second * 7
|
|||||||
|
|
||||||
// GRPCMetrics are gRPC server metrics
|
// GRPCMetrics are gRPC server metrics
|
||||||
type GRPCMetrics struct {
|
type GRPCMetrics struct {
|
||||||
meter metric.Meter
|
meter metric.Meter
|
||||||
syncRequestsCounter metric.Int64Counter
|
syncRequestsCounter metric.Int64Counter
|
||||||
syncRequestsBlockedCounter metric.Int64Counter
|
syncRequestsBlockedCounter metric.Int64Counter
|
||||||
loginRequestsCounter metric.Int64Counter
|
loginRequestsCounter metric.Int64Counter
|
||||||
loginRequestsBlockedCounter metric.Int64Counter
|
loginRequestsBlockedCounter metric.Int64Counter
|
||||||
loginRequestHighLatencyCounter metric.Int64Counter
|
loginRequestHighLatencyCounter metric.Int64Counter
|
||||||
getKeyRequestsCounter metric.Int64Counter
|
getKeyRequestsCounter metric.Int64Counter
|
||||||
activeStreamsGauge metric.Int64ObservableGauge
|
activeStreamsGauge metric.Int64ObservableGauge
|
||||||
syncRequestDuration metric.Int64Histogram
|
syncRequestDuration metric.Int64Histogram
|
||||||
loginRequestDuration metric.Int64Histogram
|
syncRequestDurationP95ByAccount metric.Int64Histogram
|
||||||
channelQueueLength metric.Int64Histogram
|
loginRequestDuration metric.Int64Histogram
|
||||||
ctx context.Context
|
loginRequestDurationP95ByAccount metric.Int64Histogram
|
||||||
|
channelQueueLength metric.Int64Histogram
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
|
// Per-account aggregation
|
||||||
|
syncDurationAggregator *AccountDurationAggregator
|
||||||
|
loginDurationAggregator *AccountDurationAggregator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
|
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
|
||||||
@@ -93,6 +99,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
|
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
|
||||||
metric.WithUnit("milliseconds"),
|
metric.WithUnit("milliseconds"),
|
||||||
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
||||||
@@ -101,6 +115,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
|
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
|
||||||
// Then we should be able to extract min, manx, mean and the percentiles.
|
// Then we should be able to extract min, manx, mean and the percentiles.
|
||||||
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
|
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
|
||||||
@@ -113,20 +135,32 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GRPCMetrics{
|
syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
|
||||||
meter: meter,
|
loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
|
||||||
syncRequestsCounter: syncRequestsCounter,
|
|
||||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
grpcMetrics := &GRPCMetrics{
|
||||||
loginRequestsCounter: loginRequestsCounter,
|
meter: meter,
|
||||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
syncRequestsCounter: syncRequestsCounter,
|
||||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
loginRequestsCounter: loginRequestsCounter,
|
||||||
activeStreamsGauge: activeStreamsGauge,
|
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||||
syncRequestDuration: syncRequestDuration,
|
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||||
loginRequestDuration: loginRequestDuration,
|
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||||
channelQueueLength: channelQueue,
|
activeStreamsGauge: activeStreamsGauge,
|
||||||
ctx: ctx,
|
syncRequestDuration: syncRequestDuration,
|
||||||
}, err
|
syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount,
|
||||||
|
loginRequestDuration: loginRequestDuration,
|
||||||
|
loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount,
|
||||||
|
channelQueueLength: channelQueue,
|
||||||
|
ctx: ctx,
|
||||||
|
syncDurationAggregator: syncDurationAggregator,
|
||||||
|
loginDurationAggregator: loginDurationAggregator,
|
||||||
|
}
|
||||||
|
|
||||||
|
go grpcMetrics.startSyncP95Flusher()
|
||||||
|
go grpcMetrics.startLoginP95Flusher()
|
||||||
|
|
||||||
|
return grpcMetrics, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
|
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
|
||||||
@@ -157,6 +191,9 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
|
|||||||
// CountLoginRequestDuration counts the duration of the login gRPC requests
|
// CountLoginRequestDuration counts the duration of the login gRPC requests
|
||||||
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
|
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
|
||||||
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||||
|
|
||||||
|
grpcMetrics.loginDurationAggregator.Record(accountID, duration)
|
||||||
|
|
||||||
if duration > HighLatencyThreshold {
|
if duration > HighLatencyThreshold {
|
||||||
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
||||||
}
|
}
|
||||||
@@ -165,6 +202,44 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
|
|||||||
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
||||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
||||||
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||||
|
|
||||||
|
grpcMetrics.syncDurationAggregator.Record(accountID, duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// startSyncP95Flusher periodically flushes per-account sync P95 values to the histogram
|
||||||
|
func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() {
|
||||||
|
ticker := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-grpcMetrics.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
p95s := grpcMetrics.syncDurationAggregator.FlushAndGetP95s()
|
||||||
|
for _, p95 := range p95s {
|
||||||
|
grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startLoginP95Flusher periodically flushes per-account login P95 values to the histogram
|
||||||
|
func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() {
|
||||||
|
ticker := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-grpcMetrics.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
p95s := grpcMetrics.loginDurationAggregator.FlushAndGetP95s()
|
||||||
|
for _, p95 := range p95s {
|
||||||
|
grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package accesslog
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -13,6 +14,23 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
requestThreshold = 10000 // Log every 10k requests
|
||||||
|
bytesThreshold = 1024 * 1024 * 1024 // Log every 1GB
|
||||||
|
usageCleanupPeriod = 1 * time.Hour // Clean up stale counters every hour
|
||||||
|
usageInactiveWindow = 24 * time.Hour // Consider domain inactive if no traffic for 24 hours
|
||||||
|
)
|
||||||
|
|
||||||
|
type domainUsage struct {
|
||||||
|
requestCount int64
|
||||||
|
requestStartTime time.Time
|
||||||
|
|
||||||
|
bytesTransferred int64
|
||||||
|
bytesStartTime time.Time
|
||||||
|
|
||||||
|
lastActivity time.Time // Track last activity for cleanup
|
||||||
|
}
|
||||||
|
|
||||||
type gRPCClient interface {
|
type gRPCClient interface {
|
||||||
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
|
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
|
||||||
}
|
}
|
||||||
@@ -22,6 +40,11 @@ type Logger struct {
|
|||||||
client gRPCClient
|
client gRPCClient
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
trustedProxies []netip.Prefix
|
trustedProxies []netip.Prefix
|
||||||
|
|
||||||
|
usageMux sync.Mutex
|
||||||
|
domainUsage map[string]*domainUsage
|
||||||
|
|
||||||
|
cleanupCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLogger creates a new access log Logger. The trustedProxies parameter
|
// NewLogger creates a new access log Logger. The trustedProxies parameter
|
||||||
@@ -31,10 +54,26 @@ func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Pre
|
|||||||
if logger == nil {
|
if logger == nil {
|
||||||
logger = log.StandardLogger()
|
logger = log.StandardLogger()
|
||||||
}
|
}
|
||||||
return &Logger{
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l := &Logger{
|
||||||
client: client,
|
client: client,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
trustedProxies: trustedProxies,
|
trustedProxies: trustedProxies,
|
||||||
|
domainUsage: make(map[string]*domainUsage),
|
||||||
|
cleanupCancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start background cleanup routine
|
||||||
|
go l.cleanupStaleUsage(ctx)
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine. Should be called during graceful shutdown.
|
||||||
|
func (l *Logger) Close() {
|
||||||
|
if l.cleanupCancel != nil {
|
||||||
|
l.cleanupCancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,6 +90,8 @@ type logEntry struct {
|
|||||||
AuthMechanism string
|
AuthMechanism string
|
||||||
UserId string
|
UserId string
|
||||||
AuthSuccess bool
|
AuthSuccess bool
|
||||||
|
BytesUpload int64
|
||||||
|
BytesDownload int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(ctx context.Context, entry logEntry) {
|
func (l *Logger) log(ctx context.Context, entry logEntry) {
|
||||||
@@ -84,6 +125,8 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
|
|||||||
AuthMechanism: entry.AuthMechanism,
|
AuthMechanism: entry.AuthMechanism,
|
||||||
UserId: entry.UserId,
|
UserId: entry.UserId,
|
||||||
AuthSuccess: entry.AuthSuccess,
|
AuthSuccess: entry.AuthSuccess,
|
||||||
|
BytesUpload: entry.BytesUpload,
|
||||||
|
BytesDownload: entry.BytesDownload,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// If it fails to send on the gRPC connection, then at least log it to the error log.
|
// If it fails to send on the gRPC connection, then at least log it to the error log.
|
||||||
@@ -103,3 +146,82 @@ func (l *Logger) log(ctx context.Context, entry logEntry) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// trackUsage records request and byte counts per domain, logging when thresholds are hit.
|
||||||
|
func (l *Logger) trackUsage(domain string, bytesTransferred int64) {
|
||||||
|
if domain == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.usageMux.Lock()
|
||||||
|
defer l.usageMux.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
usage, exists := l.domainUsage[domain]
|
||||||
|
if !exists {
|
||||||
|
usage = &domainUsage{
|
||||||
|
requestStartTime: now,
|
||||||
|
bytesStartTime: now,
|
||||||
|
lastActivity: now,
|
||||||
|
}
|
||||||
|
l.domainUsage[domain] = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.lastActivity = now
|
||||||
|
|
||||||
|
usage.requestCount++
|
||||||
|
if usage.requestCount >= requestThreshold {
|
||||||
|
elapsed := time.Since(usage.requestStartTime)
|
||||||
|
l.logger.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"requests": usage.requestCount,
|
||||||
|
"duration": elapsed.String(),
|
||||||
|
}).Infof("domain %s had %d requests over %s", domain, usage.requestCount, elapsed)
|
||||||
|
|
||||||
|
usage.requestCount = 0
|
||||||
|
usage.requestStartTime = now
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.bytesTransferred += bytesTransferred
|
||||||
|
if usage.bytesTransferred >= bytesThreshold {
|
||||||
|
elapsed := time.Since(usage.bytesStartTime)
|
||||||
|
bytesInGB := float64(usage.bytesTransferred) / (1024 * 1024 * 1024)
|
||||||
|
l.logger.WithFields(log.Fields{
|
||||||
|
"domain": domain,
|
||||||
|
"bytes": usage.bytesTransferred,
|
||||||
|
"bytes_gb": bytesInGB,
|
||||||
|
"duration": elapsed.String(),
|
||||||
|
}).Infof("domain %s transferred %.2f GB over %s", domain, bytesInGB, elapsed)
|
||||||
|
|
||||||
|
usage.bytesTransferred = 0
|
||||||
|
usage.bytesStartTime = now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupStaleUsage removes usage entries for domains that have been inactive.
|
||||||
|
func (l *Logger) cleanupStaleUsage(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(usageCleanupPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
l.usageMux.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
removed := 0
|
||||||
|
for domain, usage := range l.domainUsage {
|
||||||
|
if now.Sub(usage.lastActivity) > usageInactiveWindow {
|
||||||
|
delete(l.domainUsage, domain)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.usageMux.Unlock()
|
||||||
|
|
||||||
|
if removed > 0 {
|
||||||
|
l.logger.Debugf("cleaned up %d stale domain usage entries", removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
|||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var bytesRead int64
|
||||||
|
if r.Body != nil {
|
||||||
|
r.Body = &bodyCounter{
|
||||||
|
ReadCloser: r.Body,
|
||||||
|
bytesRead: &bytesRead,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve the source IP using trusted proxy configuration before passing
|
// Resolve the source IP using trusted proxy configuration before passing
|
||||||
// the request on, as the proxy will modify forwarding headers.
|
// the request on, as the proxy will modify forwarding headers.
|
||||||
sourceIp := extractSourceIP(r, l.trustedProxies)
|
sourceIp := extractSourceIP(r, l.trustedProxies)
|
||||||
@@ -53,6 +61,9 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
|||||||
host = r.Host
|
host = r.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bytesUpload := bytesRead
|
||||||
|
bytesDownload := sw.bytesWritten
|
||||||
|
|
||||||
entry := logEntry{
|
entry := logEntry{
|
||||||
ID: requestID,
|
ID: requestID,
|
||||||
ServiceId: capturedData.GetServiceId(),
|
ServiceId: capturedData.GetServiceId(),
|
||||||
@@ -66,10 +77,15 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
|||||||
AuthMechanism: capturedData.GetAuthMethod(),
|
AuthMechanism: capturedData.GetAuthMethod(),
|
||||||
UserId: capturedData.GetUserID(),
|
UserId: capturedData.GetUserID(),
|
||||||
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
|
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
|
||||||
|
BytesUpload: bytesUpload,
|
||||||
|
BytesDownload: bytesDownload,
|
||||||
}
|
}
|
||||||
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
||||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
|
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
|
||||||
|
|
||||||
l.log(r.Context(), entry)
|
l.log(r.Context(), entry)
|
||||||
|
|
||||||
|
// Track usage for cost monitoring (upload + download) by domain
|
||||||
|
l.trackUsage(host, bytesUpload+bytesDownload)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,39 @@
|
|||||||
package accesslog
|
package accesslog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// statusWriter captures the HTTP status code from WriteHeader calls.
|
// statusWriter captures the HTTP status code and bytes written from responses.
|
||||||
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||||
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||||
type statusWriter struct {
|
type statusWriter struct {
|
||||||
*responsewriter.PassthroughWriter
|
*responsewriter.PassthroughWriter
|
||||||
status int
|
status int
|
||||||
|
bytesWritten int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *statusWriter) WriteHeader(status int) {
|
func (w *statusWriter) WriteHeader(status int) {
|
||||||
w.status = status
|
w.status = status
|
||||||
w.PassthroughWriter.WriteHeader(status)
|
w.PassthroughWriter.WriteHeader(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||||
|
n, err := w.PassthroughWriter.Write(b)
|
||||||
|
w.bytesWritten += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// bodyCounter wraps an io.ReadCloser and counts bytes read from the request body.
|
||||||
|
type bodyCounter struct {
|
||||||
|
io.ReadCloser
|
||||||
|
bytesRead *int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bc *bodyCounter) Read(p []byte) (int, error) {
|
||||||
|
n, err := bc.ReadCloser.Read(p)
|
||||||
|
*bc.bytesRead += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ type domainInfo struct {
|
|||||||
err string
|
err string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type metricsRecorder interface {
|
||||||
|
RecordCertificateIssuance(duration time.Duration)
|
||||||
|
}
|
||||||
|
|
||||||
// Manager wraps autocert.Manager with domain tracking and cross-replica
|
// Manager wraps autocert.Manager with domain tracking and cross-replica
|
||||||
// coordination via a pluggable locking strategy. The locker prevents
|
// coordination via a pluggable locking strategy. The locker prevents
|
||||||
// duplicate ACME requests when multiple replicas share a certificate cache.
|
// duplicate ACME requests when multiple replicas share a certificate cache.
|
||||||
@@ -55,6 +59,7 @@ type Manager struct {
|
|||||||
|
|
||||||
certNotifier certificateNotifier
|
certNotifier certificateNotifier
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
metrics metricsRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new ACME certificate manager. The certDir is used
|
// NewManager creates a new ACME certificate manager. The certDir is used
|
||||||
@@ -63,7 +68,7 @@ type Manager struct {
|
|||||||
// eabKID and eabHMACKey are optional External Account Binding credentials
|
// eabKID and eabHMACKey are optional External Account Binding credentials
|
||||||
// required for some CAs like ZeroSSL. The eabHMACKey should be the base64
|
// required for some CAs like ZeroSSL. The eabHMACKey should be the base64
|
||||||
// URL-encoded string provided by the CA.
|
// URL-encoded string provided by the CA.
|
||||||
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
|
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod, metrics metricsRecorder) *Manager {
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
logger = log.StandardLogger()
|
logger = log.StandardLogger()
|
||||||
}
|
}
|
||||||
@@ -73,6 +78,7 @@ func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificat
|
|||||||
domains: make(map[domain.Domain]*domainInfo),
|
domains: make(map[domain.Domain]*domainInfo),
|
||||||
certNotifier: notifier,
|
certNotifier: notifier,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
metrics: metrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
var eab *acme.ExternalAccountBinding
|
var eab *acme.ExternalAccountBinding
|
||||||
@@ -129,24 +135,45 @@ func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) {
|
|||||||
|
|
||||||
// prefetchCertificate proactively triggers certificate generation for a domain.
|
// prefetchCertificate proactively triggers certificate generation for a domain.
|
||||||
// It acquires a distributed lock to prevent multiple replicas from issuing
|
// It acquires a distributed lock to prevent multiple replicas from issuing
|
||||||
// duplicate ACME requests. The second replica will block until the first
|
// duplicate ACME requests. If the certificate appears in the cache while waiting
|
||||||
// finishes, then find the certificate in the cache.
|
// for the lock, it cancels the wait and uses the cached certificate.
|
||||||
func (mgr *Manager) prefetchCertificate(d domain.Domain) {
|
func (mgr *Manager) prefetchCertificate(d domain.Domain) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
name := d.PunycodeString()
|
name := d.PunycodeString()
|
||||||
|
|
||||||
|
if mgr.certExistsInCache(ctx, name) {
|
||||||
|
mgr.logger.Infof("certificate for domain %q already exists in cache before lock attempt", name)
|
||||||
|
mgr.loadAndFinalizeCachedCert(ctx, d, name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go mgr.pollCacheAndCancel(ctx, name, cancel)
|
||||||
|
|
||||||
|
// Acquire lock
|
||||||
mgr.logger.Infof("acquiring cert lock for domain %q", name)
|
mgr.logger.Infof("acquiring cert lock for domain %q", name)
|
||||||
lockStart := time.Now()
|
lockStart := time.Now()
|
||||||
unlock, err := mgr.locker.Lock(ctx, name)
|
unlock, err := mgr.locker.Lock(ctx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", name, err)
|
if mgr.certExistsInCache(context.Background(), name) {
|
||||||
|
mgr.logger.Infof("certificate for domain %q appeared in cache while waiting for lock", name)
|
||||||
|
mgr.loadAndFinalizeCachedCert(context.Background(), d, name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mgr.logger.Warnf("acquire cert lock for domain %q: %v", name, err)
|
||||||
|
// Continue without lock
|
||||||
} else {
|
} else {
|
||||||
mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockStart))
|
mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockStart))
|
||||||
defer unlock()
|
defer unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if mgr.certExistsInCache(ctx, name) {
|
||||||
|
mgr.logger.Infof("certificate for domain %q already exists in cache after lock acquisition, skipping ACME request", name)
|
||||||
|
mgr.loadAndFinalizeCachedCert(ctx, d, name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
hello := &tls.ClientHelloInfo{
|
hello := &tls.ClientHelloInfo{
|
||||||
ServerName: name,
|
ServerName: name,
|
||||||
Conn: &dummyConn{ctx: ctx},
|
Conn: &dummyConn{ctx: ctx},
|
||||||
@@ -161,6 +188,10 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if mgr.metrics != nil {
|
||||||
|
mgr.metrics.RecordCertificateIssuance(elapsed)
|
||||||
|
}
|
||||||
|
|
||||||
mgr.setDomainState(d, domainReady, "")
|
mgr.setDomainState(d, domainReady, "")
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -199,6 +230,78 @@ func (mgr *Manager) setDomainState(d domain.Domain, state domainState, errMsg st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// certExistsInCache checks if a certificate exists on the shared disk for the given domain.
|
||||||
|
func (mgr *Manager) certExistsInCache(ctx context.Context, domain string) bool {
|
||||||
|
if mgr.Cache == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err := mgr.Cache.Get(ctx, domain)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pollCacheAndCancel periodically checks if a certificate appears in the cache.
|
||||||
|
// If found, it cancels the context to abort the lock wait.
|
||||||
|
func (mgr *Manager) pollCacheAndCancel(ctx context.Context, domain string, cancel context.CancelFunc) {
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if mgr.certExistsInCache(context.Background(), domain) {
|
||||||
|
mgr.logger.Debugf("cert detected in cache for domain %q, cancelling lock wait", domain)
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndFinalizeCachedCert loads a certificate from cache and updates domain state.
|
||||||
|
func (mgr *Manager) loadAndFinalizeCachedCert(ctx context.Context, d domain.Domain, name string) {
|
||||||
|
hello := &tls.ClientHelloInfo{
|
||||||
|
ServerName: name,
|
||||||
|
Conn: &dummyConn{ctx: ctx},
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := mgr.GetCertificate(hello)
|
||||||
|
if err != nil {
|
||||||
|
mgr.logger.Warnf("load cached certificate for domain %q: %v", name, err)
|
||||||
|
mgr.setDomainState(d, domainFailed, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.setDomainState(d, domainReady, "")
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if cert != nil && cert.Leaf != nil {
|
||||||
|
leaf := cert.Leaf
|
||||||
|
mgr.logger.Infof("loaded cached certificate for domain %q: serial=%s SANs=%v notBefore=%s, notAfter=%s, now=%s",
|
||||||
|
name,
|
||||||
|
leaf.SerialNumber.Text(16),
|
||||||
|
leaf.DNSNames,
|
||||||
|
leaf.NotBefore.UTC().Format(time.RFC3339),
|
||||||
|
leaf.NotAfter.UTC().Format(time.RFC3339),
|
||||||
|
now.UTC().Format(time.RFC3339),
|
||||||
|
)
|
||||||
|
mgr.logCertificateDetails(name, leaf, now)
|
||||||
|
} else {
|
||||||
|
mgr.logger.Infof("loaded cached certificate for domain %q", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.mu.RLock()
|
||||||
|
info := mgr.domains[d]
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
|
||||||
|
if info != nil && mgr.certNotifier != nil {
|
||||||
|
if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil {
|
||||||
|
mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// logCertificateDetails logs certificate validity and SCT timestamps.
|
// logCertificateDetails logs certificate validity and SCT timestamps.
|
||||||
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
|
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
|
||||||
if cert.NotBefore.After(now) {
|
if cert.NotBefore.After(now) {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestHostPolicy(t *testing.T) {
|
func TestHostPolicy(t *testing.T) {
|
||||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
|
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil)
|
||||||
mgr.AddDomain("example.com", "acc1", "rp1")
|
mgr.AddDomain("example.com", "acc1", "rp1")
|
||||||
|
|
||||||
// Wait for the background prefetch goroutine to finish so the temp dir
|
// Wait for the background prefetch goroutine to finish so the temp dir
|
||||||
@@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDomainStates(t *testing.T) {
|
func TestDomainStates(t *testing.T) {
|
||||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
|
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "", nil)
|
||||||
|
|
||||||
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
|
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
|
||||||
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")
|
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")
|
||||||
|
|||||||
@@ -1,64 +1,106 @@
|
|||||||
package metrics
|
package metrics
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"go.opentelemetry.io/otel/metric"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
requestsTotal prometheus.Counter
|
ctx context.Context
|
||||||
activeRequests prometheus.Gauge
|
requestsTotal metric.Int64Counter
|
||||||
configuredDomains prometheus.Gauge
|
activeRequests metric.Int64UpDownCounter
|
||||||
pathsPerDomain *prometheus.GaugeVec
|
configuredDomains metric.Int64UpDownCounter
|
||||||
requestDuration *prometheus.HistogramVec
|
totalPaths metric.Int64UpDownCounter
|
||||||
backendDuration *prometheus.HistogramVec
|
requestDuration metric.Int64Histogram
|
||||||
|
backendDuration metric.Int64Histogram
|
||||||
|
certificateIssueDuration metric.Int64Histogram
|
||||||
|
|
||||||
|
mappingsMux sync.Mutex
|
||||||
|
mappingPaths map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(reg prometheus.Registerer) *Metrics {
|
func New(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||||
promFactory := promauto.With(reg)
|
requestsTotal, err := meter.Int64Counter(
|
||||||
return &Metrics{
|
"proxy.http.request.counter",
|
||||||
requestsTotal: promFactory.NewCounter(prometheus.CounterOpts{
|
metric.WithUnit("1"),
|
||||||
Name: "netbird_proxy_requests_total",
|
metric.WithDescription("Total number of requests made to the netbird proxy"),
|
||||||
Help: "Total number of requests made to the netbird proxy",
|
)
|
||||||
}),
|
if err != nil {
|
||||||
activeRequests: promFactory.NewGauge(prometheus.GaugeOpts{
|
return nil, err
|
||||||
Name: "netbird_proxy_active_requests_count",
|
|
||||||
Help: "Current in-flight requests handled by the netbird proxy",
|
|
||||||
}),
|
|
||||||
configuredDomains: promFactory.NewGauge(prometheus.GaugeOpts{
|
|
||||||
Name: "netbird_proxy_domains_count",
|
|
||||||
Help: "Current number of domains configured on the netbird proxy",
|
|
||||||
}),
|
|
||||||
pathsPerDomain: promFactory.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Name: "netbird_proxy_paths_count",
|
|
||||||
Help: "Current number of paths configured on the netbird proxy labelled by domain",
|
|
||||||
},
|
|
||||||
[]string{"domain"},
|
|
||||||
),
|
|
||||||
requestDuration: promFactory.NewHistogramVec(
|
|
||||||
prometheus.HistogramOpts{
|
|
||||||
Name: "netbird_proxy_request_duration_seconds",
|
|
||||||
Help: "Duration of requests made to the netbird proxy",
|
|
||||||
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
|
||||||
},
|
|
||||||
[]string{"status", "size", "method", "host", "path"},
|
|
||||||
),
|
|
||||||
backendDuration: promFactory.NewHistogramVec(prometheus.HistogramOpts{
|
|
||||||
Name: "netbird_proxy_backend_duration_seconds",
|
|
||||||
Help: "Duration of peer round trip time from the netbird proxy",
|
|
||||||
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
|
||||||
},
|
|
||||||
[]string{"status", "size", "method", "host", "path"},
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
activeRequests, err := meter.Int64UpDownCounter(
|
||||||
|
"proxy.http.active_requests",
|
||||||
|
metric.WithUnit("1"),
|
||||||
|
metric.WithDescription("Current in-flight requests handled by the netbird proxy"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
configuredDomains, err := meter.Int64UpDownCounter(
|
||||||
|
"proxy.domains.count",
|
||||||
|
metric.WithUnit("1"),
|
||||||
|
metric.WithDescription("Current number of domains configured on the netbird proxy"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
totalPaths, err := meter.Int64UpDownCounter(
|
||||||
|
"proxy.paths.count",
|
||||||
|
metric.WithUnit("1"),
|
||||||
|
metric.WithDescription("Total number of paths configured on the netbird proxy"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestDuration, err := meter.Int64Histogram(
|
||||||
|
"proxy.http.request.duration.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithDescription("Duration of requests made to the netbird proxy"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
backendDuration, err := meter.Int64Histogram(
|
||||||
|
"proxy.backend.duration.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithDescription("Duration of peer round trip time from the netbird proxy"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certificateIssueDuration, err := meter.Int64Histogram(
|
||||||
|
"proxy.certificate.issue.duration.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithDescription("Duration of ACME certificate issuance"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Metrics{
|
||||||
|
ctx: ctx,
|
||||||
|
requestsTotal: requestsTotal,
|
||||||
|
activeRequests: activeRequests,
|
||||||
|
configuredDomains: configuredDomains,
|
||||||
|
totalPaths: totalPaths,
|
||||||
|
requestDuration: requestDuration,
|
||||||
|
backendDuration: backendDuration,
|
||||||
|
certificateIssueDuration: certificateIssueDuration,
|
||||||
|
mappingPaths: make(map[string]int),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type responseInterceptor struct {
|
type responseInterceptor struct {
|
||||||
@@ -80,23 +122,19 @@ func (w *responseInterceptor) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
m.requestsTotal.Inc()
|
m.requestsTotal.Add(m.ctx, 1)
|
||||||
m.activeRequests.Inc()
|
m.activeRequests.Add(m.ctx, 1)
|
||||||
|
|
||||||
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
next.ServeHTTP(interceptor, r)
|
defer func() {
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
m.activeRequests.Add(m.ctx, -1)
|
||||||
|
m.requestDuration.Record(m.ctx, duration.Milliseconds())
|
||||||
|
}()
|
||||||
|
|
||||||
m.activeRequests.Desc()
|
next.ServeHTTP(interceptor, r)
|
||||||
m.requestDuration.With(prometheus.Labels{
|
|
||||||
"status": strconv.Itoa(interceptor.status),
|
|
||||||
"size": strconv.Itoa(interceptor.size),
|
|
||||||
"method": r.Method,
|
|
||||||
"host": r.Host,
|
|
||||||
"path": r.URL.Path,
|
|
||||||
}).Observe(duration.Seconds())
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,44 +146,52 @@ func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||||||
|
|
||||||
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
|
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
|
||||||
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
labels := prometheus.Labels{
|
|
||||||
"method": req.Method,
|
|
||||||
"host": req.Host,
|
|
||||||
// Fill potentially empty labels with default values to avoid cardinality issues.
|
|
||||||
"path": "/",
|
|
||||||
"status": "0",
|
|
||||||
"size": "0",
|
|
||||||
}
|
|
||||||
if req.URL != nil {
|
|
||||||
labels["path"] = req.URL.Path
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
res, err := next.RoundTrip(req)
|
res, err := next.RoundTrip(req)
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
|
||||||
// Not all labels will be available if there was an error.
|
m.backendDuration.Record(m.ctx, duration.Milliseconds())
|
||||||
if res != nil {
|
|
||||||
labels["status"] = strconv.Itoa(res.StatusCode)
|
|
||||||
labels["size"] = strconv.Itoa(int(res.ContentLength))
|
|
||||||
}
|
|
||||||
|
|
||||||
m.backendDuration.With(labels).Observe(duration.Seconds())
|
|
||||||
|
|
||||||
return res, err
|
return res, err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
|
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
|
||||||
m.configuredDomains.Inc()
|
m.mappingsMux.Lock()
|
||||||
m.pathsPerDomain.With(prometheus.Labels{
|
defer m.mappingsMux.Unlock()
|
||||||
"domain": mapping.Host,
|
|
||||||
}).Set(float64(len(mapping.Paths)))
|
newPathCount := len(mapping.Paths)
|
||||||
|
oldPathCount, exists := m.mappingPaths[mapping.Host]
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
m.configuredDomains.Add(m.ctx, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pathDelta := newPathCount - oldPathCount
|
||||||
|
if pathDelta != 0 {
|
||||||
|
m.totalPaths.Add(m.ctx, int64(pathDelta))
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mappingPaths[mapping.Host] = newPathCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
|
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
|
||||||
m.configuredDomains.Dec()
|
m.mappingsMux.Lock()
|
||||||
m.pathsPerDomain.With(prometheus.Labels{
|
defer m.mappingsMux.Unlock()
|
||||||
"domain": mapping.Host,
|
|
||||||
}).Set(0)
|
oldPathCount, exists := m.mappingPaths[mapping.Host]
|
||||||
|
if !exists {
|
||||||
|
// Nothing to remove
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.configuredDomains.Add(m.ctx, -1)
|
||||||
|
m.totalPaths.Add(m.ctx, -int64(oldPathCount))
|
||||||
|
|
||||||
|
delete(m.mappingPaths, mapping.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordCertificateIssuance records the duration of a certificate issuance.
|
||||||
|
func (m *Metrics) RecordCertificateIssuance(duration time.Duration) {
|
||||||
|
m.certificateIssueDuration.Record(m.ctx, duration.Milliseconds())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
package metrics_test
|
package metrics_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||||
|
"go.opentelemetry.io/otel/sdk/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/metrics"
|
"github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type testRoundTripper struct {
|
type testRoundTripper struct {
|
||||||
@@ -47,7 +51,19 @@ func TestMetrics_RoundTripper(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m := metrics.New(prometheus.NewRegistry())
|
exporter, err := prometheus.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create prometheus exporter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||||
|
pkg := reflect.TypeOf(metrics.Metrics{}).PkgPath()
|
||||||
|
meter := provider.Meter(pkg)
|
||||||
|
|
||||||
|
m, err := metrics.New(context.Background(), meter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create metrics: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
|
|||||||
@@ -28,10 +28,12 @@ func BenchmarkServeHTTP(b *testing.B) {
|
|||||||
ID: rand.Text(),
|
ID: rand.Text(),
|
||||||
AccountID: types.AccountID(rand.Text()),
|
AccountID: types.AccountID(rand.Text()),
|
||||||
Host: "app.example.com",
|
Host: "app.example.com",
|
||||||
Paths: map[string]*url.URL{
|
Paths: map[string]*proxy.PathTarget{
|
||||||
"/": {
|
"/": {
|
||||||
Scheme: "http",
|
URL: &url.URL{
|
||||||
Host: "10.0.0.1:8080",
|
Scheme: "http",
|
||||||
|
Host: "10.0.0.1:8080",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -67,10 +69,12 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
|
|||||||
ID: id,
|
ID: id,
|
||||||
AccountID: types.AccountID(rand.Text()),
|
AccountID: types.AccountID(rand.Text()),
|
||||||
Host: host,
|
Host: host,
|
||||||
Paths: map[string]*url.URL{
|
Paths: map[string]*proxy.PathTarget{
|
||||||
"/": {
|
"/": {
|
||||||
Scheme: "http",
|
URL: &url.URL{
|
||||||
Host: "10.0.0.1:8080",
|
Scheme: "http",
|
||||||
|
Host: "10.0.0.1:8080",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -100,15 +104,17 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
|
|||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paths := make(map[string]*url.URL, pathCount)
|
paths := make(map[string]*proxy.PathTarget, pathCount)
|
||||||
for i := range pathCount {
|
for i := range pathCount {
|
||||||
path := "/" + rand.Text()
|
path := "/" + rand.Text()
|
||||||
if int64(i) == targetIndex.Int64() {
|
if int64(i) == targetIndex.Int64() {
|
||||||
target = path
|
target = path
|
||||||
}
|
}
|
||||||
paths[path] = &url.URL{
|
paths[path] = &proxy.PathTarget{
|
||||||
Scheme: "http",
|
URL: &url.URL{
|
||||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
Scheme: "http",
|
||||||
|
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rp.AddMapping(proxy.Mapping{
|
rp.AddMapping(proxy.Mapping{
|
||||||
|
|||||||
@@ -80,14 +80,30 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
capturedData.SetAccountId(result.accountID)
|
capturedData.SetAccountId(result.accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pt := result.target
|
||||||
|
|
||||||
|
if pt.SkipTLSVerify {
|
||||||
|
ctx = roundtrip.WithSkipTLSVerify(ctx)
|
||||||
|
}
|
||||||
|
if pt.RequestTimeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriteMatchedPath := result.matchedPath
|
||||||
|
if pt.PathRewrite == PathRewritePreserve {
|
||||||
|
rewriteMatchedPath = ""
|
||||||
|
}
|
||||||
|
|
||||||
rp := &httputil.ReverseProxy{
|
rp := &httputil.ReverseProxy{
|
||||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
Rewrite: p.rewriteFunc(pt.URL, rewriteMatchedPath, result.passHostHeader, pt.PathRewrite, pt.CustomHeaders),
|
||||||
Transport: p.transport,
|
Transport: p.transport,
|
||||||
FlushInterval: -1,
|
FlushInterval: -1,
|
||||||
ErrorHandler: proxyErrorHandler,
|
ErrorHandler: proxyErrorHandler,
|
||||||
}
|
}
|
||||||
if result.rewriteRedirects {
|
if result.rewriteRedirects {
|
||||||
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
|
rp.ModifyResponse = p.rewriteLocationFunc(pt.URL, rewriteMatchedPath, r) //nolint:bodyclose
|
||||||
}
|
}
|
||||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
@@ -97,16 +113,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// forwarding headers and stripping proxy authentication credentials.
|
// forwarding headers and stripping proxy authentication credentials.
|
||||||
// When passHostHeader is true, the original client Host header is preserved
|
// When passHostHeader is true, the original client Host header is preserved
|
||||||
// instead of being rewritten to the backend's address.
|
// instead of being rewritten to the backend's address.
|
||||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
|
// The pathRewrite parameter controls how the request path is transformed.
|
||||||
|
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool, pathRewrite PathRewriteMode, customHeaders map[string]string) func(r *httputil.ProxyRequest) {
|
||||||
return func(r *httputil.ProxyRequest) {
|
return func(r *httputil.ProxyRequest) {
|
||||||
// Strip the matched path prefix from the incoming request path before
|
switch pathRewrite {
|
||||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
case PathRewritePreserve:
|
||||||
if matchedPath != "" && matchedPath != "/" {
|
// Keep the full original request path as-is.
|
||||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
default:
|
||||||
if r.Out.URL.Path == "" {
|
if matchedPath != "" && matchedPath != "/" {
|
||||||
r.Out.URL.Path = "/"
|
// Strip the matched path prefix from the incoming request path before
|
||||||
|
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||||
|
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||||
|
if r.Out.URL.Path == "" {
|
||||||
|
r.Out.URL.Path = "/"
|
||||||
|
}
|
||||||
|
r.Out.URL.RawPath = ""
|
||||||
}
|
}
|
||||||
r.Out.URL.RawPath = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.SetURL(target)
|
r.SetURL(target)
|
||||||
@@ -116,6 +138,10 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
|||||||
r.Out.Host = target.Host
|
r.Out.Host = target.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k, v := range customHeaders {
|
||||||
|
r.Out.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||||
|
|
||||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
|||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
|
|
||||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -37,7 +37,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||||
rewrite := p.rewriteFunc(target, "", true)
|
rewrite := p.rewriteFunc(target, "", true, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -52,7 +52,7 @@ func TestRewriteFunc_HostRewriting(t *testing.T) {
|
|||||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||||
target, _ := url.Parse("http://backend.internal:8080")
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||||
@@ -89,7 +89,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -99,7 +99,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -109,7 +109,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||||
pr.In.TLS = &tls.ConnectionState{}
|
pr.In.TLS = &tls.ConnectionState{}
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -130,7 +130,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||||
pr.In.TLS = &tls.ConnectionState{}
|
pr.In.TLS = &tls.ConnectionState{}
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -151,7 +151,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "https"}
|
p := &ReverseProxy{forwardedProto: "https"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
// No TLS, but forced to https
|
// No TLS, but forced to https
|
||||||
|
|
||||||
@@ -162,7 +162,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("forced http proto", func(t *testing.T) {
|
t.Run("forced http proto", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "http"}
|
p := &ReverseProxy{forwardedProto: "http"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||||
pr.In.TLS = &tls.ConnectionState{}
|
pr.In.TLS = &tls.ConnectionState{}
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
|||||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||||
target, _ := url.Parse("http://backend.internal:8080")
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
@@ -220,7 +220,7 @@ func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
|||||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||||
target, _ := url.Parse("http://backend.internal:8080")
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
p := &ReverseProxy{forwardedProto: "auto"}
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||||
@@ -248,7 +248,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -261,7 +261,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
||||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||||
rewrite := p.rewriteFunc(target, "/app", false)
|
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -274,7 +274,7 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
||||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||||
rewrite := p.rewriteFunc(target, "/app", false)
|
rewrite := p.rewriteFunc(target, "/app", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -332,7 +332,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||||
@@ -344,7 +344,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||||
@@ -357,7 +357,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
||||||
@@ -370,7 +370,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||||
@@ -382,7 +382,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||||
@@ -394,7 +394,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
||||||
@@ -406,7 +406,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
|
|
||||||
@@ -418,7 +418,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
|
|
||||||
@@ -429,7 +429,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||||
@@ -454,7 +454,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||||
@@ -467,7 +467,7 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
||||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||||
rewrite := p.rewriteFunc(target, "", false)
|
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil)
|
||||||
|
|
||||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||||
|
|
||||||
@@ -490,7 +490,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
|||||||
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
||||||
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
||||||
target, _ := url.Parse("https://heise.de:443/heise")
|
target, _ := url.Parse("https://heise.de:443/heise")
|
||||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -501,7 +501,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
||||||
target, _ := url.Parse("https://heise.de:443/heise")
|
target, _ := url.Parse("https://heise.de:443/heise")
|
||||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -513,7 +513,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
|||||||
// What the behavior WOULD be if target URL had no path (true stripping)
|
// What the behavior WOULD be if target URL had no path (true stripping)
|
||||||
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
||||||
target, _ := url.Parse("https://heise.de:443")
|
target, _ := url.Parse("https://heise.de:443")
|
||||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -524,7 +524,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
||||||
target, _ := url.Parse("https://heise.de:443")
|
target, _ := url.Parse("https://heise.de:443")
|
||||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
rewrite := p.rewriteFunc(target, "/heise", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -536,7 +536,7 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
|||||||
// Root path "/" — no stripping expected
|
// Root path "/" — no stripping expected
|
||||||
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
||||||
target, _ := url.Parse("https://backend.example.com:443/")
|
target, _ := url.Parse("https://backend.example.com:443/")
|
||||||
rewrite := p.rewriteFunc(target, "/", false)
|
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||||
|
|
||||||
rewrite(pr)
|
rewrite(pr)
|
||||||
@@ -546,6 +546,82 @@ func TestRewriteFunc_PathForwarding(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRewriteFunc_PreservePath(t *testing.T) {
|
||||||
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
|
|
||||||
|
t.Run("preserve keeps full request path", func(t *testing.T) {
|
||||||
|
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, nil)
|
||||||
|
pr := newProxyRequest(t, "http://example.com/api/users/123", "1.2.3.4:5000")
|
||||||
|
|
||||||
|
rewrite(pr)
|
||||||
|
|
||||||
|
assert.Equal(t, "/api/users/123", pr.Out.URL.Path,
|
||||||
|
"preserve should keep the full original request path")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserve with root matchedPath", func(t *testing.T) {
|
||||||
|
rewrite := p.rewriteFunc(target, "/", false, PathRewritePreserve, nil)
|
||||||
|
pr := newProxyRequest(t, "http://example.com/anything", "1.2.3.4:5000")
|
||||||
|
|
||||||
|
rewrite(pr)
|
||||||
|
|
||||||
|
assert.Equal(t, "/anything", pr.Out.URL.Path)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteFunc_CustomHeaders(t *testing.T) {
|
||||||
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
|
|
||||||
|
t.Run("injects custom headers", func(t *testing.T) {
|
||||||
|
headers := map[string]string{
|
||||||
|
"X-Custom-Auth": "token-abc",
|
||||||
|
"X-Env": "production",
|
||||||
|
}
|
||||||
|
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||||
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
|
|
||||||
|
rewrite(pr)
|
||||||
|
|
||||||
|
assert.Equal(t, "token-abc", pr.Out.Header.Get("X-Custom-Auth"))
|
||||||
|
assert.Equal(t, "production", pr.Out.Header.Get("X-Env"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil customHeaders is fine", func(t *testing.T) {
|
||||||
|
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, nil)
|
||||||
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
|
|
||||||
|
rewrite(pr)
|
||||||
|
|
||||||
|
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom headers override existing request headers", func(t *testing.T) {
|
||||||
|
headers := map[string]string{"X-Override": "new-value"}
|
||||||
|
rewrite := p.rewriteFunc(target, "/", false, PathRewriteDefault, headers)
|
||||||
|
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||||
|
pr.In.Header.Set("X-Override", "old-value")
|
||||||
|
|
||||||
|
rewrite(pr)
|
||||||
|
|
||||||
|
assert.Equal(t, "new-value", pr.Out.Header.Get("X-Override"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteFunc_PreservePathWithCustomHeaders(t *testing.T) {
|
||||||
|
p := &ReverseProxy{forwardedProto: "auto"}
|
||||||
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
|
|
||||||
|
rewrite := p.rewriteFunc(target, "/api", false, PathRewritePreserve, map[string]string{"X-Via": "proxy"})
|
||||||
|
pr := newProxyRequest(t, "http://example.com/api/deep/path", "1.2.3.4:5000")
|
||||||
|
|
||||||
|
rewrite(pr)
|
||||||
|
|
||||||
|
assert.Equal(t, "/api/deep/path", pr.Out.URL.Path, "preserve should keep the full original path")
|
||||||
|
assert.Equal(t, "proxy", pr.Out.Header.Get("X-Via"), "custom header should be set")
|
||||||
|
}
|
||||||
|
|
||||||
func TestRewriteLocationFunc(t *testing.T) {
|
func TestRewriteLocationFunc(t *testing.T) {
|
||||||
target, _ := url.Parse("http://backend.internal:8080")
|
target, _ := url.Parse("http://backend.internal:8080")
|
||||||
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
|
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
|
||||||
|
|||||||
@@ -6,21 +6,41 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||||
|
type PathRewriteMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PathRewriteDefault strips the matched prefix and joins with the target path.
|
||||||
|
PathRewriteDefault PathRewriteMode = iota
|
||||||
|
// PathRewritePreserve keeps the full original request path as-is.
|
||||||
|
PathRewritePreserve
|
||||||
|
)
|
||||||
|
|
||||||
|
// PathTarget holds a backend URL and per-target behavioral options.
|
||||||
|
type PathTarget struct {
|
||||||
|
URL *url.URL
|
||||||
|
SkipTLSVerify bool
|
||||||
|
RequestTimeout time.Duration
|
||||||
|
PathRewrite PathRewriteMode
|
||||||
|
CustomHeaders map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
type Mapping struct {
|
type Mapping struct {
|
||||||
ID string
|
ID string
|
||||||
AccountID types.AccountID
|
AccountID types.AccountID
|
||||||
Host string
|
Host string
|
||||||
Paths map[string]*url.URL
|
Paths map[string]*PathTarget
|
||||||
PassHostHeader bool
|
PassHostHeader bool
|
||||||
RewriteRedirects bool
|
RewriteRedirects bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type targetResult struct {
|
type targetResult struct {
|
||||||
url *url.URL
|
target *PathTarget
|
||||||
matchedPath string
|
matchedPath string
|
||||||
serviceID string
|
serviceID string
|
||||||
accountID types.AccountID
|
accountID types.AccountID
|
||||||
@@ -55,10 +75,14 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
|
|||||||
|
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
if strings.HasPrefix(req.URL.Path, path) {
|
if strings.HasPrefix(req.URL.Path, path) {
|
||||||
target := m.Paths[path]
|
pt := m.Paths[path]
|
||||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
|
if pt == nil || pt.URL == nil {
|
||||||
|
p.logger.Warnf("invalid mapping for host: %s, path: %s (nil target)", host, path)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, pt.URL)
|
||||||
return targetResult{
|
return targetResult{
|
||||||
url: target,
|
target: pt,
|
||||||
matchedPath: path,
|
matchedPath: path,
|
||||||
serviceID: m.ID,
|
serviceID: m.ID,
|
||||||
accountID: m.AccountID,
|
accountID: m.AccountID,
|
||||||
|
|||||||
32
proxy/internal/roundtrip/context_test.go
Normal file
32
proxy/internal/roundtrip/context_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package roundtrip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountIDContext(t *testing.T) {
|
||||||
|
t.Run("returns empty when missing", func(t *testing.T) {
|
||||||
|
assert.Equal(t, types.AccountID(""), AccountIDFromContext(context.Background()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("round-trips value", func(t *testing.T) {
|
||||||
|
ctx := WithAccountID(context.Background(), "acc-123")
|
||||||
|
assert.Equal(t, types.AccountID("acc-123"), AccountIDFromContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipTLSVerifyContext(t *testing.T) {
|
||||||
|
t.Run("false by default", func(t *testing.T) {
|
||||||
|
assert.False(t, skipTLSVerifyFromContext(context.Background()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("true when set", func(t *testing.T) {
|
||||||
|
ctx := WithSkipTLSVerify(context.Background())
|
||||||
|
assert.True(t, skipTLSVerifyFromContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package roundtrip
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -52,9 +53,12 @@ type domainNotification struct {
|
|||||||
type clientEntry struct {
|
type clientEntry struct {
|
||||||
client *embed.Client
|
client *embed.Client
|
||||||
transport *http.Transport
|
transport *http.Transport
|
||||||
domains map[domain.Domain]domainInfo
|
// insecureTransport is a clone of transport with TLS verification disabled,
|
||||||
createdAt time.Time
|
// used when per-target skip_tls_verify is set.
|
||||||
started bool
|
insecureTransport *http.Transport
|
||||||
|
domains map[domain.Domain]domainInfo
|
||||||
|
createdAt time.Time
|
||||||
|
started bool
|
||||||
// Per-backend in-flight limiting keyed by target host:port.
|
// Per-backend in-flight limiting keyed by target host:port.
|
||||||
// TODO: clean up stale entries when backend targets change.
|
// TODO: clean up stale entries when backend targets change.
|
||||||
inflightMu sync.Mutex
|
inflightMu sync.Mutex
|
||||||
@@ -130,6 +134,9 @@ type ClientDebugInfo struct {
|
|||||||
// accountIDContextKey is the context key for storing the account ID.
|
// accountIDContextKey is the context key for storing the account ID.
|
||||||
type accountIDContextKey struct{}
|
type accountIDContextKey struct{}
|
||||||
|
|
||||||
|
// skipTLSVerifyContextKey is the context key for requesting insecure TLS.
|
||||||
|
type skipTLSVerifyContextKey struct{}
|
||||||
|
|
||||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||||
// one is created by authenticating with the management server using the provided token.
|
// one is created by authenticating with the management server using the provided token.
|
||||||
// Multiple domains can share the same client.
|
// Multiple domains can share the same client.
|
||||||
@@ -249,27 +256,33 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
|||||||
// Create a transport using the client dialer. We do this instead of using
|
// Create a transport using the client dialer. We do this instead of using
|
||||||
// the client's HTTPClient to avoid issues with request validation that do
|
// the client's HTTPClient to avoid issues with request validation that do
|
||||||
// not work with reverse proxied requests.
|
// not work with reverse proxied requests.
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: client.DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||||
|
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||||
|
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||||
|
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||||
|
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||||
|
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||||
|
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||||
|
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||||
|
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||||
|
DisableCompression: n.transportCfg.disableCompression,
|
||||||
|
}
|
||||||
|
|
||||||
|
insecureTransport := transport.Clone()
|
||||||
|
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
|
||||||
|
|
||||||
return &clientEntry{
|
return &clientEntry{
|
||||||
client: client,
|
client: client,
|
||||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||||
transport: &http.Transport{
|
transport: transport,
|
||||||
DialContext: client.DialContext,
|
insecureTransport: insecureTransport,
|
||||||
ForceAttemptHTTP2: true,
|
createdAt: time.Now(),
|
||||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
started: false,
|
||||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
inflightMap: make(map[backendKey]chan struct{}),
|
||||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
maxInflight: n.transportCfg.maxInflight,
|
||||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
|
||||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
|
||||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
|
||||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
|
||||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
|
||||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
|
||||||
DisableCompression: n.transportCfg.disableCompression,
|
|
||||||
},
|
|
||||||
createdAt: time.Now(),
|
|
||||||
started: false,
|
|
||||||
inflightMap: make(map[backendKey]chan struct{}),
|
|
||||||
maxInflight: n.transportCfg.maxInflight,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -373,6 +386,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
|||||||
|
|
||||||
client := entry.client
|
client := entry.client
|
||||||
transport := entry.transport
|
transport := entry.transport
|
||||||
|
insecureTransport := entry.insecureTransport
|
||||||
delete(n.clients, accountID)
|
delete(n.clients, accountID)
|
||||||
n.clientsMux.Unlock()
|
n.clientsMux.Unlock()
|
||||||
|
|
||||||
@@ -387,6 +401,7 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d d
|
|||||||
}
|
}
|
||||||
|
|
||||||
transport.CloseIdleConnections()
|
transport.CloseIdleConnections()
|
||||||
|
insecureTransport.CloseIdleConnections()
|
||||||
|
|
||||||
if err := client.Stop(ctx); err != nil {
|
if err := client.Stop(ctx); err != nil {
|
||||||
n.logger.WithFields(log.Fields{
|
n.logger.WithFields(log.Fields{
|
||||||
@@ -415,6 +430,9 @@ func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
}
|
}
|
||||||
client := entry.client
|
client := entry.client
|
||||||
transport := entry.transport
|
transport := entry.transport
|
||||||
|
if skipTLSVerifyFromContext(req.Context()) {
|
||||||
|
transport = entry.insecureTransport
|
||||||
|
}
|
||||||
n.clientsMux.RUnlock()
|
n.clientsMux.RUnlock()
|
||||||
|
|
||||||
release, ok := entry.acquireInflight(req.URL.Host)
|
release, ok := entry.acquireInflight(req.URL.Host)
|
||||||
@@ -457,6 +475,7 @@ func (n *NetBird) StopAll(ctx context.Context) error {
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for accountID, entry := range n.clients {
|
for accountID, entry := range n.clients {
|
||||||
entry.transport.CloseIdleConnections()
|
entry.transport.CloseIdleConnections()
|
||||||
|
entry.insecureTransport.CloseIdleConnections()
|
||||||
if err := entry.client.Stop(ctx); err != nil {
|
if err := entry.client.Stop(ctx); err != nil {
|
||||||
n.logger.WithFields(log.Fields{
|
n.logger.WithFields(log.Fields{
|
||||||
"account_id": accountID,
|
"account_id": accountID,
|
||||||
@@ -579,3 +598,14 @@ func AccountIDFromContext(ctx context.Context) types.AccountID {
|
|||||||
}
|
}
|
||||||
return accountID
|
return accountID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSkipTLSVerify marks the context to use an insecure transport that skips
|
||||||
|
// TLS certificate verification for the backend connection.
|
||||||
|
func WithSkipTLSVerify(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, skipTLSVerifyContextKey{}, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipTLSVerifyFromContext(ctx context.Context) bool {
|
||||||
|
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|||||||
@@ -116,6 +116,9 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
|||||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create real users manager
|
// Create real users manager
|
||||||
usersManager := users.NewManager(testStore)
|
usersManager := users.NewManager(testStore)
|
||||||
|
|
||||||
@@ -131,6 +134,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
|||||||
proxyService := nbgrpc.NewProxyServiceServer(
|
proxyService := nbgrpc.NewProxyServiceServer(
|
||||||
&testAccessLogManager{},
|
&testAccessLogManager{},
|
||||||
tokenStore,
|
tokenStore,
|
||||||
|
pkceStore,
|
||||||
oidcConfig,
|
oidcConfig,
|
||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
|
|||||||
@@ -19,14 +19,17 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/pires/go-proxyproto"
|
"github.com/pires/go-proxyproto"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
prometheus2 "github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||||
|
"go.opentelemetry.io/otel/sdk/metric"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@@ -42,7 +45,7 @@ import (
|
|||||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/metrics"
|
proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||||
@@ -63,7 +66,7 @@ type Server struct {
|
|||||||
debug *http.Server
|
debug *http.Server
|
||||||
healthServer *health.Server
|
healthServer *health.Server
|
||||||
healthChecker *health.Checker
|
healthChecker *health.Checker
|
||||||
meter *metrics.Metrics
|
meter *proxymetrics.Metrics
|
||||||
|
|
||||||
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||||
@@ -152,8 +155,19 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
|
|||||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||||
s.initDefaults()
|
s.initDefaults()
|
||||||
|
|
||||||
reg := prometheus.NewRegistry()
|
exporter, err := prometheus.New()
|
||||||
s.meter = metrics.New(reg)
|
if err != nil {
|
||||||
|
return fmt.Errorf("create prometheus exporter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||||
|
pkg := reflect.TypeOf(Server{}).PkgPath()
|
||||||
|
meter := provider.Meter(pkg)
|
||||||
|
|
||||||
|
s.meter, err = proxymetrics.New(ctx, meter)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create metrics: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
mgmtConn, err := s.dialManagement()
|
mgmtConn, err := s.dialManagement()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -193,7 +207,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
|
|
||||||
s.startDebugEndpoint()
|
s.startDebugEndpoint()
|
||||||
|
|
||||||
if err := s.startHealthServer(reg); err != nil {
|
if err := s.startHealthServer(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,12 +298,12 @@ func (s *Server) startDebugEndpoint() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// startHealthServer launches the health probe and metrics server.
|
// startHealthServer launches the health probe and metrics server.
|
||||||
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
|
func (s *Server) startHealthServer() error {
|
||||||
healthAddr := s.HealthAddress
|
healthAddr := s.HealthAddress
|
||||||
if healthAddr == "" {
|
if healthAddr == "" {
|
||||||
healthAddr = defaultHealthAddr
|
healthAddr = defaultHealthAddr
|
||||||
}
|
}
|
||||||
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
|
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
|
||||||
healthListener, err := net.Listen("tcp", healthAddr)
|
healthListener, err := net.Listen("tcp", healthAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
|
||||||
@@ -423,7 +437,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
|||||||
"acme_server": s.ACMEDirectory,
|
"acme_server": s.ACMEDirectory,
|
||||||
"challenge_type": s.ACMEChallengeType,
|
"challenge_type": s.ACMEChallengeType,
|
||||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod)
|
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod, s.meter)
|
||||||
|
|
||||||
if s.ACMEChallengeType == "http-01" {
|
if s.ACMEChallengeType == "http-01" {
|
||||||
s.http = &http.Server{
|
s.http = &http.Server{
|
||||||
@@ -720,7 +734,7 @@ func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
||||||
paths := make(map[string]*url.URL)
|
paths := make(map[string]*proxy.PathTarget)
|
||||||
for _, pathMapping := range mapping.GetPath() {
|
for _, pathMapping := range mapping.GetPath() {
|
||||||
targetURL, err := url.Parse(pathMapping.GetTarget())
|
targetURL, err := url.Parse(pathMapping.GetTarget())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -734,7 +748,17 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
|||||||
}).WithError(err).Error("failed to parse target URL for path, skipping")
|
}).WithError(err).Error("failed to parse target URL for path, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
paths[pathMapping.GetPath()] = targetURL
|
|
||||||
|
pt := &proxy.PathTarget{URL: targetURL}
|
||||||
|
if opts := pathMapping.GetOptions(); opts != nil {
|
||||||
|
pt.SkipTLSVerify = opts.GetSkipTlsVerify()
|
||||||
|
pt.PathRewrite = protoToPathRewrite(opts.GetPathRewrite())
|
||||||
|
pt.CustomHeaders = opts.GetCustomHeaders()
|
||||||
|
if d := opts.GetRequestTimeout(); d != nil {
|
||||||
|
pt.RequestTimeout = d.AsDuration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
paths[pathMapping.GetPath()] = pt
|
||||||
}
|
}
|
||||||
return proxy.Mapping{
|
return proxy.Mapping{
|
||||||
ID: mapping.GetId(),
|
ID: mapping.GetId(),
|
||||||
@@ -746,6 +770,15 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func protoToPathRewrite(mode proto.PathRewriteMode) proxy.PathRewriteMode {
|
||||||
|
switch mode {
|
||||||
|
case proto.PathRewriteMode_PATH_REWRITE_PRESERVE:
|
||||||
|
return proxy.PathRewritePreserve
|
||||||
|
default:
|
||||||
|
return proxy.PathRewriteDefault
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// debugEndpointAddr returns the address for the debug endpoint.
|
// debugEndpointAddr returns the address for the debug endpoint.
|
||||||
// If addr is empty, it defaults to localhost:8444 for security.
|
// If addr is empty, it defaults to localhost:8444 for security.
|
||||||
func debugEndpointAddr(addr string) string {
|
func debugEndpointAddr(addr string) string {
|
||||||
|
|||||||
271
shared/management/client/rest/reverse_proxy_services_test.go
Normal file
271
shared/management/client/rest/reverse_proxy_services_test.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package rest_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testServiceTarget = api.ServiceTarget{
|
||||||
|
TargetId: "peer-123",
|
||||||
|
TargetType: "peer",
|
||||||
|
Protocol: "https",
|
||||||
|
Port: 8443,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var testService = api.Service{
|
||||||
|
Id: "svc-1",
|
||||||
|
Name: "test-service",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
Auth: api.ServiceAuthConfig{},
|
||||||
|
Meta: api.ServiceMeta{
|
||||||
|
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
Status: "active",
|
||||||
|
},
|
||||||
|
Targets: []api.ServiceTarget{testServiceTarget},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_List_200(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal([]api.Service{testService})
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.List(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, ret, 1)
|
||||||
|
assert.Equal(t, testService.Id, ret[0].Id)
|
||||||
|
assert.Equal(t, testService.Name, ret[0].Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_List_Err(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
|
w.WriteHeader(400)
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.List(context.Background())
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "No", err.Error())
|
||||||
|
assert.Empty(t, ret)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Get_200(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal(testService)
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testService.Id, ret.Id)
|
||||||
|
assert.Equal(t, testService.Domain, ret.Domain)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Get_Err(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 404})
|
||||||
|
w.WriteHeader(404)
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.Get(context.Background(), "svc-1")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "No", err.Error())
|
||||||
|
assert.Nil(t, ret)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Create_200(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "POST", r.Method)
|
||||||
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var req api.ServiceRequest
|
||||||
|
require.NoError(t, json.Unmarshal(reqBytes, &req))
|
||||||
|
assert.Equal(t, "test-service", req.Name)
|
||||||
|
assert.Equal(t, "test.example.com", req.Domain)
|
||||||
|
retBytes, _ := json.Marshal(testService)
|
||||||
|
_, err = w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
|
||||||
|
Name: "test-service",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
Auth: api.ServiceAuthConfig{},
|
||||||
|
Targets: []api.ServiceTarget{testServiceTarget},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testService.Id, ret.Id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Create_Err(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
|
w.WriteHeader(400)
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
|
||||||
|
Name: "test-service",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
Auth: api.ServiceAuthConfig{},
|
||||||
|
Targets: []api.ServiceTarget{testServiceTarget},
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "No", err.Error())
|
||||||
|
assert.Nil(t, ret)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "POST", r.Method)
|
||||||
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var req api.ServiceRequest
|
||||||
|
require.NoError(t, json.Unmarshal(reqBytes, &req))
|
||||||
|
|
||||||
|
require.Len(t, req.Targets, 1)
|
||||||
|
target := req.Targets[0]
|
||||||
|
require.NotNil(t, target.Options, "options should be present")
|
||||||
|
opts := target.Options
|
||||||
|
require.NotNil(t, opts.SkipTlsVerify, "skip_tls_verify should be present")
|
||||||
|
assert.True(t, *opts.SkipTlsVerify)
|
||||||
|
require.NotNil(t, opts.RequestTimeout, "request_timeout should be present")
|
||||||
|
assert.Equal(t, "30s", *opts.RequestTimeout)
|
||||||
|
require.NotNil(t, opts.PathRewrite, "path_rewrite should be present")
|
||||||
|
assert.Equal(t, api.ServiceTargetOptionsPathRewrite("preserve"), *opts.PathRewrite)
|
||||||
|
require.NotNil(t, opts.CustomHeaders, "custom_headers should be present")
|
||||||
|
assert.Equal(t, "bar", (*opts.CustomHeaders)["X-Foo"])
|
||||||
|
|
||||||
|
retBytes, _ := json.Marshal(testService)
|
||||||
|
_, err = w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
pathRewrite := api.ServiceTargetOptionsPathRewrite("preserve")
|
||||||
|
ret, err := c.ReverseProxyServices.Create(context.Background(), api.PostApiReverseProxiesServicesJSONRequestBody{
|
||||||
|
Name: "test-service",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
Auth: api.ServiceAuthConfig{},
|
||||||
|
Targets: []api.ServiceTarget{
|
||||||
|
{
|
||||||
|
TargetId: "peer-123",
|
||||||
|
TargetType: "peer",
|
||||||
|
Protocol: "https",
|
||||||
|
Port: 8443,
|
||||||
|
Enabled: true,
|
||||||
|
Options: &api.ServiceTargetOptions{
|
||||||
|
SkipTlsVerify: ptr(true),
|
||||||
|
RequestTimeout: ptr("30s"),
|
||||||
|
PathRewrite: &pathRewrite,
|
||||||
|
CustomHeaders: &map[string]string{"X-Foo": "bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testService.Id, ret.Id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Update_200(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "PUT", r.Method)
|
||||||
|
reqBytes, err := io.ReadAll(r.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var req api.ServiceRequest
|
||||||
|
require.NoError(t, json.Unmarshal(reqBytes, &req))
|
||||||
|
assert.Equal(t, "updated-service", req.Name)
|
||||||
|
retBytes, _ := json.Marshal(testService)
|
||||||
|
_, err = w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
|
||||||
|
Name: "updated-service",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
Auth: api.ServiceAuthConfig{},
|
||||||
|
Targets: []api.ServiceTarget{testServiceTarget},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testService.Id, ret.Id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Update_Err(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||||
|
w.WriteHeader(400)
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
ret, err := c.ReverseProxyServices.Update(context.Background(), "svc-1", api.PutApiReverseProxiesServicesServiceIdJSONRequestBody{
|
||||||
|
Name: "updated-service",
|
||||||
|
Domain: "test.example.com",
|
||||||
|
Enabled: true,
|
||||||
|
Auth: api.ServiceAuthConfig{},
|
||||||
|
Targets: []api.ServiceTarget{testServiceTarget},
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "No", err.Error())
|
||||||
|
assert.Nil(t, ret)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Delete_200(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "DELETE", r.Method)
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyServices_Delete_Err(t *testing.T) {
|
||||||
|
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/api/reverse-proxies/services/svc-1", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||||
|
w.WriteHeader(404)
|
||||||
|
_, err := w.Write(retBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
err := c.ReverseProxyServices.Delete(context.Background(), "svc-1")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "Not found", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -2822,6 +2822,16 @@ components:
|
|||||||
type: string
|
type: string
|
||||||
description: "City name from geolocation"
|
description: "City name from geolocation"
|
||||||
example: "San Francisco"
|
example: "San Francisco"
|
||||||
|
bytes_upload:
|
||||||
|
type: integer
|
||||||
|
format: int64
|
||||||
|
description: "Bytes uploaded (request body size)"
|
||||||
|
example: 1024
|
||||||
|
bytes_download:
|
||||||
|
type: integer
|
||||||
|
format: int64
|
||||||
|
description: "Bytes downloaded (response body size)"
|
||||||
|
example: 8192
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- service_id
|
- service_id
|
||||||
@@ -2831,6 +2841,8 @@ components:
|
|||||||
- path
|
- path
|
||||||
- duration_ms
|
- duration_ms
|
||||||
- status_code
|
- status_code
|
||||||
|
- bytes_upload
|
||||||
|
- bytes_download
|
||||||
ProxyAccessLogsResponse:
|
ProxyAccessLogsResponse:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -3027,6 +3039,28 @@ components:
|
|||||||
- targets
|
- targets
|
||||||
- auth
|
- auth
|
||||||
- enabled
|
- enabled
|
||||||
|
ServiceTargetOptions:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
skip_tls_verify:
|
||||||
|
type: boolean
|
||||||
|
description: Skip TLS certificate verification for this backend
|
||||||
|
request_timeout:
|
||||||
|
type: string
|
||||||
|
description: Per-target response timeout as a Go duration string (e.g. "30s", "2m")
|
||||||
|
path_rewrite:
|
||||||
|
type: string
|
||||||
|
description: Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path.
|
||||||
|
enum: [preserve]
|
||||||
|
custom_headers:
|
||||||
|
type: object
|
||||||
|
description: Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected.
|
||||||
|
propertyNames:
|
||||||
|
type: string
|
||||||
|
pattern: '^[!#$%&''*+.^_`|~0-9A-Za-z-]+$'
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
pattern: '^[^\r\n]*$'
|
||||||
ServiceTarget:
|
ServiceTarget:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -3053,6 +3087,8 @@ components:
|
|||||||
enabled:
|
enabled:
|
||||||
type: boolean
|
type: boolean
|
||||||
description: Whether this target is enabled
|
description: Whether this target is enabled
|
||||||
|
options:
|
||||||
|
$ref: '#/components/schemas/ServiceTargetOptions'
|
||||||
required:
|
required:
|
||||||
- target_id
|
- target_id
|
||||||
- target_type
|
- target_type
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.26.0
|
// protoc-gen-go v1.26.0
|
||||||
// protoc v6.33.3
|
// protoc v6.33.0
|
||||||
// source: management.proto
|
// source: management.proto
|
||||||
|
|
||||||
package proto
|
package proto
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ package management;
|
|||||||
|
|
||||||
option go_package = "/proto";
|
option go_package = "/proto";
|
||||||
|
|
||||||
|
import "google/protobuf/duration.proto";
|
||||||
import "google/protobuf/timestamp.proto";
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
// ProxyService - Management is the SERVER, Proxy is the CLIENT
|
// ProxyService - Management is the SERVER, Proxy is the CLIENT
|
||||||
@@ -50,9 +51,22 @@ enum ProxyMappingUpdateType {
|
|||||||
UPDATE_TYPE_REMOVED = 2;
|
UPDATE_TYPE_REMOVED = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum PathRewriteMode {
|
||||||
|
PATH_REWRITE_DEFAULT = 0;
|
||||||
|
PATH_REWRITE_PRESERVE = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PathTargetOptions {
|
||||||
|
bool skip_tls_verify = 1;
|
||||||
|
google.protobuf.Duration request_timeout = 2;
|
||||||
|
PathRewriteMode path_rewrite = 3;
|
||||||
|
map<string, string> custom_headers = 4;
|
||||||
|
}
|
||||||
|
|
||||||
message PathMapping {
|
message PathMapping {
|
||||||
string path = 1;
|
string path = 1;
|
||||||
string target = 2;
|
string target = 2;
|
||||||
|
PathTargetOptions options = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Authentication {
|
message Authentication {
|
||||||
@@ -101,6 +115,8 @@ message AccessLog {
|
|||||||
string auth_mechanism = 11;
|
string auth_mechanism = 11;
|
||||||
string user_id = 12;
|
string user_id = 12;
|
||||||
bool auth_success = 13;
|
bool auth_success = 13;
|
||||||
|
int64 bytes_upload = 14;
|
||||||
|
int64 bytes_download = 15;
|
||||||
}
|
}
|
||||||
|
|
||||||
message AuthenticateRequest {
|
message AuthenticateRequest {
|
||||||
|
|||||||
Reference in New Issue
Block a user