xiyi
解説は https://zenn.dev/sigma425/articles/826180135a39cb でやってもらったのでコードだけ。
コードは以下の通り。以下の点が記事とは違うことに注意。
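- 離散対数を取るところで、底 $-2$ と暗号文 $\alpha$ の両方を $e$ 乗して位数 `gcd_o` の部分群に落とし、$\log_{(-2)^{e}} \alpha^{e} \bmod \mathrm{gcd\_o}$
を計算している(コード中の `exp_0`, `exp_1` が $e$ にあたる)。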
- s は $s \equiv 1 \pmod{p^2}$ かつ $s \equiv -2 \pmod{q}$
を満たす(コード中の `x` にあたる)。これによって、 mod $p^2$ で出る結果が $-r$ に、mod $q$ で出る結果が $y-r$ になる。
presolve.py (パラメーターの算出)
from Crypto.Util.number import isPrime
def order_prim(g: int, p: int, n: int) -> int:
    """Return the factor that the prime ``p`` contributes to the order of ``g`` modulo ``n``.

    ``g`` is raised to the power ``(n - 1) // p`` modulo ``n``. If that power
    is 1 the contribution is 1, otherwise it is ``p``.

    NOTE(review): this equals the exact p-part of the order only when ``n`` is
    prime and ``p`` divides ``n - 1`` exactly once -- confirm at the call sites.
    """
    collapses_to_identity = pow(g, (n - 1) // p, n) == 1
    return 1 if collapses_to_identity else p
def order_log(g: int, n: int, prod_list: list[int]):
    """Multiply together the per-prime contributions ``order_prim(g, p, n)``.

    Args:
        g: Element whose order is being measured.
        n: Modulus.
        prod_list: Primes to test, one ``order_prim`` call each.

    Returns:
        The product of the contributions; 1 when ``prod_list`` is empty.
    """
    contributions = (order_prim(g, prime, n) for prime in prod_list)
    order = 1
    for contribution in contributions:
        order *= contribution
    return order
def main() -> None:
    """Search for two suitable primes and print the parameter block for solve.py.

    The primes have the form ``p = prod * i + 1``, where ``prod`` is the product
    of all primes below 374 and ``i`` is itself prime, and ``p`` must be exactly
    518 bits long. For each of the first two hits the function prints Python
    source lines defining ``p{index}``, ``prod_list{index}``, two sanity asserts
    and ``o{index}`` (the value of ``order_log(p - 2, p, factors)``). It then
    prints the lines that derive ``gcd_o`` and ``gcd_o_factors``.

    Raises:
        RuntimeError: If fewer than two matching primes exist in the search range.
    """
    small_primes = [k for k in range(2, 374) if isPrime(k)]
    prod = 1
    for prime in small_primes:
        prod *= prime

    found = []
    for cofactor in range(1, 1000000):
        p = prod * cofactor + 1
        bits = p.bit_length()
        if bits > 518:
            # p grows with cofactor, so no later candidate can be 518 bits.
            break
        # The cofactor must be prime so that small_primes + [cofactor] is the
        # full prime factorisation of p - 1.
        if bits < 518 or not isPrime(cofactor):
            continue
        if isPrime(p):
            found.append((p, cofactor))
            if len(found) == 2:
                break
    if len(found) < 2:
        raise RuntimeError('could not find two 518-bit primes of the form prod * i + 1')

    for index, (p, cofactor) in enumerate(found):
        factors = small_primes + [cofactor]
        print(f'p{index} = ' + ' * '.join(map(str, factors)) + ' + 1')
        print(f'prod_list{index} =', factors)
        print(f'assert p{index}.bit_length() == 518')
        print(f'assert isPrime(p{index})')
        # p - 2 is congruent to -2 modulo p.
        o = order_log(p - 2, p, factors)
        print(f'o{index} = {o}')

    print('gcd_o = math.gcd(o0, o1)')
    print('assert gcd_o.bit_length() >= 256')
    # The emitted snippet is Python source, so its loop and if bodies must be
    # indented. The leading empty entry reproduces the blank separator line.
    snippet_lines = [
        '',
        'gcd_o_factors = []',
        'for p in prod_list0:',
        '    if gcd_o % p == 0:',
        '        gcd_o_factors.append(p)',
        'assert functools.reduce(lambda x, y: x * y, gcd_o_factors) == gcd_o',
    ]
    print('\n'.join(snippet_lines))


if __name__ == '__main__':
    main()
solve.py
import math
import time
import functools
import json
import sys
from Crypto.Util.number import isPrime
from pwn import remote, process
from params import L
# Parameter block in the format printed by presolve.py.
# prod_list0 / prod_list1 hold the prime factors of p0 - 1 / p1 - 1:
# every prime below 374 plus one larger prime.
prod_list0 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 95111]
p0 = math.prod(prod_list0) + 1
assert p0.bit_length() == 518
assert isPrime(p0)
# o0 is the value presolve.py prints for order_log(p0 - 2, p0, prod_list0).
o0 = 36661539024568114673630848676502229358837161250611654971980007790086057602635278619098631925083991929859953297191541221964453055358660192576524314518690

prod_list1 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 95471]
p1 = math.prod(prod_list1) + 1
assert p1.bit_length() == 518
assert isPrime(p1)
# o1 is the value presolve.py prints for order_log(p1 - 2, p1, prod_list1).
o1 = 433176388095566017443503977303018864439997415658581630574274185147818012525806889798787919014260344688923985302972314688075228224036485629700737413953932390

# The discrete logs are taken in a subgroup of order gcd_o modulo both primes.
gcd_o = math.gcd(o0, o1)
assert gcd_o.bit_length() >= 256
# Primes dividing gcd_o; the assert checks that gcd_o is exactly their product.
gcd_o_factors = [prime for prime in prod_list0 if gcd_o % prime == 0]
assert functools.reduce(lambda acc, prime: acc * prime, gcd_o_factors) == gcd_o
def crt(a0: int, mo0: int, a1: int, mo1: int) -> int:
    """Combine two congruences with the Chinese remainder theorem.

    Args:
        a0: Residue modulo ``mo0``.
        mo0: First modulus.
        a1: Residue modulo ``mo1``.
        mo1: Second modulus.

    Returns:
        The value in ``[0, mo0 * mo1)`` congruent to ``a0`` modulo ``mo0`` and
        to ``a1`` modulo ``mo1``.

    Raises:
        ValueError: From ``pow`` when a modular inverse does not exist, i.e.
            the moduli are not coprime.
    """
    inv_mo1_mod_mo0 = pow(mo1, -1, mo0)
    inv_mo0_mod_mo1 = pow(mo0, -1, mo1)
    # Each term vanishes modulo the other modulus and leaves the wanted residue modulo its own.
    term0 = a0 * mo1 * inv_mo1_mod_mo0
    term1 = a1 * mo0 * inv_mo0_mod_mo1
    return (term0 + term1) % (mo0 * mo1)
def disc_log_prim(g: int, h: int, factor: int, order: int, n: int) -> int | None:
g = pow(g, order // factor, n)
h = pow(h, order // factor, n)
cur = 1
for i in range(factor):
if cur == h:
return i
cur = cur * g % n
return None
def disc_log(g: int, h: int, n: int, prod_list: list[int]) -> tuple[int, int]:
    """Compute a discrete logarithm of ``h`` to base ``g`` modulo ``n``, one factor at a time.

    The group order is taken to be the product of ``prod_list``. It is asserted
    to annihilate ``g`` and to equal the module-level ``gcd_o``. For each factor
    the residue of the logarithm is found by ``disc_log_prim`` and merged into
    the running result with ``crt``.

    Returns:
        ``(a, mo)``, where ``a`` is the logarithm modulo ``mo`` and ``mo`` is
        the product of ``prod_list``.
    """
    order = functools.reduce(lambda acc, prime: acc * prime, prod_list)
    assert pow(g, order, n) == 1
    assert order == gcd_o

    residue, modulus = 0, 1
    for factor in prod_list:
        partial = disc_log_prim(g, h, factor, order, n)
        assert partial is not None, f'{g = }, {h = }, {factor = }, {order = }, {n = }'
        # The right-hand side is evaluated with the old modulus before either name is rebound.
        residue, modulus = crt(residue, modulus, partial, factor), modulus * factor
    return (residue, modulus)
# With no command-line arguments, run server.py as a local subprocess.
# Otherwise connect to the host and port given as argv[1] and argv[2].
local = len(sys.argv) == 1
if local:
    io = process(['python3', 'server.py'])
else:
    io = remote(sys.argv[1], int(sys.argv[2]))
def main() -> None:
    """Run the attack against the connection held in the module-level ``io``.

    Steps:
        1. Send ``n = p0^2 * p1`` together with ``L`` copies of ``x``, where
           ``x`` is 1 modulo ``p0^2`` and ``-2`` modulo ``p1``.
        2. Read ``enc_alphas`` from the reply.
        3. Raise ``-2`` to suitable exponents so that the results have order
           ``gcd_o`` modulo ``p0^2`` and modulo ``p1``.
        4. For every ``enc_alphas[i]`` take the discrete log on both sides and
           store their difference modulo ``gcd_o`` as ``ys[i]``.
        5. Submit ``ys`` with ``p0`` and ``p1`` and print the next two lines
           received from the server.
    """
    started_at = time.time()
    p0_squared = p0 * p0
    n = p0_squared * p1
    x = crt(1, p0_squared, p1 - 2, p1)
    assert x % p0_squared == 1
    assert x % p1 == p1 - 2

    request = {"n": n, "enc_xs": [x] * L}
    io.sendlineafter(b"> ", json.dumps(request).encode())
    reply = json.loads(io.recvline().strip().decode())
    enc_alphas = reply["enc_alphas"]

    # Strip from -2 every prime of its order modulo p0 that is not in gcd_o.
    base_0, exp_0 = -2, 1
    for prime in [q for q in prod_list0 if q not in gcd_o_factors]:
        assert (p0 - 1) % prime == 0
        if pow(-2, (p0 - 1) // prime, p0) == 1:
            continue
        base_0 = pow(base_0, prime, p0)
        exp_0 *= prime
    # Move to modulus p0^2, raising to p0 as well so the asserts below hold there.
    base_0 = pow(base_0, p0, p0_squared)
    exp_0 *= p0

    # Same stripping modulo p1.
    base_1, exp_1 = -2, 1
    for prime in [q for q in prod_list1 if q not in gcd_o_factors]:
        if pow(-2, (p1 - 1) // prime, p1) == 1:
            continue
        base_1 = pow(base_1, prime, p1)
        exp_1 *= prime

    # Both bases must have order exactly gcd_o.
    assert pow(base_0, gcd_o, p0) == 1
    assert pow(base_0, gcd_o, p0_squared) == 1
    assert pow(base_1, gcd_o, p1) == 1
    for prime in gcd_o_factors:
        reduced = gcd_o // prime
        assert pow(base_0, reduced, p0) != 1
        assert pow(base_0, reduced, p0_squared) != 1
        assert pow(base_1, reduced, p1) != 1

    ys = []
    for i in range(L):
        enc_alpha = enc_alphas[i]
        disc_0 = disc_log(base_0, pow(enc_alpha, exp_0, p0_squared), p0_squared, gcd_o_factors)
        disc_1 = disc_log(base_1, pow(enc_alpha, exp_1, p1), p1, gcd_o_factors)
        print(f'# ({time.time() - started_at:.2f}s) {i = }, {disc_0 = }, {disc_1 = }')
        # Per the write-up, the log modulo p0^2 is -r and the log modulo p1 is y - r.
        ys.append((disc_1[0] - disc_0[0]) % gcd_o)
    print(f'{ys = }')

    answer = {"ys": ys, "p": p0, "q": p1}
    io.sendlineafter(b"> ", json.dumps(answer).encode())
    for _ in range(2):
        print(io.recvline().strip().decode())


if __name__ == "__main__":
    main()