Skip to content

Commit

Permalink
feat(web): Add support for required headers for http runner
Browse files Browse the repository at this point in the history
These changes allow to configure a list of required HTTP headers
to limit clients accessing the http runner.
a list of headers can be specified using a map
  • Loading branch information
farzadghanei committed May 10, 2024
1 parent 584198d commit 786bca1
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 114 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
---

name: tests

on:
Expand Down
5 changes: 3 additions & 2 deletions .golangci.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
---
# https://golangci-lint.run/usage/configuration/

run:
Expand Down Expand Up @@ -66,7 +67,7 @@ linters-settings:
- performance
- style
disabled-checks:
- dupImport # https://github.com/go-critic/go-critic/issues/845
- dupImport # https://github.com/go-critic/go-critic/issues/845
- ifElseChain
- octalLiteral
- whyNoLint
Expand Down Expand Up @@ -98,7 +99,7 @@ linters-settings:
# There are three different modes: `original`, `strict`, and `lax`.
# Default: "original"
list-mode: original
# List of file globs that will match this list of settings to compare against.
# File globs that will match this list of settings to compare against.
# Default: $all
files:
- "$all"
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
---
exclude: "^docs/|/.vscode/"
default_stages: [commit]

Expand Down
2 changes: 2 additions & 0 deletions cmd/chkok_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func TestRunHttp(t *testing.T) {
t.Fatalf("Failed to create HTTP request: %v", err)
}
req.Header.Set("X-Server-Shutdown", "test-shutdown-signal") // shutdown the server after the request
req.Header.Set("X-Required-Header", "required-value")
req.Header.Set("X-Required-Header2", "anything")

// Send the request multiple times, waiting for the server to
// start up and respond
Expand Down
11 changes: 8 additions & 3 deletions examples/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ runners:
# response_timeout: "TIMEOUT"
cli: {} # override default runner only for CLI mode
http: # override default runner only for HTTP mode
# shutdown_signal_header is mainly useful for testing http mode, do not set it in production
# if set, better be treated like a secret, and a secure transport layer should be used.
listen_address: "127.0.0.1:51234"
# shutdown_signal_header is mainly useful for testing http mode,
# do not set it in production
# if set, better be treated like a secret, and a secure transport
# layer should be used.
# this is the value set on "X-Shutdown-Signal" header in the http request
# shutdown_signal_header: "test-shutdown-signal"
# listen_address: "127.0.0.1:51234"
# request_read_timeout: 2s
# response_write_timeout: 2s
# timeout: 5s
# max_header_bytes: 8192
# max_concurrent_requests: 1 # 0 means no limit
# request_required_headers:
# "X-Required-Header": "required-value"
# "X-Required-Header2": "" # header existance is required, not value


check_suites:
Expand Down
10 changes: 8 additions & 2 deletions examples/test-http.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
runners:
default:
timeout: 1m
request_required_headers:
"X-Required-Header": "required-value"
http:
listen_address: "127.0.0.1:51234"
request_read_timeout: 2s
response_write_timeout: 2s
# shutdown_signal_header is mainly useful for testing http mode, do not set it in production
# if set, better be treated like a secret, and a secure transport layer should be used.
# shutdown_signal_header is mainly useful for testing http mode,
# do not set it in production
# if set, better be treated like a secret, and a secure transport
# layer should be used.
# this is the value set on "X-Shutdown-Signal" header in the http request
shutdown_signal_header: "test-shutdown-signal"
timeout: 5s
request_required_headers:
"X-Required-Header2": "" # header existance is required


check_suites:
Expand Down
92 changes: 56 additions & 36 deletions internal/conf.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package chkok

