diff --git a/src/ringct/bulletproofs.cc b/src/ringct/bulletproofs.cc index 381f50872..0e5b3b55f 100644 --- a/src/ringct/bulletproofs.cc +++ b/src/ringct/bulletproofs.cc @@ -127,15 +127,6 @@ static void sub_acc_p3(ge_p3 *acc_p3, const rct::key &point) ge_p1p1_to_p3(acc_p3, &p1); } -static rct::key scalarmultKey(const ge_p3 &P, const rct::key &a) -{ - ge_p2 R; - ge_scalarmult(&R, a.bytes, &P); - rct::key aP; - ge_tobytes(aP.bytes, &R); - return aP; -} - static rct::key get_exponent(const rct::key &base, size_t idx) { static const std::string salt("bulletproof"); @@ -193,23 +184,28 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b) } /* Compute a custom vector-scalar commitment */ -static rct::key vector_exponent_custom(const rct::keyV &A, const rct::keyV &B, const rct::keyV &a, const rct::keyV &b) +static rct::key cross_vector_exponent8(size_t size, const std::vector &A, size_t Ao, const std::vector &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const ge_p3 *extra_point, const rct::key *extra_scalar) { - CHECK_AND_ASSERT_THROW_MES(A.size() == B.size(), "Incompatible sizes of A and B"); - CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); - CHECK_AND_ASSERT_THROW_MES(a.size() == A.size(), "Incompatible sizes of a and A"); - CHECK_AND_ASSERT_THROW_MES(a.size() <= maxN*maxM, "Incompatible sizes of a and maxN"); + CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A"); + CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B"); + CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a"); + CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b"); + CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large"); + CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present"); std::vector multiexp_data; - multiexp_data.reserve(a.size()*2); - for (size_t i = 0; i < a.size(); ++i) + multiexp_data.resize(size*2 + (!!extra_point)); + for (size_t i = 0; i < size; ++i) { - multiexp_data.resize(multiexp_data.size() + 1); - multiexp_data.back().scalar = a[i]; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, A[i].bytes) == 0, "ge_frombytes_vartime failed"); - multiexp_data.resize(multiexp_data.size() + 1); - multiexp_data.back().scalar = b[i]; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, B[i].bytes) == 0, "ge_frombytes_vartime failed"); + sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);; + multiexp_data[i*2].point = A[Ao+i]; + sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes); + multiexp_data[i*2+1].point = B[Bo+i]; + } + if (extra_point) + { + sc_mul(multiexp_data.back().scalar.bytes, extra_scalar->bytes, INV_EIGHT.bytes); + multiexp_data.back().point = *extra_point; } return multiexp(multiexp_data, false); } @@ -273,16 +269,19 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b) return res; } -/* Given two curvepoint arrays, construct the Hadamard product */ -static rct::keyV hadamard2(const rct::keyV &a, const rct::keyV &b) +/* folds a curvepoint array using a two way scaled Hadamard product */ +static void hadamard_fold(std::vector &v, const rct::key &a, const rct::key &b) { - CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); - rct::keyV res(a.size()); - for (size_t i = 0; i < a.size(); ++i) + CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even"); + const size_t sz = v.size() / 2; + for (size_t n = 0; n < sz; ++n) { - rct::addKeys(res[i], a[i], b[i]); + ge_dsmp c[2]; + ge_dsm_precomp(c[0], &v[n]); + ge_dsm_precomp(c[1], &v[sz + n]); + ge_double_scalarmult_precomp_vartime2_p3(&v[n], a.bytes, c[0], b.bytes, c[1]); } - return res; + v.resize(sz); } /* Add two vectors */ @@ -326,17 +325,6 @@ static rct::keyV vector_dup(const rct::key &x, size_t N) return rct::keyV(N, x); } -/* Exponentiate a curve vector by a scalar */ -static rct::keyV vector_scalar2(const rct::keyV &a, const rct::key &x) -{ - rct::keyV res(a.size()); - for (size_t i = 0; i < a.size(); ++i) - { - rct::scalarmultKey(res[i], a[i], x); - } - return res; -} - /* Get the sum of a vector's elements */ static rct::key vector_sum(const rct::keyV &a) { @@ -620,16 +608,16 @@ try_again: // These are used in the inner product rounds size_t nprime = N; - rct::keyV Gprime(N); - rct::keyV Hprime(N); + std::vector Gprime(N); + std::vector Hprime(N); rct::keyV aprime(N); rct::keyV bprime(N); const rct::key yinv = invert(y); rct::key yinvpow = rct::identity(); for (size_t i = 0; i < N; ++i) { - Gprime[i] = Gi[i]; - Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow); + Gprime[i] = Gi_p3[i]; + ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]); sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); aprime[i] = l[i]; bprime[i] = r[i]; @@ -652,14 +640,10 @@ try_again: rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); // PAPER LINES 18-19 - L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); - rct::addKeys(L[round], L[round], rct::scalarmultH(tmp)); - L[round] = rct::scalarmultKey(L[round], INV_EIGHT); - R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); + L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); - rct::addKeys(R[round], R[round], rct::scalarmultH(tmp)); - R[round] = rct::scalarmultKey(R[round], INV_EIGHT); + R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp); // PAPER LINES 21-22 w[round] = hash_cache_mash(hash_cache, L[round], R[round]); @@ -672,8 +656,11 @@ try_again: // PAPER LINES 24-25 const rct::key winv = invert(w[round]); - Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round])); - Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv)); + if (nprime > 1) + { + hadamard_fold(Gprime, winv, w[round]); + hadamard_fold(Hprime, w[round], winv); + } // PAPER LINES 28-29 aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv)); @@ -914,16 +901,16 @@ try_again: // These are used in the inner product rounds size_t nprime = MN; - rct::keyV Gprime(MN); - rct::keyV Hprime(MN); + std::vector Gprime(MN); + std::vector Hprime(MN); rct::keyV aprime(MN); rct::keyV bprime(MN); const rct::key yinv = invert(y); rct::key yinvpow = rct::identity(); for (size_t i = 0; i < MN; ++i) { - Gprime[i] = Gi[i]; - Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow); + Gprime[i] = Gi_p3[i]; + ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]); sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); aprime[i] = l[i]; bprime[i] = r[i]; @@ -942,18 +929,18 @@ try_again: nprime /= 2; // PAPER LINES 16-17 + PERF_TIMER_START_BP(PROVE_inner_product); rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); + PERF_TIMER_STOP(PROVE_inner_product); // PAPER LINES 18-19 - L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); + PERF_TIMER_START_BP(PROVE_LR); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); - rct::addKeys(L[round], L[round], rct::scalarmultH(tmp)); - L[round] = rct::scalarmultKey(L[round], INV_EIGHT); - R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); + L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); - rct::addKeys(R[round], R[round], rct::scalarmultH(tmp)); - R[round] = rct::scalarmultKey(R[round], INV_EIGHT); + R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp); + PERF_TIMER_STOP(PROVE_LR); // PAPER LINES 21-22 w[round] = hash_cache_mash(hash_cache, L[round], R[round]); @@ -966,12 +953,19 @@ try_again: // PAPER LINES 24-25 const rct::key winv = invert(w[round]); - Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round])); - Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv)); + if (nprime > 1) + { + PERF_TIMER_START_BP(PROVE_hadamard2); + hadamard_fold(Gprime, winv, w[round]); + hadamard_fold(Hprime, w[round], winv); + PERF_TIMER_STOP(PROVE_hadamard2); + } // PAPER LINES 28-29 + PERF_TIMER_START_BP(PROVE_prime); aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv)); bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round])); + PERF_TIMER_STOP(PROVE_prime); ++round; }