Skip to content

Commit

Permalink
feat: allow setting max concurrent requests
Browse files Browse the repository at this point in the history
This is useful for limiting the number of simultaneous connections the
web runner handles at any given time, preventing overloading the server
  • Loading branch information
farzadghanei committed May 9, 2024
1 parent 6c91e4e commit 584198d
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 54 deletions.
1 change: 1 addition & 0 deletions examples/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ runners:
# response_write_timeout: 2s
# timeout: 5s
# max_header_bytes: 8192
# max_concurrent_requests: 1 # 0 means no limit


check_suites:
Expand Down
89 changes: 58 additions & 31 deletions internal/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ type Conf struct {

// ConfRunner is config for the check runners
type ConfRunner struct {
Timeout *time.Duration
ShutdownSignalHeader *string `yaml:"shutdown_signal_header"`
MaxHeaderBytes *int `yaml:"max_header_bytes"`
ListenAddress string `yaml:"listen_address"`
RequestReadTimeout *time.Duration `yaml:"request_read_timeout"`
ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"`
ResponseOK *string `yaml:"response_ok"`
ResponseFailed *string `yaml:"response_failed"`
ResponseTimeout *string `yaml:"response_timeout"`
Timeout *time.Duration
ShutdownSignalHeader *string `yaml:"shutdown_signal_header"`
MaxHeaderBytes *int `yaml:"max_header_bytes"`
MaxConcurrentRequests *int `yaml:"max_concurrent_requests"`
ListenAddress string `yaml:"listen_address"`
RequestReadTimeout *time.Duration `yaml:"request_read_timeout"`
ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"`
ResponseOK *string `yaml:"response_ok"`
ResponseFailed *string `yaml:"response_failed"`
ResponseTimeout *string `yaml:"response_timeout"`
ResponseUnavailable *string `yaml:"response_unavailable"`
}

// ConfCheckSpec is the spec for each check configuration
Expand Down Expand Up @@ -58,23 +60,33 @@ func ReadConf(path string) (*Conf, error) {
return &conf, err
}

// GetDefaultConfRunner returns a ConfRunner based on the default configuration
func GetDefaultConfRunner(runners *ConfRunners) ConfRunner {
// GetBaseConfRunner returns a base ConfRunner with default literal values
func GetBaseConfRunner() ConfRunner {
var timeout, readTimeout, writeTimout time.Duration = 5 * time.Minute, 30 * time.Second, 30 * time.Second
var maxHeaderBytes int = 8 * 1024
var MaxConcurrentRequests int = 1
var respOK, respFailed, respTimeout string = "OK", "FAILED", "TIMEOUT"
var respUnavailable string = "UNAVAILABLE"

baseConf := ConfRunner{
Timeout: &timeout,
ShutdownSignalHeader: nil,
MaxHeaderBytes: &maxHeaderBytes,
ListenAddress: "127.0.0.1:8880",
RequestReadTimeout: &readTimeout,
ResponseWriteTimeout: &writeTimout,
ResponseOK: &respOK,
ResponseFailed: &respFailed,
ResponseTimeout: &respTimeout,
Timeout: &timeout,
ShutdownSignalHeader: nil,
MaxHeaderBytes: &maxHeaderBytes,
ListenAddress: "127.0.0.1:8880",
RequestReadTimeout: &readTimeout,
ResponseWriteTimeout: &writeTimout,
ResponseOK: &respOK,
ResponseFailed: &respFailed,
ResponseTimeout: &respTimeout,
ResponseUnavailable: &respUnavailable,
MaxConcurrentRequests: &MaxConcurrentRequests,
}
return baseConf
}

// GetDefaultConfRunner returns a ConfRunner based on the default configuration
func GetDefaultConfRunner(runners *ConfRunners) ConfRunner {
baseConf := GetBaseConfRunner()

if defaultConf, defaultExists := (*runners)["default"]; defaultExists {
baseConf = MergedConfRunners(&baseConf, &defaultConf)
Expand All @@ -100,17 +112,7 @@ func GetConfRunner(runners *ConfRunners, name string) (ConfRunner, bool) {

// MergedConfRunners merges the baseConf with the overrideConf and returns the merged ConfRunner
func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner {
mergedConf := ConfRunner{
Timeout: overrideConf.Timeout,
ShutdownSignalHeader: overrideConf.ShutdownSignalHeader,
ListenAddress: overrideConf.ListenAddress,
RequestReadTimeout: overrideConf.RequestReadTimeout,
ResponseWriteTimeout: overrideConf.ResponseWriteTimeout,
ResponseOK: overrideConf.ResponseOK,
ResponseFailed: overrideConf.ResponseFailed,
ResponseTimeout: overrideConf.ResponseTimeout,
MaxHeaderBytes: overrideConf.MaxHeaderBytes,
}
mergedConf := CopyConfRunner(overrideConf)

if mergedConf.Timeout == nil {
mergedConf.Timeout = baseConf.Timeout
Expand Down Expand Up @@ -148,5 +150,30 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner {
mergedConf.ResponseTimeout = baseConf.ResponseTimeout
}

if mergedConf.ResponseUnavailable == nil {
mergedConf.ResponseUnavailable = baseConf.ResponseUnavailable
}

if mergedConf.MaxConcurrentRequests == nil {
mergedConf.MaxConcurrentRequests = baseConf.MaxConcurrentRequests
}

return mergedConf
}

// CopyConfRunner returns a copy of the ConfRunner with the same values
func CopyConfRunner(conf *ConfRunner) ConfRunner {
return ConfRunner{
Timeout: conf.Timeout,
ShutdownSignalHeader: conf.ShutdownSignalHeader,
ListenAddress: conf.ListenAddress,
RequestReadTimeout: conf.RequestReadTimeout,
ResponseWriteTimeout: conf.ResponseWriteTimeout,
ResponseOK: conf.ResponseOK,
ResponseFailed: conf.ResponseFailed,
ResponseTimeout: conf.ResponseTimeout,
ResponseUnavailable: conf.ResponseUnavailable,
MaxHeaderBytes: conf.MaxHeaderBytes,
MaxConcurrentRequests: conf.MaxConcurrentRequests,
}
}
31 changes: 17 additions & 14 deletions internal/conf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ func TestGetDefaultConfRunner(t *testing.T) {
respNo := "NO"
respMaybe := "MAYBE"

wantMaxHeaderBytes := 8 * 1024

runners := ConfRunners{
"default": ConfRunner{
Timeout: &wantTimeout,
Expand Down Expand Up @@ -115,25 +113,30 @@ func TestGetDefaultConfRunner(t *testing.T) {
}

// Test case where default key does not exist
baseRunner := GetBaseConfRunner()
runners = ConfRunners{}
defaultRunner = GetDefaultConfRunner(&runners)
if *defaultRunner.Timeout != 5*time.Minute {
t.Errorf("Expected Timeout to be 0, got %v", *defaultRunner.Timeout)
if *defaultRunner.Timeout != *baseRunner.Timeout {
t.Errorf("Timeout want %v, got %v", *baseRunner.Timeout, *defaultRunner.Timeout)
}
if defaultRunner.ListenAddress != baseRunner.ListenAddress {
t.Errorf("ListenAddress want %v, got %s", baseRunner.ListenAddress, defaultRunner.ListenAddress)
}
if defaultRunner.ListenAddress != "127.0.0.1:8880" {
t.Errorf("Expected ListenAddress to be 127.0.0.1:8080, got %s", defaultRunner.ListenAddress)
if *defaultRunner.ResponseOK != *baseRunner.ResponseOK {
t.Errorf("ResponseOK want %v, got %s", *baseRunner.ResponseOK, *defaultRunner.ResponseOK)
}
if *defaultRunner.ResponseOK != "OK" {
t.Errorf("Expected ResponseOK to be OK, got %s", *defaultRunner.ResponseOK)
if *defaultRunner.ResponseFailed != *baseRunner.ResponseFailed {
t.Errorf("ResponseFailed want %v, got %s", *baseRunner.ResponseFailed, *defaultRunner.ResponseFailed)
}
if *defaultRunner.ResponseFailed != "FAILED" {
t.Errorf("Expected ResponseFailed to be FAILED, got %s", *defaultRunner.ResponseFailed)
if *defaultRunner.ResponseTimeout != *baseRunner.ResponseTimeout {
t.Errorf("ResponseTimeout want %v, got %s", *baseRunner.ResponseTimeout, *defaultRunner.ResponseTimeout)
}
if *defaultRunner.ResponseTimeout != "TIMEOUT" {
t.Errorf("Expected ResponseTimeout to be TIMEOUT, got %s", *defaultRunner.ResponseTimeout)
if *defaultRunner.MaxHeaderBytes != *baseRunner.MaxHeaderBytes {
t.Errorf("MaxHeaderBytes want %v, got %v", *baseRunner.MaxHeaderBytes, *defaultRunner.MaxHeaderBytes)
}
if *defaultRunner.MaxHeaderBytes != wantMaxHeaderBytes {
t.Errorf("Expected MaxHeaderBytes to be %v, got %v", wantMaxHeaderBytes, *defaultRunner.MaxHeaderBytes)
if *defaultRunner.MaxConcurrentRequests != *baseRunner.MaxConcurrentRequests {
t.Errorf("MaxConcurrentRequests want %v, got %v", *baseRunner.MaxConcurrentRequests,
*defaultRunner.MaxConcurrentRequests)
}
}

Expand Down
23 changes: 14 additions & 9 deletions internal/run_modes.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,24 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger)
if conf.ShutdownSignalHeader != nil {
shutdownSignalHeaderValue = *conf.ShutdownSignalHeader
}
listenAddress := conf.ListenAddress
timeout := *conf.Timeout
maxConcurrentRequests := *conf.MaxConcurrentRequests
responseOK := *conf.ResponseOK
responseFailed := *conf.ResponseFailed
responseTimeout := *conf.ResponseTimeout
requestReadTimeout := *conf.RequestReadTimeout
responseWriteTimeout := *conf.ResponseWriteTimeout
responseUnavailable := *conf.ResponseUnavailable

runner := Runner{Log: logger, Timeout: timeout}
runner := Runner{Log: logger, Timeout: *conf.Timeout}

var runningRequests atomic.Int32
var reqHandlerChan = make(chan *http.Request, 1)

httpHandler := func(w http.ResponseWriter, r *http.Request) {
runningRequests.Add(1)
if maxConcurrentRequests > 0 && runningRequests.Load() > int32(maxConcurrentRequests) {
w.WriteHeader(http.StatusServiceUnavailable) // 503
fmt.Fprint(w, responseUnavailable)
}
defer runningRequests.Add(-1)
logger.Printf("processing http request: %s", httpRequestAsString(r))
_, failed, timedout := runChecks(&runner, checkGroups, logger)
if timedout > 0 {
Expand All @@ -73,10 +78,10 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger)
http.HandleFunc("/", httpHandler)

server := &http.Server{
Addr: listenAddress,
Addr: conf.ListenAddress,
Handler: nil, // use http.DefaultServeMux
ReadTimeout: requestReadTimeout,
WriteTimeout: responseWriteTimeout,
ReadTimeout: *conf.RequestReadTimeout,
WriteTimeout: *conf.ResponseWriteTimeout,
IdleTimeout: 0 * time.Second, // set to 0 so uses read timeout
MaxHeaderBytes: *conf.MaxHeaderBytes,
}
Expand All @@ -101,7 +106,7 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger)
}
}()

logger.Printf("starting http server listening on %s", listenAddress)
logger.Printf("starting http server listening on %s", conf.ListenAddress)
err := server.ListenAndServe()
close(reqHandlerChan)
if err != nil && err != http.ErrServerClosed {
Expand Down

0 comments on commit 584198d

Please sign in to comment.