#!/usr/bin/env python3
# -*- coding: utf8 -*-
import sys; assert sys.version_info[0] >= 3, "Python 3 required."

from sapling_jubjub import FieldElement
from utils import leos2ip

# Field moduli: Fq (the field the curve below is defined over) works modulo q,
# Scalar works modulo p.
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001

# (q - 1) / 2: the Euler-criterion exponent used by Fq.sqrt.
qm1d2 = 0x2000000000000000000000000000000011234c7e04ca546ec623759080000000
assert (q - 1) // 2 == qm1d2

# Tonelli-Shanks decomposition q - 1 = 2^S * T, with T odd.
S = 32
T = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb21
assert (q - 1) == (1 << S) * T

# (T - 1) / 2
tm1d2 = 0x2000000000000000000000000000000011234c7e04ca546ec6237590
assert (T - 1) // 2 == tm1d2

# 5^T (mod q); used as the initial "z" value in Fq.sqrt.
ROOT_OF_UNITY = 0x2de6a9b8746d3f589e5c4dfd492ae26e9bb97ea3c106f049a70e2c1102b6d05f

#
# Field arithmetic
#

class Fq(FieldElement):
    """Element of F_q (modulus ``q``), the base field of the curve below."""

    @staticmethod
    def from_bytes(buf):
        """Decode ``buf`` with ``leos2ip`` and build a strict Fq.

        Strict construction rejects out-of-range values; callers in this file
        catch ``ValueError`` for that case.
        """
        return Fq(leos2ip(buf), strict=True)

    @staticmethod
    def random(rand):
        """Sample an Fq by rejection: draw 32 bytes from ``rand.b`` until the
        decoded integer is accepted by strict construction.

        Uniformity depends on ``rand.b`` returning uniform bytes (assumed).
        """
        while True:
            try:
                return Fq(leos2ip(rand.b(32)), strict=True)
            except ValueError:
                # Candidate rejected (out of range); draw again.
                pass

    def __init__(self, s, strict=False):
        FieldElement.__init__(self, Fq, s, q, strict=strict)

    def __str__(self):
        return 'Fq(%s)' % self.s

    def sgn0(self):
        """Return True iff the integer representative of ``self`` is odd.

        https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-10#section-4.1
        """
        return (self.s % 2) == 1

    def sqrt(self):
        """Return a square root of ``self``, ``None`` for a non-square, or
        ``Fq.ZERO`` when ``self`` is zero.

        Tonelli-Shanks for a modulus with q mod 16 = 1, following
        https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5).
        The root returned is whichever one the algorithm lands on; no sign
        normalisation is applied (callers pick the sign themselves).
        """
        # Euler's criterion: self^((q-1)/2) is 1 for non-zero squares and
        # -1 for non-squares; anything else falls through to the zero case.
        a = self.exp(qm1d2)
        if a == self.ONE:
            # z <- c^t
            c = Fq(ROOT_OF_UNITY)
            # x <- a \omega
            x = self.exp(tm1d2 + 1)
            # b <- x \omega = a \omega^2
            b = self.exp(T)
            y = S

            # 7: while b != 1 do
            while b != self.ONE:
                # 8: Find least integer k >= 0 such that b^(2^k) == 1
                k = 1
                b2k = b * b
                while b2k != self.ONE:
                    b2k = b2k * b2k
                    k += 1
                # k must strictly decrease each round, or the input was bad.
                assert k < y

                # 9:
                # w <- z^(2^(y-k-1))
                for _ in range(0, y - k - 1):
                    c = c * c
                # x <- xw
                x = x * c
                # z <- w^2
                c = c * c
                # b <- bz
                b = b * c
                # y <- k
                y = k

            assert x * x == self
            return x
        elif a == self.MINUS_ONE:
            return None
        return self.ZERO


class Scalar(FieldElement):
    """Element of F_p (modulus ``p``); the scalar type accepted by ``Point.__mul__``."""

    def __init__(self, s, strict=False):
        FieldElement.__init__(self, Scalar, s, p, strict=strict)

    def __str__(self):
        return 'Scalar(%s)' % self.s

    @staticmethod
    def from_bytes(buf):
        """Decode ``buf`` with ``leos2ip`` and build a strict Scalar."""
        return Scalar(leos2ip(buf), strict=True)

    @staticmethod
    def random(rand):
        """Sample a Scalar by rejection: draw 32 bytes from ``rand.b`` until
        the decoded integer is accepted by strict construction."""
        while True:
            try:
                return Scalar(leos2ip(rand.b(32)), strict=True)
            except ValueError:
                # Candidate rejected (out of range); draw again.
                pass


# Shared constants and basic sanity checks for both field types.
for F in (Fq, Scalar):
    F.ZERO = F(0)
    F.ONE = F(1)
    F.MINUS_ONE = F(-1)

    assert F.ZERO + F.ZERO == F.ZERO
    assert F.ZERO + F.ONE == F.ONE
    assert F.ONE + F.ZERO == F.ONE
    assert F.ZERO - F.ONE == F.MINUS_ONE
    assert F.ZERO * F.ONE == F.ZERO
    assert F.ONE * F.ZERO == F.ZERO


#
# Point arithmetic
#

# Curve constant b in y^2 = x^3 + b.
VESTA_B = Fq(5)


