From f0973543a02ed4afb56be0e22f670caf2b150fb1 Mon Sep 17 00:00:00 2001 From: Sam Lancia Date: Tue, 22 Oct 2019 12:36:09 +0100 Subject: [PATCH] Server able to close all connections for given user --- server.go | 42 ++++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/server.go b/server.go index cad0402..9a3a4d5 100644 --- a/server.go +++ b/server.go @@ -60,7 +60,7 @@ type Server struct { listenerWg sync.WaitGroup mu sync.Mutex listeners map[net.Listener]struct{} - conns map[*gossh.ServerConn]struct{} + userConns map[string]map[*gossh.ServerConn]struct{} connWg sync.WaitGroup doneChan chan struct{} } @@ -155,9 +155,12 @@ func (srv *Server) Close() error { defer srv.mu.Unlock() srv.closeDoneChanLocked() err := srv.closeListenersLocked() - for c := range srv.conns { - c.Close() - delete(srv.conns, c) + for user, conns := range srv.userConns { + for c := range conns { + c.Close() + delete(conns, c) + } + delete(srv.userConns, user) } return err } @@ -323,6 +326,19 @@ func (srv *Server) SetOption(option Option) error { return option(srv) } +func (srv *Server) CloseConnectionsForUser(user string) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.userConns[user] == nil { + return + } + + for conn := range srv.userConns[user] { + conn.Close() + } +} + func (srv *Server) getDoneChan() <-chan struct{} { srv.mu.Lock() defer srv.mu.Unlock() @@ -368,7 +384,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { if add { // If the *Server is being reused after a previous // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { + if len(srv.listeners) == 0 && len(srv.userConns) == 0 { srv.doneChan = nil } srv.listeners[ln] = struct{}{} @@ -382,14 +398,20 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { srv.mu.Lock() defer srv.mu.Unlock() - if srv.conns == nil { - srv.conns = make(map[*gossh.ServerConn]struct{}) + if srv.userConns == nil { + srv.userConns = make(map[string]map[*gossh.ServerConn]struct{}) + } + if srv.userConns[c.User()] == nil { + srv.userConns[c.User()] = make(map[*gossh.ServerConn]struct{}) } if add { - srv.conns[c] = struct{}{} + srv.userConns[c.User()][c] = struct{}{} srv.connWg.Add(1) } else { - delete(srv.conns, c) + delete(srv.userConns[c.User()], c) + if len(srv.userConns[c.User()]) == 0 { + delete(srv.userConns, c.User()) + } srv.connWg.Done() } -} +} \ No newline at end of file