ISITDTU CTF QUALS 2024 writeup

ISITDTU CTF QUALS 2024 にチーム ZK cha で参加した。日本時間で 2024-10-26 11:00 から 2024-10-27 19:00 まで。

結果は 95/315 位。
SpookyCTF 2024 で散々な目にあったから、SpookyCTF 2024 の競技中にもかかわらずこちらに移行した。

解法集

crypto

ShareMixer1

素数 p が決められる。flag を 1 つ含み他はランダムな係数を持つ、次数 32 の多項式 cs をジャッジが決めるので、長さ 256 以下の数列 xs を送れ。それの各値を代入した結果をシャッフルして返す。」という問題。

xs の中の頻度を調節する (たとえばある数は 1 個だけ入れ、ある数は 2 個入れ、…) と、特定の数を代入した結果が分かる。
これを利用して頻度を [(1, 3), (2, 3), (3, 2), (4, 2), (5, 2), (6, 2), (7, 2), (8, 2), (9, 2), (10, 2), (11, 2), (12, 2), (13, 1), (14, 1), (15, 1), (16, 1), (17, 1), (18, 1)] にすることで、組み合わせを 9216 通り試せば良くなる。
(この列は (頻度, その頻度をもつ数の個数) の列で、同じ頻度をもつ数は全ての順列を試す必要があるので、3! * 3! * 2^10 = 9216 通り)

なお、この問題では PoW を実行することを求められるが、配布されるソースコードにはそれが記載されておらず、リモートサーバーに接続した時に初めて分かる仕様になっていた。

pow.py は以下の通り。

#!/usr/bin/env python3
"""
Copied and modified from https://github.com/balsn/proof-of-work/blob/master/solver/python3.py
"""
import hashlib
import sys


difficulty = 24
zeros = '0' * difficulty


def is_valid(digest):
    if sys.version_info.major == 2:
        digest = [ord(i) for i in digest]
    bits = ''.join(bin(i)[2:].zfill(8) for i in digest)
    return bits[:difficulty] == zeros


def find(prefix: str) -> str:
    i = 0
    while True:
        i += 1
        s = prefix + str(i)
        if is_valid(hashlib.sha256(s.encode()).digest()):
            return str(i)

また solve.sage は以下の通り。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
import itertools
import pow
from pwn import remote, process
from Crypto.Util.number import long_to_bytes, getPrime


