diff --git a/ip_cores/common/src/rtl/common_if.sv b/ip_cores/common/src/rtl/common_if.sv index 0f3286b..53d4280 100644 --- a/ip_cores/common/src/rtl/common_if.sv +++ b/ip_cores/common/src/rtl/common_if.sv @@ -75,7 +75,7 @@ interface if_axi_stream # ( endfunction // Task to apply signals from one task to another in a clocked process - task copy_if(if_t in); + task automatic copy_if(if_t in); dat <= in.dat; val <= in.val; sop <= in.sop; @@ -86,7 +86,7 @@ interface if_axi_stream # ( endtask // Same task but for comb - task copy_if_comb(if_t in); + task automatic copy_if_comb(if_t in); dat = in.dat; val = in.val; sop = in.sop; diff --git a/ip_cores/util/src/rtl/bin_inv.sv b/ip_cores/util/src/rtl/bin_inv.sv index e27d492..118b735 100644 --- a/ip_cores/util/src/rtl/bin_inv.sv +++ b/ip_cores/util/src/rtl/bin_inv.sv @@ -1,5 +1,5 @@ /* - Calculates inversion mod P using binary gcd algorithm. + Calculates inversion mod p using binary gcd algorithm. Copyright (C) 2019 Benjamin Devlin and Zcash Foundation @@ -18,13 +18,13 @@ */ module bin_inv #( - parameter BITS, - parameter [BITS-1:0] P + parameter BITS )( input i_clk, input i_rst, input [BITS-1:0] i_dat, input i_val, + input [BITS-1:0] i_p, output logic o_rdy, output logic [BITS-1:0] o_dat, output logic o_val, @@ -32,6 +32,7 @@ module bin_inv #( ); logic [BITS:0] x1, x2, u, v; +logic [BITS-1:0] p_l; enum {IDLE, U_STATE, @@ -48,6 +49,7 @@ always_ff @ (posedge i_clk) begin o_rdy <= 0; o_val <= 0; o_dat <= 0; + p_l <= 0; state <= IDLE; end else begin o_rdy <= 0; @@ -58,7 +60,8 @@ always_ff @ (posedge i_clk) begin if (o_rdy && i_val) begin o_rdy <= 0; u <= i_dat; - v <= P; + v <= i_p; + p_l <= i_p; x1 <= 1; x2 <= 0; state <= U_STATE; @@ -72,7 +75,7 @@ always_ff @ (posedge i_clk) begin if (x1 % 2 == 0) begin x1 <= x1/2; end else begin - x1 <= (x1 + P)/2; + x1 <= (x1 + p_l)/2; end if ((u/2) % 2 == 1) begin state <= V_STATE; @@ -87,7 +90,7 @@ always_ff @ (posedge i_clk) begin if (x2 % 2 == 0) begin x2 <= x2/2; end else begin - x2 <= (x2 + P)/2; + x2 <= (x2 + p_l)/2; end if ((v/2 % 2) == 1) begin state <= UPDATE; @@ -98,13 +101,13 @@ always_ff @ (posedge i_clk) begin state <= U_STATE; if (u >= v) begin u <= u - v; - x1 <= x1 + (x1 >= x2 ? 0 : P) - x2; + x1 <= x1 + (x1 >= x2 ? 0 : p_l) - x2; if (u - v == 1 || v == 1) begin state <= FINISHED; end end else begin v <= v - u; - x2 <= x2 + (x2 >= x1 ? 0 : P) - x1; + x2 <= x2 + (x2 >= x1 ? 0 : p_l) - x1; if (v - u == 1 || u == 1) begin state <= FINISHED; end diff --git a/ip_cores/util/src/rtl/karatsuba_ofman_mult.sv b/ip_cores/util/src/rtl/karatsuba_ofman_mult.sv index 6773955..b54599b 100644 --- a/ip_cores/util/src/rtl/karatsuba_ofman_mult.sv +++ b/ip_cores/util/src/rtl/karatsuba_ofman_mult.sv @@ -39,13 +39,22 @@ module karatsuba_ofman_mult # ( localparam HBITS = BITS/2; -logic [BITS-1:0] m0, m1, m2; +logic [BITS-1:0] m0, m1, m2, dat_a, dat_b; logic [BITS*2-1:0] q; logic [HBITS-1:0] a0, a1; logic sign, sign_; logic val; logic [CTL_BITS-1:0] ctl; +always_ff @ (posedge i_clk) begin + dat_a <= i_dat_a; + dat_b <= i_dat_b; + + o_dat <= q; + o_val <= val; + o_ctl <= ctl; +end + generate always_comb begin a0 = i_dat_a[0 +: HBITS] > i_dat_a[HBITS +: HBITS] ? i_dat_a[0 +: HBITS] - i_dat_a[HBITS +: HBITS] : i_dat_a[HBITS +: HBITS] - i_dat_a[0 +: HBITS]; @@ -137,10 +146,4 @@ generate end endgenerate -always_ff @ (posedge i_clk) begin - o_dat <= q; - o_val <= val; - o_ctl <= ctl; -end - endmodule \ No newline at end of file diff --git a/ip_cores/util/src/rtl/packet_arb.sv b/ip_cores/util/src/rtl/packet_arb.sv index c0e170a..82fab0c 100644 --- a/ip_cores/util/src/rtl/packet_arb.sv +++ b/ip_cores/util/src/rtl/packet_arb.sv @@ -1,6 +1,8 @@ /* Takes in multiple streams and round robins between them. + The last $clog2(NUM_IN) bits on ctl will be overwritten with the identifier for the channel. + Copyright (C) 2019 Benjamin Devlin and Zcash Foundation This program is free software: you can redistribute it and/or modify @@ -20,7 +22,8 @@ module packet_arb # ( parameter DAT_BYTS, parameter CTL_BITS, - parameter NUM_IN + parameter NUM_IN, + parameter PIPELINE = 1 ) ( input i_clk, i_rst, @@ -42,17 +45,51 @@ logic [NUM_IN-1:0][CTL_BITS-1:0] ctl; generate genvar g; for (g = 0; g < NUM_IN; g++) begin: GEN - always_comb begin - i_axi[g].rdy = rdy[g]; - val[g] = i_axi[g].val; - eop[g] = i_axi[g].eop; - sop[g] = i_axi[g].sop; - err[g] = i_axi[g].err; - dat[g] = i_axi[g].dat; - mod[g] = i_axi[g].mod; - ctl[g] = i_axi[g].ctl; + + // Optionally pipeline the input + if (PIPELINE == 0) begin: PIPELINE_GEN + + always_comb begin + i_axi[g].rdy = rdy[g]; + val[g] = i_axi[g].val; + eop[g] = i_axi[g].eop; + sop[g] = i_axi[g].sop; + err[g] = i_axi[g].err; + dat[g] = i_axi[g].dat; + mod[g] = i_axi[g].mod; + ctl[g] = i_axi[g].ctl; + ctl[g][CTL_BITS-1 -: $clog2(NUM_IN)] = g; + end + + end else begin + + always_comb i_axi[g].rdy = ~val[g] || (val[g] && rdy[g]); + + always_ff @ (posedge i_clk) begin + if (i_rst) begin + val[g] <= 0; + eop[g] <= 0; + sop[g] <= 0; + err[g] <= 0; + dat[g] <= 0; + mod[g] <= 0; + ctl[g] <= 0; + end else begin + if (~val[g] || (val[g] && rdy[g])) begin + val[g] <= i_axi[g].val; + eop[g] <= i_axi[g].eop; + sop[g] <= i_axi[g].sop; + err[g] <= i_axi[g].err; + dat[g] <= i_axi[g].dat; + mod[g] <= i_axi[g].mod; + ctl[g] <= i_axi[g].ctl; + ctl[g][CTL_BITS-1 -: $clog2(NUM_IN)] <= g; + end + end + end + end - end + end endgenerate always_comb begin @@ -75,7 +112,7 @@ always_ff @ (posedge i_clk) begin end else begin if (~locked) begin idx <= get_next(idx); - if (val[get_next(idx)]) begin + if (val[get_next(idx)] && ~(eop[idx] && rdy[idx])) begin locked <= 1; end end else if (eop[idx] && val[idx] && rdy[idx]) begin diff --git a/ip_cores/util/src/tb/bin_inv_tb.sv b/ip_cores/util/src/tb/bin_inv_tb.sv index 9c5a00f..d50a5c3 100644 --- a/ip_cores/util/src/tb/bin_inv_tb.sv +++ b/ip_cores/util/src/tb/bin_inv_tb.sv @@ -50,13 +50,13 @@ always_ff @ (posedge clk) $error(1, "%m %t ERROR: output .err asserted", $time); bin_inv #( - .P ( secp256k1_pkg::p_eq ), - .BITS ( 256 ) + .BITS ( 256 ) ) bin_inv ( .i_clk( clk ), .i_rst( rst ), .i_dat( in_if.dat ), + .i_p ( secp256k1_pkg::p_eq ), .i_val( in_if.val ), .o_rdy( in_if.rdy ), .o_dat( out_if.dat ), diff --git a/zcash_fpga/src/rtl/secp256k1/secp256k1_pkg.sv b/zcash_fpga/src/rtl/secp256k1/secp256k1_pkg.sv index 30653d6..dae199c 100644 --- a/zcash_fpga/src/rtl/secp256k1/secp256k1_pkg.sv +++ b/zcash_fpga/src/rtl/secp256k1/secp256k1_pkg.sv @@ -19,7 +19,6 @@ package secp256k1_pkg; - // TODO might have to flip these parameter [255:0] p = 256'hFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F; parameter [255:0] a = 256'h0; parameter [255:0] b = 256'h7; @@ -47,6 +46,8 @@ package secp256k1_pkg; logic [255:0] x, y, z; } jb_point_t; + jb_point_t G_p = {x: secp256k1_pkg::Gx, y: secp256k1_pkg::Gy, z:1}; + typedef struct packed { logic [5:0] padding; logic X_INFINITY_POINT; @@ -56,19 +57,34 @@ package secp256k1_pkg; function is_zero(jb_point_t p); is_zero = (p.x == 0 && p.y == 0 && p.z == 1); + return is_zero; endfunction // Function to double point in Jacobian coordinates (for comparison in testbench) // Here a is 0, and we also mod p the result function jb_point_t dbl_jb_point(jb_point_t p); - logic [1023:0] A, B, C, D; - A = (p.y*p.y) % p_eq; - B = (4*p.x*A) % p_eq; - C = (8*A*A) % p_eq; - D = (3*p.x*p.x) % p_eq; - dbl_jb_point.x = (D*D - 2*B) % p_eq; - dbl_jb_point.y = (D*(B-dbl_jb_point.x) - C) % p_eq; - dbl_jb_point.z = (2*p.y*p.z) % p_eq; + logic signed [512:0] I_X, I_Y, I_Z, A, B, C, D, X, Y, Z; + + I_X = p.x; + I_Y = p.y; + I_Z = p.z; + A = (I_Y*I_Y) % p_eq; + B = (((4*I_X) % p_eq)*A) % p_eq; + C = (((8*A) % p_eq)*A) % p_eq; + D = (((3*I_X)% p_eq)*I_X) % p_eq; + X = (D*D)% p_eq; + X = X + ((2*B) % p_eq > X ? p_eq : 0) - (2*B) % p_eq; + + Y = (D*((B + (X > B ? p_eq : 0)-X) % p_eq)) % p_eq; + Y = Y + (C > Y ? p_eq : 0) - C; + Z = (((2*I_Y)% p_eq)*I_Z) % p_eq; + + dbl_jb_point = {x:X, y:Y, z:Z}; + return dbl_jb_point; + endfunction + + function on_curve(jb_point_t p); + return (p.y*p.y - p.x*p.x*p.x - secp256k1_pkg::a*p.x*p.z*p.z*p.z*p.z - secp256k1_pkg::b*p.z*p.z*p.z*p.z*p.z*p.z); endfunction function print_jb_point(jb_point_t p); diff --git a/zcash_fpga/src/rtl/secp256k1/secp256k1_point_add.sv b/zcash_fpga/src/rtl/secp256k1/secp256k1_point_add.sv new file mode 100644 index 0000000..b08c2fd --- /dev/null +++ b/zcash_fpga/src/rtl/secp256k1/secp256k1_point_add.sv @@ -0,0 +1,278 @@ +/* + This performs point addition. + + Copyright (C) 2019 Benjamin Devlin and Zcash Foundation + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +module secp256k1_point_add + import secp256k1_pkg::*; +#( +)( + input i_clk, i_rst, + // Input points + input jb_point_t i_p1, + input jb_point_t i_p2, + input logic i_val, + output logic o_rdy, + // Output point + output jb_point_t o_p, + input logic i_rdy, + output logic o_val, + output logic o_err, + // Interface to 256bit multiplier (mod p) + if_axi_stream.source o_mult_if, + if_axi_stream.source i_mult_if, + // Interface to only mod reduction block + if_axi_stream.source o_mod_if, + if_axi_stream.source i_mod_if +); + +/* + * These are the equations that need to be computed, they are issued as variables + * become valid. We have a bitmask to track what equation results are valid which + * will trigger other equations. [] show what equations must be valid before this starts. + * We reuse input points (as they are latched) when possible to reduce register usage. + * + * 0. A = i_p1.y - i_p2.y mod p + * 1. B = i_p1.x - i_p2.x mod p + * 2. o_p.z = B * i_p1.z mod p [eq1] + * 3. i_p1.z = B * B mod p [eq2] + * 4. i_p2.x = A * A mod p [eq0, eq5] + * 5. o_p.x = i_p1.x + i_p2.x mod p + * 6. o_p.x = o_p.x * i_p1.z mod p [eq5, eq3] + * 7. o_p.x = i_p2.x - o_p.x mod p[eq6, eq4] + * 8. o_p.y = i_p1.x*i_p1.z mod p [eq3] + * 9. o_p.y = o_p.y - o_p.x mod p [eq3, eq7, eq8] + * 10. o_p.y = o_p.y * A mod p [eq0, eq9] + * 11. i_p2.y = B * i_p1.z mod p [eq1, eq3, eq0] + * 12. i_p2.y = i_p2.y * i_p1.y [eq11] + * 13. o_p.y = o_p.y - i_p2.y mod p [eq12, eq10] + */ + + // We also check in the inital state if one of the inputs is "None" (.z == 0), and set the output to the other point +logic [13:0] eq_val, eq_wait; + +// Temporary variables +logic [255:0] A, B; +jb_point_t i_p1_l, i_p2_l; + +always_comb begin + o_mult_if.sop = 1; + o_mult_if.eop = 1; + o_mod_if.sop = 1; + o_mod_if.eop = 1; + o_mod_if.err = 1; + o_mod_if.mod = 0; + o_mult_if.err = 1; + o_mult_if.mod = 0; +end + +enum {IDLE, START, FINISHED} state; +always_ff @ (posedge i_clk) begin + if (i_rst) begin + o_val <= 0; + o_rdy <= 0; + o_p <= 0; + o_mult_if.val <= 0; + o_mod_if.val <= 0; + o_mult_if.dat <= 0; + o_mod_if.dat <= 0; + i_mult_if.rdy <= 0; + i_mod_if.rdy <= 0; + eq_val <= 0; + state <= IDLE; + eq_wait <= 0; + i_p1_l <= 0; + i_p2_l <= 0; + o_err <= 0; + A <= 0; + B <= 0; + end else begin + + if (o_mult_if.rdy) o_mult_if.val <= 0; + if (o_mod_if.rdy) o_mod_if.val <= 0; + + case(state) + {IDLE}: begin + o_rdy <= 1; + eq_val <= 0; + eq_wait <= 0; + o_err <= 0; + i_mult_if.rdy <= 1; + i_p1_l <= i_p1; + i_p2_l <= i_p2; + A <= 0; + B <= 0; + if (i_val && o_rdy) begin + state <= START; + o_rdy <= 0; + // If one point is at infinity + if (i_p1.z == 0 || i_p2.z == 0) begin + state <= FINISHED; + o_val <= 1; + o_p <= (i_p1.z == 0 ? i_p2 : i_p1); + end else + // If the points are opposite each other + if ((i_p1.x == i_p2.x) && (i_p1.y != i_p2.y)) begin + state <= FINISHED; + o_val <= 1; + o_p <= 0; // Return infinity + end else + // If the points are the same this module cannot be used + if ((i_p1.x == i_p2.x) && (i_p1.y == i_p2.y)) begin + state <= FINISHED; + o_err <= 1; + o_val <= 1; + end + end + end + // Just a big if tree where we issue equations if the required inputs + // are valid + {START}: begin + i_mod_if.rdy <= 1; + i_mult_if.rdy <= 1; + + // Check any results from multiplier + if (i_mod_if.val && i_mod_if.rdy) begin + eq_val[i_mod_if.ctl] <= 1; + case(i_mod_if.ctl) + 5: o_p.x <= i_mod_if.dat; + default: o_err <= 1; + endcase + end + + // Check any results from multiplier + if (i_mult_if.val && i_mult_if.rdy) begin + eq_val[i_mult_if.ctl] <= 1; + case(i_mult_if.ctl) inside + 2: o_p.z <= i_mult_if.dat; + 3: i_p1_l.z <= i_mult_if.dat; + 4: i_p2_l.x <= i_mult_if.dat; + 6: o_p.x <= i_mult_if.dat; + 8: o_p.y <= i_mult_if.dat; + 10: o_p.y <= i_mult_if.dat; + 11: i_p1_l.y <= i_mult_if.dat; + 12: i_p2_l.y <= i_mult_if.dat; + default: o_err <= 1; + endcase + end + + // Issue new multiplies + if (eq_val[1] && ~eq_wait[2]) begin // 2. o_p.z = B * i_p1.z mod p [eq1] + multiply(2, B, i_p1_l.z); + end else + if (eq_val[2] && ~eq_wait[3]) begin // 3. i_p1.z = B * B mod p [eq2] + multiply(3, B, B); + end else + if (eq_val[0] && eq_val[5] && ~eq_wait[4]) begin // 4. i_p2.x = A * A mod p [eq0, eq5] + multiply(4, A, A); + end else + if (eq_val[3] && eq_val[5] && ~eq_wait[6]) begin // 6. o_p.x = o_p.x * i_p1.z mod p [eq5, eq3] + multiply(6, o_p.x, i_p1_l.z); + end else + if (eq_val[3] && ~eq_wait[8]) begin // 8. o_p.y = i_p1.x*i_p1.z mod p [eq3] + multiply(8, i_p1_l.x, i_p1_l.z); + end else + if (eq_val[0] && eq_val[9] && ~eq_wait[10]) begin // 10. o_p.y = o_p.y * A mod p [eq0, eq9] + multiply(10, o_p.y, A); + end else + if (eq_val[0] && eq_val[1] && eq_val[3] && ~eq_wait[11]) begin // 11. i_p2.y = B * i_p1.z mod p [eq1, eq3, eq0] + multiply(11, B, i_p1_l.z); + end else + if (eq_val[11] && ~eq_wait[12]) begin // 12. i_p2.y = i_p2.y * i_p1.y [eq11] + multiply(12, i_p1_l.y, i_p2_l.y); + end + + // Issue new modulo reductions + if (~eq_wait[5]) begin // 5. o_p.x = i_p1.x + i_p2.x mod p + modulo(5, i_p1.x + i_p2.x); + end + + // Subtractions we do in-module + if (~eq_wait[0]) begin //0. A = i_p1.y - i_p2.y mod p + A <= subtract(0, i_p1_l.y, i_p2_l.y); + end + if (~eq_wait[1]) begin //1. B = i_p1.x - i_p2.x mod p + B <= subtract(1, i_p1_l.x, i_p2_l.x); + end + if (~eq_wait[7] && eq_val[6] && eq_val[4]) begin //7. o_p.x = i_p2.x - o_p.x mod p[eq6, eq4] + o_p.x <= subtract(7, i_p2_l.x, o_p.x); + end + if (~eq_wait[9] && eq_val[3] && eq_val[7] && eq_val[8]) begin //9. o_p.y = o_p.y - o_p.x mod p [eq3, eq7, eq8] + o_p.y <= subtract(9, o_p.y, o_p.x); + end + if (~eq_wait[13] && eq_val[12] && eq_val[10]) begin //13. o_p.y = o_p.y - i_p2.y mod p [eq12, eq10] + o_p.y <= subtract(13, o_p.y, i_p2_l.y); + end + + + + if (&eq_val) begin + state <= FINISHED; + o_val <= 1; + end + end + {FINISHED}: begin + if (o_val && i_rdy) begin + state <= IDLE; + o_val <= 0; + o_rdy <= 1; + end + end + endcase + + if (o_err) begin + o_val <= 1; + if (o_val && i_rdy) begin + o_err <= 0; + state <= IDLE; + end + end + + end +end + +// Task for subtractions +function logic [255:0] subtract(input int unsigned ctl, input logic [255:0] a, b); + eq_wait[ctl] <= 1; + eq_val[ctl] <= 1; + return (a + (b > a ? secp256k1_pkg::p : 0) - b); +endfunction + + +// Task for using multiplies +task multiply(input int unsigned ctl, input logic [255:0] a, b); + if (~o_mult_if.val || (o_mult_if.val && o_mult_if.rdy)) begin + o_mult_if.val <= 1; + o_mult_if.dat[0 +: 256] <= a; + o_mult_if.dat[256 +: 256] <= b; + o_mult_if.ctl <= ctl; + eq_wait[ctl] <= 1; + end +endtask + +// Task for using modulo +task modulo(input int unsigned ctl, input logic [512:0] a); + if (~o_mod_if.val || (o_mod_if.val && o_mod_if.rdy)) begin + o_mod_if.val <= 1; + o_mod_if.dat <= a; + o_mod_if.ctl <= ctl; + eq_wait[ctl] <= 1; + end +endtask + + +endmodule \ No newline at end of file diff --git a/zcash_fpga/src/rtl/secp256k1/secp256k1_point_dbl.sv b/zcash_fpga/src/rtl/secp256k1/secp256k1_point_dbl.sv index 2bec329..f96d61b 100644 --- a/zcash_fpga/src/rtl/secp256k1/secp256k1_point_dbl.sv +++ b/zcash_fpga/src/rtl/secp256k1/secp256k1_point_dbl.sv @@ -56,7 +56,7 @@ module secp256k1_point_dbl * 9. (o_p.x) = o_p.x - E mod p [eq8, eq7] * 10 (o_p.y) = B - o_p.x mod p [eq9, eq2] * 11. (o_p.y) = D*(o_p.y) [eq10, eq6] - * 12. (o_p.y) = (o_p.y) - C mod p [eq11] + * 12. (o_p.y) = (o_p.y) - C mod p [eq11, eq4] * 13. (o_p.z) = 2*(i_p.y) mod p * 14. (o_p.z) = o_p.y * i_p.z mod p [eq14] */ @@ -66,14 +66,27 @@ logic [14:0] eq_val, eq_wait; logic [255:0] A, B, C, D, E; jb_point_t i_p_l; +always_comb begin + o_mult_if.sop = 1; + o_mult_if.eop = 1; + o_mod_if.sop = 1; + o_mod_if.eop = 1; + o_mod_if.err = 1; + o_mod_if.mod = 0; + o_mult_if.err = 1; + o_mult_if.mod = 0; +end + enum {IDLE, START, FINISHED} state; always_ff @ (posedge i_clk) begin if (i_rst) begin o_val <= 0; o_rdy <= 0; o_p <= 0; - o_mult_if.reset_source(); - o_mod_if.reset_source(); + o_mult_if.val <= 0; + o_mod_if.val <= 0; + o_mult_if.dat <= 0; + o_mod_if.dat <= 0; i_mult_if.rdy <= 0; i_mod_if.rdy <= 0; eq_val <= 0; @@ -87,10 +100,10 @@ always_ff @ (posedge i_clk) begin D <= 0; E <= 0; end else begin - if (o_mult_if.rdy) - o_mult_if.val <= 0; - if (o_mod_if.rdy) - o_mod_if.val <= 0; + + if (o_mult_if.rdy) o_mult_if.val <= 0; + if (o_mod_if.rdy) o_mod_if.val <= 0; + case(state) {IDLE}: begin o_rdy <= 1; @@ -104,12 +117,14 @@ always_ff @ (posedge i_clk) begin C <= 0; D <= 0; E <= 0; + o_val <= 0; if (i_val && o_rdy) begin state <= START; o_rdy <= 0; if (i_p.z == 0) begin - o_err <= 1; - state <= IDLE; + o_p <= i_p; + o_val <= 1; + state <= FINISHED; end end end @@ -119,7 +134,7 @@ always_ff @ (posedge i_clk) begin i_mod_if.rdy <= 1; i_mult_if.rdy <= 1; - // Check any results from multiplier + // Check any results from modulo if (i_mod_if.val && i_mod_if.rdy) begin eq_val[i_mod_if.ctl] <= 1; case(i_mod_if.ctl) @@ -190,22 +205,18 @@ always_ff @ (posedge i_clk) begin // Additions / subtractions we do in-module if (eq_val[8] && eq_val[7] && ~eq_wait[9]) begin //9. (o_p.x) = o_p.x - E mod p [eq8, eq7] - eq_wait[9] <= 1; - eq_val[9] <= 1; - o_p.x <= o_p.x + (E > o_p.x ? secp256k1_pkg::p : 0) - E; + o_p.x <= subtract(9, o_p.x, E); end if (eq_val[9] && eq_val[2] && ~eq_wait[10]) begin //10. (o_p.y) = B - o_p.x mod p [eq9, eq2] eq_wait[10] <= 1; eq_val[10] <= 1; - o_p.y <= B + (o_p.x > B ? secp256k1_pkg::p : 0) - o_p.x; + o_p.y <= subtract(10, B, o_p.x); end - if (eq_val[11] && ~eq_wait[12]) begin //12. (o_p.y) = (o_p.y) - C mod p [eq11] - eq_wait[12] <= 1; - eq_val[12] <= 1; - o_p.y <= o_p.y + (C > o_p.y ? secp256k1_pkg::p : 0) - C; + if (eq_val[4] && eq_val[11] && ~eq_wait[12]) begin //12. (o_p.y) = (o_p.y) - C mod p [eq11, eq4] + o_p.y <= subtract(12, o_p.y ,C); end if (&eq_val) begin @@ -233,6 +244,13 @@ always_ff @ (posedge i_clk) begin end end +// Task for subtractions +function logic [255:0] subtract(input int unsigned ctl, input logic [255:0] a, b); + eq_wait[ctl] <= 1; + eq_val[ctl] <= 1; + return (a + (b > a ? secp256k1_pkg::p : 0) - b); +endfunction + // Task for using multiplies task multiply(input int unsigned ctl, input logic [255:0] a, b); if (~o_mult_if.val || (o_mult_if.val && o_mult_if.rdy)) begin diff --git a/zcash_fpga/src/rtl/secp256k1/secp256k1_point_mult.sv b/zcash_fpga/src/rtl/secp256k1/secp256k1_point_mult.sv index 0d7c007..0a74059 100644 --- a/zcash_fpga/src/rtl/secp256k1/secp256k1_point_mult.sv +++ b/zcash_fpga/src/rtl/secp256k1/secp256k1_point_mult.sv @@ -35,17 +35,19 @@ module secp256k1_point_mult output logic o_err ); -if_axi_stream #(.DAT_BYTS(256*2/8), .CTL_BITS(8)) mult_in_if(i_clk); -if_axi_stream #(.DAT_BYTS(256/8), .CTL_BITS(8)) mult_out_if(i_clk); - -if_axi_stream #(.DAT_BYTS(256*2/8), .CTL_BITS(8)) mod_in_if(i_clk); -if_axi_stream #(.DAT_BYTS(256/8), .CTL_BITS(8)) mod_out_if(i_clk); +// [0] is connection from/to dbl block, [1] is add block, [2] is arbitrated value +if_axi_stream #(.DAT_BYTS(256*2/8), .CTL_BITS(8)) mult_in_if [2:0] (i_clk); +if_axi_stream #(.DAT_BYTS(256/8), .CTL_BITS(8)) mult_out_if [2:0] (i_clk); +if_axi_stream #(.DAT_BYTS(256*2/8), .CTL_BITS(8)) mod_in_if [2:0] (i_clk); +if_axi_stream #(.DAT_BYTS(256/8), .CTL_BITS(8)) mod_out_if [2:0] (i_clk); logic [255:0] k_l; -jb_point_t p_n, p_q, p_dbl; -logic p_dbl_in_val, p_dbl_in_rdy, p_dbl_out_err, p_dbl_out_val, p_dbl_out_rdy; +jb_point_t p_n, p_q, p_dbl, p_add; +logic p_dbl_in_val, p_dbl_in_rdy, p_dbl_out_err, p_dbl_out_val, p_dbl_out_rdy, p_dbl_done; +logic p_add_in_val, p_add_in_rdy, p_add_out_err, p_add_out_val, p_add_out_rdy, p_add_done; +logic special_dbl; -enum {IDLE, DOUBLE, ADD, FINISHED} state; +enum {IDLE, DOUBLE_ADD, FINISHED} state; always_ff @ (posedge i_clk) begin if (i_rst) begin @@ -56,51 +58,75 @@ always_ff @ (posedge i_clk) begin p_q <= 0; p_dbl_in_val <= 0; p_dbl_out_rdy <= 0; + p_add_in_val <= 0; + p_add_out_rdy <= 0; state <= IDLE; o_p <= 0; p_n <= 0; + p_dbl_done <= 0; + p_add_done <= 0; + special_dbl <= 0; end else begin - p_dbl_in_val <= 0; p_dbl_out_rdy <= 1; + p_add_out_rdy <= 1; case (state) {IDLE}: begin + p_dbl_done <= 1; + p_add_done <= 1; + special_dbl <= 0; o_rdy <= 1; o_err <= 0; - p_q <= {x:0, y:0, z:1}; // p_q starts at 0 + p_q <= 0; // p_q starts at 0 + p_n <= i_p; + k_l <= i_k; if (o_rdy && i_val) begin - k_l <= i_k; - p_n <= i_p; - // Regardless of i_k[0] we skip the first add since it would set p_q to i_p - if (i_k[0]) begin - p_q <= i_p; - end - state <= DOUBLE; - p_dbl_in_val <= 1; + state <= DOUBLE_ADD; end end - {DOUBLE}: begin - if(p_dbl_in_val && p_dbl_in_rdy) begin - p_dbl_in_val <= 0; - end + {DOUBLE_ADD}: begin + p_dbl_in_val <= (p_dbl_in_val && p_dbl_in_rdy) ? 0 : p_dbl_in_val; + p_add_in_val <= (p_add_in_val && p_add_in_rdy) ? 0 : p_add_in_val; if (p_dbl_out_val && p_dbl_out_rdy) begin + p_dbl_done <= 1; + if (special_dbl) begin + p_q <= p_dbl; + special_dbl <= 0; + end p_n <= p_dbl; - k_l <= k_l >> 1; - if (k_l[1] == 1) begin - state <= ADD; - end else if (k_l[255:1] == 0) begin - state <= FINISHED; - o_p <= p_dbl; - o_val <= 1; - end else begin - state <= DOUBLE; - p_dbl_in_val <= 1; - end end - end - {ADD}: begin - state <= DOUBLE; - p_q <= p_n; - p_dbl_in_val <= 1; + if (p_add_out_val && p_add_out_rdy) begin + p_add_done <= 1; + p_q <= p_add; + end + + // Update variables and issue new commands + if (p_add_done && p_dbl_done) begin + p_add_done <= 0; + p_dbl_done <= 0; + k_l <= k_l >> 1; + if (k_l[0]) begin + p_add_in_val <= 1; + // Need to check for special case where the x, y point is the same + if (p_q.x == p_n.x && p_q.y == p_n.y) begin + special_dbl <= 1; + p_add_in_val <= 0; + p_add_done <= 1; + end + end else begin + p_add_done <= 1; + end + + p_dbl_in_val <= 1; + + if (k_l == 0) begin + state <= FINISHED; + o_p <= p_add; + o_val <= 1; + p_dbl_in_val <= 0; + p_add_in_val <= 0; + end + end + end {FINISHED}: begin if (i_rdy && o_val) begin @@ -110,7 +136,7 @@ always_ff @ (posedge i_clk) begin end endcase - if (p_dbl_out_err) begin + if (p_dbl_out_err || p_add_out_err) begin o_err <= 1; o_val <= 1; state <= FINISHED; @@ -132,12 +158,90 @@ secp256k1_point_dbl secp256k1_point_dbl( .i_rdy ( p_dbl_out_rdy ), .o_val ( p_dbl_out_val ), // Interfaces to shared multipliers / modulo blocks - .o_mult_if ( mult_in_if ), - .i_mult_if ( mult_out_if ), - .o_mod_if ( mod_in_if ), - .i_mod_if ( mod_out_if ) + .o_mult_if ( mult_in_if[0] ), + .i_mult_if ( mult_out_if[0] ), + .o_mod_if ( mod_in_if[0] ), + .i_mod_if ( mod_out_if[0] ) ); +secp256k1_point_add secp256k1_point_add( + .i_clk ( i_clk ), + .i_rst ( i_rst ), + // Input points + .i_p1 ( p_q ), + .i_p2 ( p_n ), + .i_val ( p_add_in_val ), + .o_rdy ( p_add_in_rdy ), + // Output point + .o_p ( p_add ), + .o_err ( p_add_out_err ), + .i_rdy ( p_add_out_rdy ), + .o_val ( p_add_out_val ), + // Interfaces to shared multipliers / modulo blocks + .o_mult_if ( mult_in_if[1] ), + .i_mult_if ( mult_out_if[1] ), + .o_mod_if ( mod_in_if[1] ), + .i_mod_if ( mod_out_if[1] ) +); + +// We add arbitrators to these to share with the point add module +packet_arb # ( + .DAT_BYTS ( 512/8 ), + .CTL_BITS ( 8 ), + .NUM_IN ( 2 ), + .PIPELINE ( 1 ) +) +packet_arb_mult ( + .i_clk ( i_clk ), + .i_rst ( i_rst ), + .i_axi ( mult_in_if[1:0] ), + .o_axi ( mult_in_if[2] ) +); + +packet_arb # ( + .DAT_BYTS ( 512/8 ), + .CTL_BITS ( 8 ), + .NUM_IN ( 2 ), + .PIPELINE ( 1 ) +) +packet_arb_mod ( + .i_clk ( i_clk ), + .i_rst ( i_rst ), + .i_axi ( mod_in_if[1:0] ), + .o_axi ( mod_in_if[2] ) +); + +always_comb begin + mod_out_if[0].copy_if_comb(mod_out_if[2].to_struct()); + mod_out_if[1].copy_if_comb(mod_out_if[2].to_struct()); + + mod_out_if[0].ctl = {1'd0, mod_out_if[2].ctl[6:0]}; + mod_out_if[1].ctl = {1'd0, mod_out_if[2].ctl[6:0]}; + + mod_out_if[1].val = mod_out_if[2].val && mod_out_if[2].ctl[7] == 1; + mod_out_if[0].val = mod_out_if[2].val && mod_out_if[2].ctl[7] == 0; + mod_out_if[2].rdy = mod_out_if[2].ctl[7] == 0 ? mod_out_if[0].rdy : mod_out_if[1].rdy; + + mod_out_if[2].sop = 1; + mod_out_if[2].eop = 1; + mod_out_if[2].mod = 0; +end + +always_comb begin + mult_out_if[0].copy_if_comb(mult_out_if[2].to_struct()); + mult_out_if[1].copy_if_comb(mult_out_if[2].to_struct()); + + mult_out_if[0].ctl = {1'd0, mult_out_if[2].ctl[6:0]}; + mult_out_if[1].ctl = {1'd0, mult_out_if[2].ctl[6:0]}; + + mult_out_if[1].val = mult_out_if[2].val && mult_out_if[2].ctl[7] == 1; + mult_out_if[0].val = mult_out_if[2].val && mult_out_if[2].ctl[7] == 0; + mult_out_if[2].rdy = mult_out_if[2].ctl[7] == 0 ? mult_out_if[0].rdy : mult_out_if[1].rdy; + + mult_out_if[2].sop = 1; + mult_out_if[2].eop = 1; + mult_out_if[2].mod = 0; +end secp256k1_mult_mod #( .CTL_BITS ( 8 ) @@ -145,17 +249,17 @@ secp256k1_mult_mod #( secp256k1_mult_mod ( .i_clk ( i_clk ), .i_rst ( i_rst ), - .i_dat_a ( mult_in_if.dat[0 +: 256] ), - .i_dat_b ( mult_in_if.dat[256 +: 256] ), - .i_val ( mult_in_if.val ), - .i_err ( mult_in_if.err ), - .i_ctl ( mult_in_if.ctl ), - .o_rdy ( mult_in_if.rdy ), - .o_dat ( mult_out_if.dat ), - .i_rdy ( mult_out_if.rdy ), - .o_val ( mult_out_if.val ), - .o_ctl ( mult_out_if.ctl ), - .o_err ( mult_out_if.err ) + .i_dat_a ( mult_in_if[2].dat[0 +: 256] ), + .i_dat_b ( mult_in_if[2].dat[256 +: 256] ), + .i_val ( mult_in_if[2].val ), + .i_err ( mult_in_if[2].err ), + .i_ctl ( mult_in_if[2].ctl ), + .o_rdy ( mult_in_if[2].rdy ), + .o_dat ( mult_out_if[2].dat ), + .i_rdy ( mult_out_if[2].rdy ), + .o_val ( mult_out_if[2].val ), + .o_ctl ( mult_out_if[2].ctl ), + .o_err ( mult_out_if[2].err ) ); secp256k1_mod #( @@ -165,16 +269,16 @@ secp256k1_mod #( secp256k1_mod ( .i_clk( i_clk ), .i_rst( i_rst ), - .i_dat( mod_in_if.dat ), - .i_val( mod_in_if.val ), - .i_err( mod_in_if.err ), - .i_ctl( mod_in_if.ctl ), - .o_rdy( mod_in_if.rdy ), - .o_dat( mod_out_if.dat ), - .o_ctl( mod_out_if.ctl ), - .o_err( mod_out_if.err ), - .i_rdy( mod_out_if.rdy ), - .o_val( mod_out_if.val ) + .i_dat( mod_in_if[2].dat ), + .i_val( mod_in_if[2].val ), + .i_err( mod_in_if[2].err ), + .i_ctl( mod_in_if[2].ctl ), + .o_rdy( mod_in_if[2].rdy ), + .o_dat( mod_out_if[2].dat ), + .o_ctl( mod_out_if[2].ctl ), + .o_err( mod_out_if[2].err ), + .i_rdy( mod_out_if[2].rdy ), + .o_val( mod_out_if[2].val ) ); endmodule \ No newline at end of file diff --git a/zcash_fpga/src/rtl/secp256k1/secp256k1_top.sv b/zcash_fpga/src/rtl/secp256k1/secp256k1_top.sv index 48bdde7..6cb341d 100644 --- a/zcash_fpga/src/rtl/secp256k1/secp256k1_top.sv +++ b/zcash_fpga/src/rtl/secp256k1/secp256k1_top.sv @@ -49,6 +49,8 @@ logic [255:0] r, w; logic [5:0] cnt; // Counter for parsing command inputs logic if_axi_mm_rd; +logic [255:0] inv_p; + always_comb begin header = if_cmd_rx.dat; end @@ -69,6 +71,7 @@ always_ff @ (posedge i_clk) begin bin_inv_in_if.reset_source(); bin_inv_out_if.rdy <= 0; secp256k1_ver <= 0; + inv_p <= secp256k1_pkg::n; end else begin register_file_a.en <= 1; @@ -80,6 +83,7 @@ always_ff @ (posedge i_clk) begin case(secp256k1_state) {IDLE}: begin + inv_p <= secp256k1_pkg::n; secp256k1_ver <= 0; if_cmd_rx.rdy <= 1; header_l <= header; @@ -190,12 +194,12 @@ bram #( // Calculate binary inverse mod n begin: BINARY_INVERSE_MOD_N bin_inv #( - .BITS ( 256 ), - .P ( secp256k1_pkg::n ) + .BITS ( 256 ) )( .i_clk ( i_clk ), .i_rst ( i_rst) , .i_dat ( bin_inv_in_if.dat ), + .i_p ( inv_p ), .i_val ( bin_inv_in_if.val ), .o_rdy ( bin_inv_in_if.rdy ), .o_dat ( bin_inv_out_if.dat ), @@ -232,6 +236,21 @@ end // Modulo p reducer (shared with arbitrator) // Modulo n reducer (output from karatsuba multiplier) +barret_mod #( + .IN_BITS ( 512 ), + .OUT_BITS ( 256 ), + .P ( secp256k1_pkg::n ) +) +barret_mod ( + .i_clk ( i_clk ), + .i_rst ( i_rst ), + .i_dat ( in_if.dat ), + .i_val ( in_if.val ), + .o_rdy ( in_if.rdy ), + .o_dat ( out_if.dat ), + .o_val ( out_if.val ), + .i_rdy ( out_if.rdy ) +); // 256 bit Karatsuba_ofman multiplier (shared with arbitrator) diff --git a/zcash_fpga/src/tb/secp256k1_point_dbl_tb.sv b/zcash_fpga/src/tb/secp256k1_point_dbl_tb.sv index 1c568c6..4910f4f 100644 --- a/zcash_fpga/src/tb/secp256k1_point_dbl_tb.sv +++ b/zcash_fpga/src/tb/secp256k1_point_dbl_tb.sv @@ -127,7 +127,9 @@ begin logic [255:0] in_a, in_b; jb_point_t p_in, p_exp, p_out; $display("Running test_0..."); - p_in = {z:1, x:2, y:3}; + //p_in = {z:1, x:4, y:2}; + //p_in = {z:10, x:64, y:23}; + p_in = secp256k1_pkg::G_p; p_exp = dbl_jb_point(p_in); fork diff --git a/zcash_fpga/src/tb/secp256k1_point_mult_tb.sv b/zcash_fpga/src/tb/secp256k1_point_mult_tb.sv index 012e756..0373721 100644 --- a/zcash_fpga/src/tb/secp256k1_point_mult_tb.sv +++ b/zcash_fpga/src/tb/secp256k1_point_mult_tb.sv @@ -20,7 +20,7 @@ module secp256k1_point_mult_tb (); import common_pkg::*; import secp256k1_pkg::*; -localparam CLK_PERIOD = 100; +localparam CLK_PERIOD = 1000; logic clk, rst; @@ -28,7 +28,7 @@ if_axi_stream #(.DAT_BYTS(256*3/8)) in_if(clk); if_axi_stream #(.DAT_BYTS(256*3/8)) out_if(clk); jb_point_t in_p, out_p; -logic [255:0] k; +logic [255:0] k_in; always_comb begin in_p = in_if.dat; @@ -42,7 +42,7 @@ end initial begin clk = 0; - forever #CLK_PERIOD clk = ~clk; + forever #(CLK_PERIOD/2) clk = ~clk; end always_comb begin @@ -64,7 +64,7 @@ secp256k1_point_mult secp256k1_point_mult ( .i_clk ( clk ), .i_rst ( rst ), .i_p ( in_if.dat ), - .i_k ( k ), + .i_k ( k_in ), .i_val ( in_if.val ), .o_rdy ( in_if.rdy ), .o_p ( out_p ), @@ -73,46 +73,51 @@ secp256k1_point_mult secp256k1_point_mult ( .o_err ( out_if.err ) ); -task test_0(); +// Test a point +task test(input logic [255:0] k, jb_point_t p_exp); begin integer signed get_len; logic [common_pkg::MAX_SIM_BYTS*8-1:0] expected, get_dat; logic [255:0] in_a, in_b; - jb_point_t p_in, p_exp, p_out; + jb_point_t p_in, p_out; $display("Running test_0..."); - p_in = {z:1, x:2, y:3}; - k = 100; - //p_exp = dbl_jb_point(p_in); - + p_in = secp256k1_pkg::G_p; + k_in = k; fork in_if.put_stream(p_in, 256*3/8); out_if.get_stream(get_dat, get_len); join - /*p_out = get_dat; + p_out = get_dat; if (p_exp != p_out) begin $display("Expected:"); print_jb_point(p_exp); $display("Was:"); print_jb_point(p_out); - $fatal(1, "%m %t ERROR: test_0 point was wrong", $time); - end */ + $fatal(1, "%m %t ERROR: test with k=%d was wrong", $time, integer'(k)); + end - $display("test_0 PASSED"); + $display("test with k=%d PASSED", integer'(k)); end endtask; -function compare_point(); - -endfunction - initial begin out_if.rdy = 0; in_if.val = 0; #(40*CLK_PERIOD); - - test_0(); + + test(1, {x:256'h79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, + y:256'h483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8, + z:256'h1}); + + test(2, {x:256'h7d152c041ea8e1dc2191843d1fa9db55b68f88fef695e2c791d40444b365afc2, + y:256'h56915849f52cc8f76f5fd7e4bf60db4a43bf633e1b1383f85fe89164bfadcbdb, + z:256'h9075b4ee4d4788cabb49f7f81c221151fa2f68914d0aa833388fa11ff621a970}); + + test(3, {x:256'hca90ef9b06d7eb51d650e9145e3083cbd8df8759168862036f97a358f089848, + y:256'h435afe76017b8d55d04ff8a98dd60b2ba7eb6f87f6b28182ca4493d7165dd127, + z:256'h9242fa9c0b9f23a3bfea6a0eb6dbcfcbc4853fe9a25ee948105dc66a2a9b5baa}); #1us $finish(); end