Skip to content

Commit d544fa3

Browse files
committed
Join Kubernetes sessions in the web ui
1 parent 6119e09 commit d544fa3

File tree

12 files changed

+313
-71
lines changed

12 files changed

+313
-71
lines changed

integration/kube_integration_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2098,7 +2098,7 @@ func kubeJoin(kubeConfig kube.ProxyConfig, tc *client.TeleportClient, meta types
20982098
return nil, trace.Wrap(err)
20992099
}
21002100

2101-
sess, err := client.NewKubeSession(context.TODO(), tc, meta, tc.KubeProxyAddr, "", mode, tlsConfig)
2101+
sess, err := client.NewKubeSession(context.TODO(), tc, meta, "", mode, tlsConfig)
21022102
if err != nil {
21032103
return nil, trace.Wrap(err)
21042104
}

lib/client/kubesession.go

+38-32
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ type KubeSession struct {
4949
}
5050

5151
// NewKubeSession joins a live kubernetes session.
52-
func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, kubeAddr string, tlsServer string, mode types.SessionParticipantMode, tlsConfig *tls.Config) (*KubeSession, error) {
52+
func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, tlsServer string, mode types.SessionParticipantMode, tlsConfig *tls.Config) (*KubeSession, error) {
5353
ctx, cancel := context.WithCancel(ctx)
54-
joinEndpoint := "wss://" + kubeAddr + "/api/v1/teleport/join/" + meta.GetSessionID()
54+
joinEndpoint := "wss://" + tc.KubeProxyAddr + "/api/v1/teleport/join/" + meta.GetSessionID()
5555

5656
if tlsServer != "" {
5757
tlsConfig.ServerName = tlsServer
5858
}
5959

6060
dialer := &websocket.Dialer{
61-
NetDialContext: kubeSessionNetDialer(ctx, tc, kubeAddr).DialContext,
61+
NetDialContext: kubeSessionNetDialer(ctx, tc).DialContext,
6262
TLSClientConfig: tlsConfig,
6363
}
6464

@@ -93,6 +93,10 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT
9393
return nil, trace.Wrap(err)
9494
}
9595

