Skip to content

Commit b1dc9f0

Browse files
verify cluster name of TLS peer certificates (#52133)
1 parent 1520ba6 commit b1dc9f0

18 files changed

+607
-207
lines changed

api/types/authority_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,17 @@ func TestRotationZero(t *testing.T) {
5252
require.Equal(t, tt.z, tt.r.IsZero(), tt.d)
5353
}
5454
}
55+
56+
// Test that the spec cluster name name will be set to match the resource name
57+
func TestCheckAndSetDefaults(t *testing.T) {
58+
ca := CertAuthorityV2{
59+
Metadata: Metadata{Name: "caName"},
60+
Spec: CertAuthoritySpecV2{
61+
ClusterName: "clusterName",
62+
Type: HostCA,
63+
},
64+
}
65+
err := ca.CheckAndSetDefaults()
66+
require.NoError(t, err)
67+
require.Equal(t, ca.Metadata.Name, ca.Spec.ClusterName)
68+
}

integration/assist/command_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ func newTestCredentials(t *testing.T, rc *helpers.TeleInstance, user types.User)
288288
}
289289

290290
pool := x509.NewCertPool()
291-
pool.AppendCertsFromPEM(rc.Secrets.TLSCACert)
291+
pool.AppendCertsFromPEM(rc.Secrets.TLSHostCACert)
292+
pool.AppendCertsFromPEM(rc.Secrets.TLSUserCACert)
292293

