diff --git a/zk-token-sdk/src/encryption/aes.rs b/zk-token-sdk/src/encryption/aes.rs index ed2b679a3..e6dc704a9 100644 --- a/zk-token-sdk/src/encryption/aes.rs +++ b/zk-token-sdk/src/encryption/aes.rs @@ -1,64 +1,86 @@ -#[cfg(not(target_arch = "bpf"))] -use rand::{rngs::OsRng, Rng}; - use { - aes::{ - cipher::{BlockDecrypt, BlockEncrypt, NewBlockCipher}, - Aes128, Block, - }, - arrayref::array_ref, + ed25519_dalek::SecretKey as SigningKey, + solana_sdk::pubkey::Pubkey, + std::convert::TryInto, zeroize::Zeroize, }; +#[cfg(not(target_arch = "bpf"))] +use { + aes_gcm::{aead::Aead, Aes128Gcm, NewAead}, + rand::{CryptoRng, rngs::OsRng, Rng, RngCore}, + sha3::{Digest, Sha3_256}, +}; -pub struct AES; +struct AES; impl AES { #[cfg(not(target_arch = "bpf"))] #[allow(clippy::new_ret_no_self)] - pub fn new() -> AESKey { + fn keygen(rng: &mut T) -> AesKey { let random_bytes = OsRng.gen::<[u8; 16]>(); - AESKey(random_bytes) + AesKey(random_bytes) } #[cfg(not(target_arch = "bpf"))] - pub fn encrypt(sk: &AESKey, amount: u64) -> AESCiphertext { - let amount_bytes = amount.to_le_bytes(); + fn encrypt(sk: &AesKey, amount: u64) -> AesCiphertext { + let plaintext = amount.to_le_bytes(); + let nonce = OsRng.gen::<[u8; 12]>(); - let mut aes_block: Block = [0_u8; 16].into(); - aes_block[..8].copy_from_slice(&amount_bytes); + // TODO: it seems like encryption cannot fail, but will need to double check + let ciphertext = Aes128Gcm::new(&sk.0.into()) + .encrypt(&nonce.into(), plaintext.as_ref()).unwrap(); - Aes128::new(&sk.0.into()).encrypt_block(&mut aes_block); - AESCiphertext(aes_block.into()) + AesCiphertext { + nonce, + ciphertext: ciphertext.try_into().unwrap(), + } } #[cfg(not(target_arch = "bpf"))] - pub fn decrypt(sk: &AESKey, ct: &AESCiphertext) -> u64 { - let mut aes_block: Block = ct.0.into(); - Aes128::new(&sk.0.into()).decrypt_block(&mut aes_block); + fn decrypt(sk: &AesKey, ct: &AesCiphertext) -> Option { + let plaintext = Aes128Gcm::new(&sk.0.into()) + .decrypt(&ct.nonce.into(), ct.ciphertext.as_ref()); - let amount_bytes = array_ref![aes_block[..8], 0, 8]; - u64::from_le_bytes(*amount_bytes) + if let Ok(plaintext) = plaintext { + let amount_bytes: [u8; 8] = plaintext.try_into().unwrap(); + Some(u64::from_le_bytes(amount_bytes)) + } else { + None + } } } #[derive(Debug, Zeroize)] -pub struct AESKey([u8; 16]); -impl AESKey { - pub fn encrypt(&self, amount: u64) -> AESCiphertext { +pub struct AesKey([u8; 16]); +impl AesKey { + pub fn new(signing_key: &SigningKey, address: &Pubkey) -> Self { + let mut hashable = [0_u8; 64]; + hashable[..32].copy_from_slice(&signing_key.to_bytes()); + hashable[32..].copy_from_slice(&address.to_bytes()); + + let mut hasher = Sha3_256::new(); + hasher.update(hashable); + + let result: [u8; 16] = hasher.finalize()[..16].try_into().unwrap(); + AesKey(result) + } + + pub fn random(rng: &mut T) -> Self { + AES::keygen(&mut rng) + } + + pub fn encrypt(&self, amount: u64) -> AesCiphertext { AES::encrypt(self, amount) } } #[derive(Debug)] -pub struct AESCiphertext(pub [u8; 16]); -impl AESCiphertext { - pub fn decrypt(&self, sk: &AESKey) -> u64 { - AES::decrypt(sk, self) - } +pub struct AesCiphertext { + pub nonce: [u8; 12], + pub ciphertext: [u8; 24], } - -impl Default for AESCiphertext { - fn default() -> Self { - AESCiphertext([0_u8; 16]) +impl AesCiphertext { + pub fn decrypt(&self, key: &AesKey) -> Option { + AES::decrypt(key, self) } } @@ -68,11 +90,11 @@ mod tests { #[test] fn test_aes_encrypt_decrypt_correctness() { - let sk = AES::new(); + let key = AesKey::random(&mut OsRng); let amount = 55; - let ct = sk.encrypt(amount); - let decrypted_amount = ct.decrypt(&sk); + let ct = key.encrypt(amount); + let decrypted_amount = ct.decrypt(&key).unwrap(); assert_eq!(amount, decrypted_amount); }