diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePlugin.java b/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePlugin.java index 88ee2331f33..2ed9a722516 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePlugin.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePlugin.java @@ -295,8 +295,10 @@ private static ServiceInfo buildHttpServiceInfo(String serviceName, builder = FieldInfo.builder(paramName, toTypeSignature(parameter)); } - fieldInfosBuilder.add(builder.requirement(FieldRequirement.REQUIRED) - .location(fieldLocation) + builder.requirement(parameter.isRequired() ? FieldRequirement.REQUIRED + : FieldRequirement.OPTIONAL); + + fieldInfosBuilder.add(builder.location(fieldLocation) .build()); } }); diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/HttpEndpointSpecification.java b/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/HttpEndpointSpecification.java index 06ffb286f53..5c2031f18e0 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/HttpEndpointSpecification.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/server/grpc/HttpEndpointSpecification.java @@ -180,10 +180,12 @@ public String toString() { public static class Parameter { private final JavaType type; private final boolean isRepeated; + private final boolean isRequired; - public Parameter(JavaType type, boolean isRepeated) { + public Parameter(JavaType type, boolean isRepeated, boolean isRequired) { this.type = requireNonNull(type, "type"); this.isRepeated = isRepeated; + this.isRequired = isRequired; } public JavaType type() { @@ -194,6 +196,10 @@ public boolean isRepeated() { return isRepeated; } + public boolean isRequired() { + return isRequired; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java index 107f5aa0b44..209b4a5cd19 100644 --- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java +++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java @@ -45,6 +45,8 @@ import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.api.AnnotationsProto; +import com.google.api.FieldBehavior; +import com.google.api.FieldBehaviorProto; import com.google.api.HttpBody; import com.google.api.HttpRule; import com.google.common.annotations.VisibleForTesting; @@ -57,6 +59,7 @@ import com.google.protobuf.Any; import com.google.protobuf.BoolValue; import com.google.protobuf.BytesValue; +import com.google.protobuf.DescriptorProtos; import com.google.protobuf.DescriptorProtos.MethodOptions; import com.google.protobuf.Descriptors; import com.google.protobuf.Descriptors.Descriptor; @@ -328,7 +331,7 @@ private static Map buildFields(Descriptor desc, final ImmutableMap.Builder builder = ImmutableMap.builder(); for (FieldDescriptor field : desc.getFields()) { final JavaType type = field.getJavaType(); - + final boolean isRequired = hasRequiredFieldBehavior(field); final String fieldName; switch (currentMatchRule) { case ORIGINAL_FIELD: @@ -368,13 +371,14 @@ private static Map buildFields(Descriptor desc, case BYTE_STRING: case ENUM: // Use field name which is specified in proto file. - builder.put(key, new Field(field, parentNames, field.getJavaType())); + builder.put(key, new Field(field, parentNames, field.getJavaType(), isRequired)); break; case MESSAGE: @Nullable final JavaType wellKnownFieldType = getJavaTypeForWellKnownTypes(field); + if (wellKnownFieldType != null) { - builder.put(key, new Field(field, parentNames, wellKnownFieldType)); + builder.put(key, new Field(field, parentNames, wellKnownFieldType, isRequired)); break; } @@ -408,7 +412,7 @@ private static Map buildFields(Descriptor desc, throw e; } - builder.put(key, new Field(field, parentNames, JavaType.MESSAGE)); + builder.put(key, new Field(field, parentNames, JavaType.MESSAGE, isRequired)); } break; } @@ -418,6 +422,18 @@ private static Map buildFields(Descriptor desc, return builder.buildKeepingLast(); } + private static boolean hasRequiredFieldBehavior(FieldDescriptor field) { + if (field.isRepeated()) { + return false; + } + + final List fieldBehaviors = field + .getOptions() + .getExtension((ExtensionLite>) + FieldBehaviorProto.fieldBehavior); + return fieldBehaviors.contains(FieldBehavior.REQUIRED); + } + @Nullable private static JavaType getJavaTypeForWellKnownTypes(FieldDescriptor fd) { // MapField can be sent only via HTTP body. @@ -434,15 +450,7 @@ private static JavaType getJavaTypeForWellKnownTypes(FieldDescriptor fd) { return JavaType.STRING; } - if (DoubleValue.getDescriptor().getFullName().equals(fullName) || - FloatValue.getDescriptor().getFullName().equals(fullName) || - Int64Value.getDescriptor().getFullName().equals(fullName) || - UInt64Value.getDescriptor().getFullName().equals(fullName) || - Int32Value.getDescriptor().getFullName().equals(fullName) || - UInt32Value.getDescriptor().getFullName().equals(fullName) || - BoolValue.getDescriptor().getFullName().equals(fullName) || - StringValue.getDescriptor().getFullName().equals(fullName) || - BytesValue.getDescriptor().getFullName().equals(fullName)) { + if (isScalarValueWrapperMessage(fullName)) { // "value" field. Wrappers must have one field. assert messageType.getFields().size() == 1 : "Wrappers must have one 'value' field."; return messageType.getFields().get(0).getJavaType(); @@ -469,6 +477,18 @@ private static JavaType getJavaTypeForWellKnownTypes(FieldDescriptor fd) { return null; } + private static boolean isScalarValueWrapperMessage(String fullName) { + return DoubleValue.getDescriptor().getFullName().equals(fullName) || + FloatValue.getDescriptor().getFullName().equals(fullName) || + Int64Value.getDescriptor().getFullName().equals(fullName) || + UInt64Value.getDescriptor().getFullName().equals(fullName) || + Int32Value.getDescriptor().getFullName().equals(fullName) || + UInt32Value.getDescriptor().getFullName().equals(fullName) || + BoolValue.getDescriptor().getFullName().equals(fullName) || + StringValue.getDescriptor().getFullName().equals(fullName) || + BytesValue.getDescriptor().getFullName().equals(fullName); + } + // to make it more efficient, we calculate whether extract response body one time // if there is no matching toplevel field, we set it to null @Nullable @@ -606,7 +626,8 @@ public HttpEndpointSpecification httpEndpointSpecification(Route route) { spec.originalFields.entrySet().stream().collect( toImmutableMap(Entry::getKey, fieldEntry -> new Parameter(fieldEntry.getValue().type(), - fieldEntry.getValue().isRepeated()))); + fieldEntry.getValue().isRepeated(), + fieldEntry.getValue().isRequired()))); return new HttpEndpointSpecification(spec.order, route, paramNames, @@ -981,11 +1002,17 @@ static final class Field { private final FieldDescriptor descriptor; private final List parentNames; private final JavaType javaType; + private final boolean isRequired; - private Field(FieldDescriptor descriptor, List parentNames, JavaType javaType) { + private Field(FieldDescriptor descriptor, + List parentNames, + JavaType javaType, + boolean isRequired + ) { this.descriptor = descriptor; this.parentNames = parentNames; this.javaType = javaType; + this.isRequired = isRequired; } JavaType type() { @@ -999,6 +1026,10 @@ String name() { boolean isRepeated() { return descriptor.isRepeated(); } + + boolean isRequired() { + return isRequired; + } } /** diff --git a/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePluginTest.java b/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePluginTest.java index 061c7b7e34a..fd9408109b6 100644 --- a/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePluginTest.java +++ b/grpc/src/test/java/com/linecorp/armeria/internal/server/grpc/GrpcDocServicePluginTest.java @@ -403,11 +403,11 @@ void httpEndpoint() { FieldInfo.builder("message_id", TypeSignature.ofBase(JavaType.STRING.name())) .location(FieldLocation.PATH).requirement(FieldRequirement.REQUIRED).build(), FieldInfo.builder("revision", TypeSignature.ofBase(JavaType.LONG.name())) - .location(FieldLocation.QUERY).requirement(FieldRequirement.REQUIRED).build(), + .location(FieldLocation.QUERY).requirement(FieldRequirement.OPTIONAL).build(), FieldInfo.builder("sub.subfield", TypeSignature.ofBase(JavaType.STRING.name())) - .location(FieldLocation.QUERY).requirement(FieldRequirement.REQUIRED).build(), + .location(FieldLocation.QUERY).requirement(FieldRequirement.OPTIONAL).build(), FieldInfo.builder("type", TypeSignature.ofBase(JavaType.ENUM.name())) - .location(FieldLocation.QUERY).requirement(FieldRequirement.REQUIRED).build())); + .location(FieldLocation.QUERY).requirement(FieldRequirement.OPTIONAL).build())); assertThat(getMessageV2.useParameterAsRoot()).isFalse(); final MethodInfo getMessageV3 = serviceInfo.methods().stream() @@ -422,7 +422,7 @@ void httpEndpoint() { .location(FieldLocation.PATH).requirement(FieldRequirement.REQUIRED).build(), FieldInfo.builder("revision", TypeSignature.ofList(TypeSignature.ofBase(JavaType.LONG.name()))) - .location(FieldLocation.QUERY).requirement(FieldRequirement.REQUIRED).build())); + .location(FieldLocation.QUERY).requirement(FieldRequirement.OPTIONAL).build())); assertThat(getMessageV3.useParameterAsRoot()).isFalse(); // Check HTTP PATCH method. @@ -437,7 +437,7 @@ void httpEndpoint() { FieldInfo.builder("message_id", TypeSignature.ofBase(JavaType.STRING.name())) .location(FieldLocation.PATH).requirement(FieldRequirement.REQUIRED).build(), FieldInfo.builder("text", TypeSignature.ofBase(JavaType.STRING.name())) - .location(FieldLocation.BODY).requirement(FieldRequirement.REQUIRED).build())); + .location(FieldLocation.BODY).requirement(FieldRequirement.OPTIONAL).build())); assertThat(updateMessageV1.useParameterAsRoot()).isFalse(); final MethodInfo updateMessageV2 = serviceInfo.methods().stream() @@ -451,7 +451,11 @@ void httpEndpoint() { FieldInfo.builder("message_id", TypeSignature.ofBase(JavaType.STRING.name())) .location(FieldLocation.PATH).requirement(FieldRequirement.REQUIRED).build(), FieldInfo.builder("text", TypeSignature.ofBase(JavaType.STRING.name())) - .location(FieldLocation.BODY).requirement(FieldRequirement.REQUIRED).build())); + .location(FieldLocation.BODY).requirement(FieldRequirement.OPTIONAL).build(), + FieldInfo.builder("required_text", TypeSignature.ofBase(JavaType.STRING.name())) + .location(FieldLocation.BODY).requirement(FieldRequirement.REQUIRED).build(), + FieldInfo.builder("optional_text", TypeSignature.ofBase(JavaType.STRING.name())) + .location(FieldLocation.BODY).requirement(FieldRequirement.OPTIONAL).build())); assertThat(updateMessageV2.useParameterAsRoot()).isFalse(); } diff --git a/grpc/src/test/proto/testing/grpc/transcoding.proto b/grpc/src/test/proto/testing/grpc/transcoding.proto index a86572dafed..554c3ca9898 100644 --- a/grpc/src/test/proto/testing/grpc/transcoding.proto +++ b/grpc/src/test/proto/testing/grpc/transcoding.proto @@ -21,6 +21,7 @@ package armeria.grpc.testing; option java_package = "testing.grpc"; import "google/api/annotations.proto"; +import "google/api/field_behavior.proto"; import "google/protobuf/timestamp.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -292,6 +293,8 @@ message UpdateMessageRequestV1 { message Message { string text = 1; // The resource content. + google.protobuf.StringValue required_text = 2 [(google.api.field_behavior) = REQUIRED]; + google.protobuf.StringValue optional_text = 3; } enum MessageType {