96+
context.AfterFunc(ctx, func() {
97+
_ = stream.Close()
98+
})
99+
96100
term, err := terminal.New(tc.Stdin, tc.Stdout, tc.Stderr)
97101
if err != nil {
98102
cancel()
@@ -108,26 +112,25 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT
108112
stdout := utils.NewSyncWriter(term.Stdout())
109113

110114
go handleOutgoingResizeEvents(ctx, stream, term)
111-
go handleIncomingResizeEvents(stream, term)
115+
go handleIncomingResizeEvents(ctx, stream, term)
112116

113117
s := &KubeSession{stream, term, ctx, cancel, meta, sync.WaitGroup{}}
114-
err = s.handleMFA(ctx, tc, mode, stdout)
115-
if err != nil {
118+
if err := s.handleMFA(ctx, tc, mode, stdout); err != nil {
116119
return nil, trace.Wrap(err)
117120
}
118121

119-
s.pipeInOut(stdout, tc.EnableEscapeSequences, mode)
122+
s.pipeInOut(ctx, stdout, tc.EnableEscapeSequences, mode)
120123
return s, nil
121124
}
122125

123-
func kubeSessionNetDialer(ctx context.Context, tc *TeleportClient, kubeAddr string) client.ContextDialer {
126+
func kubeSessionNetDialer(ctx context.Context, tc *TeleportClient) client.ContextDialer {
124127
dialOpts := []client.DialOption{
125128
client.WithInsecureSkipVerify(tc.InsecureSkipVerify),
126129
}
127130

128131
// Add options for ALPN connection upgrade only if kube is served at Proxy
129132
// web address.
130-
if tc.WebProxyAddr == kubeAddr && tc.TLSRoutingConnUpgradeRequired {
133+
if tc.WebProxyAddr == tc.KubeProxyAddr && tc.TLSRoutingConnUpgradeRequired {
131134
dialOpts = append(dialOpts,
132135
client.WithALPNConnUpgrade(tc.TLSRoutingConnUpgradeRequired),
133136
client.WithALPNConnUpgradePing(true), // Use Ping protocol for long-lived connections.
@@ -157,27 +160,30 @@ func handleOutgoingResizeEvents(ctx context.Context, stream *streamproto.Session
157160
}
158161
}
159162

160-
func handleIncomingResizeEvents(stream *streamproto.SessionStream, term *terminal.Terminal) {
163+
func handleIncomingResizeEvents(ctx context.Context, stream *streamproto.SessionStream, term *terminal.Terminal) {
161164
events := term.Subscribe()
162165

163166
for {
164-
event, more := <-events
165-
_, ok := event.(terminal.ResizeEvent)
166-
if ok {
167-
w, h, err := term.Size()
168-
if err != nil {
169-
fmt.Printf("Error attempting to fetch terminal size: %v\n\r", err)
170-
}
167+
select {
168+
case <-ctx.Done():
169+
return
170+
case event, more := <-events:
171+
_, ok := event.(terminal.ResizeEvent)
172+
if ok {
173+
w, h, err := term.Size()
174+
if err != nil {
175+
fmt.Printf("Error attempting to fetch terminal size: %v\n\r", err)
176+
}
171177

172-
size := remotecommand.TerminalSize{Width: uint16(w), Height: uint16(h)}
173-
err = stream.Resize(&size)
174-
if err != nil {
175-
fmt.Printf("Error attempting to resize terminal: %v\n\r", err)
178+
size := remotecommand.TerminalSize{Width: uint16(w), Height: uint16(h)}
179+
if err := stream.Resize(&size); err != nil {
180+
fmt.Printf("Error attempting to resize terminal: %v\n\r", err)
181+
}
176182
}
177-
}
178183

179-
if !more {
180-
break
184+
if !more {
185+
return
186+
}
181187
}
182188
}
183189
}
@@ -205,14 +211,15 @@ func (s *KubeSession) handleMFA(ctx context.Context, tc *TeleportClient, mode ty
205211
}
206212

207213
// pipeInOut starts background tasks that copy input to and from the terminal.
208-
func (s *KubeSession) pipeInOut(stdout io.Writer, enableEscapeSequences bool, mode types.SessionParticipantMode) {
214+
func (s *KubeSession) pipeInOut(ctx context.Context, stdout io.Writer, enableEscapeSequences bool, mode types.SessionParticipantMode) {
209215
// wait for the session to copy everything
210216
s.wg.Add(1)
211217
go func() {
212-
defer s.wg.Done()
213-
defer s.cancel()
214-
_, err := io.Copy(stdout, s.stream)
215-
if err != nil {
218+
defer func() {
219+
s.wg.Done()
220+
s.cancel()
221+
}()
222+
if _, err := io.Copy(stdout, s.stream); err != nil {
216223
fmt.Printf("Error while reading remote stream: %v\n\r", err.Error())
217224
}
218225
}()
@@ -225,9 +232,8 @@ func (s *KubeSession) pipeInOut(stdout io.Writer, enableEscapeSequences bool, mo
225232
handlePeerControls(s.term, enableEscapeSequences, s.stream)
226233
default:
227234
handleNonPeerControls(mode, s.term, func() {
228-
err := s.stream.ForceTerminate()
229-
if err != nil {
230-
log.DebugContext(context.Background(), "Error sending force termination request", "error", err)
235+
if err := s.stream.ForceTerminate(); err != nil {
236+
log.DebugContext(ctx, "Error sending force termination request", "error", err)
231237
fmt.Print("\n\rError while sending force termination request\n\r")
232238
}
233239
})

lib/web/apiserver.go

+26-20
Original file line numberDiff line numberDiff line change
@@ -3724,6 +3724,10 @@ func (h *Handler) setDefaultConnectorHandle(w http.ResponseWriter, r *http.Reque
37243724
type podConnectParams struct {
37253725
// Term is the initial PTY size.
37263726
Term session.TerminalParams `json:"term"`
3727+
// SessionID is a Teleport session ID to join as.
3728+
SessionID session.ID `json:"sid"`
3729+
// ParticipantMode is the mode that determines what you can do when you join an active session.
3730+
ParticipantMode types.SessionParticipantMode `json:"mode"`
37273731
}
37283732

37293733
func (h *Handler) podConnect(
@@ -3743,6 +3747,20 @@ func (h *Handler) podConnect(
37433747
return nil, trace.Wrap(err)
37443748
}
37453749

3750+
// If a session is provided, then join an existing session
3751+
// instead of creating a new one.
3752+
if !params.SessionID.IsZero() {
3753+
return nil, trace.Wrap(h.joinKubernetesSession(
3754+
r.Context(),
3755+
params.SessionID.String(),
3756+
params.ParticipantMode,
3757+
sctx,
3758+
site,
3759+
ws,
3760+
))
3761+
}
3762+
3763+
// Wait for the user to supply the pod information.
37463764
execReq, err := readPodExecRequestFromWS(ws)
37473765
if err != nil {
37483766
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || terminal.IsOKWebsocketCloseError(trace.Unwrap(err)) {
@@ -3761,26 +3779,11 @@ func (h *Handler) podConnect(
37613779
return nil, trace.Wrap(err)
37623780
}
37633781

3764-
clt, err := sctx.GetUserClient(r.Context(), site)
3765-
if err != nil {
3766-
return nil, trace.Wrap(err)
3767-
}
3768-
3769-
clusterName := site.GetName()
3770-
3771-
accessChecker, err := sctx.GetUserAccessChecker()
3772-
if err != nil {
3773-
return session.Session{}, trace.Wrap(err)
3774-
}
3775-
policySets := accessChecker.SessionPolicySets()
3776-
accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.KubernetesSessionKind, sctx.GetUser())
3777-
37783782
sess := session.Session{
37793783
Kind: types.KubernetesSessionKind,
37803784
Login: "root",
3781-
ClusterName: clusterName,
3785+
ClusterName: site.GetName(),
37823786
KubernetesClusterName: execReq.KubeCluster,
3783-
Moderated: accessEvaluator.IsModerated(),
37843787
ID: session.NewID(),
37853788
Created: h.clock.Now().UTC(),
37863789
LastActive: h.clock.Now().UTC(),
@@ -3807,8 +3810,6 @@ func (h *Handler) podConnect(
38073810
return nil, trace.Wrap(err)
38083811
}
38093812

3810-
keepAliveInterval := netConfig.GetKeepAliveInterval()
3811-
38123813
serverAddr, tlsServerName, err := h.getKubeExecClusterData(netConfig)
38133814
if err != nil {
38143815
return nil, trace.Wrap(err)
@@ -3822,13 +3823,18 @@ func (h *Handler) podConnect(
38223823
return nil, trace.Wrap(err)
38233824
}
38243825

3825-
ph := podHandler{
3826+
clt, err := sctx.GetUserClient(r.Context(), site)
3827+
if err != nil {
3828+
return nil, trace.Wrap(err)
3829+
}
3830+
3831+
ph := podExecHandler{
38263832
req: execReq,
38273833
sess: sess,
38283834
sctx: sctx,
38293835
teleportCluster: site.GetName(),
38303836
ws: ws,
3831-
keepAliveInterval: keepAliveInterval,
3837+
keepAliveInterval: netConfig.GetKeepAliveInterval(),
38323838
logger: h.logger.With(teleport.ComponentKey, "pod"),
38333839
userClient: clt,
38343840
localCA: hostCA,

0 commit comments

Comments
 (0)