33
33
import com .linecorp .armeria .common .HttpHeaderNames ;
34
34
import com .linecorp .armeria .common .HttpHeaders ;
35
35
import com .linecorp .armeria .common .HttpObject ;
36
+ import com .linecorp .armeria .common .HttpStatus ;
36
37
import com .linecorp .armeria .common .RequestHeaders ;
37
38
import com .linecorp .armeria .common .ResponseCompleteException ;
39
+ import com .linecorp .armeria .common .ResponseHeaders ;
38
40
import com .linecorp .armeria .common .SessionProtocol ;
39
41
import com .linecorp .armeria .common .annotation .Nullable ;
40
42
import com .linecorp .armeria .common .logging .RequestLogBuilder ;
50
52
import io .netty .channel .ChannelFuture ;
51
53
import io .netty .channel .ChannelFutureListener ;
52
54
import io .netty .channel .ChannelPromise ;
55
+ import io .netty .handler .codec .http .HttpHeaderValues ;
53
56
import io .netty .handler .codec .http2 .Http2Error ;
54
57
import io .netty .handler .proxy .ProxyConnectException ;
55
58
@@ -61,6 +64,7 @@ enum State {
61
64
NEEDS_TO_WRITE_FIRST_HEADER ,
62
65
NEEDS_DATA ,
63
66
NEEDS_DATA_OR_TRAILERS ,
67
+ NEEDS_100_CONTINUE ,
64
68
DONE
65
69
}
66
70
@@ -143,6 +147,11 @@ public final void operationComplete(ChannelFuture future) throws Exception {
143
147
responseWrapper .initTimeout ();
144
148
}
145
149
150
+ if (state == State .NEEDS_100_CONTINUE ) {
151
+ assert responseWrapper != null ;
152
+ responseWrapper .initTimeout ();
153
+ }
154
+
146
155
onWriteSuccess ();
147
156
return ;
148
157
}
@@ -176,7 +185,7 @@ final boolean tryInitialize() {
176
185
}
177
186
178
187
this .session = session ;
179
- responseWrapper = responseDecoder .addResponse (id , originalRes , ctx , ch .eventLoop ());
188
+ responseWrapper = responseDecoder .addResponse (this , id , originalRes , ctx , ch .eventLoop ());
180
189
181
190
if (timeoutMillis > 0 ) {
182
191
// The timer would be executed if the first message has not been sent out within the timeout.
@@ -187,34 +196,39 @@ final boolean tryInitialize() {
187
196
return true ;
188
197
}
189
198
199
+ RequestHeaders mergedRequestHeaders (RequestHeaders headers ) {
200
+ final HttpHeaders internalHeaders ;
201
+ final ClientRequestContextExtension ctxExtension = ctx .as (ClientRequestContextExtension .class );
202
+ if (ctxExtension == null ) {
203
+ internalHeaders = HttpHeaders .of ();
204
+ } else {
205
+ internalHeaders = ctxExtension .internalRequestHeaders ();
206
+ }
207
+ return mergeRequestHeaders (
208
+ headers , ctx .defaultRequestHeaders (), ctx .additionalRequestHeaders (), internalHeaders );
209
+ }
210
+
190
211
/**
191
212
* Writes the {@link RequestHeaders} to the {@link Channel}.
192
213
* The {@link RequestHeaders} is merged with {@link ClientRequestContext#additionalRequestHeaders()}
193
214
* before being written.
194
215
* Note that the written data is not flushed by this method. The caller should explicitly call
195
216
* {@link Channel#flush()} when each write unit is done.
196
217
*/
197
- final void writeHeaders (RequestHeaders headers ) {
218
+ final void writeHeaders (RequestHeaders headers , boolean needs100Continue ) {
198
219
final SessionProtocol protocol = session .protocol ();
199
220
assert protocol != null ;
200
- if (headersOnly ) {
221
+ if (needs100Continue ) {
222
+ state = State .NEEDS_100_CONTINUE ;
223
+ } else if (headersOnly ) {
201
224
state = State .DONE ;
202
225
} else if (allowTrailers ) {
203
226
state = State .NEEDS_DATA_OR_TRAILERS ;
204
227
} else {
205
228
state = State .NEEDS_DATA ;
206
229
}
207
230
208
- final HttpHeaders internalHeaders ;
209
- final ClientRequestContextExtension ctxExtension = ctx .as (ClientRequestContextExtension .class );
210
- if (ctxExtension == null ) {
211
- internalHeaders = HttpHeaders .of ();
212
- } else {
213
- internalHeaders = ctxExtension .internalRequestHeaders ();
214
- }
215
- final RequestHeaders merged = mergeRequestHeaders (
216
- headers , ctx .defaultRequestHeaders (), ctx .additionalRequestHeaders (), internalHeaders );
217
- logBuilder .requestHeaders (merged );
231
+ logBuilder .requestHeaders (headers );
218
232
219
233
final String connectionOption = headers .get (HttpHeaderNames .CONNECTION );
220
234
if (CLOSE_STRING .equalsIgnoreCase (connectionOption ) || !keepAlive ) {
@@ -230,9 +244,37 @@ final void writeHeaders(RequestHeaders headers) {
230
244
// Attach a listener first to make the listener early handle a cause raised while writing headers
231
245
// before any other callbacks like `onStreamClosed()` are invoked.
232
246
promise .addListener (this );
233
- encoder .writeHeaders (id , streamId (), merged , headersOnly , promise );
247
+ encoder .writeHeaders (id , streamId (), headers , headersOnly , promise );
234
248
}
235
249
250
+ static boolean needs100Continue (RequestHeaders headers ) {
251
+ return headers .contains (HttpHeaderNames .EXPECT , HttpHeaderValues .CONTINUE .toString ());
252
+ }
253
+
254
+ void handle100Continue (ResponseHeaders responseHeaders ) {
255
+ if (state != State .NEEDS_100_CONTINUE ) {
256
+ return ;
257
+ }
258
+
259
+ if (responseHeaders .status () == HttpStatus .CONTINUE ) {
260
+ state = State .NEEDS_DATA_OR_TRAILERS ;
261
+ resume ();
262
+ // TODO(minwoox): reset the timeout
263
+ } else {
264
+ // We do not retry the request when HttpStatus.EXPECTATION_FAILED is received
265
+ // because:
266
+ // - Most servers support 100-continue.
267
+ // - It's much simpler to just fail the request and let the user retry.
268
+ state = State .DONE ;
269
+ logBuilder .endRequest ();
270
+ discardRequestBody ();
271
+ }
272
+ }
273
+
274
+ abstract void resume ();
275
+
276
+ abstract void discardRequestBody ();
277
+
236
278
/**
237
279
* Writes the {@link HttpData} to the {@link Channel}.
238
280
* Note that the written data is not flushed by this method. The caller should explicitly call
0 commit comments