Skip to content

Commit 5dee41c

Browse files
authored
Merge pull request #193 from hashicorp/fix-server-automtls
automtls: fix bidirectional communication when AutoMTLS is enabled
2 parents e4102ee + afb1659 commit 5dee41c

File tree

3 files changed

+77
-4
lines changed

3 files changed

+77
-4
lines changed

client.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ func (c *Client) Start() (addr net.Addr, err error) {
574574

575575
c.config.TLSConfig = &tls.Config{
576576
Certificates: []tls.Certificate{cert},
577+
ClientAuth: tls.RequireAndVerifyClientCert,
578+
MinVersion: tls.VersionTLS12,
577579
ServerName: "localhost",
578580
}
579581
}
@@ -776,7 +778,7 @@ func (c *Client) Start() (addr net.Addr, err error) {
776778
}
777779

778780
// loadServerCert is used by AutoMTLS to read an x.509 cert returned by the
779-
// server, and load it as the RootCA for the client TLSConfig.
781+
// server, and load it as the RootCA and ClientCA for the client TLSConfig.
780782
func (c *Client) loadServerCert(cert string) error {
781783
certPool := x509.NewCertPool()
782784

@@ -793,6 +795,7 @@ func (c *Client) loadServerCert(cert string) error {
793795
certPool.AddCert(x509Cert)
794796

795797
c.config.TLSConfig.RootCAs = certPool
798+
c.config.TLSConfig.ClientCAs = certPool
796799
return nil
797800
}
798801

server.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,13 @@ func Serve(opts *ServeConfig) {
304304

305305
certPEM, keyPEM, err := generateCert()
306306
if err != nil {
307-
logger.Error("failed to generate client certificate", "error", err)
307+
logger.Error("failed to generate server certificate", "error", err)
308308
panic(err)
309309
}
310310

311311
cert, err := tls.X509KeyPair(certPEM, keyPEM)
312312
if err != nil {
313-
logger.Error("failed to parse client certificate", "error", err)
313+
logger.Error("failed to parse server certificate", "error", err)
314314
panic(err)
315315
}
316316

@@ -319,6 +319,8 @@ func Serve(opts *ServeConfig) {
319319
ClientAuth: tls.RequireAndVerifyClientCert,
320320
ClientCAs: clientCertPool,
321321
MinVersion: tls.VersionTLS12,
322+
RootCAs: clientCertPool,
323+
ServerName: "localhost",
322324
}
323325

324326
// We send back the raw leaf cert data for the client rather than the

server_test.go

+69-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,75 @@ func TestServer_testMode(t *testing.T) {
8282
t.Logf("HELLO")
8383
}
8484

85+
func TestServer_testMode_AutoMTLS(t *testing.T) {
86+
ctx, cancel := context.WithCancel(context.Background())
87+
defer cancel()
88+
89+
closeCh := make(chan struct{})
90+
go Serve(&ServeConfig{
91+
HandshakeConfig: testVersionedHandshake,
92+
VersionedPlugins: map[int]PluginSet{
93+
2: testGRPCPluginMap,
94+
},
95+
GRPCServer: DefaultGRPCServer,
96+
Logger: hclog.NewNullLogger(),
97+
Test: &ServeTestConfig{
98+
Context: ctx,
99+
ReattachConfigCh: nil,
100+
CloseCh: closeCh,
101+
},
102+
})
103+
104+
// Connect!
105+
process := helperProcess("test-mtls")
106+
c := NewClient(&ClientConfig{
107+
Cmd: process,
108+
HandshakeConfig: testVersionedHandshake,
109+
VersionedPlugins: map[int]PluginSet{
110+
2: testGRPCPluginMap,
111+
},
112+
AllowedProtocols: []Protocol{ProtocolGRPC},
113+
AutoMTLS: true,
114+
})
115+
client, err := c.Client()
116+
if err != nil {
117+
t.Fatalf("err: %s", err)
118+
}
119+
120+
// Pinging should work
121+
if err := client.Ping(); err != nil {
122+
t.Fatalf("should not err: %s", err)
123+
}
124+
125+
// Grab the impl
126+
raw, err := client.Dispense("test")
127+
if err != nil {
128+
t.Fatalf("err should be nil, got %s", err)
129+
}
130+
131+
tester, ok := raw.(testInterface)
132+
if !ok {
133+
t.Fatalf("bad: %#v", raw)
134+
}
135+
136+
n := tester.Double(3)
137+
if n != 6 {
138+
t.Fatal("invalid response", n)
139+
}
140+
141+
// ensure we can make use of bidirectional communication with AutoMTLS
142+
// enabled
143+
err = tester.Bidirectional()
144+
if err != nil {
145+
t.Fatal("invalid response", err)
146+
}
147+
148+
c.Kill()
149+
// Canceling should cause an exit
150+
cancel()
151+
<-closeCh
152+
}
153+
85154
func TestRmListener_impl(t *testing.T) {
86155
var _ net.Listener = new(rmListener)
87156
}
@@ -145,7 +214,6 @@ func TestProtocolSelection_no_server(t *testing.T) {
145214
if protocol != ProtocolNetRPC {
146215
t.Fatalf("bad protocol %s", protocol)
147216
}
148-
149217
}
150218

151219
func TestServer_testStdLogger(t *testing.T) {

0 commit comments

Comments
 (0)