Skip to content

Commit

Permalink
auth/clientcredentials: add HTTP middleware (#74)
Browse files Browse the repository at this point in the history
For now, this is a super-direct port of
https://sourcegraph.sourcegraph.com/github.com/sourcegraph/sourcegraph/-/blob/internal/monolithsams/auth.go
- the only major change is using our modern SDK types, and not
error-logging stuff that's user-input error. I'm going to use this for
sourcegraph/sourcegraph#1941

There were no unit tests to copy over sadly, so I'm omitting them for
now 😁

## Test plan

CI
  • Loading branch information
bobheadxi authored Nov 27, 2024
1 parent ab452c9 commit 7742c23
Show file tree
Hide file tree
Showing 6 changed files with 474 additions and 289 deletions.
193 changes: 0 additions & 193 deletions auth/clientcredentials/clientcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,9 @@ import (
"net/http"
"strings"

"connectrpc.com/connect"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
otelcodes "go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"

"github.com/sourcegraph/log"
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

Expand All @@ -31,172 +23,6 @@ type TokenIntrospector interface {
IntrospectToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error)
}

// See clientcredentials.NewInterceptor.
type Interceptor struct {
logger log.Logger
introspector TokenIntrospector
extension *protoimpl.ExtensionInfo
}

// NewInterceptor creates a serverside handler interceptor that ensures every
// incoming request has a valid client credential token with the required scopes
// indicated in the RPC method options. When used, required scopes CANNOT be
// empty - if no scopes are required, declare a separate service that does not
// use this interceptor.
//
// To declare required SAMS scopes in your RPC, add the following to your proto
// schema:
//
// extend google.protobuf.MethodOptions {
// // The SAMS scopes required to use this RPC.
// //
// // The range 50000-99999 is reserved for internal use within individual organizations
// // so you can use numbers in this range freely for in-house applications.
// repeated string sams_required_scopes = 50001;
// }
//
// In your RPCs, add the `(sams_required_scopes)` option as a comma-delimited
// list:
//
// rpc GetUserRoles(GetUserRolesRequest) returns (GetUserRolesResponse) {
// option (sams_required_scopes) = "sams::user.roles::read";
// };
//
// This will generate a variable called `E_SamsRequiredScopes` in your generated
// proto bindings. This variable should be provided to NewInterceptor to allow
// it to identify where to source the required scopes from.
//
// The provided logger is used to record internal-server errors.
func NewInterceptor(
logger log.Logger,
introspector TokenIntrospector,
methodOptionsRequiredScopesExtension *protoimpl.ExtensionInfo,
) *Interceptor {
return &Interceptor{
logger: logger.Scoped("clientcredentials"),
introspector: introspector,
extension: methodOptionsRequiredScopesExtension,
}
}

var _ connect.Interceptor = (*Interceptor)(nil)

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
return next(ctx, req) // no-op for clients
}
requiredScopes, err := extractSchemaRequiredScopes(req.Spec(), i.extension)
if err != nil {
return nil, internalError(ctx, i.logger, err, "internal schema error") // invalid schema is internal error
}
info, err := i.requireScope(ctx, req.Header(), requiredScopes)
if err != nil {
return nil, err
}
return next(WithClientInfo(ctx, info), req)
}
}

func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
return next(ctx, spec) // no-op for clients
}
}

func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
if conn.Spec().IsClient {
return next(ctx, conn) // no-op for clients
}
requiredScopes, err := extractSchemaRequiredScopes(conn.Spec(), i.extension)
if err != nil {
return internalError(ctx, i.logger, err, "internal schema error") // invalid schema is internal error
}
info, err := i.requireScope(ctx, conn.RequestHeader(), requiredScopes)
if err != nil {
return err
}
return next(WithClientInfo(ctx, info), conn)
}
}

// RequireScope ensures the request context has a valid SAMS M2M token
// with requiredScope. It returns a ConnectRPC status error suitable to be
// returned directly from a ConnectRPC implementation.
func (i *Interceptor) requireScope(ctx context.Context, headers http.Header, requiredScopes scopes.Scopes) (_ *ClientInfo, err error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "clientcredentials.requireScope")
defer func() {
if err != nil {
span.RecordError(err)
span.SetStatus(otelcodes.Error, "check failed")
}
span.End()
}()

token, err := extractBearerContents(headers)
if err != nil {
return nil, connect.NewError(connect.CodeUnauthenticated,
errors.Wrap(err, "invalid authorization header"))
}

result, err := i.introspector.IntrospectToken(ctx, token)
if err != nil {
return nil, internalError(ctx, i.logger, err, "unable to validate token")
}
span.SetAttributes(
attribute.String("client_id", result.ClientID),
attribute.String("token_expires_at", result.ExpiresAt.String()),
attribute.StringSlice("token_scopes", scopes.ToStrings(result.Scopes)))
info := &ClientInfo{
ClientID: result.ClientID,
TokenExpiresAt: result.ExpiresAt,
TokenScopes: result.Scopes,
}

// Active encapsulates whether the token is active, including expiration.
if !result.Active {
// Record detailed error in span, and return an opaque one
span.SetAttributes(attribute.String("full_error", "inactive token"))
return info, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}