import (
"maps"
"os"
"time"

Expand All @@ -21,17 +22,19 @@ type Conf struct {

// ConfRunner is config for the check runners
type ConfRunner struct {
Timeout *time.Duration
ShutdownSignalHeader *string `yaml:"shutdown_signal_header"`
MaxHeaderBytes *int `yaml:"max_header_bytes"`
MaxConcurrentRequests *int `yaml:"max_concurrent_requests"`
ListenAddress string `yaml:"listen_address"`
RequestReadTimeout *time.Duration `yaml:"request_read_timeout"`
ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"`
ResponseOK *string `yaml:"response_ok"`
ResponseFailed *string `yaml:"response_failed"`
ResponseTimeout *string `yaml:"response_timeout"`
ResponseUnavailable *string `yaml:"response_unavailable"`
Timeout *time.Duration
ShutdownSignalHeader *string `yaml:"shutdown_signal_header"`
MaxHeaderBytes *int `yaml:"max_header_bytes"`
MaxConcurrentRequests *int `yaml:"max_concurrent_requests"`
ListenAddress string `yaml:"listen_address"`
RequestReadTimeout *time.Duration `yaml:"request_read_timeout"`
RequestRequiredHeaders map[string]string `yaml:"request_required_headers"`
ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"`
ResponseOK *string `yaml:"response_ok"`
ResponseFailed *string `yaml:"response_failed"`
ResponseTimeout *string `yaml:"response_timeout"`
ResponseUnavailable *string `yaml:"response_unavailable"`
ResponseInvalidRequest *string `yaml:"response_invalid_request"`
}

// ConfCheckSpec is the spec for each check configuration
Expand Down Expand Up @@ -66,20 +69,22 @@ func GetBaseConfRunner() ConfRunner {
var maxHeaderBytes int = 8 * 1024
var MaxConcurrentRequests int = 1
var respOK, respFailed, respTimeout string = "OK", "FAILED", "TIMEOUT"
var respUnavailable string = "UNAVAILABLE"
var respUnavailable, respInvalidRequest string = "UNAVAILABLE", "INVALID REQUEST"

baseConf := ConfRunner{
Timeout: &timeout,
ShutdownSignalHeader: nil,
MaxHeaderBytes: &maxHeaderBytes,
ListenAddress: "127.0.0.1:8880",
RequestReadTimeout: &readTimeout,
ResponseWriteTimeout: &writeTimout,
ResponseOK: &respOK,
ResponseFailed: &respFailed,
ResponseTimeout: &respTimeout,
ResponseUnavailable: &respUnavailable,
MaxConcurrentRequests: &MaxConcurrentRequests,
Timeout: &timeout,
ShutdownSignalHeader: nil,
MaxHeaderBytes: &maxHeaderBytes,
ListenAddress: "127.0.0.1:8880",
RequestReadTimeout: &readTimeout,
RequestRequiredHeaders: map[string]string{},
ResponseWriteTimeout: &writeTimout,
ResponseOK: &respOK,
ResponseFailed: &respFailed,
ResponseTimeout: &respTimeout,
ResponseInvalidRequest: &respInvalidRequest,
ResponseUnavailable: &respUnavailable,
MaxConcurrentRequests: &MaxConcurrentRequests,
}
return baseConf
}
Expand Down Expand Up @@ -134,6 +139,13 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner {
mergedConf.RequestReadTimeout = baseConf.RequestReadTimeout
}

// Merge the request required headers map with the baseConf
for key, value := range baseConf.RequestRequiredHeaders {
if _, exists := mergedConf.RequestRequiredHeaders[key]; !exists {
mergedConf.RequestRequiredHeaders[key] = value
}
}

if mergedConf.ResponseWriteTimeout == nil {
mergedConf.ResponseWriteTimeout = baseConf.ResponseWriteTimeout
}
Expand All @@ -158,22 +170,30 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner {
mergedConf.MaxConcurrentRequests = baseConf.MaxConcurrentRequests
}

if mergedConf.ResponseInvalidRequest == nil {
mergedConf.ResponseInvalidRequest = baseConf.ResponseInvalidRequest
}

return mergedConf
}

// CopyConfRunner returns a copy of the ConfRunner with the same values
func CopyConfRunner(conf *ConfRunner) ConfRunner {
return ConfRunner{
Timeout: conf.Timeout,
ShutdownSignalHeader: conf.ShutdownSignalHeader,
ListenAddress: conf.ListenAddress,
RequestReadTimeout: conf.RequestReadTimeout,
ResponseWriteTimeout: conf.ResponseWriteTimeout,
ResponseOK: conf.ResponseOK,
ResponseFailed: conf.ResponseFailed,
ResponseTimeout: conf.ResponseTimeout,
ResponseUnavailable: conf.ResponseUnavailable,
MaxHeaderBytes: conf.MaxHeaderBytes,
MaxConcurrentRequests: conf.MaxConcurrentRequests,
}
newConfRunner := ConfRunner{
Timeout: conf.Timeout,
ShutdownSignalHeader: conf.ShutdownSignalHeader,
ListenAddress: conf.ListenAddress,
RequestReadTimeout: conf.RequestReadTimeout,
RequestRequiredHeaders: map[string]string{},
ResponseWriteTimeout: conf.ResponseWriteTimeout,
ResponseOK: conf.ResponseOK,
ResponseFailed: conf.ResponseFailed,
ResponseTimeout: conf.ResponseTimeout,
ResponseUnavailable: conf.ResponseUnavailable,
ResponseInvalidRequest: conf.ResponseInvalidRequest,
MaxHeaderBytes: conf.MaxHeaderBytes,
MaxConcurrentRequests: conf.MaxConcurrentRequests,
}
maps.Copy(newConfRunner.RequestRequiredHeaders, conf.RequestRequiredHeaders)
return newConfRunner
}
81 changes: 40 additions & 41 deletions internal/conf_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package chkok

import (
"maps"
"testing"
"time"
)

// TestReadConfErrors tests the ReadConf function for error handling
func TestReadConfErrors(t *testing.T) {
var conf *Conf
var err error
Expand Down Expand Up @@ -147,17 +149,19 @@ func TestGetConfRunner(t *testing.T) {

runners := ConfRunners{
"default": ConfRunner{
Timeout: &fiveSecond,
ListenAddress: "localhost:8080",
ResponseWriteTimeout: &tenSecond,
Timeout: &fiveSecond,
ListenAddress: "localhost:8080",
ResponseWriteTimeout: &tenSecond,
RequestRequiredHeaders: map[string]string{"X-Test-Default": "test"},
},
"testMinimalHttpRunner": ConfRunner{},
"testHttpRunner": ConfRunner{
Timeout: &tenSecond,
ShutdownSignalHeader: &shutdownSignalHeader,
ListenAddress: "localhost:9090",
RequestReadTimeout: &fiveSecond,
ResponseWriteTimeout: &fiveSecond,
Timeout: &tenSecond,
ShutdownSignalHeader: &shutdownSignalHeader,
ListenAddress: "localhost:9090",
RequestReadTimeout: &fiveSecond,
ResponseWriteTimeout: &fiveSecond,
RequestRequiredHeaders: map[string]string{"X-Test-2": "http-test"},
},
}

Expand All @@ -173,14 +177,15 @@ func TestGetConfRunner(t *testing.T) {
name: "Existing runner",
runnerName: "testHttpRunner",
expectedRunner: ConfRunner{
Timeout: &tenSecond,
ShutdownSignalHeader: &shutdownSignalHeader,
ListenAddress: "localhost:9090",
RequestReadTimeout: &fiveSecond,
ResponseWriteTimeout: &fiveSecond,
ResponseOK: &ok,
ResponseFailed: &failed,
ResponseTimeout: &timeout,
Timeout: &tenSecond,
ShutdownSignalHeader: &shutdownSignalHeader,
ListenAddress: "localhost:9090",
RequestReadTimeout: &fiveSecond,
RequestRequiredHeaders: map[string]string{"X-Test-2": "http-test", "X-Test-Default": "test"},
ResponseWriteTimeout: &fiveSecond,
ResponseOK: &ok,
ResponseFailed: &failed,
ResponseTimeout: &timeout,
},
expectedExists: true,
},
Expand All @@ -198,42 +203,36 @@ func TestGetConfRunner(t *testing.T) {
},
}

var wantTimeout, wantReadTimeout, wantWriteTimeout time.Duration = 0, 0, 0
var wantResponseOK, wantResponseFailed, wantResponseTimeout, wantListenAddr string = "", "", "", ""

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
runner, exists := GetConfRunner(&runners, tt.runnerName)
if exists != tt.expectedExists {
t.Errorf("expected runner exists to be %v, got %v", tt.expectedExists, exists)
t.Errorf("exists want %v got %v", tt.expectedExists, exists)
}
expRunner := tt.expectedRunner
if *runner.Timeout != *expRunner.Timeout {
t.Errorf("timout want %v got %+v", *expRunner.Timeout, runner.Timeout)
}
wantTimeout = *tt.expectedRunner.Timeout
wantReadTimeout = *tt.expectedRunner.RequestReadTimeout
wantWriteTimeout = *tt.expectedRunner.ResponseWriteTimeout
wantListenAddr = tt.expectedRunner.ListenAddress
wantResponseOK = *tt.expectedRunner.ResponseOK
wantResponseFailed = *tt.expectedRunner.ResponseFailed
wantResponseTimeout = *tt.expectedRunner.ResponseTimeout
if *runner.Timeout != wantTimeout {
t.Errorf("expected runner timeout to be %+v, got %+v", wantTimeout, runner.Timeout)
if *runner.RequestReadTimeout != *expRunner.RequestReadTimeout {
t.Errorf("read timeout want %+v got %+v", *expRunner.Timeout, runner.RequestReadTimeout)
}
if *runner.RequestReadTimeout != wantReadTimeout {
t.Errorf("expected runner read timeout to be %+v, got %+v", wantReadTimeout, runner.RequestReadTimeout)
if *runner.ResponseWriteTimeout != *expRunner.ResponseWriteTimeout {
t.Errorf("write timeout want %+v got %+v", *expRunner.ResponseWriteTimeout, runner.ResponseWriteTimeout)
}
if *runner.ResponseWriteTimeout != wantWriteTimeout {
t.Errorf("expected runner write timeout to be %+v, got %+v", wantWriteTimeout, runner.ResponseWriteTimeout)
if runner.ListenAddress != expRunner.ListenAddress {
t.Errorf("listen address want %s got %s", expRunner.ListenAddress, runner.ListenAddress)
}
if runner.ListenAddress != wantListenAddr {
t.Errorf("expected runner listen address to be %s, got %s", wantListenAddr, runner.ListenAddress)
if *runner.ResponseOK != *expRunner.ResponseOK {
t.Errorf("response ok want %s got %s", *expRunner.ResponseOK, *runner.ResponseOK)
}
if *runner.ResponseOK != wantResponseOK {
t.Errorf("expected runner response ok to be %s, got %s", wantResponseOK, *runner.ResponseOK)
if *runner.ResponseFailed != *expRunner.ResponseFailed {
t.Errorf("response failed want %s got %s", *expRunner.ResponseFailed, *runner.ResponseFailed)
}
if *runner.ResponseFailed != wantResponseFailed {
t.Errorf("expected runner response failed to be %s, got %s", wantResponseFailed, *runner.ResponseFailed)
if *runner.ResponseTimeout != *expRunner.ResponseTimeout {
t.Errorf("response timeout want %s got %s", *expRunner.ResponseTimeout, *runner.ResponseTimeout)
}
if *runner.ResponseTimeout != wantResponseTimeout {
t.Errorf("expected runner response timeout to be %s, got %s", wantResponseTimeout, *runner.ResponseTimeout)
if !maps.Equal(expRunner.RequestRequiredHeaders, runner.RequestRequiredHeaders) {
t.Errorf("request headers want %+v got %+v", expRunner.RequestRequiredHeaders, runner.RequestRequiredHeaders)
}
})
}
Expand Down
Loading

0 comments on commit 786bca1

Please sign in to comment.