implement session manager
This commit is contained in:
parent
372febe620
commit
5afe7b0962
|
@ -103,3 +103,11 @@ func CreatePutAsymmetricKeyCommand(keyID uint16, label []byte, domains uint16, c
|
||||||
|
|
||||||
return command, nil
|
return command, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateCloseSessionCommand() (*CommandMessage, error) {
|
||||||
|
command := &CommandMessage{
|
||||||
|
CommandType: CommandTypeCloseSession,
|
||||||
|
}
|
||||||
|
|
||||||
|
return command, nil
|
||||||
|
}
|
||||||
|
|
|
@ -74,6 +74,8 @@ func ParseResponse(data []byte) (Response, error) {
|
||||||
return parseSignDataEddsaResponse(payload)
|
return parseSignDataEddsaResponse(payload)
|
||||||
case CommandTypePutAsymmetric:
|
case CommandTypePutAsymmetric:
|
||||||
return parsePutAsymmetricKeyResponse(payload)
|
return parsePutAsymmetricKeyResponse(payload)
|
||||||
|
case CommandTypeCloseSession:
|
||||||
|
return nil, nil
|
||||||
case ErrorResponseCode:
|
case ErrorResponseCode:
|
||||||
return nil, parseErrorResponse(payload)
|
return nil, parseErrorResponse(payload)
|
||||||
default:
|
default:
|
||||||
|
|
43
main.go
43
main.go
|
@ -1,43 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"aiakos/commands"
|
|
||||||
"aiakos/connector"
|
|
||||||
"aiakos/securechannel"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
channel, err := securechannel.NewSecureChannel(connector.NewHTTPConnector("127.0.0.1:12345"), 1, "password")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = channel.Authenticate()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, _ := commands.CreateGenerateAsymmetricKeyCommand(2, []byte("myKey"), commands.Domain1, commands.CapabilityAsymmetricSignEddsa, commands.AlgorighmED25519)
|
|
||||||
res, err := channel.SendEncryptedCommand(cmd)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("%v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("%v\n", res)
|
|
||||||
|
|
||||||
cmd, _ = commands.CreateSignDataEddsaCommand(2, []byte("my test message"))
|
|
||||||
res, err = channel.SendEncryptedCommand(cmd)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("%v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("signature: %v\n", res)
|
|
||||||
|
|
||||||
cmd, _ = commands.CreateResetCommand()
|
|
||||||
_, err = channel.SendEncryptedCommand(cmd)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"aiakos/connector"
|
||||||
|
"aiakos/securechannel"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
SessionManager struct {
|
||||||
|
sessions []*securechannel.SecureChannel
|
||||||
|
lock sync.Mutex
|
||||||
|
connector connector.Connector
|
||||||
|
authKeyID uint16
|
||||||
|
password string
|
||||||
|
|
||||||
|
poolSize uint
|
||||||
|
|
||||||
|
creationWait sync.WaitGroup
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewSessionManager(connector connector.Connector, authKeyID uint16, password string, poolSize uint) (*SessionManager, error) {
|
||||||
|
if poolSize > 16 {
|
||||||
|
return nil, errors.New("pool size exceeds session limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &SessionManager{
|
||||||
|
sessions: make([]*securechannel.SecureChannel, 0),
|
||||||
|
connector: connector,
|
||||||
|
authKeyID: authKeyID,
|
||||||
|
password: password,
|
||||||
|
poolSize: poolSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.household()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
manager.household()
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionManager) household() {
|
||||||
|
func() {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
for i, session := range s.sessions {
|
||||||
|
if session.Counter > securechannel.MaxMessagesPerSession*0.9 {
|
||||||
|
// Remove expired session
|
||||||
|
go session.Close()
|
||||||
|
|
||||||
|
copy(s.sessions[i:], s.sessions[i+1:])
|
||||||
|
s.sessions[len(s.sessions)-1] = nil
|
||||||
|
s.sessions = s.sessions[:len(s.sessions)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < int(s.poolSize)-len(s.sessions); i++ {
|
||||||
|
s.creationWait.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.creationWait.Done()
|
||||||
|
|
||||||
|
newSession, err := securechannel.NewSecureChannel(s.connector, s.authKeyID, s.password)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = newSession.Authenticate()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
s.sessions = append(s.sessions, newSession)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.creationWait.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionManager) GetSession() (*securechannel.SecureChannel, error) {
|
||||||
|
if len(s.sessions) == 0 {
|
||||||
|
return nil, errors.New("no sessions available")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
return s.sessions[rand.Intn(len(s.sessions))], nil
|
||||||
|
}
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/enceve/crypto/cmac"
|
"github.com/enceve/crypto/cmac"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -21,6 +22,8 @@ type (
|
||||||
authKeySlot uint16
|
authKeySlot uint16
|
||||||
// keyChain holds the keys generated in the authentication ceremony
|
// keyChain holds the keys generated in the authentication ceremony
|
||||||
keyChain *KeyChain
|
keyChain *KeyChain
|
||||||
|
// channelLock is used to lock encrypted communications to prevent race conditions
|
||||||
|
channelLock sync.Mutex
|
||||||
|
|
||||||
// ID is the ID of the session with the HSM
|
// ID is the ID of the session with the HSM
|
||||||
ID uint8
|
ID uint8
|
||||||
|
@ -76,6 +79,8 @@ const (
|
||||||
|
|
||||||
MessageTypeCommand MessageType = 0
|
MessageTypeCommand MessageType = 0
|
||||||
MessageTypeResponse MessageType = 1
|
MessageTypeResponse MessageType = 1
|
||||||
|
|
||||||
|
MaxMessagesPerSession = 10000
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewSecureChannel initiates a new secure channel to communicate with an HSM using the given authKey
|
// NewSecureChannel initiates a new secure channel to communicate with an HSM using the given authKey
|
||||||
|
@ -102,6 +107,12 @@ func NewSecureChannel(connector connector.Connector, authKeySlot uint16, passwor
|
||||||
|
|
||||||
// Authenticate establishes an authenticated session with the HSM
|
// Authenticate establishes an authenticated session with the HSM
|
||||||
func (s *SecureChannel) Authenticate() error {
|
func (s *SecureChannel) Authenticate() error {
|
||||||
|
if s.SecurityLevel != SecurityLevelUnauthenticated {
|
||||||
|
return errors.New("the session is already authenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.channelLock.Lock()
|
||||||
|
defer s.channelLock.Unlock()
|
||||||
|
|
||||||
command, _ := commands.CreateCreateSessionCommand(s.authKeySlot, s.HostChallenge)
|
command, _ := commands.CreateCreateSessionCommand(s.authKeySlot, s.HostChallenge)
|
||||||
response, err := s.SendCommand(command)
|
response, err := s.SendCommand(command)
|
||||||
|
@ -144,7 +155,7 @@ func (s *SecureChannel) Authenticate() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = s.SendMACCommand(authenticateCommand)
|
_, err = s.sendMACCommand(authenticateCommand)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -157,27 +168,6 @@ func (s *SecureChannel) Authenticate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMACCommand sends a MAC authenticated command to the HSM and returns a parsed response
|
|
||||||
func (s *SecureChannel) SendMACCommand(c *commands.CommandMessage) (commands.Response, error) {
|
|
||||||
|
|
||||||
// Set command sessionID to this session
|
|
||||||
c.SessionID = &s.ID
|
|
||||||
|
|
||||||
// Calculate MAC for the command
|
|
||||||
sum, err := s.calculateMAC(c, MessageTypeCommand)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update chain value
|
|
||||||
s.MACChainValue = sum
|
|
||||||
|
|
||||||
// Set command MAC to calculated mac
|
|
||||||
c.MAC = sum[:MACLength]
|
|
||||||
|
|
||||||
return s.SendCommand(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendCommand sends an unauthenticated command to the HSM and returns the parsed response
|
// SendCommand sends an unauthenticated command to the HSM and returns the parsed response
|
||||||
func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Response, error) {
|
func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||||
resp, err := s.connector.Request(c)
|
resp, err := s.connector.Request(c)
|
||||||
|
@ -191,6 +181,17 @@ func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Respon
|
||||||
// SendEncryptedCommand sends an encrypted & authenticated command to the HSM
|
// SendEncryptedCommand sends an encrypted & authenticated command to the HSM
|
||||||
// and returns the decrypted and parsed response.
|
// and returns the decrypted and parsed response.
|
||||||
func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (commands.Response, error) {
|
func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||||
|
if s.SecurityLevel != SecurityLevelAuthenticated {
|
||||||
|
return nil, errors.New("the session is not authenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Counter >= MaxMessagesPerSession {
|
||||||
|
return nil, errors.New("channel has reached its message limit; please recreate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock the encrypted channel
|
||||||
|
s.channelLock.Lock()
|
||||||
|
defer s.channelLock.Unlock()
|
||||||
|
|
||||||
// Create the cipher using the session encryption key
|
// Create the cipher using the session encryption key
|
||||||
block, err := aes.NewCipher(s.keyChain.EncKey)
|
block, err := aes.NewCipher(s.keyChain.EncKey)
|
||||||
|
@ -216,7 +217,7 @@ func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (comman
|
||||||
encrypter.CryptBlocks(encryptedCommand, pad(commandData))
|
encrypter.CryptBlocks(encryptedCommand, pad(commandData))
|
||||||
|
|
||||||
// Send the wrapped command in a SessionMessage
|
// Send the wrapped command in a SessionMessage
|
||||||
resp, err := s.SendMACCommand(&commands.CommandMessage{
|
resp, err := s.sendMACCommand(&commands.CommandMessage{
|
||||||
CommandType: commands.CommandTypeSessionMessage,
|
CommandType: commands.CommandTypeSessionMessage,
|
||||||
Data: encryptedCommand,
|
Data: encryptedCommand,
|
||||||
})
|
})
|
||||||
|
@ -255,6 +256,41 @@ func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (comman
|
||||||
return commands.ParseResponse(unpad(decryptedResponse))
|
return commands.ParseResponse(unpad(decryptedResponse))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SecureChannel) Close() error {
|
||||||
|
command, err := commands.CreateCloseSessionCommand()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.SendEncryptedCommand(command)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMACCommand sends a MAC authenticated command to the HSM and returns a parsed response
|
||||||
|
func (s *SecureChannel) sendMACCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||||
|
|
||||||
|
// Set command sessionID to this session
|
||||||
|
c.SessionID = &s.ID
|
||||||
|
|
||||||
|
// Calculate MAC for the command
|
||||||
|
sum, err := s.calculateMAC(c, MessageTypeCommand)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update chain value
|
||||||
|
s.MACChainValue = sum
|
||||||
|
|
||||||
|
// Set command MAC to calculated mac
|
||||||
|
c.MAC = sum[:MACLength]
|
||||||
|
|
||||||
|
return s.SendCommand(c)
|
||||||
|
}
|
||||||
|
|
||||||
// calculateMAC calculates the authenticated MAC for a command or response.
|
// calculateMAC calculates the authenticated MAC for a command or response.
|
||||||
// This is stateful since it uses the MACChainValue.
|
// This is stateful since it uses the MACChainValue.
|
||||||
func (s *SecureChannel) calculateMAC(c *commands.CommandMessage, messageType MessageType) ([]byte, error) {
|
func (s *SecureChannel) calculateMAC(c *commands.CommandMessage, messageType MessageType) ([]byte, error) {
|
||||||
|
|
Loading…
Reference in New Issue