diff --git a/README.md b/README.md index e8877494b0e..1625013d975 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ control. **Certificate hotswapping**: Ghostunnel can reload certificates at runtime without dropping existing connections. To trigger a reload, simply send `SIGUSR1` to the process (or set a time-based reloading interval). This will -cause ghostunnel to reload the keystore files. Once successful, the reloaded +cause ghostunnel to reload the certificate/key. Once successful, the reloaded certificate will be used for new connections going forward. **Monitoring and metrics**: Ghostunnel has a built-in status feature that diff --git a/main.go b/main.go index d2d1618324d..2b947ad8b88 100644 --- a/main.go +++ b/main.go @@ -121,7 +121,6 @@ var exitFunc = os.Exit // Context groups listening context data together type Context struct { - watcher chan bool status *statusHandler statusHTTP *http.Server shutdownTimeout time.Duration @@ -305,12 +304,6 @@ func run(args []string) error { } metrics := sqmetrics.NewMetrics(*metricsURL, *metricsPrefix, client, *metricsInterval, metrics.DefaultRegistry, logger) - // Set up file watchers (if requested) - watcher := make(chan bool, 1) - if *timedReload > 0 { - go watchFiles([]string{*keystorePath}, *timedReload, watcher) - } - cert, err := buildCertificate(*keystorePath, *keystorePass) if err != nil { fmt.Fprintf(os.Stderr, "error: unable to load certificates: %s\n", err) @@ -332,7 +325,8 @@ func run(args []string) error { logger.Printf("using target address %s", *serverForwardAddress) status := newStatusHandler(dial) - context := &Context{watcher, status, nil, *shutdownTimeout, dial, metrics, cert} + context := &Context{status, nil, *shutdownTimeout, dial, metrics, cert} + go context.reloadHandler(*timedReload) // Start listening err = serverListen(context) @@ -361,7 +355,8 @@ func run(args []string) error { } status := newStatusHandler(dial) - context := &Context{watcher, status, nil, *shutdownTimeout, dial, metrics, cert} + context := &Context{status, nil, *shutdownTimeout, dial, metrics, cert} + go context.reloadHandler(*timedReload) // Start listening err = clientListen(context) diff --git a/signals.go b/signals.go index e029fb0454b..b3637b94a34 100644 --- a/signals.go +++ b/signals.go @@ -72,12 +72,19 @@ func (context *Context) signalHandler(p *proxy.Proxy) { logger.Printf("received %s, reloading certificates", sig.String()) context.reload() - case <-context.watcher: - context.reload() } } } +func (context *Context) reloadHandler(interval time.Duration) { + if interval == 0 { + return + } + for range time.Tick(interval) { + context.reload() + } +} + func (context *Context) reload() { context.status.Reloading() err := context.cert.Reload() diff --git a/watcher.go b/watcher.go deleted file mode 100644 index 61e39bbd505..00000000000 --- a/watcher.go +++ /dev/null @@ -1,100 +0,0 @@ -/*- - * Copyright 2015 Square Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "bytes" - "crypto/sha256" - "encoding/hex" - "io/ioutil" - "path" - "time" -) - -// Watch files with a periodic timer, for filesystems that don't do -// inotify correctly (e.g. some fuse filesystems or other custom stuff). -func watchFiles(files []string, duration time.Duration, notify chan bool) { - hashes := hashFiles(files) - ticker := time.Tick(duration) - - for { - <-ticker - logger.Printf("running timed reload (timer fired)") - - change := false - for _, file := range files { - if fileChanged(hashes, file) { - logger.Printf("detected change on %s, reloading", path.Base(file)) - change = true - } - } - - if change { - notify <- true - } else { - logger.Printf("nothing changed, not reloading") - } - } -} - -// Hash initial state of files we're watching -func hashFiles(files []string) map[string][32]byte { - hashes := make(map[string][32]byte) - - for _, file := range files { - hash, err := hashFile(file) - if err != nil { - logger.Printf("error reading file: %s", err) - continue - } - - name := path.Base(file) - logger.Printf("sha256(%s) = %s", name, hex.EncodeToString(hash[:])) - hashes[name] = hash - } - - return hashes -} - -// Read & hash a single file -func hashFile(file string) ([32]byte, error) { - data, err := ioutil.ReadFile(file) - if err != nil { - return [32]byte{}, err - } - - return sha256.Sum256(data), nil -} - -// Check if a file has changed contents, update hash -func fileChanged(hashes map[string][32]byte, file string) bool { - newHash, err := hashFile(file) - if err != nil { - logger.Printf("error reading file: %s", err) - return false - } - - name := path.Base(file) - oldHash := hashes[name] - if !bytes.Equal(oldHash[:], newHash[:]) { - logger.Printf("sha256(%s) = %s", name, hex.EncodeToString(newHash[:])) - hashes[name] = newHash - return true - } - - return false -} diff --git a/watcher_test.go b/watcher_test.go deleted file mode 100644 index 1dd7f9475f8..00000000000 --- a/watcher_test.go +++ /dev/null @@ -1,83 +0,0 @@ -/*- - * Copyright 2015 Square Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "io/ioutil" - "os" - "testing" - "time" -) - -func TestWatchers(t *testing.T) { - // Setup - tmpDir, err := ioutil.TempDir("", "ghostunnel-test") - panicOnError(err) - - tmpFile, err := ioutil.TempFile(tmpDir, "") - panicOnError(err) - - tmpFile.WriteString("test") - tmpFile.Sync() - defer os.Remove(tmpFile.Name()) - - // Start watching - watcher := make(chan bool, 1) - - go watchFiles([]string{tmpFile.Name()}, time.Duration(100)*time.Millisecond, watcher) - - time.Sleep(time.Duration(1) * time.Second) - - // Must detect new writes - tmpFile.WriteString("new") - tmpFile.Sync() - tmpFile.Close() - - select { - case <-watcher: - case <-time.Tick(time.Duration(1) * time.Second): - t.Fatalf("timeout, no notification on changed file") - } - - // Must detect file being replaced - os.Remove(tmpFile.Name()) - tmpFile, err = os.Create(tmpFile.Name()) - panicOnError(err) - - tmpFile.WriteString("blubb") - tmpFile.Sync() - tmpFile.Close() - - select { - case <-watcher: - case <-time.Tick(time.Duration(1) * time.Second): - t.Fatalf("timeout, no notification on changed file") - } -} - -func TestHashFilesNonExistent(t *testing.T) { - res := hashFiles([]string{"./does-not-exist"}) - if len(res) > 0 { - t.Error("hash files generated hash for non-existent file") - } -} - -func TestFileChangedNonExistent(t *testing.T) { - if fileChanged(map[string][32]byte{}, "./does-not-exist") { - t.Error("hash files generated hash for non-existent file") - } -}