Skip to content

Commit

Permalink
🩹 use memoized response on chat client [spring-projects#2097]
Browse files Browse the repository at this point in the history
  • Loading branch information
Grogdunn committed Jan 22, 2025
1 parent f5761de commit 5029d00
Showing 1 changed file with 34 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,9 @@

package org.springframework.ai.chat.client;

import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.springframework.ai.tool.ToolCallbacks;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
Expand Down Expand Up @@ -65,6 +47,7 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.core.Ordered;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.Resource;
Expand All @@ -73,6 +56,22 @@
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/**
* The default implementation of {@link ChatClient} as created by the
Expand Down Expand Up @@ -393,6 +392,8 @@ public static class DefaultCallResponseSpec implements CallResponseSpec {

private final DefaultChatClientRequestSpec request;

private final ThreadLocal<Optional<ChatResponse>> memoizedResponse = ThreadLocal.withInitial(Optional::empty);

public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) {
Assert.notNull(request, "request cannot be null");
this.request = request;
Expand Down Expand Up @@ -506,13 +507,16 @@ private static String getContentFromChatResponse(@Nullable ChatResponse chatResp
/**
 * Returns the {@link ChatResponse} for this call specification.
 * <p>
 * NOTE(review): this text is a commit-diff rendering without +/- markers. The
 * first {@code return doGetChatResponse();} looks like the pre-commit line that
 * the commit removes, and the three statements after it look like its
 * replacement. Read literally as Java, the statements after the first
 * {@code return} would be unreachable -- confirm against the committed file.
 * <p>
 * The replacement statements read the {@code Optional} held by the
 * {@code memoizedResponse} ThreadLocal, call {@code doGetChatResponse()} only
 * when that Optional is empty, and store the result back into the ThreadLocal.
 * Because the holder is a ThreadLocal, the cached value is only visible to the
 * thread that stored it.
 * @return the chat response, possibly {@code null}
 */
@Override
@Nullable
public ChatResponse chatResponse() {
// Presumably the pre-commit body (removed by this commit): always performs a fresh call.
return doGetChatResponse();
// Reuse this thread's cached response if present; otherwise perform the call now.
final var chatResponse = memoizedResponse.get().orElseGet(this::doGetChatResponse);
// Optional.ofNullable stores an empty Optional for a null result, so a null
// response is not cached and the next invocation calls doGetChatResponse() again.
// NOTE(review): no memoizedResponse.remove() is visible in this excerpt -- confirm
// the per-thread value cannot outlive its usefulness on pooled threads.
memoizedResponse.set(Optional.ofNullable(chatResponse));
return chatResponse;
}

/**
 * Returns the textual content extracted from the chat response via
 * {@code getContentFromChatResponse}.
 * <p>
 * NOTE(review): this text is a commit-diff rendering without +/- markers. The
 * {@code ChatResponse chatResponse = doGetChatResponse();} line looks like the
 * pre-commit line that the commit removes, and the {@code final var} line plus
 * the {@code memoizedResponse.set(...)} line look like its replacement. Read
 * literally as Java, {@code chatResponse} would be declared twice -- confirm
 * against the committed file.
 * <p>
 * The replacement statements use the same {@code memoizedResponse} ThreadLocal
 * as {@code chatResponse()}, so on a given thread both accessors read and write
 * the same cached value.
 * @return the content string, possibly {@code null}
 */
@Override
@Nullable
public String content() {
// Presumably the pre-commit line (removed by this commit): always performs a fresh call.
ChatResponse chatResponse = doGetChatResponse();
// Reuse this thread's cached response if present; otherwise perform the call now.
final var chatResponse = memoizedResponse.get().orElseGet(this::doGetChatResponse);
// A null result is stored as an empty Optional, i.e. it is not memoized.
memoizedResponse.set(Optional.ofNullable(chatResponse));
return getContentFromChatResponse(chatResponse);
}

Expand All @@ -522,6 +526,8 @@ public static class DefaultStreamResponseSpec implements StreamResponseSpec {

private final DefaultChatClientRequestSpec request;

private final ThreadLocal<Optional<Flux<ChatResponse>>> memoizedFlux = ThreadLocal.withInitial(Optional::empty);

public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) {
Assert.notNull(request, "request cannot be null");
this.request = request;
Expand Down Expand Up @@ -559,12 +565,18 @@ private Flux<ChatResponse> doGetObservableFluxChatResponse(DefaultChatClientRequ

/**
 * Returns the stream of {@link ChatResponse} elements for this stream
 * specification.
 * <p>
 * NOTE(review): this text is a commit-diff rendering without +/- markers. The
 * first {@code return doGetObservableFluxChatResponse(this.request);} looks like
 * the pre-commit line that the commit removes, and the statements after it look
 * like its replacement. Read literally as Java, the statements after the first
 * {@code return} would be unreachable -- confirm against the committed file.
 * <p>
 * The replacement statements read the {@code Optional} held by the
 * {@code memoizedFlux} ThreadLocal, build a new Flux through
 * {@code doGetObservableFluxChatResponse} only when that Optional is empty, and
 * store the Flux reference back into the ThreadLocal.
 * @return the Flux of chat responses
 */
@Override
public Flux<ChatResponse> chatResponse() {
// Presumably the pre-commit body (removed by this commit): builds a new Flux on every call.
return doGetObservableFluxChatResponse(this.request);
// Reuse the Flux instance cached for this thread, or create it on first use.
final var chatResponseFlux = memoizedFlux.get()
.orElseGet(() -> doGetObservableFluxChatResponse(this.request));
// What is cached here is the Flux reference, not the elements it emits.
// NOTE(review): whether a second subscription to the cached Flux repeats the
// underlying request depends on how doGetObservableFluxChatResponse builds it
// (not visible in this excerpt) -- confirm.
// Optional.of throws NullPointerException if the Flux were null.
memoizedFlux.set(Optional.of(chatResponseFlux));
return chatResponseFlux;
}

@Override
public Flux<String> content() {
return doGetObservableFluxChatResponse(this.request).map(r -> {
final var chatResponseFlux = memoizedFlux.get()
.orElseGet(() -> doGetObservableFluxChatResponse(this.request));
memoizedFlux.set(Optional.of(chatResponseFlux));
return chatResponseFlux.map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getText() == null) {
return "";
Expand Down

0 comments on commit 5029d00

Please sign in to comment.