diff --git a/internal/outpost/radius/request.go b/internal/outpost/radius/request.go index 962c929832..bdf20cbcf6 100644 --- a/internal/outpost/radius/request.go +++ b/internal/outpost/radius/request.go @@ -1,7 +1,6 @@ package radius import ( - "bytes" "crypto/hmac" "crypto/md5" "errors" @@ -40,13 +39,18 @@ func (r *RadiusRequest) ID() string { func (r *RadiusRequest) validateMessageAuthenticator() error { mauth := rfc2869.MessageAuthenticator_Get(r.Packet) - hash := hmac.New(md5.New, r.Secret) + // Per RFC 2869 ยง5.14, the Message-Authenticator field must be treated as + // 16 zero bytes when computing the HMAC-MD5 for verification. + _ = rfc2869.MessageAuthenticator_Set(r.Packet, make([]byte, 16)) + hash := hmac.New(md5.New, r.pi.SharedSecret) encode, err := r.MarshalBinary() + // Restore the original value regardless of whether marshaling succeeded. + _ = rfc2869.MessageAuthenticator_Set(r.Packet, mauth) if err != nil { return err } hash.Write(encode) - if bytes.Equal(mauth, hash.Sum(nil)) { + if !hmac.Equal(mauth, hash.Sum(nil)) { return ErrInvalidMessageAuthenticator } return nil @@ -54,7 +58,7 @@ func (r *RadiusRequest) validateMessageAuthenticator() error { func (r *RadiusRequest) setMessageAuthenticator(rp *radius.Packet) error { _ = rfc2869.MessageAuthenticator_Set(rp, make([]byte, 16)) - hash := hmac.New(md5.New, rp.Secret) + hash := hmac.New(md5.New, r.pi.SharedSecret) encode, err := rp.MarshalBinary() if err != nil { return err diff --git a/internal/outpost/radius/request_test.go b/internal/outpost/radius/request_test.go new file mode 100644 index 0000000000..744da04e9c --- /dev/null +++ b/internal/outpost/radius/request_test.go @@ -0,0 +1,63 @@ +package radius + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "layeh.com/radius" +) + +var ( + radiusPacketAccReq = []byte{0x1, 0x8f, 0x0, 0x4d, 0x4a, 0xd5, 0x47, 0x98, 0xbf, 0x18, 0xe, 0x4b, 0x6a, 0xdd, 0x0, 0xc7, 0x99, 0xb4, 0xa6, 0x57, 0x50, 0x12, 0xa5, 0xf7, 0x16, 0x88, 0xc5, 0xd8, 0xd9, 0xec, 0x19, 0xc8, 0x51, 0x47, 0x9, 0x5f, 0xe5, 0x60, 0x1, 0x9, 0x61, 0x6b, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2, 0x12, 0x37, 0x36, 0x8, 0xa3, 0x72, 0x20, 0xf, 0xf4, 0xc0, 0xc, 0xd2, 0x40, 0xc1, 0xc3, 0x3f, 0xef, 0x4, 0x6, 0xa, 0x78, 0x14, 0x4c, 0x5, 0x6, 0x0, 0x0, 0x0, 0xa} +) + +func Test_Request_validateMessageAuthenticator_valid(t *testing.T) { + p, err := radius.Parse(radiusPacketAccReq, []byte("foo")) + assert.NoError(t, err) + req := RadiusRequest{ + Request: &radius.Request{ + Packet: p, + }, + pi: &ProviderInstance{ + SharedSecret: []byte("foo"), + }, + } + assert.NoError(t, req.validateMessageAuthenticator()) +} + +func Test_Request_validateMessageAuthenticator_invalid(t *testing.T) { + p, err := radius.Parse(radiusPacketAccReq, []byte("bar")) + assert.NoError(t, err) + req := RadiusRequest{ + Request: &radius.Request{ + Packet: p, + }, + pi: &ProviderInstance{ + SharedSecret: []byte("bar"), + }, + } + assert.Error(t, req.validateMessageAuthenticator(), ErrInvalidMessageAuthenticator) +} + +func Test_Request_setMessageAuthenticator(t *testing.T) { + p, err := radius.Parse(radiusPacketAccReq, []byte("foo")) + assert.NoError(t, err) + req := RadiusRequest{ + Request: &radius.Request{ + Packet: p, + }, + pi: &ProviderInstance{ + SharedSecret: []byte("foo"), + }, + } + res := p.Response(radius.CodeAccessAccept) + assert.NoError(t, req.setMessageAuthenticator(res)) + + nr := RadiusRequest{ + Request: &radius.Request{ + Packet: res, + }, + pi: req.pi, + } + assert.NoError(t, nr.validateMessageAuthenticator()) +}