Extract common perm() helper from poseidon_fp

This commit is contained in:
therealyingtong 2022-05-19 12:59:38 +08:00
parent bbec810925
commit 3aca9f5b80
2 changed files with 8 additions and 46 deletions

View File

@ -18,7 +18,7 @@ R_P = 56
# Width
t = 3
def perm(input_words):
def perm(input_words, ROUND_CONSTANTS, MDS_MATRIX):
R_f = int(R_F / 2)
round_constants_counter = 0
@ -157,7 +157,7 @@ CAPACITY_ELEMENT = Fp(1 << 65)
def hash(x, y):
assert isinstance(x, Fp)
assert isinstance(y, Fp)
return perm([x, y, CAPACITY_ELEMENT])[0]
return perm([x, y, CAPACITY_ELEMENT], ROUND_CONSTANTS, MDS_MATRIX)[0]
def main():
@ -173,7 +173,7 @@ def main():
Fp(0x0a49c868c6976544256fcd597984561af7cfdfe1bda42c7b359029a1d34e9ddd),
]
assert perm(fixed_test_input) == fixed_test_output
assert perm(fixed_test_input, ROUND_CONSTANTS, MDS_MATRIX) == fixed_test_output
test_vectors = [fixed_test_input]
@ -203,7 +203,7 @@ def main():
),
[{
'initial_state': list(map(bytes, input)),
'final_state': list(map(bytes, perm(input))),
'final_state': list(map(bytes, perm(input, ROUND_CONSTANTS, MDS_MATRIX))),
} for input in test_vectors],
)

View File

@ -5,6 +5,7 @@ import numpy as np
from itertools import chain
from .vesta import Fq
from .poseidon_fp import perm
from ..utils import leos2ip
from ..output import render_args, render_tv
@ -18,45 +19,6 @@ R_P = 56
# Width
t = 3
def perm(input_words):
R_f = int(R_F / 2)
round_constants_counter = 0
state_words = list(input_words)
assert len(state_words) == t
# First full rounds
for r in range(0, R_f):
# Round constants, nonlinear layer, matrix multiplication
for i in range(0, t):
state_words[i] = state_words[i] + ROUND_CONSTANTS[round_constants_counter]
round_constants_counter += 1
for i in range(0, t):
state_words[i] = (state_words[i]).exp(5)
state_words = list(np.array(MDS_MATRIX).dot(np.array(state_words, dtype=object)))
# Middle partial rounds
for r in range(0, R_P):
# Round constants, nonlinear layer, matrix multiplication
for i in range(0, t):
state_words[i] = state_words[i] + ROUND_CONSTANTS[round_constants_counter]
round_constants_counter += 1
state_words[0] = (state_words[0]).exp(5)
state_words = list(np.array(MDS_MATRIX).dot(np.array(state_words, dtype=object)))
# Last full rounds
for r in range(0, R_f):
# Round constants, nonlinear layer, matrix multiplication
for i in range(0, t):
state_words[i] = state_words[i] + ROUND_CONSTANTS[round_constants_counter]
round_constants_counter += 1
for i in range(0, t):
state_words[i] = (state_words[i]).exp(5)
state_words = list(np.array(MDS_MATRIX).dot(np.array(state_words, dtype=object)))
return state_words
# Round constants generated by the reference implementation script, commit 659de89
# https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/generate_parameters_grain.sage
#
@ -157,7 +119,7 @@ CAPACITY_ELEMENT = Fq(1 << 65)
def hash(x, y):
assert isinstance(x, Fq)
assert isinstance(y, Fq)
return perm([x, y, CAPACITY_ELEMENT])[0]
return perm([x, y, CAPACITY_ELEMENT], ROUND_CONSTANTS, MDS_MATRIX)[0]
def main():
@ -173,7 +135,7 @@ def main():
Fq(0x25ab8aece9537168117fdb2420d8ea605019bfd4e0423fa014d542372a7ba0d9),
]
assert perm(fixed_test_input) == fixed_test_output
assert perm(fixed_test_input, ROUND_CONSTANTS, MDS_MATRIX) == fixed_test_output
test_vectors = [fixed_test_input]
@ -203,7 +165,7 @@ def main():
),
[{
'initial_state': list(map(bytes, input)),
'final_state': list(map(bytes, perm(input))),
'final_state': list(map(bytes, perm(input, ROUND_CONSTANTS, MDS_MATRIX))),
} for input in test_vectors],
)