293294
tlsConf := &tls.Config{
294295
Certificates: []tls.Certificate{cert},

integration/helpers/instance.go

+67-44
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package helpers
1717
import (
1818
"bytes"
1919
"context"
20-
"crypto/rsa"
2120
"crypto/tls"
2221
"crypto/x509/pkix"
2322
"encoding/json"
@@ -44,6 +43,7 @@ import (
4443
"github.com/gravitational/teleport/api/breaker"
4544
clientproto "github.com/gravitational/teleport/api/client/proto"
4645
"github.com/gravitational/teleport/api/types"
46+
"github.com/gravitational/teleport/api/utils/keys"
4747
"github.com/gravitational/teleport/lib/auth/authclient"
4848
"github.com/gravitational/teleport/lib/auth/keygen"
4949
"github.com/gravitational/teleport/lib/auth/state"
@@ -97,11 +97,15 @@ type InstanceSecrets struct {
9797
// PrivKey is instance private key
9898
PrivKey []byte `json:"priv"`
9999
// Cert is SSH host certificate
100-
Cert []byte `json:"cert"`
101-
// TLSCACert is the certificate of the trusted certificate authority
102-
TLSCACert []byte `json:"tls_ca_cert"`
103-
// TLSCert is client TLS X509 certificate
104-
TLSCert []byte `json:"tls_cert"`
100+
SSHHostCert []byte `json:"cert"`
101+
// TLSHostCACert is the certificate of the trusted host certificate authority
102+
TLSHostCACert []byte `json:"tls_host_ca_cert"`
103+
// TLSCert is client TLS host X509 certificate
104+
TLSHostCert []byte `json:"tls_host_cert"`
105+
// TLSUserCACert is the certificate of the trusted user certificate authority
106+
TLSUserCACert []byte `json:"tls_user_ca_cert"`
107+
// TLSUserCert is client TLS user X509 certificate
108+
TLSUserCert []byte `json:"tls_user_cert"`
105109
// TunnelAddr is a reverse tunnel listening port, allowing
106110
// other sites to connect to i instance. Set to empty
107111
// string if i instance is not allowing incoming tunnels
@@ -132,9 +136,7 @@ func (s *InstanceSecrets) GetRoles(t *testing.T) []types.Role {
132136
return roles
133137
}
134138

135-
// GetCAs return an array of CAs stored by the secrets object. In i
136-
// case we always return hard-coded userCA + hostCA (and they share keys
137-
// for simplicity)
139+
// GetCAs return an array of CAs stored by the secrets object
138140
func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
139141
hostCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{
140142
Type: types.HostCA,
@@ -148,7 +150,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
148150
TLS: []*types.TLSKeyPair{{
149151
Key: s.PrivKey,
150152
KeyType: types.PrivateKeyType_RAW,
151-
Cert: s.TLSCACert,
153+
Cert: s.TLSHostCACert,
152154
}},
153155
},
154156
})
@@ -168,7 +170,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
168170
TLS: []*types.TLSKeyPair{{
169171
Key: s.PrivKey,
170172
KeyType: types.PrivateKeyType_RAW,
171-
Cert: s.TLSCACert,
173+
Cert: s.TLSUserCACert,
172174
}},
173175
},
174176
Roles: []string{services.RoleNameForCertAuthority(s.SiteName)},
@@ -184,7 +186,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
184186
TLS: []*types.TLSKeyPair{{
185187
Key: s.PrivKey,
186188
KeyType: types.PrivateKeyType_RAW,
187-
Cert: s.TLSCACert,
189+
Cert: s.TLSHostCACert,
188190
}},
189191
},
190192
})
@@ -199,7 +201,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
199201
TLS: []*types.TLSKeyPair{{
200202
Key: s.PrivKey,
201203
KeyType: types.PrivateKeyType_RAW,
202-
Cert: s.TLSCACert,
204+
Cert: s.TLSHostCACert,
203205
}},
204206
},
205207
})
@@ -256,9 +258,9 @@ func (s *InstanceSecrets) AsSlice() []*InstanceSecrets {
256258

257259
func (s *InstanceSecrets) GetIdentity() *state.Identity {
258260
i, err := state.ReadIdentityFromKeyPair(s.PrivKey, &clientproto.Certs{
259-
SSH: s.Cert,
260-
TLS: s.TLSCert,
261-
TLSCACerts: [][]byte{s.TLSCACert},
261+
SSH: s.SSHHostCert,
262+
TLS: s.TLSHostCert,
263+
TLSCACerts: [][]byte{s.TLSHostCACert},
262264
})
263265
fatalIf(err)
264266
return i
@@ -338,20 +340,14 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
338340
if cfg.Priv == nil || cfg.Pub == nil {
339341
cfg.Priv, cfg.Pub, _ = keygen.GenerateKeyPair()
340342
}
341-
rsaKey, err := ssh.ParseRawPrivateKey(cfg.Priv)
343+
key, err := keys.ParsePrivateKey(cfg.Priv)
342344
fatalIf(err)
343345

344-
tlsCACert, err := tlsca.GenerateSelfSignedCAWithSigner(rsaKey.(*rsa.PrivateKey), pkix.Name{
345-
CommonName: cfg.ClusterName,
346-
Organization: []string{cfg.ClusterName},
347-
}, nil, defaults.CATTL)
348-
fatalIf(err)
349-
350-
signer, err := ssh.ParsePrivateKey(cfg.Priv)
346+
sshSigner, err := ssh.NewSignerFromSigner(key)
351347
fatalIf(err)
352348

353-
cert, err := keygen.GenerateHostCert(services.HostCertParams{
354-
CASigner: signer,
349+
hostCert, err := keygen.GenerateHostCert(services.HostCertParams{
350+
CASigner: sshSigner,
355351
PublicHostKey: cfg.Pub,
356352
HostID: cfg.HostID,
357353
NodeName: cfg.NodeName,
@@ -360,23 +356,48 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
360356
TTL: 24 * time.Hour,
361357
})
362358
fatalIf(err)
363-
tlsCA, err := tlsca.FromKeys(tlsCACert, cfg.Priv)
364-
fatalIf(err)
365-
cryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
366-
fatalIf(err)
367-
identity := tlsca.Identity{
368-
Username: fmt.Sprintf("%v.%v", cfg.HostID, cfg.ClusterName),
369-
Groups: []string{string(types.RoleAdmin)},
370-
}
359+
371360
clock := cfg.Clock
372361
if clock == nil {
373362
clock = clockwork.NewRealClock()
374363
}
364+
365+
identity := tlsca.Identity{
366+
Username: fmt.Sprintf("%v.%v", cfg.HostID, cfg.ClusterName),
367+
Groups: []string{string(types.RoleAdmin)},
368+
}
375369
subject, err := identity.Subject()
376370
fatalIf(err)
377-
tlsCert, err := tlsCA.GenerateCertificate(tlsca.CertificateRequest{
371+
372+
tlsCAHostCert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
373+
CommonName: cfg.ClusterName,
374+
Organization: []string{cfg.ClusterName},
375+
}, nil, defaults.CATTL)
376+
fatalIf(err)
377+
tlsHostCA, err := tlsca.FromKeys(tlsCAHostCert, cfg.Priv)
378+
fatalIf(err)
379+
hostCryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
380+
fatalIf(err)
381+
tlsHostCert, err := tlsHostCA.GenerateCertificate(tlsca.CertificateRequest{
382+
Clock: clock,
383+
PublicKey: hostCryptoPubKey,
384+
Subject: subject,
385+
NotAfter: clock.Now().UTC().Add(time.Hour * 24),
386+
})
387+
fatalIf(err)
388+
389+
tlsCAUserCert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
390+
CommonName: cfg.ClusterName,
391+
Organization: []string{cfg.ClusterName},
392+
}, nil, defaults.CATTL)
393+
fatalIf(err)
394+
tlsUserCA, err := tlsca.FromKeys(tlsCAHostCert, cfg.Priv)
395+
fatalIf(err)
396+
userCryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
397+
fatalIf(err)
398+
tlsUserCert, err := tlsUserCA.GenerateCertificate(tlsca.CertificateRequest{
378399
Clock: clock,
379-
PublicKey: cryptoPubKey,
400+
PublicKey: userCryptoPubKey,
380401
Subject: subject,
381402
NotAfter: clock.Now().UTC().Add(time.Hour * 24),
382403
})
@@ -391,14 +412,16 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
391412
}
392413

393414
secrets := InstanceSecrets{
394-
SiteName: cfg.ClusterName,
395-
PrivKey: cfg.Priv,
396-
PubKey: cfg.Pub,
397-
Cert: cert,
398-
TLSCACert: tlsCACert,
399-
TLSCert: tlsCert,
400-
TunnelAddr: i.ReverseTunnel,
401-
Users: make(map[string]*User),
415+
SiteName: cfg.ClusterName,
416+
PrivKey: cfg.Priv,
417+
PubKey: cfg.Pub,
418+
SSHHostCert: hostCert,
419+
TLSHostCACert: tlsCAHostCert,
420+
TLSHostCert: tlsHostCert,
421+
TLSUserCACert: tlsCAUserCert,
422+
TLSUserCert: tlsUserCert,
423+
TunnelAddr: i.ReverseTunnel,
424+
Users: make(map[string]*User),
402425
}
403426

404427
i.Secrets = secrets

lib/auth/authclient/tls.go

+81-27
Original file line numberDiff line numberDiff line change
@@ -39,59 +39,113 @@ type CAGetter interface {
3939
GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error)
4040
}
4141

