SECCON CTF 13 Quals writeup

SECCON CTF 13 Quals にチーム ZK Lovers で参加した。日本時間で 2024-11-23 14:00 から 2024-11-24 14:00 まで。

結果は 35/653 位。日本のチーム内では 8/303 位だったため、決勝に進出した。

解法集

crypto

xiyi

解説は https://zenn.dev/sigma425/articles/826180135a39cb でやってもらったのでコードだけ。
コードは以下の通り。以下の点が記事とは違うことに注意。

  • 離散対数を取るところで、\log_{-2}(s^yg^r) を計算している。
  • s は s\equiv 1 \pmod{p^2}, s\equiv -2 \pmod{q} を満たす(コード中では s が x、p が p0、q が p1 に対応する)。これによって、 mod p^2 で出る結果が -r に、mod q で出る結果が y-r になる。

presolve.py (パラメーターの算出)

from Crypto.Util.number import isPrime


def order_prim(g: int, p: int, n: int) -> int:
    """Return ``p`` if ``g ** ((n - 1) // p)`` is not 1 modulo ``n``, else 1.

    When ``n`` is prime and the prime ``p`` divides ``n - 1`` exactly once,
    the result is the ``p``-part of the multiplicative order of ``g`` mod ``n``.
    """
    return 1 if pow(g, (n - 1) // p, n) == 1 else p


def order_log(g: int, n: int, prod_list: list[int]):
    """Return the product of ``order_prim(g, p, n)`` over every ``p`` in ``prod_list``.

    When ``n`` is prime and ``n - 1`` is the squarefree product of the primes
    in ``prod_list``, this is the multiplicative order of ``g`` modulo ``n``.
    An empty ``prod_list`` yields 1.
    """
    order = 1
    for prime in prod_list:
        # Each call contributes either 1 or the prime itself.
        order = order * order_prim(g, prime, n)
    return order


def main() -> None:
    """Search for two suitable primes and print them as Python source lines.

    Each prime has the form ``prod * i + 1`` where ``prod`` is the product of
    all primes below 374 and ``i`` is itself prime, so ``p - 1`` factors
    completely into the primes collected in the printed ``prod_list``.  The
    printed text (definitions, asserts and the ``gcd_o`` computation) is meant
    to be pasted into the solver script.
    """
    # Smooth part of p - 1: every prime below 374, and their product.
    prod_list = [i for i in range(2, 374) if isPrime(i)]
    prod = 1
    for small_prime in prod_list:
        prod *= small_prime

    # Keep the first two 518-bit primes prod * i + 1 whose cofactor i is prime.
    ps = []
    for i in range(1, 1000000):
        if not isPrime(i):
            continue
        p = prod * i + 1
        if p.bit_length() == 518 and isPrime(p):
            ps.append((p, i))
            if len(ps) == 2:
                break

    for index in range(2):
        p, i = ps[index]
        # Full list of prime factors of p - 1.
        tmp = prod_list + [i]
        print(f'p{index} = ' + ' * '.join(map(str, tmp)) + ' + 1')
        print(f'prod_list{index} =', tmp)
        print(f'assert p{index}.bit_length() == 518')
        print(f'assert isPrime(p{index})')
        # p - 2 is -2 modulo p, so this is order_log for the base -2.
        o = order_log(p - 2, p, tmp)
        print(f'o{index} = {o}')

    print('gcd_o = math.gcd(o0, o1)')
    print('assert gcd_o.bit_length() >= 256')
    print('''
gcd_o_factors = []
for p in prod_list0:
    if gcd_o % p == 0:
        gcd_o_factors.append(p)
assert functools.reduce(lambda x, y: x * y, gcd_o_factors) == gcd_o''')

# Script entry point: run the parameter search only when executed directly.
if __name__ == '__main__':
    main()

solve.py

import math
import time
import functools
import json
import sys
from Crypto.Util.number import isPrime
from pwn import remote, process
from params import L


# Hard-coded parameters, in the format printed by presolve.py.
# p0 and p1 are 518-bit primes (asserted below) of the form
# (product of all primes up to 373) * (one larger prime) + 1;
# prod_list0 / prod_list1 list the prime factors of p0 - 1 / p1 - 1.
p0 = 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 29 * 31 * 37 * 41 * 43 * 47 * 53 * 59 * 61 * 67 * 71 * 73 * 79 * 83 * 89 * 97 * 101 * 103 * 107 * 109 * 113 * 127 * 131 * 137 * 139 * 149 * 151 * 157 * 163 * 167 * 173 * 179 * 181 * 191 * 193 * 197 * 199 * 211 * 223 * 227 * 229 * 233 * 239 * 241 * 251 * 257 * 263 * 269 * 271 * 277 * 281 * 283 * 293 * 307 * 311 * 313 * 317 * 331 * 337 * 347 * 349 * 353 * 359 * 367 * 373 * 95111 + 1
prod_list0 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 95111]
assert p0.bit_length() == 518
assert isPrime(p0)
# o0 / o1: presumably the values presolve.py prints from
# order_log(p - 2, p, prod_list), i.e. the product of the primes q in
# prod_list for which (-2) ** ((p - 1) // q) is not 1 modulo p.
o0 = 36661539024568114673630848676502229358837161250611654971980007790086057602635278619098631925083991929859953297191541221964453055358660192576524314518690
p1 = 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 29 * 31 * 37 * 41 * 43 * 47 * 53 * 59 * 61 * 67 * 71 * 73 * 79 * 83 * 89 * 97 * 101 * 103 * 107 * 109 * 113 * 127 * 131 * 137 * 139 * 149 * 151 * 157 * 163 * 167 * 173 * 179 * 181 * 191 * 193 * 197 * 199 * 211 * 223 * 227 * 229 * 233 * 239 * 241 * 251 * 257 * 263 * 269 * 271 * 277 * 281 * 283 * 293 * 307 * 311 * 313 * 317 * 331 * 337 * 347 * 349 * 353 * 359 * 367 * 373 * 95471 + 1
prod_list1 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 95471]
assert p1.bit_length() == 518
assert isPrime(p1)
o1 = 433176388095566017443503977303018864439997415658581630574274185147818012525806889798787919014260344688923985302972314688075228224036485629700737413953932390
# Common part of o0 and o1; the discrete logs in this script are taken
# modulo gcd_o, which must be at least 256 bits long.
gcd_o = math.gcd(o0, o1)
assert gcd_o.bit_length() >= 256

