diff --git a/commands/constructors.go b/commands/constructors.go index c86745b..6e14898 100644 --- a/commands/constructors.go +++ b/commands/constructors.go @@ -103,3 +103,11 @@ func CreatePutAsymmetricKeyCommand(keyID uint16, label []byte, domains uint16, c return command, nil } + +func CreateCloseSessionCommand() (*CommandMessage, error) { + command := &CommandMessage{ + CommandType: CommandTypeCloseSession, + } + + return command, nil +} diff --git a/commands/response.go b/commands/response.go index 968731e..7541dd7 100644 --- a/commands/response.go +++ b/commands/response.go @@ -74,6 +74,8 @@ func ParseResponse(data []byte) (Response, error) { return parseSignDataEddsaResponse(payload) case CommandTypePutAsymmetric: return parsePutAsymmetricKeyResponse(payload) + case CommandTypeCloseSession: + return nil, nil case ErrorResponseCode: return nil, parseErrorResponse(payload) default: diff --git a/main.go b/main.go deleted file mode 100644 index a9ea1d6..0000000 --- a/main.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "aiakos/commands" - "aiakos/connector" - "aiakos/securechannel" - "fmt" -) - -func main() { - - channel, err := securechannel.NewSecureChannel(connector.NewHTTPConnector("127.0.0.1:12345"), 1, "password") - if err != nil { - panic(err) - } - - err = channel.Authenticate() - if err != nil { - panic(err) - } - - cmd, _ := commands.CreateGenerateAsymmetricKeyCommand(2, []byte("myKey"), commands.Domain1, commands.CapabilityAsymmetricSignEddsa, commands.AlgorighmED25519) - res, err := channel.SendEncryptedCommand(cmd) - if err != nil { - fmt.Printf("%v\n", err) - } - - fmt.Printf("%v\n", res) - - cmd, _ = commands.CreateSignDataEddsaCommand(2, []byte("my test message")) - res, err = channel.SendEncryptedCommand(cmd) - if err != nil { - fmt.Printf("%v\n", err) - } - - fmt.Printf("signature: %v\n", res) - - cmd, _ = commands.CreateResetCommand() - _, err = channel.SendEncryptedCommand(cmd) - if err != nil { - panic(err) - } -} diff --git a/manager.go b/manager.go new file mode 100644 index 0000000..b9cafb0 --- /dev/null +++ b/manager.go @@ -0,0 +1,103 @@ +package main + +import ( + "aiakos/connector" + "aiakos/securechannel" + "errors" + "fmt" + "math/rand" + "sync" + "time" +) + +type ( + SessionManager struct { + sessions []*securechannel.SecureChannel + lock sync.Mutex + connector connector.Connector + authKeyID uint16 + password string + + poolSize uint + + creationWait sync.WaitGroup + } +) + +func NewSessionManager(connector connector.Connector, authKeyID uint16, password string, poolSize uint) (*SessionManager, error) { + if poolSize > 16 { + return nil, errors.New("pool size exceeds session limit") + } + + manager := &SessionManager{ + sessions: make([]*securechannel.SecureChannel, 0), + connector: connector, + authKeyID: authKeyID, + password: password, + poolSize: poolSize, + } + + manager.household() + + go func() { + for { + manager.household() + time.Sleep(5 * time.Second) + } + }() + + return manager, nil +} + +func (s *SessionManager) household() { + func() { + s.lock.Lock() + defer s.lock.Unlock() + + for i, session := range s.sessions { + if session.Counter > securechannel.MaxMessagesPerSession*0.9 { + // Remove expired session + go session.Close() + + copy(s.sessions[i:], s.sessions[i+1:]) + s.sessions[len(s.sessions)-1] = nil + s.sessions = s.sessions[:len(s.sessions)-1] + } + } + + for i := 0; i < int(s.poolSize)-len(s.sessions); i++ { + s.creationWait.Add(1) + go func() { + defer s.creationWait.Done() + + newSession, err := securechannel.NewSecureChannel(s.connector, s.authKeyID, s.password) + if err != nil { + fmt.Println(err.Error()) + return + } + + err = newSession.Authenticate() + if err != nil { + fmt.Println(err) + return + } + + s.lock.Lock() + defer s.lock.Unlock() + s.sessions = append(s.sessions, newSession) + }() + } + }() + + s.creationWait.Wait() +} + +func (s *SessionManager) GetSession() (*securechannel.SecureChannel, error) { + if len(s.sessions) == 0 { + return nil, errors.New("no sessions available") + } + + s.lock.Lock() + defer s.lock.Unlock() + return s.sessions[rand.Intn(len(s.sessions))], nil +} diff --git a/securechannel/channel.go b/securechannel/channel.go index 9f5ebd4..fb7363a 100644 --- a/securechannel/channel.go +++ b/securechannel/channel.go @@ -10,6 +10,7 @@ import ( "encoding/binary" "errors" "github.com/enceve/crypto/cmac" + "sync" ) type ( @@ -21,6 +22,8 @@ type ( authKeySlot uint16 // keyChain holds the keys generated in the authentication ceremony keyChain *KeyChain + // channelLock is used to lock encrypted communications to prevent race conditions + channelLock sync.Mutex // ID is the ID of the session with the HSM ID uint8 @@ -76,6 +79,8 @@ const ( MessageTypeCommand MessageType = 0 MessageTypeResponse MessageType = 1 + + MaxMessagesPerSession = 10000 ) // NewSecureChannel initiates a new secure channel to communicate with an HSM using the given authKey @@ -102,6 +107,12 @@ func NewSecureChannel(connector connector.Connector, authKeySlot uint16, passwor // Authenticate establishes an authenticated session with the HSM func (s *SecureChannel) Authenticate() error { + if s.SecurityLevel != SecurityLevelUnauthenticated { + return errors.New("the session is already authenticated") + } + + s.channelLock.Lock() + defer s.channelLock.Unlock() command, _ := commands.CreateCreateSessionCommand(s.authKeySlot, s.HostChallenge) response, err := s.SendCommand(command) @@ -144,7 +155,7 @@ func (s *SecureChannel) Authenticate() error { if err != nil { return err } - _, err = s.SendMACCommand(authenticateCommand) + _, err = s.sendMACCommand(authenticateCommand) if err != nil { return err } @@ -157,27 +168,6 @@ func (s *SecureChannel) Authenticate() error { return nil } -// SendMACCommand sends a MAC authenticated command to the HSM and returns a parsed response -func (s *SecureChannel) SendMACCommand(c *commands.CommandMessage) (commands.Response, error) { - - // Set command sessionID to this session - c.SessionID = &s.ID - - // Calculate MAC for the command - sum, err := s.calculateMAC(c, MessageTypeCommand) - if err != nil { - return nil, err - } - - // Update chain value - s.MACChainValue = sum - - // Set command MAC to calculated mac - c.MAC = sum[:MACLength] - - return s.SendCommand(c) -} - // SendCommand sends an unauthenticated command to the HSM and returns the parsed response func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Response, error) { resp, err := s.connector.Request(c) @@ -191,6 +181,17 @@ func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Respon // SendEncryptedCommand sends an encrypted & authenticated command to the HSM // and returns the decrypted and parsed response. func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (commands.Response, error) { + if s.SecurityLevel != SecurityLevelAuthenticated { + return nil, errors.New("the session is not authenticated") + } + + if s.Counter >= MaxMessagesPerSession { + return nil, errors.New("channel has reached its message limit; please recreate") + } + + // Lock the encrypted channel + s.channelLock.Lock() + defer s.channelLock.Unlock() // Create the cipher using the session encryption key block, err := aes.NewCipher(s.keyChain.EncKey) @@ -216,7 +217,7 @@ func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (comman encrypter.CryptBlocks(encryptedCommand, pad(commandData)) // Send the wrapped command in a SessionMessage - resp, err := s.SendMACCommand(&commands.CommandMessage{ + resp, err := s.sendMACCommand(&commands.CommandMessage{ CommandType: commands.CommandTypeSessionMessage, Data: encryptedCommand, }) @@ -255,6 +256,41 @@ func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (comman return commands.ParseResponse(unpad(decryptedResponse)) } +func (s *SecureChannel) Close() error { + command, err := commands.CreateCloseSessionCommand() + if err != nil { + return err + } + + _, err = s.SendEncryptedCommand(command) + if err != nil { + return err + } + + return nil +} + +// sendMACCommand sends a MAC authenticated command to the HSM and returns a parsed response +func (s *SecureChannel) sendMACCommand(c *commands.CommandMessage) (commands.Response, error) { + + // Set command sessionID to this session + c.SessionID = &s.ID + + // Calculate MAC for the command + sum, err := s.calculateMAC(c, MessageTypeCommand) + if err != nil { + return nil, err + } + + // Update chain value + s.MACChainValue = sum + + // Set command MAC to calculated mac + c.MAC = sum[:MACLength] + + return s.SendCommand(c) +} + // calculateMAC calculates the authenticated MAC for a command or response. // This is stateful since it uses the MACChainValue. func (s *SecureChannel) calculateMAC(c *commands.CommandMessage, messageType MessageType) ([]byte, error) {