42-
// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caTypes.
42+
// HostAndUserCAInfo is a map of CA raw subjects and type info for Host
43+
// and User CAs. The key is the RawSubject of the X.509 certificate authority
44+
// (so it's ASN.1 data, not printable).
45+
type HostAndUserCAInfo = map[string]CATypeInfo
46+
47+
// CATypeInfo indicates whether the CA is a host or user CA, or both.
48+
type CATypeInfo struct {
49+
IsHostCA bool
50+
IsUserCA bool
51+
}
52+
53+
// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caType.
4354
// In addition, it returns the total length of all subjects added to the cert pool, allowing
4455
// the caller to validate that the pool doesn't exceed the maximum 2-byte length prefix before
4556
// using it.
46-
func ClientCertPool(ctx context.Context, client CAGetter, clusterName string, caTypes ...types.CertAuthType) (*x509.CertPool, int64, error) {
47-
if len(caTypes) == 0 {
48-
return nil, 0, trace.BadParameter("at least one CA type is required")
57+
func ClientCertPool(ctx context.Context, client CAGetter, clusterName string, caType types.CertAuthType) (*x509.CertPool, int64, error) {
58+
authorities, err := getCACerts(ctx, client, clusterName, caType)
59+
if err != nil {
60+
return nil, 0, trace.Wrap(err)
4961
}
5062

5163
pool := x509.NewCertPool()
52-
var authorities []types.CertAuthority
53-
if clusterName == "" {
54-
for _, caType := range caTypes {
55-
cas, err := client.GetCertAuthorities(ctx, caType, false)
56-
if err != nil {
57-
return nil, 0, trace.Wrap(err)
58-
}
59-
authorities = append(authorities, cas...)
60-
}
61-
} else {
62-
for _, caType := range caTypes {
63-
ca, err := client.GetCertAuthority(
64-
ctx,
65-
types.CertAuthID{Type: caType, DomainName: clusterName},
66-
false)
64+
var totalSubjectsLen int64
65+
for _, auth := range authorities {
66+
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
67+
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
6768
if err != nil {
6869
return nil, 0, trace.Wrap(err)
6970
}
71+
pool.AddCert(cert)
7072

71-
authorities = append(authorities, ca)
73+
// Each subject in the list gets a separate 2-byte length prefix.
74+
totalSubjectsLen += 2
75+
totalSubjectsLen += int64(len(cert.RawSubject))
7276
}
7377
}
78+
return pool, totalSubjectsLen, nil
79+
}
80+
81+
// DefaultClientCertPool returns default trusted x509 certificate authority pool.
82+
func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName string) (*x509.CertPool, HostAndUserCAInfo, int64, error) {
83+
authorities, err := getCACerts(ctx, client, clusterName, types.HostCA, types.UserCA)
84+
if err != nil {
85+
return nil, nil, 0, trace.Wrap(err)
86+
}
7487

88+
pool := x509.NewCertPool()
89+
caInfos := make(HostAndUserCAInfo, len(authorities))
7590
var totalSubjectsLen int64
7691
for _, auth := range authorities {
7792
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
7893
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
7994
if err != nil {
80-
return nil, 0, trace.Wrap(err)
95+
return nil, nil, 0, trace.Wrap(err)
8196
}
8297
pool.AddCert(cert)
8398

99+
caType := auth.GetType()
100+
caInfo := caInfos[string(cert.RawSubject)]
101+
switch caType {
102+
case types.HostCA:
103+
caInfo.IsHostCA = true
104+
case types.UserCA:
105+
caInfo.IsUserCA = true
106+
default:
107+
return nil, nil, 0, trace.BadParameter("unexpected CA type %q", caType)
108+
}
109+
caInfos[string(cert.RawSubject)] = caInfo
110+
84111
// Each subject in the list gets a separate 2-byte length prefix.
85112
totalSubjectsLen += 2
86113
totalSubjectsLen += int64(len(cert.RawSubject))
87114
}
88115
}
89-
return pool, totalSubjectsLen, nil
116+
117+
return pool, caInfos, totalSubjectsLen, nil
90118
}
91119

92-
// DefaultClientCertPool returns default trusted x509 certificate authority pool.
93-
func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName string) (*x509.CertPool, int64, error) {
94-
return ClientCertPool(ctx, client, clusterName, types.HostCA, types.UserCA)
120+
func getCACerts(ctx context.Context, client CAGetter, clusterName string, caTypes ...types.CertAuthType) ([]types.CertAuthority, error) {
121+
if len(caTypes) == 0 {
122+
return nil, trace.BadParameter("at least one CA type is required")
123+
}
124+
125+
var authorities []types.CertAuthority
126+
if clusterName == "" {
127+
for _, caType := range caTypes {
128+
cas, err := client.GetCertAuthorities(ctx, caType, false)
129+
if err != nil {
130+
return nil, trace.Wrap(err)
131+
}
132+
authorities = append(authorities, cas...)
133+
}
134+
} else {
135+
for _, caType := range caTypes {
136+
ca, err := client.GetCertAuthority(
137+
ctx,
138+
types.CertAuthID{Type: caType, DomainName: clusterName},
139+
false)
140+
if err != nil {
141+
return nil, trace.Wrap(err)
142+
}
143+
144+
authorities = append(authorities, ca)
145+
}
146+
}
147+
148+
return authorities, nil
95149
}
96150

97151
// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
@@ -110,7 +164,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
110164
}
111165
}
112166
}
113-
pool, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
167+
pool, _, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
114168
if err != nil {
115169
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
116170
// this falls back to the default config
@@ -132,7 +186,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
132186
if totalSubjectsLen >= int64(math.MaxUint16) {
133187
log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")
134188

135-
pool, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
189+
pool, _, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
136190
if err != nil {
137191
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
138192
// this falls back to the default config

0 commit comments

Comments
 (0)