class Point(object):
    """Affine point on y^2 = x^3 + VESTA_B over Fq, or the identity.

    The identity is flagged with ``is_identity=True`` and stores x = y = 0
    as placeholders (not a curve point).
    """

    @staticmethod
    def rand(rand):
        """Draw 32-byte strings from ``rand.b`` until one decodes to a point."""
        while True:
            data = rand.b(32)
            candidate = Point.from_bytes(data)
            if candidate is not None:
                return candidate

    @staticmethod
    def from_bytes(buf):
        """Decode a 32-byte compressed point encoding.

        All-zero bytes decode to the identity. Otherwise the top bit of
        ``buf[31]`` is the desired parity of y and the remaining 255 bits are
        decoded as x via ``Fq.from_bytes``. Returns ``None`` if x is not a
        valid field element or x^3 + b has no square root.
        """
        assert len(buf) == 32
        if buf == bytes([0]*32):
            return Point.identity()

        y_sign = buf[31] >> 7
        buf = buf[:31] + bytes([buf[31] & 0b01111111])
        try:
            x = Fq.from_bytes(buf)
        except ValueError:
            return None

        x3 = x * x * x
        y2 = x3 + VESTA_B

        y = y2.sqrt()
        if y is None:
            return None

        # sqrt returns an arbitrary root; pick the one with the encoded parity.
        if y.s % 2 != y_sign:
            y = Fq.ZERO - y

        return Point(x, y)

    def __init__(self, x, y, is_identity=False):
        self.x = x
        self.y = y
        self.is_identity = is_identity

        if is_identity:
            assert self.x == Fq.ZERO
            assert self.y == Fq.ZERO
        else:
            # Reject coordinates that are not on the curve.
            assert self.y * self.y == self.x * self.x * self.x + VESTA_B

    @staticmethod
    def identity():
        """Return a fresh identity point."""
        p = Point(Fq.ZERO, Fq.ZERO, True)
        return p

    def __neg__(self):
        if self.is_identity:
            return self
        else:
            return Point(Fq(self.x.s), -Fq(self.y.s))

    def __add__(self, a):
        """Complete addition: handles identity, P + (-P), and doubling."""
        if self.is_identity:
            return a
        elif a.is_identity:
            return self
        else:
            (x1, y1) = (self.x, self.y)
            (x2, y2) = (a.x, a.y)

            if x1 != x2:
                # section 4.1
                λ = (y1 - y2) / (x1 - x2)
                x3 = λ*λ - x1 - x2
                y3 = λ*(x1 - x3) - y1
                return Point(x3, y3)
            elif y1 == -y2:
                # Same x, opposite y: the points are inverses.
                return Point.identity()
            else:
                return self.double()

    def checked_incomplete_add(self, a):
        """Add, asserting the inputs avoid every exceptional case of
        incomplete addition (equal, opposite, or identity operands)."""
        assert self != a
        assert self != -a
        assert self != Point.identity()
        assert a != Point.identity()
        return self + a

    def __sub__(self, a):
        return (-a) + self

    def double(self):
        """Return 2 * self using the tangent-line formula."""
        if self.is_identity:
            return self

        # section 4.1
        λ = (Fq(3) * self.x * self.x) / (self.y + self.y)
        x = λ*λ - self.x - self.x
        y = λ*(self.x - x) - self.y
        return Point(x, y)

    def extract(self):
        """Return the x-coordinate, or Fq.ZERO for the identity."""
        if self.is_identity:
            return Fq.ZERO
        return self.x

    def __mul__(self, s):
        """Scalar multiplication by a ``Scalar``.

        Left-to-right double-and-add over the 256-bit binary expansion of
        ``s``; branches on scalar bits, so it is not constant-time.
        """
        assert isinstance(s, Scalar)
        s = format(s.s, '0256b')
        ret = self.ZERO
        for c in s:
            ret = ret.double()
            if int(c):
                ret = ret + self
        return ret

    def __bytes__(self):
        """Encode as 32 bytes: bytes(x) with the top bit of the last byte set
        when y is odd; the identity encodes as 32 zero bytes."""
        if self.is_identity:
            return bytes([0] * 32)

        buf = bytes(self.x)
        if self.y.s % 2 == 1:
            buf = buf[:31] + bytes([buf[31] | (1 << 7)])
        return buf

    def __eq__(self, a):
        if a is None:
            return False
        if not isinstance(a, Point):
            # Avoid AttributeError on foreign objects; let Python fall back to
            # its default comparison instead.
            return NotImplemented
        if not (self.is_identity or a.is_identity):
            return self.x == a.x and self.y == a.y
        else:
            return self.is_identity == a.is_identity

    def __str__(self):
        if self.is_identity:
            return 'Point(identity)'
        else:
            return 'Point(%s, %s)' % (self.x, self.y)


Point.ZERO = Point.identity()
# (-1, 2) lies on the curve: 2^2 == (-1)^3 + 5.
Point.GENERATOR = Point(Fq.MINUS_ONE, Fq(2))

assert Point.ZERO + Point.ZERO == Point.ZERO
assert Point.GENERATOR - Point.GENERATOR == Point.ZERO
assert Point.GENERATOR + Point.GENERATOR + Point.GENERATOR == Point.GENERATOR * Scalar(3)
assert Point.GENERATOR + Point.GENERATOR - Point.GENERATOR == Point.GENERATOR

assert Point.from_bytes(bytes([0]*32)) == Point.ZERO
assert Point.from_bytes(bytes(Point.GENERATOR)) == Point.GENERATOR