diff --git a/lib/src/main/java/cash/z/ecc/android/bip39/Mnemonics.kt b/lib/src/main/java/cash/z/ecc/android/bip39/Mnemonics.kt index a49e50a..da725c9 100644 --- a/lib/src/main/java/cash/z/ecc/android/bip39/Mnemonics.kt +++ b/lib/src/main/java/cash/z/ecc/android/bip39/Mnemonics.kt @@ -128,58 +128,73 @@ object Mnemonics { } } - // verify: checksum (this function contains a checksum validation) - toEntropy() + // verify: checksum + validateChecksum() } /** - * Convert this mnemonic word list to its original entropy value. + * Convenience method for validating the checksum of this MnemonicCode. Since validation + * requires deriving the original entropy, this function is the same as calling [toEntropy]. + */ + fun validateChecksum() = toEntropy() + + /** + * Get the original entropy that was used to create this MnemonicCode. This call will fail + * if the words have an invalid length or checksum. + * + * @throws WordCountException when the word count is zero or not a multiple of 3. + * @throws ChecksumException if the checksum does not match the expected value. */ fun toEntropy(): ByteArray { - wordCount.let { wordCount -> - if (wordCount % 3 > 0) throw WordCountException(wordCount) - } - if (isEmpty()) throw RuntimeException("Word list is empty.") + wordCount.let { if (it <= 0 || it % 3 > 0) throw WordCountException(wordCount) } // Look up all the words in the list and construct the // concatenation of the original entropy and the checksum. // - val concatLenBits = wordCount * 11 - val concatBits = BooleanArray(concatLenBits) - var wordindex = 0 + val totalLengthBits = wordCount * 11 + val checksumLengthBits = totalLengthBits / 33 + val entropy = ByteArray((totalLengthBits - checksumLengthBits) / 8) + val checksumBits = mutableListOf() - // TODO: iterate by characters instead of by words, for a little added security - forEach { word -> - // Find the words index in the wordlist. - val ndx = getCachedWords(languageCode).binarySearch(word) - if (ndx < 0) throw InvalidWordException(word) - - // Set the next 11 bits to the value of the index. - for (ii in 0..10) concatBits[wordindex * 11 + ii] = - ndx and (1 shl 10 - ii) != 0 - ++wordindex - } - val checksumLengthBits = concatLenBits / 33 - val entropyLengthBits = concatLenBits - checksumLengthBits - - // Extract original entropy as bytes. - val entropy = ByteArray(entropyLengthBits / 8) - for (ii in entropy.indices) - for (jj in 0..7) - if (concatBits[ii * 8 + jj]) { - entropy[ii] = entropy[ii] or (1 shl 7 - jj).toByte() + val words = getCachedWords(languageCode) + var bitsProcessed = 0 + var nextByte = 0.toByte() + this.forEach { + words.binarySearch(it).let { phraseIndex -> + // fail if the word was not found on the list + if (phraseIndex < 0) throw InvalidWordException(it) + // for each of the 11 bits of the phraseIndex + (10 downTo 0).forEach { i -> + // isolate the next bit (starting from the big end) + val bit = phraseIndex and (1 shl i) != 0 + // if the bit is set, then update the corresponding bit in the nextByte + if (bit) nextByte = nextByte or (1 shl 7 - (bitsProcessed).rem(8)).toByte() + val entropyIndex = ((++bitsProcessed) - 1) / 8 + // if we're at a byte boundary (excluding the extra checksum bits) + if (bitsProcessed.rem(8) == 0 && entropyIndex < entropy.size) { + // then set the byte and prepare to process the next byte + entropy[entropyIndex] = nextByte + nextByte = 0.toByte() + // if we're now processing checksum bits, then track them for later + } else if (entropyIndex >= entropy.size) { + checksumBits.add(bit) + } } + } + } - // Take the digest of the entropy. - val hash: ByteArray = entropy.toSha256() - val hashBits = hash.toBits() + // Check each required checksum bit, against the first byte of the sha256 of entropy + entropy.toSha256()[0].toBits().let { hashFirstByteBits -> + repeat(checksumLengthBits) { i -> + // failure means that each word was valid BUT they were in the wrong order + if (hashFirstByteBits[i] != checksumBits[i]) throw ChecksumException + } + } - // Check all the checksum bits. - for (i in 0 until checksumLengthBits) - if (concatBits[entropyLengthBits + i] != hashBits[i]) throw ChecksumException return entropy } + companion object { /** @@ -317,6 +332,9 @@ fun MnemonicCode.toSeed( passphrase: CharArray = charArrayOf(), validate: Boolean = true ): ByteArray { + // we can skip validation when we know for sure that the code is valid + // such as when it was just generated from new/correct entropy (common case for new seeds) + if (validate) validate() return (DEFAULT_PASSPHRASE.toCharArray() + passphrase).toBytes().let { salt -> PBEKeySpec(chars, salt, INTERATION_COUNT, KEY_SIZE).let { pbeKeySpec -> SecretKeyFactory.getInstance(PBE_ALGORITHM).generateSecret(pbeKeySpec).encoded.also { @@ -337,9 +355,9 @@ fun WordCount.toEntropy(): ByteArray = ByteArray(bitLength / 8).apply { private fun ByteArray?.toSha256() = MessageDigest.getInstance("SHA-256").digest(this) -private fun ByteArray.toBits(): List { - return flatMap { b -> (7 downTo 0).map { (b.toInt() and (1 shl it)) != 0 } } -} +private fun ByteArray.toBits(): List = flatMap { it.toBits() } + +private fun Byte.toBits(): List = (7 downTo 0).map { (toInt() and (1 shl it)) != 0 } private fun CharArray.toBytes(): ByteArray { val byteBuffer = CharBuffer.wrap(this).let { Charset.forName("UTF-8").encode(it) } diff --git a/lib/src/test/java/cash/z/ecc/android/bip39/MnemonicsTest.kt b/lib/src/test/java/cash/z/ecc/android/bip39/MnemonicsTest.kt index faaf823..6c84aeb 100644 --- a/lib/src/test/java/cash/z/ecc/android/bip39/MnemonicsTest.kt +++ b/lib/src/test/java/cash/z/ecc/android/bip39/MnemonicsTest.kt @@ -6,10 +6,8 @@ import com.squareup.moshi.JsonClass import com.squareup.moshi.Moshi import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory import io.kotest.assertions.asClue -import io.kotest.assertions.fail -import io.kotest.assertions.forEachAsClue +import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.assertions.throwables.shouldThrow -import io.kotest.assertions.withClue import io.kotest.core.spec.style.BehaviorSpec import io.kotest.data.forAll import io.kotest.data.row @@ -139,8 +137,7 @@ class MnemonicsTest : BehaviorSpec({ val mnemonic = it[1].toCharArray() val seed = it[2] val passphrase = "TREZOR".toCharArray() - val language = Locale.ENGLISH.language - MnemonicCode(mnemonic, language).toSeed(passphrase).toHex() shouldBe seed + MnemonicCode(mnemonic, lang).toSeed(passphrase).toHex() shouldBe seed } } } @@ -149,25 +146,54 @@ class MnemonicsTest : BehaviorSpec({ Given("an invalid mnemonic") { When("it was created by swapping two words in a valid mnemonic") { // swapped "trend" and "flight" - val mnemonicPhrase = validPhrase.swap(4, 5) - Then("it fails with a checksum error") { - mnemonicPhrase.asClue { + validPhrase.swap(4, 5).asClue { mnemonicPhrase -> + Then("validate() fails with a checksum error") { shouldThrow { MnemonicCode(mnemonicPhrase).validate() } } + Then("toEntropy() fails with a checksum error") { + shouldThrow { + MnemonicCode(mnemonicPhrase).toEntropy() + } + } + Then("toSeed() fails with a checksum error") { + shouldThrow { + MnemonicCode(mnemonicPhrase).toSeed() + } + } + Then("toSeed(validate=false) succeeds!!") { + shouldNotThrowAny { + MnemonicCode(mnemonicPhrase).toSeed(validate = false) + } + } } } When("it contains an invalid word") { val mnemonicPhrase = validPhrase.split(' ').let { words -> validPhrase.replace(words[23], "convincee") } - Then("it fails with a word validation error") { - mnemonicPhrase.asClue { + mnemonicPhrase.asClue { + Then("validate() fails with a word validation error") { shouldThrow { MnemonicCode(mnemonicPhrase).validate() } } + Then("toEntropy() fails with a word validation error") { + shouldThrow { + MnemonicCode(mnemonicPhrase).toEntropy() + } + } + Then("toSeed() fails with a word validation error") { + shouldThrow { + MnemonicCode(mnemonicPhrase).toSeed() + } + } + Then("toSeed(validate=false) succeeds!!") { + shouldNotThrowAny { + MnemonicCode(mnemonicPhrase).toSeed(validate = false) + } + } } } When("it contains an unsupported number of words") { @@ -176,6 +202,15 @@ class MnemonicsTest : BehaviorSpec({ shouldThrow { MnemonicCode(mnemonicPhrase).validate() } + shouldThrow { + MnemonicCode(mnemonicPhrase).toEntropy() + } + shouldThrow { + MnemonicCode(mnemonicPhrase).toSeed() + } + shouldNotThrowAny { + MnemonicCode(mnemonicPhrase).toSeed(validate = false) + } } } }