implement session manager
This commit is contained in:
parent
372febe620
commit
5afe7b0962
|
@ -103,3 +103,11 @@ func CreatePutAsymmetricKeyCommand(keyID uint16, label []byte, domains uint16, c
|
|||
|
||||
return command, nil
|
||||
}
|
||||
|
||||
func CreateCloseSessionCommand() (*CommandMessage, error) {
|
||||
command := &CommandMessage{
|
||||
CommandType: CommandTypeCloseSession,
|
||||
}
|
||||
|
||||
return command, nil
|
||||
}
|
||||
|
|
|
@ -74,6 +74,8 @@ func ParseResponse(data []byte) (Response, error) {
|
|||
return parseSignDataEddsaResponse(payload)
|
||||
case CommandTypePutAsymmetric:
|
||||
return parsePutAsymmetricKeyResponse(payload)
|
||||
case CommandTypeCloseSession:
|
||||
return nil, nil
|
||||
case ErrorResponseCode:
|
||||
return nil, parseErrorResponse(payload)
|
||||
default:
|
||||
|
|
43
main.go
43
main.go
|
@ -1,43 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"aiakos/commands"
|
||||
"aiakos/connector"
|
||||
"aiakos/securechannel"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
channel, err := securechannel.NewSecureChannel(connector.NewHTTPConnector("127.0.0.1:12345"), 1, "password")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = channel.Authenticate()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cmd, _ := commands.CreateGenerateAsymmetricKeyCommand(2, []byte("myKey"), commands.Domain1, commands.CapabilityAsymmetricSignEddsa, commands.AlgorighmED25519)
|
||||
res, err := channel.SendEncryptedCommand(cmd)
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("%v\n", res)
|
||||
|
||||
cmd, _ = commands.CreateSignDataEddsaCommand(2, []byte("my test message"))
|
||||
res, err = channel.SendEncryptedCommand(cmd)
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("signature: %v\n", res)
|
||||
|
||||
cmd, _ = commands.CreateResetCommand()
|
||||
_, err = channel.SendEncryptedCommand(cmd)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"aiakos/connector"
|
||||
"aiakos/securechannel"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
SessionManager struct {
|
||||
sessions []*securechannel.SecureChannel
|
||||
lock sync.Mutex
|
||||
connector connector.Connector
|
||||
authKeyID uint16
|
||||
password string
|
||||
|
||||
poolSize uint
|
||||
|
||||
creationWait sync.WaitGroup
|
||||
}
|
||||
)
|
||||
|
||||
func NewSessionManager(connector connector.Connector, authKeyID uint16, password string, poolSize uint) (*SessionManager, error) {
|
||||
if poolSize > 16 {
|
||||
return nil, errors.New("pool size exceeds session limit")
|
||||
}
|
||||
|
||||
manager := &SessionManager{
|
||||
sessions: make([]*securechannel.SecureChannel, 0),
|
||||
connector: connector,
|
||||
authKeyID: authKeyID,
|
||||
password: password,
|
||||
poolSize: poolSize,
|
||||
}
|
||||
|
||||
manager.household()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
manager.household()
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}()
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func (s *SessionManager) household() {
|
||||
func() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
for i, session := range s.sessions {
|
||||
if session.Counter > securechannel.MaxMessagesPerSession*0.9 {
|
||||
// Remove expired session
|
||||
go session.Close()
|
||||
|
||||
copy(s.sessions[i:], s.sessions[i+1:])
|
||||
s.sessions[len(s.sessions)-1] = nil
|
||||
s.sessions = s.sessions[:len(s.sessions)-1]
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < int(s.poolSize)-len(s.sessions); i++ {
|
||||
s.creationWait.Add(1)
|
||||
go func() {
|
||||
defer s.creationWait.Done()
|
||||
|
||||
newSession, err := securechannel.NewSecureChannel(s.connector, s.authKeyID, s.password)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = newSession.Authenticate()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.sessions = append(s.sessions, newSession)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
s.creationWait.Wait()
|
||||
}
|
||||
|
||||
func (s *SessionManager) GetSession() (*securechannel.SecureChannel, error) {
|
||||
if len(s.sessions) == 0 {
|
||||
return nil, errors.New("no sessions available")
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return s.sessions[rand.Intn(len(s.sessions))], nil
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/enceve/crypto/cmac"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -21,6 +22,8 @@ type (
|
|||
authKeySlot uint16
|
||||
// keyChain holds the keys generated in the authentication ceremony
|
||||
keyChain *KeyChain
|
||||
// channelLock is used to lock encrypted communications to prevent race conditions
|
||||
channelLock sync.Mutex
|
||||
|
||||
// ID is the ID of the session with the HSM
|
||||
ID uint8
|
||||
|
@ -76,6 +79,8 @@ const (
|
|||
|
||||
MessageTypeCommand MessageType = 0
|
||||
MessageTypeResponse MessageType = 1
|
||||
|
||||
MaxMessagesPerSession = 10000
|
||||
)
|
||||
|
||||
// NewSecureChannel initiates a new secure channel to communicate with an HSM using the given authKey
|
||||
|
@ -102,6 +107,12 @@ func NewSecureChannel(connector connector.Connector, authKeySlot uint16, passwor
|
|||
|
||||
// Authenticate establishes an authenticated session with the HSM
|
||||
func (s *SecureChannel) Authenticate() error {
|
||||
if s.SecurityLevel != SecurityLevelUnauthenticated {
|
||||
return errors.New("the session is already authenticated")
|
||||
}
|
||||
|
||||
s.channelLock.Lock()
|
||||
defer s.channelLock.Unlock()
|
||||
|
||||
command, _ := commands.CreateCreateSessionCommand(s.authKeySlot, s.HostChallenge)
|
||||
response, err := s.SendCommand(command)
|
||||
|
@ -144,7 +155,7 @@ func (s *SecureChannel) Authenticate() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.SendMACCommand(authenticateCommand)
|
||||
_, err = s.sendMACCommand(authenticateCommand)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -157,27 +168,6 @@ func (s *SecureChannel) Authenticate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SendMACCommand sends a MAC authenticated command to the HSM and returns a parsed response
|
||||
func (s *SecureChannel) SendMACCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||
|
||||
// Set command sessionID to this session
|
||||
c.SessionID = &s.ID
|
||||
|
||||
// Calculate MAC for the command
|
||||
sum, err := s.calculateMAC(c, MessageTypeCommand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update chain value
|
||||
s.MACChainValue = sum
|
||||
|
||||
// Set command MAC to calculated mac
|
||||
c.MAC = sum[:MACLength]
|
||||
|
||||
return s.SendCommand(c)
|
||||
}
|
||||
|
||||
// SendCommand sends an unauthenticated command to the HSM and returns the parsed response
|
||||
func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||
resp, err := s.connector.Request(c)
|
||||
|
@ -191,6 +181,17 @@ func (s *SecureChannel) SendCommand(c *commands.CommandMessage) (commands.Respon
|
|||
// SendEncryptedCommand sends an encrypted & authenticated command to the HSM
|
||||
// and returns the decrypted and parsed response.
|
||||
func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||
if s.SecurityLevel != SecurityLevelAuthenticated {
|
||||
return nil, errors.New("the session is not authenticated")
|
||||
}
|
||||
|
||||
if s.Counter >= MaxMessagesPerSession {
|
||||
return nil, errors.New("channel has reached its message limit; please recreate")
|
||||
}
|
||||
|
||||
// Lock the encrypted channel
|
||||
s.channelLock.Lock()
|
||||
defer s.channelLock.Unlock()
|
||||
|
||||
// Create the cipher using the session encryption key
|
||||
block, err := aes.NewCipher(s.keyChain.EncKey)
|
||||
|
@ -216,7 +217,7 @@ func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (comman
|
|||
encrypter.CryptBlocks(encryptedCommand, pad(commandData))
|
||||
|
||||
// Send the wrapped command in a SessionMessage
|
||||
resp, err := s.SendMACCommand(&commands.CommandMessage{
|
||||
resp, err := s.sendMACCommand(&commands.CommandMessage{
|
||||
CommandType: commands.CommandTypeSessionMessage,
|
||||
Data: encryptedCommand,
|
||||
})
|
||||
|
@ -255,6 +256,41 @@ func (s *SecureChannel) SendEncryptedCommand(c *commands.CommandMessage) (comman
|
|||
return commands.ParseResponse(unpad(decryptedResponse))
|
||||
}
|
||||
|
||||
func (s *SecureChannel) Close() error {
|
||||
command, err := commands.CreateCloseSessionCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.SendEncryptedCommand(command)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMACCommand sends a MAC authenticated command to the HSM and returns a parsed response
|
||||
func (s *SecureChannel) sendMACCommand(c *commands.CommandMessage) (commands.Response, error) {
|
||||
|
||||
// Set command sessionID to this session
|
||||
c.SessionID = &s.ID
|
||||
|
||||
// Calculate MAC for the command
|
||||
sum, err := s.calculateMAC(c, MessageTypeCommand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update chain value
|
||||
s.MACChainValue = sum
|
||||
|
||||
// Set command MAC to calculated mac
|
||||
c.MAC = sum[:MACLength]
|
||||
|
||||
return s.SendCommand(c)
|
||||
}
|
||||
|
||||
// calculateMAC calculates the authenticated MAC for a command or response.
|
||||
// This is stateful since it uses the MACChainValue.
|
||||
func (s *SecureChannel) calculateMAC(c *commands.CommandMessage, messageType MessageType) ([]byte, error) {
|
||||
|
|
Loading…
Reference in New Issue