Skip to content

Commit

Permalink
Split certificate reloading logic into a separate package
Browse files Browse the repository at this point in the history
  • Loading branch information
csstaub committed Jun 18, 2018
1 parent 3d206c4 commit 77e0fa6
Show file tree
Hide file tree
Showing 22 changed files with 938 additions and 360 deletions.
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SOURCE_FILES := $(shell find . \( -name '*.go' -not -path './vendor*' \))
SOURCE_FILES := $(shell find . \( -name '*.go' -not -path './vendor/*' \))
INTEGRATION_TESTS := $(shell find tests -name 'test-*.py' -exec basename {} .py \;)
VERSION := $(shell git describe --always --dirty)

Expand All @@ -8,7 +8,7 @@ ghostunnel: $(SOURCE_FILES)

# Test binary with coverage instrumentation
ghostunnel.test: $(SOURCE_FILES)
go test -c -covermode=count -coverpkg .
go test -c -covermode=count -coverpkg .,./auth,./certloader

# Clean build output
clean:
Expand All @@ -25,6 +25,7 @@ test: unit $(INTEGRATION_TESTS)
unit:
go test -v -covermode=count -coverprofile=coverage-unit-test-base.out .
go test -v -covermode=count -coverprofile=coverage-unit-test-auth.out ./auth
go test -v -covermode=count -coverprofile=coverage-unit-test-certloader.out ./certloader
.PHONY: unit

# Run integration tests
Expand Down
75 changes: 75 additions & 0 deletions certloader/certigo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*-
* Copyright 2018 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package certloader

import (
"crypto/x509"
"encoding/pem"
"fmt"
"os"

certigo "github.com/square/certigo/lib"
)

func readPEM(path, password, format string) ([]*pem.Block, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}

var pemBlocks []*pem.Block
err = certigo.ReadAsPEMFromFiles(
[]*os.File{file},
format,
func(prompt string) string { return password },
func(block *pem.Block) { pemBlocks = append(pemBlocks, block) })
if err != nil {
return nil, fmt.Errorf("error reading file '%s': %s", path, err)
}
if len(pemBlocks) == 0 {
return nil, fmt.Errorf("error reading file '%s', no certificates found", path)
}

return pemBlocks, nil
}

func readX509(path string) ([]*x509.Certificate, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}

errs := []error{}
out := []*x509.Certificate{}

err = certigo.ReadAsX509FromFiles(
[]*os.File{file}, "PEM", nil,
func(cert *x509.Certificate, err error) {
if err != nil {
errs = append(errs, err)
return
}
out = append(out, cert)
})
if err != nil || len(errs) > 0 {
return nil, fmt.Errorf("error reading file '%s'", path)
}
if len(out) == 0 {
return nil, fmt.Errorf("no certificates found in file '%s'", path)
}
return out, nil
}
121 changes: 121 additions & 0 deletions certloader/certigo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*-
* Copyright 2018 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package certloader

