Skip to content

Commit a5279ee

Browse files
committed
fix: e2e test
1 parent d544fa3 commit a5279ee

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

integration/kube_integration_test.go

+42-35
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import (
3434
"os/user"
3535
"strconv"
3636
"strings"
37-
"sync"
3837
"testing"
3938
"time"
4039

@@ -2121,7 +2120,9 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
21212120

21222121
// fooey
21232122
hostUsername := suite.me.Username
2124-
participantUsername := suite.me.Username + "-participant"
2123+
peerUsername := suite.me.Username + "-peer"
2124+
observer1Username := suite.me.Username + "-observer1"
2125+
observer2Username := suite.me.Username + "-observer2"
21252126
kubeGroups := []string{kube.TestImpersonationGroup}
21262127
kubeUsers := []string{"alice@example.com"}
21272128
role, err := types.NewRole("kubemaster", types.RoleSpecV6{
@@ -2152,7 +2153,9 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
21522153
})
21532154
require.NoError(t, err)
21542155
teleport.AddUserWithRole(hostUsername, role)
2155-
teleport.AddUserWithRole(participantUsername, joinRole)
2156+
teleport.AddUserWithRole(peerUsername, joinRole)
2157+
teleport.AddUserWithRole(observer1Username, joinRole)
2158+
teleport.AddUserWithRole(observer2Username, joinRole)
21562159

21572160
err = teleport.CreateEx(t, nil, tconf)
21582161
require.NoError(t, err)
@@ -2200,30 +2203,32 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
22002203
// We need to wait for the exec request to be handled here for the session to be
22012204
// created. Sadly though the k8s API doesn't give us much indication of when that is.
22022205
var session types.SessionTracker
2203-
require.Eventually(t, func() bool {
2206+
require.EventuallyWithT(t, func(t *assert.CollectT) {
22042207
// We need to wait for the session to be created here. We can't use the
22052208
// session manager's WaitUntilExists method because it doesn't work for
22062209
// kubernetes sessions.
22072210
sessions, err := teleport.Process.GetAuthServer().GetActiveSessionTrackers(context.Background())
2208-
if err != nil || len(sessions) == 0 {
2209-
return false
2211+
assert.NoError(t, err)
2212+
if assert.Len(t, sessions, 1) {
2213+
session = sessions[0]
22102214
}
2211-
2212-
session = sessions[0]
2213-
return true
22142215
}, 10*time.Second, time.Second)
22152216

22162217
participantStdinR, participantStdinW, err := os.Pipe()
22172218
require.NoError(t, err)
22182219
participantStdoutR, participantStdoutW, err := os.Pipe()
22192220
require.NoError(t, err)
2220-
streamsMu := &sync.Mutex{}
2221-
streams := make([]*client.KubeSession, 0, 3)
2222-
observerCaptures := make([]*bytes.Buffer, 0, 2)
2221+
2222+
observerCaptures := make([]*bytes.Buffer, 2)
22232223
albProxy := helpers.MustStartMockALBProxy(t, teleport.Config.Proxy.WebAddr.Addr)
22242224

22252225
// join peer by KubeProxyAddr
22262226
group.Go(func() error {
2227+
defer func() {
2228+
// close participant stdout so that we can read it after till EOF
2229+
participantStdoutW.Close()
2230+
}()
2231+
22272232
tc, err := teleport.NewClient(helpers.ClientConfig{
22282233
Login: hostUsername,
22292234
Cluster: helpers.Site,
@@ -2238,50 +2243,52 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
22382243

22392244
stream, err := kubeJoin(kube.ProxyConfig{
22402245
T: teleport,
2241-
Username: participantUsername,
2246+
Username: peerUsername,
22422247
KubeUsers: kubeUsers,
22432248
KubeGroups: kubeGroups,
22442249
}, tc, session, types.SessionPeerMode)
22452250
if err != nil {
22462251
return trace.Wrap(err)
22472252
}
2248-
streamsMu.Lock()
2249-
streams = append(streams, stream)
2250-
streamsMu.Unlock()
2253+
22512254
stream.Wait()
2252-
// close participant stdout so that we can read it after till EOF
2253-
participantStdoutW.Close()
2255+
2256+
t.Cleanup(func() { _ = stream.Close() })
2257+
22542258
return nil
22552259
})
22562260

22572261
// join observer by WebProxyAddr
22582262
group.Go(func() error {
2259-
stream, capture := kubeJoinByWebAddr(t, teleport, participantUsername, kubeUsers, kubeGroups)
2260-
streamsMu.Lock()
2261-
streams = append(streams, stream)
2262-
observerCaptures = append(observerCaptures, capture)
2263-
streamsMu.Unlock()
2263+
stream, capture := kubeJoinByWebAddr(t, teleport, observer1Username, kubeUsers, kubeGroups)
2264+
observerCaptures[0] = capture
22642265
stream.Wait()
2266+
2267+
t.Cleanup(func() { _ = stream.Close() })
22652268
return nil
22662269
})
22672270

22682271
// join observer with ALPN conn upgrade
22692272
group.Go(func() error {
2270-
stream, capture := kubeJoinByALBAddr(t, teleport, participantUsername, kubeUsers, kubeGroups, albProxy.Addr().String())
2271-
streamsMu.Lock()
2272-
streams = append(streams, stream)
2273-
observerCaptures = append(observerCaptures, capture)
2274-
streamsMu.Unlock()
2273+
stream, capture := kubeJoinByALBAddr(t, teleport, observer2Username, kubeUsers, kubeGroups, albProxy.Addr().String())
2274+
observerCaptures[1] = capture
22752275
stream.Wait()
2276+
2277+
t.Cleanup(func() { _ = stream.Close() })
22762278
return nil
22772279
})
22782280

2279-
// We wait again for the second user to finish joining the session.
2280-
// We allow a bit of time to pass here to give the session manager time to recognize the
2281-
// new IO streams of the second client.
2282-
time.Sleep(time.Second * 5)
2281+
// Wait for all users to finish joining the session.
2282+
require.EventuallyWithT(t, func(t *assert.CollectT) {
2283+
session, err := teleport.Process.GetAuthServer().GetSessionTracker(context.Background(), session.GetName())
2284+
if !assert.NoError(t, err) {
2285+
return
2286+
}
2287+
2288+
assert.Len(t, session.GetParticipants(), 4)
2289+
}, 30*time.Second, 500*time.Millisecond)
22832290

2284-
// sent a test message from the participant
2291+
// send a test message from the participant
22852292
participantStdinW.Write([]byte("\ahi from peer\n\r"))
22862293

22872294
// lets type "echo hi" followed by "enter" and then "exit" + "enter":
@@ -2306,8 +2313,8 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
23062313

23072314
// Verify observers.
23082315
for _, capture := range observerCaptures {
2309-
require.Contains(t, capture.String(), "hi from peer")
2310-
require.Contains(t, capture.String(), "hi from term")
2316+
assert.Contains(t, capture.String(), "hi from peer")
2317+
assert.Contains(t, capture.String(), "hi from term")
23112318
}
23122319
}
23132320

lib/client/kubesession.go

+13-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type KubeSession struct {
4949
}
5050

5151
// NewKubeSession joins a live kubernetes session.
52-
func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, tlsServer string, mode types.SessionParticipantMode, tlsConfig *tls.Config) (*KubeSession, error) {
52+
func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, tlsServer string, mode types.SessionParticipantMode, tlsConfig *tls.Config) (_ *KubeSession, err error) {
5353
ctx, cancel := context.WithCancel(ctx)
5454
joinEndpoint := "wss://" + tc.KubeProxyAddr + "/api/v1/teleport/join/" + meta.GetSessionID()
5555

@@ -86,6 +86,15 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT
8686

8787
return nil, trace.BadParameter("failed to decode remote error: %v", string(body))
8888
}
89+
defer func() {
90+
if err == nil {
91+
return
92+
}
93+
94+
if err := ws.Close(); err != nil {
95+
log.DebugContext(ctx, "Close stream in response to context termination", "error", err)
96+
}
97+
}()
8998

9099
stream, err := streamproto.NewSessionStream(ws, streamproto.ClientHandshake{Mode: mode})
91100
if err != nil {
@@ -94,7 +103,9 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT
94103
}
95104

96105
context.AfterFunc(ctx, func() {
97-
_ = stream.Close()
106+
log.DebugContext(ctx, "Closing stream in response to context termination")
107+
err := stream.Close()
108+
log.DebugContext(ctx, "Closed stream in response to context termination", "error", err)
98109
})
99110

100111
term, err := terminal.New(tc.Stdin, tc.Stdout, tc.Stderr)

0 commit comments

Comments
 (0)