From cbe259b1eea4d9429b0d083619a980e581ac59e8 Mon Sep 17 00:00:00 2001 From: Hendrik Hofstadt Date: Sat, 20 Oct 2018 17:38:17 +0200 Subject: [PATCH] Fix race in session recycler; Closes #2 --- manager.go | 123 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 78 insertions(+), 45 deletions(-) diff --git a/manager.go b/manager.go index cc66c27..22ddb58 100644 --- a/manager.go +++ b/manager.go @@ -15,7 +15,7 @@ import ( type ( // SessionManager manages a pool of authenticated secure sessions with a YubiHSM2 SessionManager struct { - sessions []*securechannel.SecureChannel + sessions sessionList lock sync.Mutex connector connector.Connector authKeyID uint16 @@ -27,8 +27,11 @@ type ( destroyed bool // Connected indicates whether a successful connection with the HSM is established - Connected chan bool + Connected chan bool + recycleQueue chan *securechannel.SecureChannel } + + sessionList []*securechannel.SecureChannel ) var ( @@ -43,13 +46,14 @@ func NewSessionManager(connector connector.Connector, authKeyID uint16, password } manager := &SessionManager{ - sessions: make([]*securechannel.SecureChannel, 0), - connector: connector, - authKeyID: authKeyID, - password: password, - poolSize: poolSize, - destroyed: false, - Connected: make(chan bool, 1), + sessions: make([]*securechannel.SecureChannel, 0), + connector: connector, + authKeyID: authKeyID, + password: password, + poolSize: poolSize, + destroyed: false, + Connected: make(chan bool, 1), + recycleQueue: make(chan *securechannel.SecureChannel, 20), } manager.household() @@ -61,6 +65,29 @@ func NewSessionManager(connector connector.Connector, authKeyID uint16, password } }() + // Recycler function + go func() { + for channel := range manager.recycleQueue { + func() { + manager.lock.Lock() + defer manager.lock.Unlock() + + // Remove from list + pos := manager.sessions.pos(channel) + + manager.sessions[pos] = manager.sessions[len(manager.sessions)-1] + manager.sessions[len(manager.sessions)-1] = nil + manager.sessions = manager.sessions[:len(manager.sessions)-1] + }() + + channel.Close() + err := manager.createSession() + if err != nil { + fmt.Println(err.Error()) + } + } + }() + return manager, nil } @@ -69,7 +96,7 @@ func (s *SessionManager) household() { s.lock.Lock() defer s.lock.Unlock() - for i, session := range s.sessions { + for _, session := range s.sessions { // Send echo command command, _ := commands.CreateEchoCommand(echoPayload) resp, err := session.SendEncryptedCommand(command) @@ -85,47 +112,46 @@ func (s *SessionManager) household() { if session.Counter > securechannel.MaxMessagesPerSession*0.9 || err != nil { // Remove expired session - go session.Close() - - copy(s.sessions[i:], s.sessions[i+1:]) - s.sessions[len(s.sessions)-1] = nil - s.sessions = s.sessions[:len(s.sessions)-1] + s.recycleQueue <- session } } - - for i := 0; i < int(s.poolSize)-len(s.sessions); i++ { - s.creationWait.Add(1) - go func() { - defer s.creationWait.Done() - - newSession, err := securechannel.NewSecureChannel(s.connector, s.authKeyID, s.password) - if err != nil { - fmt.Println(err.Error()) - return - } - - err = newSession.Authenticate() - if err != nil { - fmt.Println(err) - return - } - - s.lock.Lock() - defer s.lock.Unlock() - s.sessions = append(s.sessions, newSession) - select { - case s.Connected <- true: - default: - } - }() - } }() - s.creationWait.Wait() + for i := 0; i < int(s.poolSize)-len(s.sessions); i++ { + err := s.createSession() + if err != nil { + fmt.Println(err.Error()) + } + } +} + +func (s *SessionManager) createSession() error { + newSession, err := securechannel.NewSecureChannel(s.connector, s.authKeyID, s.password) + if err != nil { + return err + } + + err = newSession.Authenticate() + if err != nil { + return err + } + + s.lock.Lock() + defer s.lock.Unlock() + s.sessions = append(s.sessions, newSession) + select { + case s.Connected <- true: + default: + } + + return nil } // GetSession returns a secure authenticated session with the HSM from the pool on which commands can be executed func (s *SessionManager) GetSession() (*securechannel.SecureChannel, error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.destroyed { return nil, errors.New("sessionmanager has already been destroyed") } @@ -133,8 +159,6 @@ func (s *SessionManager) GetSession() (*securechannel.SecureChannel, error) { return nil, errors.New("no sessions available") } - s.lock.Lock() - defer s.lock.Unlock() return s.sessions[rand.Intn(len(s.sessions))], nil } @@ -149,3 +173,12 @@ func (s *SessionManager) Destroy() { } s.destroyed = true } + +func (slice sessionList) pos(value *securechannel.SecureChannel) int { + for p, v := range slice { + if v == value { + return p + } + } + return -1 +}