import (
"io/ioutil"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

const testCertificate = `
-----BEGIN CERTIFICATE-----
MIIDKDCCAhCgAwIBAgIJAPjKcAKZMSkUMA0GCSqGSIb3DQEBCwUAMCMxEjAQBgNV
BAMTCWxvY2FsaG9zdDENMAsGA1UECxMEdGVzdDAeFw0xNTEwMDcxODExNTlaFw0x
NjEwMDYxODExNTlaMCMxEjAQBgNVBAMTCWxvY2FsaG9zdDENMAsGA1UECxMEdGVz
dDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAK4EbZf3EMb/ciW5nGlN
yrf5Pcfz3ZnjWRy1kvBriuPD6NQSZaTWTPmJnbdS/Q5FH0p/6ZjdZKXf6f7WNnAz
JwW0XK7NT3N2DrWfgQqrrVvLAYlfqgHnC7Fxqq7FCpgWjf7L8wcQXfdIYkhdsE4n
osLmCRvx7qS+wuasb6nLzBtg7b99ZvO8K/sezrDIjwzemBWA1Vovztw/vGD4J4/h
D0hiOOqFGWstwFxB9oG4d/QJ45VttLMGuiZCY+A4IyBgPCxphrEec6zf8H4u/ceQ
bB8i1IMmD1VTsq9afeVhMKuoSn2Bs3VRB6c9FpL41/ftN5mYpZCteZH+qQ/DhK/y
Dz0CAwEAAaNfMF0wDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAqwwHQYDVR0lBBYw
FAYIKwYBBQUHAwIGCCsGAQUFBwMBMCEGA1UdEQQaMBiHBH8AAAGHEAAAAAAAAAAA
AAAAAAAAAAEwDQYJKoZIhvcNAQELBQADggEBABuBe5cuyZy6StCYebI3FLN3CEla
/3Hreul6i5giqkF90X6M+9eERZCqSqm2whBMSF4vG+1B6GX1K6S29PUOmTDWyasW
B0WlBgRiZld3JfFBuJu6xk1a8+XwwlGOgEsggepjkrAXbjbqnUMAKOJkjFIyIPvk
5p97SYDJYiOh7MmjyXUIzyNdqpL5WiUgKPTxXL+1tNzxH1jjxfVdjaNaNcOJuu20
9tsMqDZyTm2yZWOBUXbtqlaMQHrs5Ksz5EKk5/U5KfJehKss8oba2npg/6echTJU
nkOOZ6U4eEju7H1S46qlN9ZmUmSrrjwec3H7CnvxQ0ncEyZXlEiTlbO2JQI=
-----END CERTIFICATE-----`

const testCertificateBad = `
-----BEGIN CERTIFICATE-----
////////////////////////////////////////////////////////////////
-----END CERTIFICATE-----`

func TestReadPEMValid(t *testing.T) {
cert, err := ioutil.TempFile("", "ghostunnel-test")
assert.Nil(t, err, "temp file error")
defer os.Remove(cert.Name())

_, err = cert.Write([]byte(testCertificate))
assert.Nil(t, err, "temp file error")

blocks, err := readPEM(cert.Name(), "", "PEM")
assert.Nil(t, err, "should read PEM file")
assert.Len(t, blocks, 1, "should find one PEM block")
}

func TestReadPEMInvalid(t *testing.T) {
cert, err := ioutil.TempFile("", "ghostunnel-test")
assert.Nil(t, err, "temp file error")
defer os.Remove(cert.Name())

_, err = cert.Write([]byte("invalid"))
assert.Nil(t, err, "temp file error")

blocks, err := readPEM(cert.Name(), "", "PEM")
assert.NotNil(t, err, "should not parse invalid file")
assert.Len(t, blocks, 0, "should not return PEM blocks")

blocks, err = readPEM("does-not-exist", "", "PEM")
assert.NotNil(t, err, "should not parse invalid file")
assert.Len(t, blocks, 0, "should not return PEM blocks")
}

func TestReadX509Valid(t *testing.T) {
cert, err := ioutil.TempFile("", "ghostunnel-test")
assert.Nil(t, err, "temp file error")
defer os.Remove(cert.Name())

_, err = cert.Write([]byte(testCertificate))
assert.Nil(t, err, "temp file error")

certs, err := readX509(cert.Name())
assert.Nil(t, err, "should parse certificate from PEM file")
assert.Len(t, certs, 1, "should find one certificate")
}

func TestReadX509Invalid(t *testing.T) {
cert0, err := ioutil.TempFile("", "ghostunnel-test")
assert.Nil(t, err, "temp file error")
defer os.Remove(cert0.Name())

cert1, err := ioutil.TempFile("", "ghostunnel-test")
assert.Nil(t, err, "temp file error")
defer os.Remove(cert1.Name())

_, err = cert0.Write([]byte("invalid"))
assert.Nil(t, err, "temp file error")
_, err = cert1.Write([]byte(testCertificateBad))
assert.Nil(t, err, "temp file error")

certs, err := readX509(cert0.Name())
assert.NotNil(t, err, "should not parse invalid file")
assert.Len(t, certs, 0, "should not parse invalid file")

certs, err = readX509(cert1.Name())
assert.NotNil(t, err, "should not parse invalid file")
assert.Len(t, certs, 0, "should not parse invalid file")

certs, err = readX509("does-not-exist")
assert.NotNil(t, err, "should not parse invalid file")
assert.Len(t, certs, 0, "should not parse invalid file")
}
33 changes: 33 additions & 0 deletions certloader/certstore_disabled.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// +build !certstore

/*-
* Copyright 2018 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package certloader

import "errors"

// SupportsKeychain returns true or false, depending on whether the
// binary was built with Certstore/Keychain support or not (requires CGO, recent
// Darwin to build).
func SupportsKeychain() bool {
return false
}

// CertificateFromKeychainIdentity creates a reloadable certificate from a system keychain identity.
func CertificateFromKeychainIdentity(commonName string) (Certificate, error) {
return nil, errors.New("not supported")
}
114 changes: 114 additions & 0 deletions certloader/certstore_enabled.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// +build certstore

/*-
* Copyright 2018 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package certloader

import (
"crypto/tls"
"crypto/x509"
"fmt"
"sync/atomic"
"unsafe"

"github.com/mastahyeti/certstore"
)

type certstoreCertificate struct {
// Common name of keychain identity
commonName string
// Cached *tls.Certificate
cached unsafe.Pointer
}

// SupportsKeychain returns true or false, depending on whether the
// binary was built with Certstore/Keychain support or not (requires CGO, recent
// Darwin to build).
func SupportsKeychain() bool {
return true
}

// CertificateFromKeychainIdentity creates a reloadable certificate from a system keychain identity.
func CertificateFromKeychainIdentity(commonName string) (Certificate, error) {
c := certstoreCertificate{
commonName: commonName,
}
err := c.Reload()
if err != nil {
return nil, err
}
return &c, nil
}

// Reload transparently reloads the certificate.
func (c *certstoreCertificate) Reload() error {
store, err := certstore.Open()
if err != nil {
return err
}

identitites, err := store.Identities()
if err != nil {
return err
}

var certAndKey *tls.Certificate
for _, identity := range identitites {
chain, err := identity.CertificateChain()
if err != nil {
continue
}

signer, err := identity.Signer()
if err != nil {
continue
}

if chain[0].Subject.CommonName == c.commonName {
certAndKey = &tls.Certificate{
Certificate: serializeChain(chain),
PrivateKey: signer,
}
break
}
}

if certAndKey != nil {
atomic.StorePointer(&c.cached, unsafe.Pointer(certAndKey))
return nil
}

return fmt.Errorf("unable to find identity with common name '%s' in keychain", c.commonName)
}

// GetCertificate retrieves the actual underlying tls.Certificate.
func (c *certstoreCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil
}

// GetClientCertificate retrieves the actual underlying tls.Certificate.
func (c *certstoreCertificate) GetClientCertificate(certInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return (*tls.Certificate)(atomic.LoadPointer(&c.cached)), nil
}

func serializeChain(chain []*x509.Certificate) [][]byte {
out := [][]byte{}
for _, cert := range chain {
out = append(out, cert.Raw)
}
return out
}
Loading

0 comments on commit 77e0fa6

Please sign in to comment.