diff --git a/README.md b/README.md index 12375ff..915fbaf 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,17 @@ used with caution, still in development ```go +//default bearer token extractor and parser +//extract token from "Authorization": Bearer , and parse token into jwt.MapClaim +authMiddleware, err := middleware.NewAuthMiddleware(middleware.WithDefaultBearerExtractorAndParser([]byte("secret"))) +if err != nil { + panic(err) +} +http.ListenAndServe( + "localhost:8080", + // Use h2c so we can serve HTTP/2 without TLS. + h2c.NewHandler(authMiddleware.Wrap(mux), &http2.Server{}), +) ``` ## TODO diff --git a/authInterceptor.go b/authInterceptor.go index 9017354..06152fa 100644 --- a/authInterceptor.go +++ b/authInterceptor.go @@ -29,14 +29,13 @@ However, the added generic has no benefit when extract the value from context, u */ type authInterceptor struct { ServiceHandlerType - parser Parser clientHandler ClientTokenGetter serviceHandler *AuthHandler } -type opt func(*authInterceptor) +type authInterceptorOpt func(*authInterceptor) -func NewAuthInterceptor(opts ...opt) (*authInterceptor, error) { +func NewAuthInterceptor(opts ...authInterceptorOpt) (*authInterceptor, error) { i := authInterceptor{ ServiceHandlerType: UnaryHandler, serviceHandler: &AuthHandler{ @@ -70,24 +69,24 @@ func (i *authInterceptor) preventNilServiceHandler() { } } -func WithDefaultBearerExtractor() opt { +func WithInterceptorDefaultBearerExtractor() authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() - i.serviceHandler.Extractor = DefaultBearerTokenExtractor().ToExtractor() + i.serviceHandler.Extractor = DefaultBasicExtractor().ToExtractor() } } -func WithDefaultBearerExtractorAndParser(signningKey any) opt { +func WithInterceptorDefaultBearerExtractorAndParser(signningKey any) authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() - i.parser = DefaultJWTMapClaimsParser(signningKey) - i.serviceHandler.Extractor = DefaultBearerTokenExtractor().ToExtractor() + i.serviceHandler.Parser = DefaultJWTMapClaimsParser(signningKey) + i.serviceHandler.Extractor = DefaultBasicExtractor().ToExtractor() } } -func WithDefaultJWTMapClaimsParser(signningKey any) opt { +func WithInterceptorDefaultJWTMapClaimsParser(signningKey any) authInterceptorOpt { return func(i *authInterceptor) { - i.parser = DefaultJWTMapClaimsParser(signningKey) + i.serviceHandler.Parser = DefaultJWTMapClaimsParser(signningKey) } } @@ -97,14 +96,14 @@ func WithDefaultJWTMapClaimsParser(signningKey any) opt { // func(ctx context.Context) jwt.Claims{ // return &jwt.MapClaims{} // } -func WithCustomJWTClaimsParser(signningKey any, claimsFunc func(context.Context) jwt.Claims) opt { +func WithInterceptorCustomJWTClaimsParser(signningKey any, claimsFunc func(context.Context) jwt.Claims) authInterceptorOpt { return func(i *authInterceptor) { p, _ := NewJWTParser(WithSigningKey(signningKey), WithNewClaimsFunc(claimsFunc)) - i.parser = p.ToParser() + i.serviceHandler.Parser = p.ToParser() } } -func WithIgnoreError() opt { +func WithInterceptorIgnoreError() authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() i.serviceHandler.ErrorHandler = func(context.Context, *Request, error) error { @@ -114,55 +113,55 @@ func WithIgnoreError() opt { } // WithClientTokenGetter sets client token getter when the interceptor in client side -func WithClientTokenGetter(getter ClientTokenGetter) opt { +func WithInterceptorClientTokenGetter(getter ClientTokenGetter) authInterceptorOpt { return func(i *authInterceptor) { i.clientHandler = getter } } // WithUnarySkipper skip the interceptor for unary handler -func WithSkipper(s Skipper) opt { +func WithInterceptorSkipper(s Skipper) authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() i.serviceHandler.Skipper = s } } -func WithBeforeFunc(fn BeforeOrSuccessFunc) opt { +func WithInterceptorBeforeFunc(fn BeforeOrSuccessFunc) authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() i.serviceHandler.BeforeFunc = fn } } -func WithSuccessFunc(fn BeforeOrSuccessFunc) opt { +func WithInterceptorSuccessFunc(fn BeforeOrSuccessFunc) authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() i.serviceHandler.SuccessFunc = fn } } -func WithErrorHandler(fn ErrorHandle) opt { +func WithInterceptorErrorHandler(fn ErrorHandle) authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() i.serviceHandler.ErrorHandler = fn } } -func WithExtractor(fn Extractor) opt { +func WithInterceptorExtractor(fn Extractor) authInterceptorOpt { return func(i *authInterceptor) { i.preventNilServiceHandler() i.serviceHandler.Extractor = fn } } -func WithParser(p Parser) opt { +func WithInterceptorParser(p Parser) authInterceptorOpt { return func(i *authInterceptor) { - i.parser = p + i.serviceHandler.Parser = p } } -func WithServiceHandlerType(s ServiceHandlerType) opt { +func WithServiceHandlerType(s ServiceHandlerType) authInterceptorOpt { return func(i *authInterceptor) { i.ServiceHandlerType = s } diff --git a/authMiddleware.go b/authMiddleware.go index e94e692..01ad618 100644 --- a/authMiddleware.go +++ b/authMiddleware.go @@ -1,10 +1,13 @@ package middleware import ( + "context" "net/http" "strings" "connectrpc.com/connect" + "github.com/cockroachdb/errors" + "github.com/golang-jwt/jwt/v5" ) type authMiddleware struct { @@ -12,11 +15,117 @@ type authMiddleware struct { errW *connect.ErrorWriter } -func NewAuthMiddleware(handler *AuthHandler) *authMiddleware { - return &authMiddleware{ - handler: handler, - //TODO opts - errW: connect.NewErrorWriter(), +func NewAuthMiddleware(opts ...authMiddlewareOpt) (*authMiddleware, error) { + m := authMiddleware{} + for _, o := range opts { + o(&m) + } + if m.handler == nil { + return nil, errors.New("no handler set") + } + if m.errW == nil { + m.errW = connect.NewErrorWriter() + } + return &m, nil +} + +type authMiddlewareOpt func(*authMiddleware) + +func WithErrorWriterOpts(opts ...connect.HandlerOption) authMiddlewareOpt { + return func(m *authMiddleware) { + m.errW = connect.NewErrorWriter(opts...) + } +} + +func (m *authMiddleware) preventNilHandler() { + if m.handler == nil { + m.handler = &AuthHandler{ + Skipper: DefaultSkipper, + } + } +} + +func WithDefaultBearerExtractor() authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.Extractor = DefaultBasicExtractor().ToExtractor() + } +} + +func WithDefaultBearerExtractorAndParser(signningKey any) authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.Extractor = DefaultBearerTokenExtractor().ToExtractor() + m.handler.Parser = DefaultJWTMapClaimsParser(signningKey) + } +} + +func WithDefaultJWTMapClaimsParser(signningKey any) authMiddlewareOpt { + return func(m *authMiddleware) { + m.handler.Parser = DefaultJWTMapClaimsParser(signningKey) + } +} + +// WithCustomJWTClaimsParser sets Parser with signning key and a claimsFunc, the claimsFunc must return a reference +// for example: +// +// func(ctx context.Context) jwt.Claims{ +// return &jwt.MapClaims{} +// } +func WithCustomJWTClaimsParser(signningKey any, claimsFunc func(context.Context) jwt.Claims) authMiddlewareOpt { + return func(m *authMiddleware) { + p, _ := NewJWTParser(WithSigningKey(signningKey), WithNewClaimsFunc(claimsFunc)) + m.handler.Parser = p.ToParser() + } +} + +func WithIgnoreError() authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.ErrorHandler = func(context.Context, *Request, error) error { + return nil + } + } +} + +// WithUnarySkipper skip the interceptor for unary handler +func WithSkipper(s Skipper) authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.Skipper = s + } +} + +func WithBeforeFunc(fn BeforeOrSuccessFunc) authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.BeforeFunc = fn + } +} + +func WithSuccessFunc(fn BeforeOrSuccessFunc) authMiddlewareOpt { + return func(m *authMiddleware) { + m.handler.SuccessFunc = fn + } +} + +func WithErrorHandler(fn ErrorHandle) authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.ErrorHandler = fn + } +} + +func WithExtractor(fn Extractor) authMiddlewareOpt { + return func(m *authMiddleware) { + m.preventNilHandler() + m.handler.Extractor = fn + } +} + +func WithParser(p Parser) authMiddlewareOpt { + return func(m *authMiddleware) { + m.handler.Parser = p } } diff --git a/auth_test.go b/auth_test.go index ee3b8fb..e228256 100644 --- a/auth_test.go +++ b/auth_test.go @@ -32,7 +32,7 @@ var unaryAuthTests = []struct { { Case: "skip", Interceptor: func(t *testing.T) connect.Interceptor { - interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret"))) + interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret"))) assert.Nil(t, err) return interceptor }, @@ -53,7 +53,7 @@ var unaryAuthTests = []struct { { Case: "ignore error", Interceptor: func(t *testing.T) connect.Interceptor { - interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret")), WithIgnoreError()) + interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret")), WithInterceptorIgnoreError()) assert.Nil(t, err) return interceptor }, @@ -74,7 +74,7 @@ var unaryAuthTests = []struct { { Case: "invalid bearer token", Interceptor: func(t *testing.T) connect.Interceptor { - interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret"))) + interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret"))) assert.Nil(t, err) return interceptor }, @@ -96,7 +96,7 @@ var unaryAuthTests = []struct { { Case: "invalid auth header", Interceptor: func(t *testing.T) connect.Interceptor { - interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret"))) + interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret"))) assert.Nil(t, err) return interceptor }, @@ -118,7 +118,7 @@ var unaryAuthTests = []struct { { Case: "invalid signing key", Interceptor: func(t *testing.T) connect.Interceptor { - interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secet"))) + interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secet"))) assert.Nil(t, err) return interceptor }, @@ -140,7 +140,7 @@ var unaryAuthTests = []struct { { Case: "default", Interceptor: func(t *testing.T) connect.Interceptor { - interceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret"))) + interceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret"))) assert.Nil(t, err) return interceptor }, @@ -164,8 +164,8 @@ var unaryAuthTests = []struct { Case: "custom claim", Interceptor: func(t *testing.T) connect.Interceptor { interceptor, err := NewAuthInterceptor( - WithDefaultBearerExtractor(), - WithCustomJWTClaimsParser([]byte("secret"), func(ctx context.Context) jwt.Claims { + WithInterceptorDefaultBearerExtractor(), + WithInterceptorCustomJWTClaimsParser([]byte("secret"), func(ctx context.Context) jwt.Claims { return &jwtCustomClaims{} }), ) @@ -226,8 +226,8 @@ var unaryAuthTests = []struct { return payload, nil } interceptor, err := NewAuthInterceptor( - WithExtractor(extractor.ToExtractor()), - WithParser(parser), + WithInterceptorExtractor(extractor.ToExtractor()), + WithInterceptorParser(parser), ) assert.Nil(t, err) return interceptor diff --git a/e2e_test.go b/e2e_test.go index 42e19a8..99d4b94 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -31,7 +31,7 @@ func newServer(t *testing.T, interceptor connect.Interceptor, validator func(con return memhttptest.New(t, mux) } func TestE2E(t *testing.T) { - authInterceptor, err := NewAuthInterceptor(WithDefaultBearerExtractorAndParser([]byte("secret"))) + authInterceptor, err := NewAuthInterceptor(WithInterceptorDefaultBearerExtractorAndParser([]byte("secret"))) assert.Nil(t, err) s := newServer(t, authInterceptor, func(ctx context.Context) { claims, ok := FromContext[jwt.MapClaims](ctx) diff --git a/example/main.go b/example/main.go index 30c9c39..cf2724b 100644 --- a/example/main.go +++ b/example/main.go @@ -64,17 +64,21 @@ func (s *PingServer) CumSum( } func main() { - auth, err := middleware.NewAuthInterceptor(middleware.WithDefaultBearerExtractorAndParser([]byte("secret"))) + // auth, err := middleware.NewAuthInterceptor(middleware.WithInterceptorDefaultBearerExtractorAndParser([]byte("secret"))) + // if err != nil { + // panic(err) + // } + // interceptors := connect.WithInterceptors(auth) + greeter := &PingServer{pingv1connect.UnimplementedPingServiceHandler{}} + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(greeter)) + authMiddleware, err := middleware.NewAuthMiddleware(middleware.WithDefaultBearerExtractorAndParser([]byte("secret"))) if err != nil { panic(err) } - interceptors := connect.WithInterceptors(auth) - greeter := &PingServer{pingv1connect.UnimplementedPingServiceHandler{}} - mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(greeter, interceptors)) http.ListenAndServe( "localhost:8080", // Use h2c so we can serve HTTP/2 without TLS. - h2c.NewHandler(mux, &http2.Server{}), + h2c.NewHandler(authMiddleware.Wrap(mux), &http2.Server{}), ) } diff --git a/jwtParser.go b/jwtParser.go index 4042deb..180987f 100644 --- a/jwtParser.go +++ b/jwtParser.go @@ -92,7 +92,7 @@ func WithJWTMapClaims(signingKey any) jwtParserOpt { p.NewClaimsFunc = func(context.Context) jwt.Claims { return jwt.MapClaims{} } - p.KeyFunc = p.defaultKeyFuncForSigningKeys + p.KeyFunc = p.defaultKeyFuncForSigningKey } }