Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output streaming support for the whole pipeline in GMC router #278

Merged
merged 4 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 118 additions & 60 deletions microservices-connector/cmd/router/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
package main

import (
"bufio"
// "bufio"
"bytes"
"context"
"encoding/json"
Expand Down Expand Up @@ -46,11 +46,12 @@ var (
)

const (
ChunkSize = 1024
BufferSize = 16
ServiceURL = "serviceUrl"
ServiceNode = "node"
DataPrep = "DataPrep"
Parameters = "parameters"
Llm = "Llm"
)

type EnsembleStepOutput struct {
Expand All @@ -63,6 +64,19 @@ type GMCGraphRoutingError struct {
Cause string `json:"cause"`
}

type ReadCloser struct {
*bytes.Reader
}

func (ReadCloser) Close() error {
// Typically, you would release resources here, but for bytes.Reader, there's nothing to do.
return nil
}

func NewReadCloser(b []byte) io.ReadCloser {
return ReadCloser{bytes.NewReader(b)}
}

func (e *GMCGraphRoutingError) Error() string {
return fmt.Sprintf("%s. %s", e.ErrorMessage, e.Cause)
}
Expand Down Expand Up @@ -127,7 +141,12 @@ func prepareErrorResponse(err error, errorMessage string) []byte {
return errorResponseBytes
}

func callService(step *mcv1alpha3.Step, serviceUrl string, input []byte, headers http.Header) ([]byte, int, error) {
func callService(
step *mcv1alpha3.Step,
serviceUrl string,
input []byte,
headers http.Header,
) (io.ReadCloser, int, error) {
defer timeTrack(time.Now(), "step", serviceUrl)
log.Info("Entering callService", "url", serviceUrl)

Expand Down Expand Up @@ -157,21 +176,7 @@ func callService(step *mcv1alpha3.Step, serviceUrl string, input []byte, headers
return nil, 500, err
}

defer func() {
if resp.Body != nil {
err := resp.Body.Close()
if err != nil {
log.Error(err, "An error has occurred while closing the response body")
}
}
}()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Error(err, "Error while reading the response")
}

return body, resp.StatusCode, err
return resp.Body, resp.StatusCode, nil
}

// Use step service name to create a K8s service if serviceURL is empty
Expand All @@ -190,7 +195,7 @@ func executeStep(
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
if step.NodeName != "" {
// when nodeName is specified make a recursive call for routing to next step
return routeStep(step.NodeName, graph, initInput, input, headers)
Expand Down Expand Up @@ -231,16 +236,16 @@ func handleSwitchNode(
initInput []byte,
request []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
var statusCode int
var responseBytes []byte
var responseBody io.ReadCloser
var err error
stepType := ServiceURL
if route.NodeName != "" {
stepType = ServiceNode
}
log.Info("Starting execution of step", "Node Name", route.NodeName, "type", stepType, "stepName", route.StepName)
if responseBytes, statusCode, err = executeStep(route, graph, initInput, request, headers); err != nil {
if responseBody, statusCode, err = executeStep(route, graph, initInput, request, headers); err != nil {
return nil, 500, err
}

Expand All @@ -253,18 +258,19 @@ func handleSwitchNode(
statusCode,
)
}
return responseBytes, statusCode, nil
return responseBody, statusCode, nil
}

func handleSwitchPipeline(nodeName string,
graph mcv1alpha3.GMConnector,
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
currentNode := graph.Spec.Nodes[nodeName]
var statusCode int
var responseBytes []byte
var responseBody io.ReadCloser
var err error

initReqData := make(map[string]interface{})
Expand All @@ -286,35 +292,48 @@ func handleSwitchPipeline(nodeName string,
}
log.Info("Current Step Information", "Node Name", nodeName, "Step Index", index)
request := input
if responseBody != nil {
responseBytes, err = io.ReadAll(responseBody)
if err != nil {
log.Error(err, "Error while reading the response body")
return nil, 500, err
}
log.Info("Print Previous Response Bytes", "Previous Response Bytes",
responseBytes, "Previous Status Code", statusCode)
err = responseBody.Close()
if err != nil {
log.Error(err, "Error while trying to close the responseBody in handleSwitchPipeline")
}
}

log.Info("Print Original Request Bytes", "Request Bytes", request)
if route.Data == "$response" && index > 0 {
request = mergeRequests(responseBytes, initReqData)
}
log.Info("Print New Request Bytes", "Request Bytes", request)
if route.Condition == "" {
responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
responseBody, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
if err != nil {
return responseBytes, statusCode, err
return nil, statusCode, err
}
} else {
if pickupRouteByCondition(initInput, route.Condition) {
responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
responseBody, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
if err != nil {
return responseBytes, statusCode, err
return nil, statusCode, err
}
}
}
log.Info("Print Response Bytes", "Response Bytes", responseBytes, "Status Code", statusCode)
}
return responseBytes, statusCode, err
return responseBody, statusCode, err
}

func handleEnsemblePipeline(nodeName string,
graph mcv1alpha3.GMConnector,
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
currentNode := graph.Spec.Nodes[nodeName]
ensembleRes := make([]chan EnsembleStepOutput, len(currentNode.Steps))
errChan := make(chan error)
Expand All @@ -328,8 +347,12 @@ func handleEnsemblePipeline(nodeName string,
resultChan := make(chan EnsembleStepOutput)
ensembleRes[i] = resultChan
go func() {
output, statusCode, err := executeStep(step, graph, initInput, input, headers)
responseBody, statusCode, err := executeStep(step, graph, initInput, input, headers)
if err == nil {
output, rerr := io.ReadAll(responseBody)
if rerr != nil {
log.Error(rerr, "Error while reading the response body")
}
var res map[string]interface{}
if err = json.Unmarshal(output, &res); err == nil {
resultChan <- EnsembleStepOutput{
Expand All @@ -339,6 +362,10 @@ func handleEnsemblePipeline(nodeName string,
return
}
}
rerr := responseBody.Close()
if rerr != nil {
log.Error(rerr, "Error while trying to close the responseBody in handleEnsemblePipeline")
}
errChan <- err
}()
}
Expand All @@ -361,7 +388,8 @@ func handleEnsemblePipeline(nodeName string,
ensembleStepOutput.StepStatusCode,
)
stepResponse, _ := json.Marshal(ensembleStepOutput.StepResponse)
return stepResponse, ensembleStepOutput.StepStatusCode, nil
stepIOReader := NewReadCloser(stepResponse)
return stepIOReader, ensembleStepOutput.StepStatusCode, nil
} else {
response[key] = ensembleStepOutput.StepResponse
}
Expand All @@ -371,17 +399,19 @@ func handleEnsemblePipeline(nodeName string,
}
// return json.Marshal(response)
combinedResponse, _ := json.Marshal(response) // TODO check if you need err handling for Marshalling
return combinedResponse, 200, nil
combinedIOReader := NewReadCloser(combinedResponse)
return combinedIOReader, 200, nil
}

func handleSequencePipeline(nodeName string,
graph mcv1alpha3.GMConnector,
initInput []byte,
input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
currentNode := graph.Spec.Nodes[nodeName]
var statusCode int
var responseBody io.ReadCloser
var responseBytes []byte
var err error

Expand Down Expand Up @@ -409,6 +439,20 @@ func handleSequencePipeline(nodeName string,
log.Info("Starting execution of step", "type", stepType, "stepName", step.StepName)
request := input
log.Info("Print Original Request Bytes", "Request Bytes", request)
if responseBody != nil {
responseBytes, err = io.ReadAll(responseBody)
if err != nil {
log.Error(err, "Error while reading the response body")
return nil, 500, err
}
log.Info("Print Previous Response Bytes", "Previous Response Bytes",
responseBytes, "Previous Status Code", statusCode)
err := responseBody.Close()
if err != nil {
log.Error(err, "Error while trying to close the responseBody in handleSequencePipeline")
}
}

if step.Data == "$response" && i > 0 {
request = mergeRequests(responseBytes, initReqData)
}
Expand All @@ -419,13 +463,12 @@ func handleSequencePipeline(nodeName string,
}
// if the condition does not match for the step in the sequence we stop and return the response
if !gjson.GetBytes(responseBytes, step.Condition).Exists() {
return responseBytes, 500, nil
return responseBody, 500, nil
}
}
if responseBytes, statusCode, err = executeStep(step, graph, initInput, request, headers); err != nil {
if responseBody, statusCode, err = executeStep(step, graph, initInput, request, headers); err != nil {
return nil, 500, err
}
log.Info("Print Response Bytes", "Response Bytes", responseBytes, "Status Code", statusCode)
/*
Only if a step is a hard dependency, we will check for its success.
*/
Expand All @@ -439,18 +482,18 @@ func handleSequencePipeline(nodeName string,
statusCode,
)
// Stop the execution of sequence right away if step is a hard dependency and is unsuccessful
return responseBytes, statusCode, nil
return responseBody, statusCode, nil
}
}
}
return responseBytes, statusCode, nil
return responseBody, statusCode, nil
}

func routeStep(nodeName string,
graph mcv1alpha3.GMConnector,
initInput, input []byte,
headers http.Header,
) ([]byte, int, error) {
) (io.ReadCloser, int, error) {
defer timeTrack(time.Now(), "node", nodeName)
currentNode := graph.Spec.Nodes[nodeName]
log.Info("Current Node", "Node Name", nodeName)
Expand Down Expand Up @@ -478,9 +521,14 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) {
go func() {
defer close(done)

inputBytes, _ := io.ReadAll(req.Body)
response, statusCode, err := routeStep(defaultNodeName, *mcGraph, inputBytes, inputBytes, req.Header)
inputBytes, err := io.ReadAll(req.Body)
if err != nil {
log.Error(err, "failed to read request body")
http.Error(w, "failed to read request body", http.StatusBadRequest)
return
}

responseBody, statusCode, err := routeStep(defaultNodeName, *mcGraph, inputBytes, inputBytes, req.Header)
if err != nil {
log.Error(err, "failed to process request")
w.Header().Set("Content-Type", "application/json")
Expand All @@ -490,37 +538,47 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) {
}
return
}
if json.Valid(response) {
w.Header().Set("Content-Type", "application/json")
}
w.WriteHeader(statusCode)

writer := bufio.NewWriter(w)
defer func() {
if err := writer.Flush(); err != nil {
log.Error(err, "error flushing writer when processing response")
err := responseBody.Close()
if err != nil {
log.Error(err, "Error while trying to close the responseBody in mcGraphHandler")
}
}()

for start := 0; start < len(response); start += ChunkSize {
end := start + ChunkSize
if end > len(response) {
end = len(response)
w.Header().Set("Content-Type", "application/json")
buffer := make([]byte, BufferSize)
for {
n, err := responseBody.Read(buffer)
if err != nil && err != io.EOF {
log.Error(err, "failed to read from response body")
http.Error(w, "failed to read from response body", http.StatusInternalServerError)
return
}
if _, err := writer.Write(response[start:end]); err != nil {
log.Error(err, "failed to write mcGraphHandler response")
if n == 0 {
break
}

log.Info("[llm - chat_stream] chunk:", "Buffer", string(buffer[:n]))

// Write the chunk to the ResponseWriter
if _, err := w.Write(buffer[:n]); err != nil {
log.Error(err, "failed to write to ResponseWriter")
return
}

if err := writer.Flush(); err != nil {
log.Error(err, "error flushing writer when processing response")
// Flush the data to the client immediately
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
} else {
log.Error(errors.New("unable to flush data"), "ResponseWriter does not support flushing")
return
}
}
}()

select {
case <-ctx.Done():
log.Error(errors.New("failed to process request"), "request timed out")
log.Error(errors.New("request timed out"), "failed to process request")
http.Error(w, "request timed out", http.StatusGatewayTimeout)
case <-done:
log.Info("mcGraphHandler is done")
Expand Down
Loading
Loading