# Primes from prod_list0 that divide gcd_o.  The final assert checks that
# their product is exactly gcd_o, i.e. gcd_o is squarefree and fully factored.
gcd_o_factors = []
for p in prod_list0:
    if gcd_o % p == 0:
        gcd_o_factors.append(p)
assert functools.reduce(lambda x, y: x * y, gcd_o_factors) == gcd_o


def crt(a0: int, mo0: int, a1: int, mo1: int) -> int:
    """Combine two congruences with the Chinese remainder theorem.

    Returns the value ``x`` in ``range(mo0 * mo1)`` with ``x % mo0 == a0 % mo0``
    and ``x % mo1 == a1 % mo1``.  ``mo0`` and ``mo1`` must be coprime;
    otherwise ``pow`` raises ``ValueError`` because no inverse exists.
    A modulus of 1 is accepted (``pow(k, -1, 1)`` is 0).
    """
    modulus = mo0 * mo1
    # Each term vanishes modulo the other modulus and reduces to its own
    # residue modulo its own modulus.
    term0 = a0 * mo1 * pow(mo1, -1, mo0)
    term1 = a1 * mo0 * pow(mo0, -1, mo1)
    return (term0 + term1) % modulus


def disc_log_prim(g: int, h: int, factor: int, order: int, n: int) -> int | None:
    g = pow(g, order // factor, n)
    h = pow(h, order // factor, n)
    cur = 1
    for i in range(factor):
        if cur == h:
            return i
        cur = cur * g % n
    return None

def disc_log(g: int, h: int, n: int, prod_list: list[int]) -> tuple[int, int]:
    """Compute the discrete log of ``h`` to base ``g`` modulo ``n``, one factor at a time.

    ``prod_list`` holds the factors of the order of ``g``; their product must
    equal the module-level ``gcd_o`` and ``g ** order`` must be 1 modulo ``n``
    (both asserted).  The log is found modulo each factor with
    ``disc_log_prim`` and the residues are merged with ``crt``.

    Returns ``(a, mo)``: the log ``a`` and the modulus ``mo`` (the product of
    ``prod_list``) it is known modulo.
    """
    order = functools.reduce(lambda x, y: x * y, prod_list)
    assert pow(g, order, n) == 1
    assert order == gcd_o

    residue, modulus = 0, 1
    for factor in prod_list:
        partial = disc_log_prim(g, h, factor, order, n)
        # A missing log means h is not in the subgroup generated by g.
        assert partial is not None, f'{g = }, {h = }, {factor = }, {order = }, {n = }'
        residue = crt(residue, modulus, partial, factor)
        modulus *= factor
    return (residue, modulus)


# With no command-line arguments, run against a locally spawned server.py;
# otherwise connect to the host and port given as the first two arguments.
local = len(sys.argv) == 1
if local:
    io = process(['python3', 'server.py'])
else:
    io = remote(sys.argv[1], int(sys.argv[2]))


def main() -> None:
    """Run the attack over the module-level connection ``io``.

    Sends a crafted modulus ``n`` and ``L`` copies of a crafted value ``x``,
    reads ``enc_alphas`` from the reply, recovers each ``y`` from two discrete
    logs (modulo ``p0 * p0`` and modulo ``p1``), submits the recovered ``ys``
    together with ``p0`` and ``p1``, and prints the server's two reply lines.
    """
    start = time.time()
    # initialize
    # n = p0^2 * p1, and x is 1 modulo p0^2 and -2 modulo p1 (asserted below).
    n = p0 * p0 * p1
    x = crt(1, p0 * p0, p1 - 2, p1)
    assert x % (p0 * p0) == 1
    assert x % p1 == p1 - 2
    # The same x is sent for every one of the L slots.
    enc_xs = [x] * L

    # 1: (client) --- n, enc_xs ---> (server)
    io.sendlineafter(b"> ", json.dumps({"n": n, "enc_xs": enc_xs}).encode())

    # 3: (server) --- enc_alphas, beta_sum_mod_n ---> (client)
    params = json.loads(io.recvline().strip().decode())
    enc_alphas = params["enc_alphas"]
    ys = []
    # base_k starts as -2 and exp_k tracks the exponent applied to it, so that
    # base_k == (-2) ** exp_k in the respective modulus throughout.
    base_0 = -2
    exp_0 = 1
    base_1 = -2
    exp_1 = 1
    # For every prime of p0 - 1 outside gcd_o_factors that contributes to the
    # order of -2 modulo p0, raise the base to that prime to remove it.
    for p in prod_list0:
        if p in gcd_o_factors:
            continue
        assert (p0 - 1) % p == 0
        if pow(-2, (p0 - 1) // p, p0) != 1:
            base_0 = pow(base_0, p, p0)
            exp_0 *= p
    # Lift to modulus p0^2 by raising to the p0-th power; the asserts below
    # check that the resulting order divides gcd_o modulo p0^2 as well.
    base_0 = pow(base_0, p0, p0 * p0)
    exp_0 *= p0
    # Same reduction for the base modulo p1.
    for p in prod_list1:
        if p in gcd_o_factors:
            continue
        if pow(-2, (p1 - 1) // p, p1) != 1:
            base_1 = pow(base_1, p, p1)
            exp_1 *= p
    # Sanity checks: the order of each base divides gcd_o ...
    assert pow(base_0, gcd_o, p0) == 1
    assert pow(base_0, gcd_o, p0 * p0) == 1
    assert pow(base_1, gcd_o, p1) == 1
    # ... and does not divide gcd_o // p for any p in gcd_o_factors.
    for p in gcd_o_factors:
        assert pow(base_0, gcd_o // p, p0) != 1
        assert pow(base_0, gcd_o // p, p0 * p0) != 1
        assert pow(base_1, gcd_o // p, p1) != 1

    # Each ciphertext is raised to the same exponent as its base before taking
    # the log.  Per the write-up, the log modulo p0^2 is -r and the log modulo
    # p1 is y - r (both modulo gcd_o), so their difference is y.
    for i in range(L):
        disc_0 = disc_log(base_0, pow(enc_alphas[i], exp_0, p0 * p0), p0 * p0, gcd_o_factors)  # -r
        disc_1 = disc_log(base_1, pow(enc_alphas[i], exp_1, p1), p1, gcd_o_factors) # y - r
        print(f'# ({time.time() - start:.2f}s) {i = }, {disc_0 = }, {disc_1 = }')
        y = (disc_1[0] - disc_0[0]) % gcd_o
        ys.append(y)

    # If, by any chance, you can guess ys, send it for the flag!
    print(f'{ys = }')
    io.sendlineafter(b"> ", json.dumps({"ys": ys, "p": p0, "q": p1}).encode())
    print(io.recvline().strip().decode())  # Congratz! or Wrong...
    print(io.recvline().strip().decode())  # flag or ys


# Script entry point: run the attack only when executed directly.
if __name__ == "__main__":
    main()

まとめ

xiyi を解いてギャンブルに大勝ち。