// Check for our required scope.
for _, required := range requiredScopes {
if !result.Scopes.Match(required) {
err = errors.Newf("got scopes %+v, required: %+v", result.Scopes, requiredScopes)
span.SetAttributes(attribute.String("full_error", err.Error()))
return info, connect.NewError(connect.CodePermissionDenied,
errors.Wrap(err, "insufficient scopes"))
}
}

return info, nil
}

func extractSchemaRequiredScopes(spec connect.Spec, extension *protoimpl.ExtensionInfo) (scopes.Scopes, error) {
method, ok := spec.Schema.(protoreflect.MethodDescriptor)
if !ok {
return nil, errors.Newf("expected protoreflect.MethodDescriptor, got %T", spec.Schema)
}

value := method.Options().ProtoReflect().Get(extension.TypeDescriptor())
if !value.IsValid() {
return nil, errors.Newf("extension field %s not valid", extension.TypeDescriptor().FullName())
}
list := value.List()
if list.Len() == 0 {
return nil, errors.Newf("extension field %s cannot be empty", extension.TypeDescriptor().FullName())
}

requiredScopes := make(scopes.Scopes, list.Len())
for i := 0; i < list.Len(); i++ {
requiredScopes[i] = scopes.Scope(list.Get(i).String())
}
return requiredScopes, nil
}

func extractBearerContents(h http.Header) (string, error) {
authHeader := h.Get("Authorization")
if authHeader == "" {
Expand All @@ -211,22 +37,3 @@ func extractBearerContents(h http.Header) (string, error) {
}
return typ[1], nil
}

// internalError logs an error, adds it to the trace, and returns a connect
// error with a safe message.
func internalError(ctx context.Context, logger log.Logger, err error, safeMsg string) error {
trace.SpanFromContext(ctx).
SetAttributes(
attribute.String("full_error", err.Error()),
)
logger.WithTrace(log.TraceContext{
TraceID: trace.SpanContextFromContext(ctx).TraceID().String(),
SpanID: trace.SpanContextFromContext(ctx).SpanID().String(),
}).
AddCallerSkip(1).
Error(safeMsg,
log.String("code", connect.CodeInternal.String()),
log.Error(err),
)
return connect.NewError(connect.CodeInternal, errors.New(safeMsg))
}
96 changes: 0 additions & 96 deletions auth/clientcredentials/clientcredentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,8 @@ package clientcredentials

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"connectrpc.com/connect"
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/sourcegraph/log/logtest"
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
clientsv1 "github.com/sourcegraph/sourcegraph-accounts-sdk-go/clients/v1"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/clients/v1/clientsv1connect"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

type mockTokenIntrospector struct {
Expand All @@ -26,86 +13,3 @@ type mockTokenIntrospector struct {
func (m *mockTokenIntrospector) IntrospectToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error) {
return m.response, nil
}

func TestInterceptor(t *testing.T) {
// All tests based on UsersService.GetUser()
for _, tc := range []struct {
name string
token *sams.IntrospectTokenResponse
wantError autogold.Value
wantLogs autogold.Value
}{{
name: "inactive token",
token: &sams.IntrospectTokenResponse{
Active: false,
},
wantError: autogold.Expect("permission_denied: permission denied"),
wantLogs: autogold.Expect([]string{}),
}, {
name: "insufficient scopes",
token: &sams.IntrospectTokenResponse{
Active: true,
},
wantError: autogold.Expect("permission_denied: insufficient scopes: got scopes [], required: [profile]"),
wantLogs: autogold.Expect([]string{}),
}, {
name: "matches required scope",
token: &sams.IntrospectTokenResponse{
Active: true,
Scopes: scopes.Scopes{"profile"},
},
wantError: autogold.Expect(nil), // should not error!
wantLogs: autogold.Expect([]string{}),
}, {
name: "wrong scope",
token: &sams.IntrospectTokenResponse{
Active: true,
Scopes: scopes.Scopes{"not-a-scope"},
},
wantError: autogold.Expect("permission_denied: insufficient scopes: got scopes [not-a-scope], required: [profile]"),
wantLogs: autogold.Expect([]string{}),
}} {
t.Run(tc.name, func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
interceptor := NewInterceptor(
logger,
&mockTokenIntrospector{
response: tc.token,
},
clientsv1.E_SamsRequiredScopes,
)
mux := http.NewServeMux()
mux.Handle(
clientsv1connect.NewUsersServiceHandler(clientsv1connect.UnimplementedUsersServiceHandler{},
connect.WithInterceptors(interceptor)),
)
srv := httptest.NewServer(mux)
c := clientsv1connect.NewUsersServiceClient(
oauth2.NewClient(
context.Background(),
oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "foobar",
TokenType: "bearer",
}),
),
srv.URL)
_, err := c.GetUser(context.Background(), connect.NewRequest(&clientsv1.GetUserRequest{}))

// Success cases are connect.CodeUnimplemented
require.Error(t, err)

var connectErr *connect.Error
if errors.As(err, &connectErr) {
if connectErr.Code() == connect.CodeUnimplemented {
tc.wantError.Equal(t, nil) // should not expect an error
} else {
tc.wantError.Equal(t, err.Error())
}
} else {
t.Errorf("error %q is not *connect.Error", err.Error())
}

tc.wantLogs.Equal(t, exportLogs().Messages())
})
}
}
Loading

0 comments on commit 7742c23

Please sign in to comment.