local = len(sys.argv) == 1
io = process(["python3", "chall.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


l = 32
getPrime(256)


def decouple_multiset(multiset: list[int]) -> dict[int, list[int]]:
    """
    freq -> [val1, val2, ...]
    """
    decoupled_freqs = {}
    for val in multiset:
        if val not in decoupled_freqs:
            decoupled_freqs[val] = 0
        decoupled_freqs[val] += 1
    ret = {}
    for k in decoupled_freqs:
        v = decoupled_freqs[k]
        if v not in ret:
            ret[v] = []
        ret[v].append(k)
    return ret


def solve_lagrange(p: int, assoc: list[tuple[int, int]]) -> list[int]:
    K = GF(p)
    R = PolynomialRing(K, 'x')
    assert len(assoc) == l
    f = R.lagrange_polynomial(assoc)
    return f.list()


def dfs(keys: list[int], start: float, count: list[int], assoc: list[tuple[int, int]], p: int, xs: dict[int, list[int]], shares: dict[int, list[int]]) -> bytes | None:
    count[0] += 1
    if count[0] % 10000 == 0:
        print(f'# ({time.time() - start:.2f}s) count: {count[0]}')
    if len(keys) == 0:
        res = solve_lagrange(p, assoc)
        for v in res:
            bs = long_to_bytes(int(v))
            if bs.startswith(b"ISITDTU{"):
                return bs
        return
    xlist = xs[keys[0]]
    sharelist = shares[keys[0]]
    assert len(xlist) == len(sharelist)
    seq = range(0, len(xlist))
    for perm in itertools.permutations(seq):
        for i, index in enumerate(perm):
            assoc.append((xlist[i], sharelist[index]))
        ret = dfs(keys[1:], start, count, assoc, p, xs, shares)
        if ret is not None:
            return ret
        for _ in seq:
            assoc.pop()
    return None


def share_mixer_decipher(p: int, start: float, xs: dict[int, list[int]], shares: dict[int, list[int]]) -> bytes:
    keys = list(xs)
    count = [0]
    res = dfs(keys, start, count, [], p, xs, shares)
    if res is None:
        raise ValueError("Failed to find the flag")
    return res


def main() -> None:
    start = time.time()
    if not local:
        io.recvuntil(b'Send a suffix that:')
        io.recvline()
        problem = io.recvline().strip().decode()
        prefix = problem.split('"')[1]
        print(f'# ({time.time() - start:.2f}s) {prefix = }')
        suffix = pow.find(prefix)
        print(f'# ({time.time() - start:.2f}s) {suffix = }')
        io.recvuntil(b"Suffix: ")
        io.sendline(suffix.encode())
    io.recvuntil(b"p = ")
    p = int(io.recvline().strip().decode())
    io.recvuntil(b"Gib me the queries: ")
    xs = [1, 2, 3, 4, 4, 5, 5, 6, 6] + sum(([x] * max((x - 1) // 2, 1) for x in range(7, 33)), []) + [28, 29, 30, 30, 31, 31, 32, 32, 32]
    print(f'# {len(xs) = }')
    io.sendline(" ".join(map(str, xs)).encode())
    io.recvuntil(b"shares = ")
    shares = io.recvline().strip().decode()[1:-1].split(", ")
    shares = list(map(int, shares))
    xs = decouple_multiset(xs)
    shares = decouple_multiset(shares)
    print(f'# ({time.time() - start:.2f}s) shares obtained')
    print(f'# length_distrib: {[(v, len(xs[v])) for v in xs]}')
    print(f'# #combinations: {reduce(lambda x, y: x * y, (len(xs[v]) for v in xs))}')
    print(f'# {p = }')
    flag = share_mixer_decipher(p, start, xs, shares)
    print(flag.decode())


if __name__ == "__main__":
    main()
ShareMixer2

素数 p が決められる。flag を 1 つ含み他はランダムな係数を持つ、次数 32 の多項式 cs をジャッジが決めるので、長さ 32 以下の数列 xs を送れ。それの各値を代入した結果をシャッフルして返す。」という問題。
ShareMixer1 に比べて xs の長さ制限が短くなった代わりに、PoW を要求されなくなった。

p-1 が 32 の倍数であれば mod p で 1 の 32 乗根が存在するようになり、さらに xs としてそれらを与えると戻ってきた値の合計が 32 * cs[0] になる。
何回も接続して 1 の 32 乗根が存在するようになるまで (32 | p-1 が成立するまで) 待ち、さらに 1 の 32 乗根を xs として与えて cs[0] を取得して、flag が cs[0] に来るまでガチャを回す。試行回数の期待値は 32 * 32 = 1024 回。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
from pwn import remote, process, context
from Crypto.Util.number import long_to_bytes


context.log_level = 'error'
local = len(sys.argv) == 1
def get_io():
    return process(["python3", "chall.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


l = 32


def try_one(start: float) -> None:
    io = get_io()
    while True:
        io.recvuntil(b"p = ")
        p = int(io.recvline().strip().decode())
        if (p - 1) % 32 == 0:
            break
        io.close()
        io = get_io()
    K = GF(p)
    g = K.multiplicative_generator()
    base_l = g ** ((p - 1) // l)
    io.recvuntil(b"Gib me the queries: ")
    xs = [int(base_l ** i) for i in range(l)]
    print(f'# {len(xs) = }')
    io.sendline(" ".join(map(str, xs)).encode())
    io.recvuntil(b"shares = ")
    shares = io.recvline().strip().decode()[1:-1].split(", ")
    shares = list(map(int, shares))
    print(f'# ({time.time() - start:.2f}s) shares obtained')
    io.close()
    cs0 = sum(shares) * pow(l, -1, p) % p
    flag = long_to_bytes(cs0)
    if flag.startswith(b'ISITDTU{'):
        print(flag.decode())
        return flag.decode()


def main() -> None:
    start = time.time()
    count = 0
    while True:
        print(f'# ({time.time() - start:.2f}s) trial {count}')
        res = try_one(start)
        if res is not None:
            print(res)
            return
        count += 1


if __name__ == "__main__":
    main()
Sign

競技終了後に解いた。
「n が未知な状態で PKCS#1 v1.5 形式でランダムなデータの署名を返すオラクルと、フラグに対して pow(flag, d, n) を返すオラクルが与えられる。flag を特定せよ。」という問題。

PKCS#1 v1.5 形式の署名は 0x1ff....ffXXXX (XXXX は対象のハッシュ値など) の形の整数を d 乗したものになることに注意すると、署名の e 乗の差分はほとんど n の倍数 (差は高々 257 ビット整数程度) であることに着目する。
Approximate GCD をやれば良い。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
from pwn import remote, process
from Crypto.Util.number import long_to_bytes


local = len(sys.argv) == 1
io = process(["python3", "chall.py"]) if local else remote(sys.argv[1], int(sys.argv[2]))


def approx_gcd(d: list[int], approx_error: int) -> int:
    """
    Returns q where d[0] ~= qx and d[i]'s are close to multiples of x.
    The caller must find (d[0] + q // 2) // q if they want to find x.
    """
    l = len(d)
    M = Matrix(ZZ, l, l)
    M[0, 0] = approx_error
    for i in range(1, l):
        M[0, i] = d[i]
        M[i, i] = -d[0]
    L = M.LLL()
    for row in L:
        if row[0] != 0:
            quot = abs(row[0] // approx_error)
            return quot


def get_random_sig() -> int:
    io.recvuntil(b'> ')
    io.sendline(b'1')
    io.recvuntil(b'sig = ')
    return int(io.recvline().strip().decode(), 16)


def get_flag_sig() -> int:
    io.recvuntil(b'> ')
    io.sendline(b'2')
    io.recvuntil(b'sig = ')
    return int(io.recvline().strip().decode(), 16)


def main() -> None:
    start = time.time()
    e = 11
    count = 14
    sigs: list[int] = []
    for _ in range(count):
        sigs.append(get_random_sig())
    print(f'# ({time.time() - start:.2f}s) {sigs[0].bit_length() = }')
    diff = [abs(sigs[i]**e - sigs[i - 1]**e) for i in range(1, count)]
    q = approx_gcd(diff, 2**256)
    n = (diff[0] + q // 2) // q
    print(f'# ({time.time() - start:.2f}s) {n.bit_length() = }')
    print(f'# ({time.time() - start:.2f}s) {hex(pow(sigs[0], e, n)) = }')
    fs = get_flag_sig()
    flag = long_to_bytes(pow(fs, e, n))
    index = flag.find(b'ISITDTU')
    print(flag[index:].decode())


if __name__ == "__main__":
    main()

(2024-11-01 23:36 修正: n を求めるところで四捨五入の代わりに間違えて切り捨てをしてしまっていたので、修正した。)

まとめ

SpookyCTF 2024 よりはマシだったがこれもちょっと不親切なところがあった。