From f58a8015d096599a3d4dd98ce60b67c7a3afc438 Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Tue, 4 Feb 2025 20:43:42 +0200 Subject: [PATCH] Fix gRPC error propagation. --- errors.go | 278 ++++++++++++-------- internal/test/my_service/my_service.pb.go | 8 +- internal/test/my_service/my_service.proto | 2 +- internal/test/my_service/my_service_test.go | 32 ++- pkg/server/rpc.go | 5 + 5 files changed, 201 insertions(+), 124 deletions(-) diff --git a/errors.go b/errors.go index 9171366..aa7d1aa 100644 --- a/errors.go +++ b/errors.go @@ -56,7 +56,168 @@ func (e ErrorCode) Error() string { return string(e) } +func (e ErrorCode) ToHTTP() int { + switch e { + case OK: + return http.StatusOK + case Unknown, MalformedResponse, Internal, DataLoss: + return http.StatusInternalServerError + case InvalidArgument, MalformedRequest: + return http.StatusBadRequest + case NotFound: + return http.StatusNotFound + case NotAcceptable: + return http.StatusNotAcceptable + case AlreadyExists, Aborted: + return http.StatusConflict + case PermissionDenied: + return http.StatusForbidden + case ResourceExhausted: + return http.StatusTooManyRequests + case FailedPrecondition: + return http.StatusPreconditionFailed + case OutOfRange: + return http.StatusRequestedRangeNotSatisfiable + case Unimplemented: + return http.StatusNotImplemented + case Canceled, DeadlineExceeded, Unavailable: + return http.StatusServiceUnavailable + case Unauthenticated: + return http.StatusUnauthorized + default: + return http.StatusInternalServerError + } +} + +func ErrorCodeFromGRPC(code codes.Code) ErrorCode { + switch code { + case codes.OK: + return OK + case codes.Canceled: + return Canceled + case codes.Unknown: + return Unknown + case codes.InvalidArgument: + return InvalidArgument + case codes.DeadlineExceeded: + return DeadlineExceeded + case codes.NotFound: + return NotFound + case codes.AlreadyExists: + return AlreadyExists + case codes.PermissionDenied: + return PermissionDenied + case codes.ResourceExhausted: + return ResourceExhausted + case codes.FailedPrecondition: + return FailedPrecondition + case codes.Aborted: + return Aborted + case codes.OutOfRange: + return OutOfRange + case codes.Unimplemented: + return Unimplemented + case codes.Internal: + return Internal + case codes.Unavailable: + return Unavailable + case codes.DataLoss: + return DataLoss + case codes.Unauthenticated: + return Unauthenticated + default: + return Unknown + } +} + +func (e ErrorCode) ToGRPC() codes.Code { + switch e { + case OK: + return codes.OK + case Canceled: + return codes.Canceled + case Unknown: + return codes.Unknown + case InvalidArgument, MalformedRequest: + return codes.InvalidArgument + case DeadlineExceeded: + return codes.DeadlineExceeded + case NotFound: + return codes.NotFound + case AlreadyExists: + return codes.AlreadyExists + case PermissionDenied: + return codes.PermissionDenied + case ResourceExhausted: + return codes.ResourceExhausted + case FailedPrecondition: + return codes.FailedPrecondition + case Aborted: + return codes.Aborted + case OutOfRange: + return codes.OutOfRange + case Unimplemented: + return codes.Unimplemented + case MalformedResponse, Internal: + return codes.Internal + case Unavailable: + return codes.Unavailable + case DataLoss: + return codes.DataLoss + case Unauthenticated: + return codes.Unauthenticated + default: + return codes.Unknown + } +} + +func (e ErrorCode) ToTwirp() twirp.ErrorCode { + switch e { + case OK: + return twirp.NoError + case Canceled: + return twirp.Canceled + case Unknown: + return twirp.Unknown + case InvalidArgument: + return twirp.InvalidArgument + case MalformedRequest, MalformedResponse: + return twirp.Malformed + case DeadlineExceeded: + return twirp.DeadlineExceeded + case NotFound: + return twirp.NotFound + case AlreadyExists: + return twirp.AlreadyExists + case PermissionDenied: + return twirp.PermissionDenied + case ResourceExhausted: + return twirp.ResourceExhausted + case FailedPrecondition: + return twirp.FailedPrecondition + case Aborted: + return twirp.Aborted + case OutOfRange: + return twirp.OutOfRange + case Unimplemented: + return twirp.Unimplemented + case Internal: + return twirp.Internal + case Unavailable: + return twirp.Unavailable + case DataLoss: + return twirp.DataLoss + case Unauthenticated: + return twirp.Unauthenticated + default: + return twirp.Unknown + } +} + func NewError(code ErrorCode, err error, details ...proto.Message) Error { + if err == nil { + panic("error is nil") + } var protoDetails []*anypb.Any for _, e := range details { if p, err := anypb.New(e); err == nil { @@ -144,36 +305,7 @@ func (e psrpcError) Code() ErrorCode { } func (e psrpcError) ToHttp() int { - switch e.code { - case OK: - return http.StatusOK - case Unknown, MalformedResponse, Internal, DataLoss: - return http.StatusInternalServerError - case InvalidArgument, MalformedRequest: - return http.StatusBadRequest - case NotFound: - return http.StatusNotFound - case NotAcceptable: - return http.StatusNotAcceptable - case AlreadyExists, Aborted: - return http.StatusConflict - case PermissionDenied: - return http.StatusForbidden - case ResourceExhausted: - return http.StatusTooManyRequests - case FailedPrecondition: - return http.StatusPreconditionFailed - case OutOfRange: - return http.StatusRequestedRangeNotSatisfiable - case Unimplemented: - return http.StatusNotImplemented - case Canceled, DeadlineExceeded, Unavailable: - return http.StatusServiceUnavailable - case Unauthenticated: - return http.StatusUnauthorized - default: - return http.StatusInternalServerError - } + return e.code.ToHTTP() } func (e psrpcError) DetailsProto() []*anypb.Any { @@ -185,97 +317,15 @@ func (e psrpcError) Details() []any { } func (e psrpcError) GRPCStatus() *status.Status { - var c codes.Code - switch e.code { - case OK: - c = codes.OK - case Canceled: - c = codes.Canceled - case Unknown: - c = codes.Unknown - case InvalidArgument, MalformedRequest: - c = codes.InvalidArgument - case DeadlineExceeded: - c = codes.DeadlineExceeded - case NotFound: - c = codes.NotFound - case AlreadyExists: - c = codes.AlreadyExists - case PermissionDenied: - c = codes.PermissionDenied - case ResourceExhausted: - c = codes.ResourceExhausted - case FailedPrecondition: - c = codes.FailedPrecondition - case Aborted: - c = codes.Aborted - case OutOfRange: - c = codes.OutOfRange - case Unimplemented: - c = codes.Unimplemented - case MalformedResponse, Internal: - c = codes.Internal - case Unavailable: - c = codes.Unavailable - case DataLoss: - c = codes.DataLoss - case Unauthenticated: - c = codes.Unauthenticated - default: - c = codes.Unknown - } - return status.FromProto(&spb.Status{ - Code: int32(c), + Code: int32(e.code.ToGRPC()), Message: e.Error(), Details: e.details, }) } func (e psrpcError) toTwirp() twirp.Error { - var c twirp.ErrorCode - switch e.code { - case OK: - c = twirp.NoError - case Canceled: - c = twirp.Canceled - case Unknown: - c = twirp.Unknown - case InvalidArgument: - c = twirp.InvalidArgument - case MalformedRequest, MalformedResponse: - c = twirp.Malformed - case DeadlineExceeded: - c = twirp.DeadlineExceeded - case NotFound: - c = twirp.NotFound - case AlreadyExists: - c = twirp.AlreadyExists - case PermissionDenied: - c = twirp.PermissionDenied - case ResourceExhausted: - c = twirp.ResourceExhausted - case FailedPrecondition: - c = twirp.FailedPrecondition - case Aborted: - c = twirp.Aborted - case OutOfRange: - c = twirp.OutOfRange - case Unimplemented: - c = twirp.Unimplemented - case Internal: - c = twirp.Internal - case Unavailable: - c = twirp.Unavailable - case DataLoss: - c = twirp.DataLoss - case Unauthenticated: - c = twirp.Unauthenticated - default: - c = twirp.Unknown - } - - return twirp.NewError(c, e.Error()) + return twirp.NewError(e.code.ToTwirp(), e.Error()) } func (e psrpcError) As(target any) bool { diff --git a/internal/test/my_service/my_service.pb.go b/internal/test/my_service/my_service.pb.go index f5dcbc8..90cef4a 100644 --- a/internal/test/my_service/my_service.pb.go +++ b/internal/test/my_service/my_service.pb.go @@ -74,7 +74,7 @@ func (*Ignored) Descriptor() ([]byte, []int) { type MyRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - ReturnError bool `protobuf:"varint,1,opt,name=return_error,json=returnError,proto3" json:"return_error,omitempty"` + ReturnError int32 `protobuf:"varint,1,opt,name=return_error,json=returnError,proto3" json:"return_error,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -109,11 +109,11 @@ func (*MyRequest) Descriptor() ([]byte, []int) { return file_my_service_proto_rawDescGZIP(), []int{1} } -func (x *MyRequest) GetReturnError() bool { +func (x *MyRequest) GetReturnError() int32 { if x != nil { return x.ReturnError } - return false + return 0 } type MyResponse struct { @@ -270,7 +270,7 @@ var file_my_service_proto_rawDesc = string([]byte{ 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x09, 0x0a, 0x07, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x64, 0x22, 0x2e, 0x0a, 0x09, 0x4d, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x74, 0x75, 0x72, 0x6e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0b, 0x72, 0x65, 0x74, 0x75, 0x72, 0x6e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0x28, 0x05, 0x52, 0x0b, 0x72, 0x65, 0x74, 0x75, 0x72, 0x6e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x0c, 0x0a, 0x0a, 0x4d, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0a, 0x0a, 0x08, 0x4d, 0x79, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, 0x11, 0x0a, 0x0f, 0x4d, 0x79, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x11, 0x0a, 0x0f, diff --git a/internal/test/my_service/my_service.proto b/internal/test/my_service/my_service.proto index 2abd19b..d03ab50 100644 --- a/internal/test/my_service/my_service.proto +++ b/internal/test/my_service/my_service.proto @@ -64,7 +64,7 @@ service MyService { message Ignored {} message MyRequest { - bool return_error = 1; + int32 return_error = 1; } message MyResponse {} message MyUpdate {} diff --git a/internal/test/my_service/my_service_test.go b/internal/test/my_service/my_service_test.go index ee04b3c..89fb926 100644 --- a/internal/test/my_service/my_service_test.go +++ b/internal/test/my_service/my_service_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "github.com/livekit/psrpc" @@ -167,8 +168,8 @@ func testGeneratedService(t *testing.T, bus func(t testing.TB) bus.MessageBus) { require.NoError(t, subA.Close()) require.NoError(t, subB.Close()) - // rpc NormalRPC(MyRequest) returns (MyResponse); - _, err = cA.NormalRPC(ctx, &MyRequest{ReturnError: true}) + // Test PSRPC error propagation (with error details) + _, err = cA.NormalRPC(ctx, &MyRequest{ReturnError: 1}) require.Error(t, err) e, ok := err.(psrpc.Error) require.True(t, ok) @@ -176,13 +177,24 @@ func testGeneratedService(t *testing.T, bus func(t testing.TB) bus.MessageBus) { require.Equal(t, "requested error", e.Error()) details := e.Details() require.Equal(t, 1, len(details)) - require.True(t, proto.Equal(&MyRequest{ReturnError: true}, details[0].(proto.Message))) + require.True(t, proto.Equal(&MyRequest{ReturnError: 1}, details[0].(proto.Message))) st := e.GRPCStatus() require.Equal(t, codes.FailedPrecondition, st.Code()) require.Equal(t, "requested error", st.Message()) details = st.Details() require.Equal(t, 1, len(details)) - require.True(t, proto.Equal(&MyRequest{ReturnError: true}, details[0].(proto.Message))) + require.True(t, proto.Equal(&MyRequest{ReturnError: 1}, details[0].(proto.Message))) + + // Test gRPC error propagation (with error details) + _, err = cA.NormalRPC(ctx, &MyRequest{ReturnError: 2}) + require.Error(t, err) + st, ok = status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.FailedPrecondition, st.Code()) + require.Equal(t, "requested error", st.Message()) + details = st.Details() + require.Equal(t, 1, len(details)) + require.True(t, proto.Equal(&MyRequest{ReturnError: 2}, details[0].(proto.Message))) } func requireTwo(t *testing.T, subA, subB psrpc.Subscription[*MyUpdate]) { @@ -234,8 +246,18 @@ type MyService struct { } func (s *MyService) NormalRPC(_ context.Context, req *MyRequest) (*MyResponse, error) { - if req.ReturnError { + switch req.ReturnError { + case 1: return nil, psrpc.NewError(psrpc.FailedPrecondition, errors.New("requested error"), req) + case 2: + st := status.New(codes.FailedPrecondition, "requested error") + st, err := st.WithDetails(req) + if err != nil { + panic(err) + } + return nil, st.Err() + case 0: + // none } s.Lock() s.counts["NormalRPC"]++ diff --git a/pkg/server/rpc.go b/pkg/server/rpc.go index b7c6257..901c1db 100644 --- a/pkg/server/rpc.go +++ b/pkg/server/rpc.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "github.com/livekit/psrpc" @@ -263,6 +264,10 @@ func (h *rpcHandlerImpl[RequestType, ResponseType]) sendResponse( res.Error = e.Error() res.Code = string(e.Code()) res.ErrorDetails = append(res.ErrorDetails, e.DetailsProto()...) + } else if st, ok := status.FromError(err); ok { + res.Error = st.Message() + res.Code = string(psrpc.ErrorCodeFromGRPC(st.Code())) + res.ErrorDetails = append(res.ErrorDetails, st.Proto().Details...) } else { res.Error = err.Error() res.Code = string(psrpc.Unknown)