diff --git a/Makefile b/Makefile index 7219706489a..86683ce3806 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -SOURCE_FILES := $(shell find . \( -name '*.go' -not -path './vendor*' \)) +SOURCE_FILES := $(shell find . \( -name '*.go' -not -path './vendor/*' \)) INTEGRATION_TESTS := $(shell find tests -name 'test-*.py' -exec basename {} .py \;) VERSION := $(shell git describe --always --dirty) @@ -8,7 +8,7 @@ ghostunnel: $(SOURCE_FILES) # Test binary with coverage instrumentation ghostunnel.test: $(SOURCE_FILES) - go test -c -covermode=count -coverpkg . + go test -c -covermode=count -coverpkg .,./auth,./certloader # Clean build output clean: @@ -25,6 +25,7 @@ test: unit $(INTEGRATION_TESTS) unit: go test -v -covermode=count -coverprofile=coverage-unit-test-base.out . go test -v -covermode=count -coverprofile=coverage-unit-test-auth.out ./auth + go test -v -covermode=count -coverprofile=coverage-unit-test-certloader.out ./certloader .PHONY: unit # Run integration tests diff --git a/certloader/certigo.go b/certloader/certigo.go new file mode 100644 index 00000000000..e997a008f69 --- /dev/null +++ b/certloader/certigo.go @@ -0,0 +1,75 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "os" + + certigo "github.com/square/certigo/lib" +) + +func readPEM(path, password, format string) ([]*pem.Block, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + + var pemBlocks []*pem.Block + err = certigo.ReadAsPEMFromFiles( + []*os.File{file}, + format, + func(prompt string) string { return password }, + func(block *pem.Block) { pemBlocks = append(pemBlocks, block) }) + if err != nil { + return nil, fmt.Errorf("error reading file '%s': %s", path, err) + } + if len(pemBlocks) == 0 { + return nil, fmt.Errorf("error reading file '%s', no certificates found", path) + } + + return pemBlocks, nil +} + +func readX509(path string) ([]*x509.Certificate, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + + errs := []error{} + out := []*x509.Certificate{} + + err = certigo.ReadAsX509FromFiles( + []*os.File{file}, "PEM", nil, + func(cert *x509.Certificate, err error) { + if err != nil { + errs = append(errs, err) + return + } + out = append(out, cert) + }) + if err != nil || len(errs) > 0 { + return nil, fmt.Errorf("error reading file '%s'", path) + } + if len(out) == 0 { + return nil, fmt.Errorf("no certificates found in file '%s'", path) + } + return out, nil +} diff --git a/certloader/certigo_test.go b/certloader/certigo_test.go new file mode 100644 index 00000000000..1ad012985a5 --- /dev/null +++ b/certloader/certigo_test.go @@ -0,0 +1,121 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +const testCertificate = ` +-----BEGIN CERTIFICATE----- +MIIDKDCCAhCgAwIBAgIJAPjKcAKZMSkUMA0GCSqGSIb3DQEBCwUAMCMxEjAQBgNV +BAMTCWxvY2FsaG9zdDENMAsGA1UECxMEdGVzdDAeFw0xNTEwMDcxODExNTlaFw0x +NjEwMDYxODExNTlaMCMxEjAQBgNVBAMTCWxvY2FsaG9zdDENMAsGA1UECxMEdGVz +dDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAK4EbZf3EMb/ciW5nGlN +yrf5Pcfz3ZnjWRy1kvBriuPD6NQSZaTWTPmJnbdS/Q5FH0p/6ZjdZKXf6f7WNnAz +JwW0XK7NT3N2DrWfgQqrrVvLAYlfqgHnC7Fxqq7FCpgWjf7L8wcQXfdIYkhdsE4n +osLmCRvx7qS+wuasb6nLzBtg7b99ZvO8K/sezrDIjwzemBWA1Vovztw/vGD4J4/h +D0hiOOqFGWstwFxB9oG4d/QJ45VttLMGuiZCY+A4IyBgPCxphrEec6zf8H4u/ceQ +bB8i1IMmD1VTsq9afeVhMKuoSn2Bs3VRB6c9FpL41/ftN5mYpZCteZH+qQ/DhK/y +Dz0CAwEAAaNfMF0wDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAqwwHQYDVR0lBBYw +FAYIKwYBBQUHAwIGCCsGAQUFBwMBMCEGA1UdEQQaMBiHBH8AAAGHEAAAAAAAAAAA +AAAAAAAAAAEwDQYJKoZIhvcNAQELBQADggEBABuBe5cuyZy6StCYebI3FLN3CEla +/3Hreul6i5giqkF90X6M+9eERZCqSqm2whBMSF4vG+1B6GX1K6S29PUOmTDWyasW +B0WlBgRiZld3JfFBuJu6xk1a8+XwwlGOgEsggepjkrAXbjbqnUMAKOJkjFIyIPvk +5p97SYDJYiOh7MmjyXUIzyNdqpL5WiUgKPTxXL+1tNzxH1jjxfVdjaNaNcOJuu20 +9tsMqDZyTm2yZWOBUXbtqlaMQHrs5Ksz5EKk5/U5KfJehKss8oba2npg/6echTJU +nkOOZ6U4eEju7H1S46qlN9ZmUmSrrjwec3H7CnvxQ0ncEyZXlEiTlbO2JQI= +-----END CERTIFICATE-----` + +const testCertificateBad = ` +-----BEGIN CERTIFICATE----- +//////////////////////////////////////////////////////////////// +-----END CERTIFICATE-----` + +func TestReadPEMValid(t *testing.T) { + cert, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(cert.Name()) + + _, err = cert.Write([]byte(testCertificate)) + assert.Nil(t, err, "temp file error") + + blocks, err := readPEM(cert.Name(), "", "PEM") + assert.Nil(t, err, "should read PEM file") + assert.Len(t, blocks, 1, "should find one PEM block") +} + +func TestReadPEMInvalid(t *testing.T) { + cert, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(cert.Name()) + + _, err = cert.Write([]byte("invalid")) + assert.Nil(t, err, "temp file error") + + blocks, err := readPEM(cert.Name(), "", "PEM") + assert.NotNil(t, err, "should not parse invalid file") + assert.Len(t, blocks, 0, "should not return PEM blocks") + + blocks, err = readPEM("does-not-exist", "", "PEM") + assert.NotNil(t, err, "should not parse invalid file") + assert.Len(t, blocks, 0, "should not return PEM blocks") +} + +func TestReadX509Valid(t *testing.T) { + cert, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(cert.Name()) + + _, err = cert.Write([]byte(testCertificate)) + assert.Nil(t, err, "temp file error") + + certs, err := readX509(cert.Name()) + assert.Nil(t, err, "should parse certificate from PEM file") + assert.Len(t, certs, 1, "should find one certificate") +} + +func TestReadX509Invalid(t *testing.T) { + cert0, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(cert0.Name()) + + cert1, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(cert1.Name()) + + _, err = cert0.Write([]byte("invalid")) + assert.Nil(t, err, "temp file error") + _, err = cert1.Write([]byte(testCertificateBad)) + assert.Nil(t, err, "temp file error") + + certs, err := readX509(cert0.Name()) + assert.NotNil(t, err, "should not parse invalid file") + assert.Len(t, certs, 0, "should not parse invalid file") + + certs, err = readX509(cert1.Name()) + assert.NotNil(t, err, "should not parse invalid file") + assert.Len(t, certs, 0, "should not parse invalid file") + + certs, err = readX509("does-not-exist") + assert.NotNil(t, err, "should not parse invalid file") + assert.Len(t, certs, 0, "should not parse invalid file") +} diff --git a/certloader/certstore_disabled.go b/certloader/certstore_disabled.go new file mode 100644 index 00000000000..47caed7065e --- /dev/null +++ b/certloader/certstore_disabled.go @@ -0,0 +1,33 @@ +// +build !certstore + +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import "errors" + +// SupportsKeychain returns true or false, depending on whether the +// binary was built with Certstore/Keychain support or not (requires CGO, recent +// Darwin to build). +func SupportsKeychain() bool { + return false +} + +// CertificateFromKeychainIdentity creates a reloadable certificate from a system keychain identity. +func CertificateFromKeychainIdentity(commonName string) (Certificate, error) { + return nil, errors.New("not supported") +} diff --git a/certloader/certstore_enabled.go b/certloader/certstore_enabled.go new file mode 100644 index 00000000000..4b37fe090ce --- /dev/null +++ b/certloader/certstore_enabled.go @@ -0,0 +1,114 @@ +// +build certstore + +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "sync/atomic" + "unsafe" + + "github.com/mastahyeti/certstore" +) + +type certstoreCertificate struct { + // Common name of keychain identity + commonName string + // Cached *tls.Certificate + cached unsafe.Pointer +} + +// SupportsKeychain returns true or false, depending on whether the +// binary was built with Certstore/Keychain support or not (requires CGO, recent +// Darwin to build). +func SupportsKeychain() bool { + return true +} + +// CertificateFromKeychainIdentity creates a reloadable certificate from a system keychain identity. +func CertificateFromKeychainIdentity(commonName string) (Certificate, error) { + c := certstoreCertificate{ + commonName: commonName, + } + err := c.Reload() + if err != nil { + return nil, err + } + return &c, nil +} + +// Reload transparently reloads the certificate. +func (c *certstoreCertificate) Reload() error { + store, err := certstore.Open() + if err != nil { + return err + } + + identitites, err := store.Identities() + if err != nil { + return err + } + + var certAndKey *tls.Certificate + for _, identity := range identitites { + chain, err := identity.CertificateChain() + if err != nil { + continue + } + + signer, err := identity.Signer() + if err != nil { + continue + } + + if chain[0].Subject.CommonName == c.commonName { + certAndKey = &tls.Certificate{ + Certificate: serializeChain(chain), + PrivateKey: signer, + } + break + } + } + + if certAndKey != nil { + atomic.StorePointer(&c.cached, unsafe.Pointer(certAndKey)) + return nil + } + + return fmt.Errorf("unable to find identity with common name '%s' in keychain", c.commonName) +} + +// GetCertificate retrieves the actual underlying tls.Certificate. +func (c *certstoreCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +} + +// GetClientCertificate retrieves the actual underlying tls.Certificate. +func (c *certstoreCertificate) GetClientCertificate(certInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +} + +func serializeChain(chain []*x509.Certificate) [][]byte { + out := [][]byte{} + for _, cert := range chain { + out = append(out, cert.Raw) + } + return out +} diff --git a/tls_no_cgo.go b/certloader/certstore_test.go similarity index 72% rename from tls_no_cgo.go rename to certloader/certstore_test.go index 89f5ceaa7dd..0efc9b4879f 100644 --- a/tls_no_cgo.go +++ b/certloader/certstore_test.go @@ -1,5 +1,3 @@ -// +build !cgo - /*- * Copyright 2018 Square Inc. * @@ -16,17 +14,15 @@ * limitations under the License. */ -package main +package certloader import ( - "crypto" - "errors" -) + "testing" -func newPKCS11(pubkey crypto.PublicKey) (crypto.PrivateKey, error) { - panic(errors.New("PKCS11 unavailable when compiled without CGO support")) -} + "github.com/stretchr/testify/assert" +) -func hasPKCS11() bool { - return false +func TestInvalidKeychainIdentity(t *testing.T) { + _, err := CertificateFromKeychainIdentity("") + assert.NotNil(t, err, "should not load invalid identity") } diff --git a/certloader/dialer.go b/certloader/dialer.go new file mode 100644 index 00000000000..2e254bdd5dd --- /dev/null +++ b/certloader/dialer.go @@ -0,0 +1,88 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "net" + "time" +) + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// Dialer is an interface for dialers. Can be a net.Dialer, http_dialer.HttpTunnel, or a dialer from this package. +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +type mtlsDialer struct { + cert Certificate + config *tls.Config + timeout time.Duration + dialer Dialer +} + +// DialerWithCertificate creates a dialer that reloads its certificate (if set) before dialing new connections. +// If the certificate is nil, the dialer will still work, but it won't supply client certificates on connections. +func DialerWithCertificate(cert Certificate, config *tls.Config, timeout time.Duration, dialer Dialer) Dialer { + d := mtlsDialer{ + cert: cert, + config: config, + timeout: timeout, + dialer: dialer, + } + if cert != nil && config.GetClientCertificate == nil { + config.GetClientCertificate = cert.GetClientCertificate + } + return &d +} + +func (d *mtlsDialer) Dial(network, address string) (net.Conn, error) { + return dialWithDialer(d.dialer, d.timeout, network, address, d.config) +} + +// Internal copy of tls.DialWithDialer, adapted so it can work with HTTP CONNECT dialers. +// See https://golang.org/pkg/crypto/tls/#DialWithDialer for original implementation. +func dialWithDialer(dialer Dialer, timeout time.Duration, network, addr string, config *tls.Config) (*tls.Conn, error) { + errChannel := make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + conn := tls.Client(rawConn, config) + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} diff --git a/certloader/dialer_test.go b/certloader/dialer_test.go new file mode 100644 index 00000000000..15146c5d376 --- /dev/null +++ b/certloader/dialer_test.go @@ -0,0 +1,30 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTimeoutError(t *testing.T) { + err := timeoutError{} + assert.False(t, err.Error() == "", "Timeout error should have message") + assert.True(t, err.Timeout(), "Timeout error should have Timeout() == true") + assert.True(t, err.Temporary(), "Timeout error should have Temporary() == true") +} diff --git a/certloader/doc.go b/certloader/doc.go new file mode 100644 index 00000000000..77230f435d0 --- /dev/null +++ b/certloader/doc.go @@ -0,0 +1,5 @@ +// Package certloader provides abstractions over certificates that can be used +// for clients and servers to make runtime reloading easier. It supports reading +// certificates from PEM files, PKCS#12 keystores, PKCS#11 hardware modules and +// from the macOS keychain. +package certloader diff --git a/certloader/generic_test.go b/certloader/generic_test.go new file mode 100644 index 00000000000..26dd6852907 --- /dev/null +++ b/certloader/generic_test.go @@ -0,0 +1,25 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import "testing" + +func TestBuildFlags(t *testing.T) { + // Return true/false based on flags, but should never panic. + SupportsPKCS11() + SupportsKeychain() +} diff --git a/certloader/keystore.go b/certloader/keystore.go new file mode 100644 index 00000000000..a786ef7e644 --- /dev/null +++ b/certloader/keystore.go @@ -0,0 +1,120 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "sync/atomic" + "unsafe" +) + +// Certificate wraps a TLS certificate and supports reloading at runtime. +type Certificate interface { + // Reload will reload the certificate and private key. Subsequent calls + // to GetCertificate/GetClientCertificate will return the newly loaded + // certificate, if reloading was successful. If reloading failed, the old + // state is kept. + Reload() error + + // GetCertificate returns the current underlying certificate. + // Can be used for tls.Config's GetCertificate callback. + GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) + + // GetClientCertificate returns the current underlying certificate. + // Can be used for tls.Config's GetClientCertificate callback. + GetClientCertificate(certInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) +} + +type keystoreCertificate struct { + // Keystore or PEM files path + keystorePaths []string + // Password for keystore (may be empty) + keystorePassword string + // File format as an indicator for certigo/lib + format string + // Cached *tls.Certificate + cached unsafe.Pointer +} + +// CertificateFromPEMFiles creates a reloadable certificate from a set of PEM files. +func CertificateFromPEMFiles(certificatePath, keyPath string) (Certificate, error) { + c := keystoreCertificate{ + keystorePaths: []string{certificatePath, keyPath}, + format: "PEM", + } + err := c.Reload() + if err != nil { + return nil, err + } + return &c, nil +} + +// CertificateFromKeystore creates a reloadable certificate from a PKCS#12 keystore. +func CertificateFromKeystore(keystorePath, keystorePassword string) (Certificate, error) { + c := keystoreCertificate{ + keystorePaths: []string{keystorePath}, + keystorePassword: keystorePassword, + format: "PKCS12", + } + err := c.Reload() + if err != nil { + return nil, err + } + return &c, nil +} + +// Reload transparently reloads the certificate. +func (c *keystoreCertificate) Reload() error { + var pemBlocks []*pem.Block + for _, path := range c.keystorePaths { + blocks, err := readPEM(path, c.keystorePassword, c.format) + if err != nil { + return err + } + pemBlocks = append(pemBlocks, blocks...) + } + + var pemBytes []byte + for _, block := range pemBlocks { + pemBytes = append(pemBytes, pem.EncodeToMemory(block)...) + } + + certAndKey, err := tls.X509KeyPair(pemBytes, pemBytes) + if err != nil { + return err + } + + certAndKey.Leaf, err = x509.ParseCertificate(certAndKey.Certificate[0]) + if err != nil { + return err + } + + atomic.StorePointer(&c.cached, unsafe.Pointer(&certAndKey)) + return nil +} + +// GetCertificate retrieves the actual underlying tls.Certificate. +func (c *keystoreCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +} + +// GetClientCertificate retrieves the actual underlying tls.Certificate. +func (c *keystoreCertificate) GetClientCertificate(certInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +} diff --git a/certloader/keystore_test.go b/certloader/keystore_test.go new file mode 100644 index 00000000000..1ed5ab15614 --- /dev/null +++ b/certloader/keystore_test.go @@ -0,0 +1,111 @@ +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +const testCombinedCertificateAndKey = ` +-----BEGIN CERTIFICATE----- +MIIC6DCCAdCgAwIBAgIJAK56Q73Kb2tfMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNV +BAMMBHJvb3QwHhcNMTgwNTI0MTg0MjAwWhcNMzIwMTMxMTg0MjAwWjARMQ8wDQYD +VQQDDAZzZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/bkEe +7lxh6H/XkGK/00GR+XNZHRgYagpbbATNnEt7zXJ3Ot6Fu2SJvUpjRhfJ2GCakjLq +2+YFlH8heN3DEYFFxuLOtnHrNzZW8TzyWlV9LWK+jO/YjEoH6wGwvZ/XiDwYTg/B +yzUphvuUdYMrWWdvV2PcLTspfjSNuUM8QjhKHchUJzddqaEWsTUM7tWIPpRZiDQH +BNmoEKklBrgwKyQZe/IJ/VL3Vntbdpp1eycHk6uh7hAWZ897Hidv8YwOP8Fusr0c +AMj2vEzS2HHED16ha8TAN+5lycAPPJ9b8bOeSv5K90w73Szjxf8fHkmgFmdI4Q3e +N9S2bVpUx3f+lNMvAgMBAAGjRTBDMBMGA1UdJQQMMAoGCCsGAQUFBwMBMCwGA1Ud +EQQlMCOHBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAGCCWxvY2FsaG9zdDANBgkqhkiG +9w0BAQsFAAOCAQEAZq3KX0T8BvKwTTtCrzV7wkdruyfUFxNER2GAzynmm9rIHcTE +UiMoRZk/s5CcqJZFNS1N9ObqCXzNDzQreBOVcPk/YnCwiVviuzDfZxPPchrB3prp +1B9b813dhUknjy2nU40Bi/djx8Fp8H59EpGM+OWFt368zxb7NWxK8PFPKJDyHvbA +QDU7QP3y99EoYugQKPmjiav6gzDFegYilBt3bBKUwRqqMOv08wia4oycaCqZW+ay +qkfXo0Io2kEp2nkbQfPhAZASq1Il7x6ytr6NyIBCxsKvgPYF2YdDqfs2a/cwxU7A +zIo7sqovg5zVX3IUCJNbnC5g6wGYRoCUXzeExg== +-----END CERTIFICATE----- +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAv25BHu5cYeh/15Biv9NBkflzWR0YGGoKW2wEzZxLe81ydzre +hbtkib1KY0YXydhgmpIy6tvmBZR/IXjdwxGBRcbizrZx6zc2VvE88lpVfS1ivozv +2IxKB+sBsL2f14g8GE4Pwcs1KYb7lHWDK1lnb1dj3C07KX40jblDPEI4Sh3IVCc3 +XamhFrE1DO7ViD6UWYg0BwTZqBCpJQa4MCskGXvyCf1S91Z7W3aadXsnB5Oroe4Q +FmfPex4nb/GMDj/BbrK9HADI9rxM0thxxA9eoWvEwDfuZcnADzyfW/Gznkr+SvdM +O90s48X/Hx5JoBZnSOEN3jfUtm1aVMd3/pTTLwIDAQABAoIBAAFkrwqnl3qK86tA +/McCpZ6HX0SNxqge1XZ24c4RTidXhnbBse7tPz0VaJ4yW2f3sDRPzhkRgqoVu5sl +ww9xaCx21x3EDC43F6koVlY5PBgOJYLXicNcugk2t6tupeQutKlEoC676OYlel1J +QawmGW+hBjQLDDwwE/fYGlos7TX04EzAfDDORZ5WQBnSFlGDXFNIV+pTXAwy5KHr +OQsbJEKoqF8KcXSS4yDZ0ZEKFJrC6pZUXAkDhUZd201UQtMIgOReGyBDmmGZZkNe +t1uBiZqwidYvMHFPT56S3R5nhL+4zeQ4SjvGluXei3c+uCkTtT7l7AKs84OC+DGg +NZDU4OECgYEA80nmBy0XmRR6CA6yg9Vy30s0Z/jIEk01USQmnHMOvIwOevi5TKiS +CgDZiHzM6b9PBTuux4/FA1vRIXJnSsMjamHIQbDHdjEi5ZuYF0SSGydEtNhCx/FF +eW5ZKKLBb/M+sfu06CN0Tts6OyKkQuWLSZ8RdjcYxwgrM1gDNktffNECgYEAyW67 +DBNXTfzrRJG5Su+dGiULxlRlXd1Nv69gZaAH0FBKBw/BV+AGnYC/MbZSHsalmnXW ++FvwdlARoP5PrHkXA2V4cHyLFxdLHuMAokt8qu+cUHUXssyqay9jXEgkX/vKfjVm +pHZszJz4iIbXuqDaX1nBJqCznUO8I3KfH1SDT/8CgYBD6lI7mJvo0O2MCEZPRSvP +J9hWWf3IFiOXJiddL0Vi3xo/u+VGgBxcjIYtcuFlM1Gk3VdaQEk4Oc50rtIk7bqa +PPfBVs8nsGnUfQ4FGNBMojas4V4rILBLSMG89UpYrSfIWcLTtuoGBW8JCQ+f2SJ8 +B9rBDHpvPVmJ+LzU0E+0sQKBgQCSAcFzL1HJJdsjCL3Wo3isys2OJP6U2yTQHL8y +6py/UnNWSwVKPQiOghQUZKOBy1ueamw3+eyC1ebxW2VFD0CvJY33e08WnbvF16VN +/omPHb+gUj+rSs78gozzBxfWuxw7/1k3POAAMIe17ofQr2eaVeS7qyCGjeKBj0Pn +4cqM4QKBgCxn5c5kskJcuSEKrCvuuSRYBbYY7FxBH2ksnFECl9VnsDl8pYMaTf0E +9kNvJK3/1WjJOaXy4cEPx/BMbHcrh01K/IM3Te2VCrp7tkA5H1V2YGQD4/aqmajA +plW93GyQzhwY+Cc1Of2ktdBwOHNn1xWyl3lgjAaW+da1nEhq6Anc +-----END RSA PRIVATE KEY-----` + +func TestCertificateFromPEMFilesValid(t *testing.T) { + file, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(file.Name()) + + _, err = file.Write([]byte(testCombinedCertificateAndKey)) + assert.Nil(t, err, "temp file error") + + cert, err := CertificateFromPEMFiles(file.Name(), file.Name()) + assert.Nil(t, err, "should read PEM file with certificate & private key") + + c0, err := cert.GetCertificate(nil) + assert.Nil(t, err, "should have a valid tls.Certificate on GetCertificate call") + + c1, err := cert.GetClientCertificate(nil) + assert.Nil(t, err, "should have a valid tls.Certificate on GetCertificate call") + + assert.Equal(t, c0.Leaf.Subject.CommonName, "server", "should have the right cert") + assert.Equal(t, c1.Leaf.Subject.CommonName, "server", "should have the right cert") + assert.Nil(t, cert.Reload(), "should be able to reload") + + // Remove file & test reload failure + os.Remove(file.Name()) + assert.NotNil(t, cert.Reload(), "should not be able to reload") +} + +func TestCertificateFromPEMFilesInvalid(t *testing.T) { + file, err := ioutil.TempFile("", "ghostunnel-test") + assert.Nil(t, err, "temp file error") + defer os.Remove(file.Name()) + + _, err = file.Write([]byte("invalid")) + assert.Nil(t, err, "temp file error") + + cert, err := CertificateFromPEMFiles(file.Name(), file.Name()) + assert.Nil(t, cert, "should not return certificate on error") + assert.NotNil(t, err, "should read PEM file with certificate & private key") +} diff --git a/certloader/pkcs11_disabled.go b/certloader/pkcs11_disabled.go new file mode 100644 index 00000000000..6e4d705e2b5 --- /dev/null +++ b/certloader/pkcs11_disabled.go @@ -0,0 +1,32 @@ +// +build !cgo + +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import "errors" + +// SupportsPKCS11 returns true or false, depending on whether the binary +// was built with PKCS11 support or not (requires CGO to build). +func SupportsPKCS11() bool { + return false +} + +// CertificateFromPKCS11Module creates a reloadable certificate from a PKCS#11 module. +func CertificateFromPKCS11Module(certificatePath, modulePath, tokenLabel, pin string) (Certificate, error) { + return nil, errors.New("not supported") +} diff --git a/certloader/pkcs11_enabled.go b/certloader/pkcs11_enabled.go new file mode 100644 index 00000000000..b5b12658183 --- /dev/null +++ b/certloader/pkcs11_enabled.go @@ -0,0 +1,101 @@ +// +build cgo + +/*- + * Copyright 2018 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "sync/atomic" + "unsafe" + + "github.com/letsencrypt/pkcs11key" +) + +type pkcs11Certificate struct { + // Certificate chain corresponding to key + certificatePath string + // Params for loading key from a PKCS#11 module + modulePath, tokenLabel, pin string + // Cached *tls.Certificate + cached unsafe.Pointer +} + +// SupportsPKCS11 returns true or false, depending on whether the binary +// was built with PKCS11 support or not (requires CGO to build). +func SupportsPKCS11() bool { + return true +} + +// CertificateFromPKCS11Module creates a reloadable certificate from a PKCS#11 module. +func CertificateFromPKCS11Module(certificatePath, modulePath, tokenLabel, pin string) (Certificate, error) { + c := pkcs11Certificate{ + certificatePath: certificatePath, + modulePath: modulePath, + tokenLabel: tokenLabel, + pin: pin, + } + err := c.Reload() + if err != nil { + return nil, err + } + return &c, nil +} + +// Reload transparently reloads the certificate. +func (c *pkcs11Certificate) Reload() error { + // Expecting certificate file to only have certificate chain, + // with the (fixed) private key being in an HSM/PKCS11 module. + certs, err := readX509(c.certificatePath) + if err != nil { + return err + } + + certAndKey := tls.Certificate{ + Leaf: certs[0], + } + for _, cert := range certs { + certAndKey.Certificate = append(certAndKey.Certificate, cert.Raw) + } + + // Reuse previously loaded PKCS11 private key if we already have it. + // We want to avoid reloading the key every time the cert reloads, as it's + // a potentially expensive operation that calls out into a shared library. + if c.cached != nil { + old, _ := c.GetCertificate(nil) + certAndKey.PrivateKey = old.PrivateKey + } else { + privateKey, err := pkcs11key.New(c.modulePath, c.tokenLabel, c.pin, certAndKey.Leaf.PublicKey) + if err != nil { + return err + } + certAndKey.PrivateKey = privateKey + } + + atomic.StorePointer(&c.cached, unsafe.Pointer(&certAndKey)) + return nil +} + +// GetCertificate retrieves the actual underlying tls.Certificate. +func (c *pkcs11Certificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +} + +// GetClientCertificate retrieves the actual underlying tls.Certificate. +func (c *pkcs11Certificate) GetClientCertificate(certInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +} diff --git a/certstore_disabled.go b/certloader/pkcs11_test.go similarity index 62% rename from certstore_disabled.go rename to certloader/pkcs11_test.go index 83a12170029..7605b17cc39 100644 --- a/certstore_disabled.go +++ b/certloader/pkcs11_test.go @@ -1,5 +1,3 @@ -// +build !certstore - /*- * Copyright 2018 Square Inc. * @@ -16,17 +14,15 @@ * limitations under the License. */ -package main +package certloader -import "fmt" +import ( + "testing" -func validateKeystoreOrIdentity() error { - if *keystorePath == "" { - return fmt.Errorf("--keystore flag (or --disable-authentication in client mode) is required, try --help") - } - return nil -} + "github.com/stretchr/testify/assert" +) -func buildCertificateFromKeystoreOrIdentity() (*certificate, error) { - return buildCertificate(*keystorePath, *keystorePass) +func TestInvalidPKCS11Module(t *testing.T) { + _, err := CertificateFromPKCS11Module("", "", "", "") + assert.NotNil(t, err, "should not load invalid PKCS11 certificate/key") } diff --git a/certstore_enabled.go b/certstore_enabled.go deleted file mode 100644 index 8f59d18f4e2..00000000000 --- a/certstore_enabled.go +++ /dev/null @@ -1,100 +0,0 @@ -// +build certstore - -/*- - * Copyright 2018 Square Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "unsafe" - - "github.com/mastahyeti/certstore" -) - -var ( - keychainIdentity = app.Flag("keychain-identity", "Use local keychain identity with given common name (instead of keystore file).").PlaceHolder("CN").String() -) - -func validateKeystoreOrIdentity() error { - if (*keystorePath == "") && (*keychainIdentity == "") { - return fmt.Errorf("one of --keystore or --keychain-identity (or --disable-authentication in client mode) flags is required, try --help") - } - if (*keystorePath != "") && (*keychainIdentity != "") { - return fmt.Errorf("--keystore and --keychain-identity flags are mutually exclusive") - } - return nil -} - -func buildCertificateFromKeystoreOrIdentity() (*certificate, error) { - if *keystorePath != "" { - return buildCertificate(*keystorePath, *keystorePass) - } - - if *keychainIdentity != "" { - cert, err := loadIdentity(*keychainIdentity) - if err != nil { - return nil, fmt.Errorf("unable to load identity from keychain: %s", err) - } - - return &certificate{false, "", "", unsafe.Pointer(&cert)}, nil - } - - return &certificate{}, nil -} - -func loadIdentity(commonName string) (tls.Certificate, error) { - store, err := certstore.Open() - if err != nil { - return tls.Certificate{}, err - } - - identitites, err := store.Identities() - if err != nil { - return tls.Certificate{}, err - } - - for _, identity := range identitites { - chain, err := identity.CertificateChain() - if err != nil { - continue - } - - signer, err := identity.Signer() - if err != nil { - continue - } - - if chain[0].Subject.CommonName == commonName { - return tls.Certificate{ - Certificate: serializeChain(chain), - PrivateKey: signer, - }, nil - } - } - - return tls.Certificate{}, fmt.Errorf("no identity with name '%s' found", commonName) -} - -func serializeChain(chain []*x509.Certificate) [][]byte { - out := [][]byte{} - for _, cert := range chain { - out = append(out, cert.Raw) - } - return out -} diff --git a/main.go b/main.go index ee878422323..6275f417461 100644 --- a/main.go +++ b/main.go @@ -37,6 +37,7 @@ import ( "github.com/mwitkow/go-http-dialer" "github.com/rcrowley/go-metrics" "github.com/square/ghostunnel/auth" + "github.com/square/ghostunnel/certloader" "github.com/square/go-sq-metrics" "gopkg.in/alecthomas/kingpin.v2" ) @@ -46,6 +47,15 @@ var ( defaultMetricsPrefix = "ghostunnel" ) +// Optional flags (enabled conditionally based on build) +var ( + keychainIdentity *string + pkcs11Module *string + pkcs11TokenLabel *string + pkcs11PIN *string +) + +// Main flags (always supported) var ( app = kingpin.New("ghostunnel", "A simple SSL/TLS proxy with mutual authentication for securing non-TLS services.") @@ -97,6 +107,17 @@ var ( enableProf = app.Flag("enable-pprof", "Enable serving /debug/pprof endpoints alongside /_status (for profiling).").Bool() ) +func init() { + if certloader.SupportsKeychain() { + keychainIdentity = app.Flag("keychain-identity", "Use local keychain identity with given common name (instead of keystore file).").PlaceHolder("CN").String() + } + if certloader.SupportsPKCS11() { + pkcs11Module = app.Flag("pkcs11-module", "Path to PKCS11 module (SO) file (optional).").Envar("PKCS11_MODULE").PlaceHolder("PATH").ExistingFile() + pkcs11TokenLabel = app.Flag("pkcs11-token-label", "Token label for slot/key in PKCS11 module (optional).").Envar("PKCS11_TOKEN_LABEL").PlaceHolder("LABEL").String() + pkcs11PIN = app.Flag("pkcs11-pin", "PIN code for slot/key in PKCS11 module (optional).").Envar("PKCS11_PIN").PlaceHolder("PIN").String() + } +} + var exitFunc = os.Exit // Context groups listening context data together @@ -107,7 +128,7 @@ type Context struct { shutdownTimeout time.Duration dial func() (net.Conn, error) metrics *sqmetrics.SquareMetrics - cert *certificate + cert certloader.Certificate } // Dialer is an interface for dialers (either net.Dialer, or http_dialer.HttpTunnel) @@ -179,20 +200,23 @@ func serverValidateFlags() error { len(*serverAllowedIPs) > 0 || len(*serverAllowedURIs) > 0 - if err := validateKeystoreOrIdentity(); err != nil { - return err + if *keystorePath == "" && !hasKeychainIdentity() { + return errors.New("at least one of --keystore or --keychain-identity (if supported) flags is required") + } + if *keystorePath != "" && hasKeychainIdentity() { + return errors.New("--keystore and --keychain-identity flags are mutually exclusive") } if !(*serverDisableAuth) && !(*serverAllowAll) && !hasAccessFlags { - return fmt.Errorf("at least one access control flag (--allow-{all,cn,ou,dns-san,ip-san,uri-san} or --disable-authentication) is required") + return errors.New("at least one access control flag (--allow-{all,cn,ou,dns-san,ip-san,uri-san} or --disable-authentication) is required") } if !(*serverDisableAuth) && *serverAllowAll && hasAccessFlags { - return fmt.Errorf("--allow-all is mutually exclusive with other access control flags") + return errors.New("--allow-all is mutually exclusive with other access control flags") } if *serverDisableAuth && (*serverAllowAll || hasAccessFlags) { - return fmt.Errorf("--disable-authentication is mutually exclusive with other access control flags") + return errors.New("--disable-authentication is mutually exclusive with other access control flags") } if !*serverUnsafeTarget && !validateUnixOrLocalhost(*serverForwardAddress) { - return fmt.Errorf("--target must be unix:PATH, localhost:PORT, 127.0.0.1:PORT or [::1]:PORT (unless --unsafe-target is set)") + return errors.New("--target must be unix:PATH, localhost:PORT, 127.0.0.1:PORT or [::1]:PORT (unless --unsafe-target is set)") } for _, suite := range strings.Split(*enabledCipherSuites, ",") { @@ -206,8 +230,11 @@ func serverValidateFlags() error { // Validate flags for client mode func clientValidateFlags() error { - if err := validateKeystoreOrIdentity(); err != nil && !(*clientDisableAuth) { - return err + if *keystorePath == "" && !hasKeychainIdentity() && !*clientDisableAuth { + return errors.New("at least one of --keystore, --keychain-identity (if supported), or --disable-authentication flags is required") + } + if *keystorePath != "" && hasKeychainIdentity() && !*clientDisableAuth { + return errors.New("--keystore, --keychain-identity, and --disable-authentication flags are mutually exclusive") } if !*clientUnsafeListen && !validateUnixOrLocalhost(*clientListenAddress) { return fmt.Errorf("--listen must be unix:PATH, localhost:PORT, 127.0.0.1:PORT or [::1]:PORT (unless --unsafe-listen is set)") @@ -283,7 +310,7 @@ func run(args []string) error { go watchFiles([]string{*keystorePath}, *timedReload, watcher) } - cert, err := buildCertificateFromKeystoreOrIdentity() + cert, err := buildCertificate(*keystorePath, *keystorePass) if err != nil { fmt.Fprintf(os.Stderr, "error: unable to load certificates: %s\n", err) return err @@ -368,7 +395,7 @@ func serverListen(context *Context) error { Logger: logger, } - config.GetCertificate = context.cert.getCertificate + config.GetCertificate = context.cert.GetCertificate config.VerifyPeerCertificate = serverACL.VerifyPeerCertificateServer if *serverDisableAuth { config.ClientAuth = tls.NoClientCert @@ -472,7 +499,9 @@ func (context *Context) serveStatus() error { return err } config.ClientAuth = tls.NoClientCert - config.GetCertificate = context.cert.getCertificate + if context.cert != nil { + config.GetCertificate = context.cert.GetCertificate + } network, address, _, err := parseUnixOrTCPAddress(*statusAddress) if err != nil { @@ -524,7 +553,7 @@ func serverBackendDialer() (func() (net.Conn, error), error) { } // Get backend dialer function in client mode (connecting to a TLS port) -func clientBackendDialer(cert *certificate, network, address, host string) (func() (net.Conn, error), error) { +func clientBackendDialer(cert certloader.Certificate, network, address, host string) (func() (net.Conn, error), error) { config, err := buildConfig(*enabledCipherSuites, *caBundlePath) if err != nil { return nil, err @@ -565,12 +594,6 @@ func clientBackendDialer(cert *certificate, network, address, host string) (func http_dialer.WithTls(proxyConfig)) } - return func() (net.Conn, error) { - if !(*clientDisableAuth) { - // Fetch latest cached certificate before initiating new connection - crt, _ := cert.getCertificate(nil) - config.Certificates = []tls.Certificate{*crt} - } - return dialWithDialer(dialer, *timeoutDuration, network, address, config) - }, nil + d := certloader.DialerWithCertificate(cert, config, *timeoutDuration, dialer) + return func() (net.Conn, error) { return d.Dial(network, address) }, nil } diff --git a/signals.go b/signals.go index f6aa0d75955..a2cd77affc2 100644 --- a/signals.go +++ b/signals.go @@ -84,7 +84,7 @@ func (context *Context) signalHandler(proxy *proxy, closeables []io.Closer) { func (context *Context) reload() { context.status.Reloading() - err := context.cert.reload() + err := context.cert.Reload() if err != nil { logger.Printf("error reloading certificates: %s", err) } diff --git a/status.go b/status.go index 9bfc1841456..6168ba5b150 100644 --- a/status.go +++ b/status.go @@ -115,8 +115,5 @@ func (s *statusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) } - _, err = w.Write(out) - if err != nil { - logger.Printf("error writing status response: %s", err) - } + _, _ = w.Write(out) } diff --git a/tls.go b/tls.go index 28b7d311bd4..2e81bf4055b 100644 --- a/tls.go +++ b/tls.go @@ -19,17 +19,12 @@ package main import ( "crypto/tls" "crypto/x509" - "encoding/pem" "errors" "fmt" "io/ioutil" - "os" "strings" - "sync/atomic" - "time" - "unsafe" - certigo "github.com/square/certigo/lib" + "github.com/square/ghostunnel/certloader" ) var cipherSuites = map[string][]uint16{ @@ -45,144 +40,34 @@ var cipherSuites = map[string][]uint16{ }, } -type timeoutError struct{} - -func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - -// certificate wraps a TLS certificate in a reloadable way -type certificate struct { - reloadable bool - keystorePath, keystorePass string - cached unsafe.Pointer -} - // Build reloadable certificate -func buildCertificate(keystorePath, keystorePass string) (*certificate, error) { - if keystorePath == "" { - return &certificate{}, nil +func buildCertificate(keystorePath, keystorePass string) (certloader.Certificate, error) { + if hasPKCS11() { + return buildCertificateFromPKCS11(keystorePath) } - cert := &certificate{true, keystorePath, keystorePass, nil} - err := cert.reload() - if err != nil { - return nil, err + if hasKeychainIdentity() { + return buildCertificateFromCertstore() } - return cert, nil + if keystorePath != "" { + return certloader.CertificateFromKeystore(keystorePath, keystorePass) + } + return nil, nil } -// Retrieve actual certificate -func (c *certificate) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil +func buildCertificateFromPKCS11(certificatePath string) (certloader.Certificate, error) { + return certloader.CertificateFromPKCS11Module(certificatePath, *pkcs11Module, *pkcs11TokenLabel, *pkcs11PIN) } -// Reload certificate -func (c *certificate) reload() error { - if !c.reloadable { - logger.Printf("certificate not reloadable, skipping") - return nil - } - - var err error - if hasPKCS11() { - err = c.reloadFromPKCS11() - } else { - err = c.reloadFromPEM() - } - - if err == nil { - cert, _ := c.getCertificate(nil) - logger.Printf("loaded certificate with common name '%s'", cert.Leaf.Subject.CommonName) - } - return err +func hasPKCS11() bool { + return pkcs11Module != nil && *pkcs11Module != "" } -func (c *certificate) reloadFromPEM() error { - keystore, err := os.Open(c.keystorePath) - if err != nil { - return err - } - - var pemBlocks []*pem.Block - err = certigo.ReadAsPEMFromFiles( - []*os.File{keystore}, - "", - func(prompt string) string { - return c.keystorePass - }, - func(block *pem.Block) { - pemBlocks = append(pemBlocks, block) - }) - if err != nil { - return fmt.Errorf("error during keystore read (%s)", err) - } - if len(pemBlocks) == 0 { - return errors.New("no certificates or private key found in keystore") - } - - var pemBytes []byte - for _, block := range pemBlocks { - pemBytes = append(pemBytes, pem.EncodeToMemory(block)...) - } - - certAndKey, err := tls.X509KeyPair(pemBytes, pemBytes) - if err != nil { - return err - } - - certAndKey.Leaf, err = x509.ParseCertificate(certAndKey.Certificate[0]) - if err != nil { - return err - } - - atomic.StorePointer(&c.cached, unsafe.Pointer(&certAndKey)) - return nil +func buildCertificateFromCertstore() (certloader.Certificate, error) { + return certloader.CertificateFromKeychainIdentity(*keychainIdentity) } -func (c *certificate) reloadFromPKCS11() error { - // Expecting keystore file to only have certificate, - // with the private key being in an HSM/PKCS11 module. - keystore, err := os.Open(c.keystorePath) - if err != nil { - return err - } - - certAndKey := tls.Certificate{} - err = certigo.ReadAsX509FromFiles( - []*os.File{keystore}, "", nil, - func(cert *x509.Certificate, err error) { - if err != nil { - logger.Printf("error during keystore read: %s", err) - return - } - if certAndKey.Leaf == nil { - certAndKey.Leaf = cert - } - certAndKey.Certificate = append(certAndKey.Certificate, cert.Raw) - }) - if err != nil { - return fmt.Errorf("error during keystore read (%s)", err) - } - if certAndKey.Leaf == nil { - return errors.New("no certificates found in keystore") - } - - // Reuse previously loaded PKCS11 private key if we already have it. We want to - // avoid reloading the key every time the cert reloads, as it's a potentially - // expensive operation that calls out into a shared library. - if c.cached != nil { - old, _ := c.getCertificate(nil) - certAndKey.PrivateKey = old.PrivateKey - } else { - privateKey, err := newPKCS11(certAndKey.Leaf.PublicKey) - if err != nil { - return err - } - certAndKey.PrivateKey = privateKey - } - - atomic.StorePointer(&c.cached, unsafe.Pointer(&certAndKey)) - return nil +func hasKeychainIdentity() bool { + return keychainIdentity != nil && *keychainIdentity != "" } func caBundle(caBundlePath string) (*x509.CertPool, error) { @@ -204,34 +89,6 @@ func caBundle(caBundlePath string) (*x509.CertPool, error) { return bundle, nil } -// Internal copy of tls.DialWithDialer, adapter so it can work with HTTP CONNECT dialers. -// See: https://golang.org/pkg/crypto/tls/#DialWithDialer -func dialWithDialer(dialer Dialer, timeout time.Duration, network, addr string, config *tls.Config) (*tls.Conn, error) { - errChannel := make(chan error, 2) - time.AfterFunc(timeout, func() { - errChannel <- timeoutError{} - }) - - rawConn, err := dialer.Dial(network, addr) - if err != nil { - return nil, err - } - - conn := tls.Client(rawConn, config) - go func() { - errChannel <- conn.Handshake() - }() - - err = <-errChannel - - if err != nil { - rawConn.Close() - return nil, err - } - - return conn, nil -} - // buildConfig reads command-line options and builds a tls.Config func buildConfig(enabledCipherSuites string, caBundlePath string) (*tls.Config, error) { ca, err := caBundle(caBundlePath) diff --git a/tls_cgo.go b/tls_cgo.go deleted file mode 100644 index 4fec979aa83..00000000000 --- a/tls_cgo.go +++ /dev/null @@ -1,39 +0,0 @@ -// +build cgo - -/*- - * Copyright 2018 Square Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package main - -import ( - "crypto" - - "github.com/letsencrypt/pkcs11key" -) - -var ( - pkcs11Module = app.Flag("pkcs11-module", "Path to PKCS11 module (SO) file (optional).").Envar("PKCS11_MODULE").PlaceHolder("PATH").ExistingFile() - pkcs11TokenLabel = app.Flag("pkcs11-token-label", "Token label for slot/key in PKCS11 module (optional).").Envar("PKCS11_TOKEN_LABEL").PlaceHolder("LABEL").String() - pkcs11PIN = app.Flag("pkcs11-pin", "PIN code for slot/key in PKCS11 module (optional).").Envar("PKCS11_PIN").PlaceHolder("PIN").String() -) - -func newPKCS11(pubkey crypto.PublicKey) (crypto.PrivateKey, error) { - return pkcs11key.New(*pkcs11Module, *pkcs11TokenLabel, *pkcs11PIN, pubkey) -} - -func hasPKCS11() bool { - return *pkcs11Module != "" -} diff --git a/tls_test.go b/tls_test.go index 0c3f9abfc11..6206191d01c 100644 --- a/tls_test.go +++ b/tls_test.go @@ -165,7 +165,6 @@ func TestBuildConfig(t *testing.T) { assert.NotNil(t, err, "should reject invalid CA cert bundle") cert, err := buildCertificate("", "") - assert.NotNil(t, cert, "empty keystorePath should lead to empty certificate") assert.Nil(t, err, "empty keystorePath should not raise an error") cert, err = buildCertificate(tmpKeystore.Name(), "totes invalid") @@ -222,7 +221,7 @@ func TestReload(t *testing.T) { c, err := buildCertificate(tmpKeystore.Name(), testKeystorePassword) assert.Nil(t, err, "should be able to build certificate") - c.reload() + c.Reload() } func TestBuildConfigSystemRoots(t *testing.T) { @@ -237,10 +236,3 @@ func TestBuildConfigSystemRoots(t *testing.T) { assert.NotNil(t, conf.ClientCAs, "config must have CA certs") assert.True(t, conf.MinVersion == tls.VersionTLS12, "must have correct TLS min version") } - -func TestTimeoutError(t *testing.T) { - err := timeoutError{} - assert.False(t, err.Error() == "", "Timeout error should have message") - assert.True(t, err.Timeout(), "Timeout error should have Timeout() == true") - assert.True(t, err.Temporary(), "Timeout error should have Temporary() == true") -}