AlpacaHack Round 5 (Crypto) writeup

AlpacaHack Round 5 (Crypto) に参加した。日本時間で 2024-10-12 12:00 から 2024-10-12 18:00 まで。

結果は 9/247 位

解法集

XorshiftStream

(key を hex string にしたもの) + (key と FLAG を xor したもの) を Xorshift で作ったストリームによって暗号化する (XOR)。

前半部分は平文のバイトが [0x30,0x39] または [0x61,0x66] の範囲に収まる。ここから、

  • 第 7 ビットが常に 0
  • 第 5 ビットが常に 1
  • 第 4 ビット xor 第 6 ビットが常に 1

あたりが言えるので、乱数ストリームの該当する位置のビットがわかる。
Xorshift はビットごとに見ると GF(2) の上の線型写像になっているので、1 << 0 から 1 << 63 までの 64 通りの seed によって得られる乱数ストリームのいくつかの xor になっているはず。それを線型代数で求める。

https://doc.sagemath.org/html/ja/tutorial/afterword.html#sagepython
にあるように、xor の記号が ^^ だと知った時かなり嫌な気持ちになった。(時間を無駄にした。)

実装は Sage でやった。

output = bytes.fromhex(open("output.txt").read().strip())
K = GF(2)


class XorshiftStream:
    def __init__(self, key: int):
        self.state = key % 2**64

    def _next(self):
        self.state = (self.state ^^ (self.state << 13)) % 2**64
        self.state = (self.state ^^ (self.state >> 7)) % 2**64
        self.state = (self.state ^^ (self.state << 17)) % 2**64
        return self.state

    def encrypt(self, data: bytes):
        ct = b""
        for i in range(0, len(data), 8):
            pt_block = data[i : i + 8]
            ct += (int.from_bytes(pt_block, "little") ^^ self._next()).to_bytes(
                8, "little"
            )[: len(pt_block)]
        return ct


def next(x: int) -> int:
    x = (x ^ (x << 13))
    x = (x ^ (x >> 7))
    x = (x ^ (x << 17))
    return x


def get_synd(dat: bytes, i: int) -> list[int]:
    u = []
    d = dat[i]
    # 0x3? or 0x6?
    u.append(d >> 7 & 1)
    c3a = d >> 5 & 1
    u.append(c3a ^^ 1)
    u.append((d >> 4 ^^ d >> 6 ^^ 1) & 1)
    return u


def main() -> None:
    keylen = len(output) // 3
    print(f'# {keylen=}')

    # Collect constraints
    units = []
    for h in range(64):
        tmp = XorshiftStream(1 << h)
        dat = tmp.encrypt(b'\x00' * (keylen * 2 + 7))
        u = []
        for i in range(keylen * 2):
            u += get_synd(dat, i)
        units.append(u)
    synd = []
    for i in range(keylen * 2):
        synd += get_synd(output, i)

    # Solve the equation
    A = matrix(K, units)
    b = matrix(K, [synd])
    print(f'# {A=}')
    print(f'# {b=}')
    x = A.solve_left(b)
    print(f'# {x=}')

    # Decrypt the flag
    seed = 0
    for i in range(64):
        seed |= int(x[0, i]) << i
    xss = XorshiftStream(seed)
    decrypted = xss.encrypt(output)
    key = bytes.fromhex(decrypted[:keylen*2].decode())
    flag = [decrypted[keylen*2+i] ^^ key[i] for i in range(keylen)]
    print(bytes(flag).decode())


if __name__ == "__main__":
    main()

NNNN

「n[0] = p * q, n[i] = (p + d[i]) * (q + d[i]) (1 <= i <= 3) が与えられる。ただし p, q は 768 ビットで d[i] は 192 ビット。このとき p, q, d[i] を求めよ。」という問題。

n[i] - n[0] = d[i] * (p + q) + d[i]^2 なので、Approximate GCD を使って p + q を求めれば良い。しかし単純にやると以下のような問題が発生する。

  • p + q が 770 ビットになる (期待される値より 2 倍くらい大きい)
  • Approximate GCD において、どの値を 0 番目として使うかによって p + q の値が異なる

これらの原因を調査したところ、Approximate GCD の値として得られる値が d[i] の 1/2 の値だった。
原因は、d[i] が常に偶数であることと、Approximate GCD が GCD としてなるべく大きい値を得ようとする (その結果戻り値が小さくなる) ことであった。

実装は Sage でやった。

from Crypto.Util.number import long_to_bytes


for line in open('output.txt').readlines():
    exec(line, globals())
ns = [n0, n1, n2, n3]
cs = [c0, c1, c2, c3]


def approx_gcd(d: list[int], approx_error: int) -> int:
    """
    Returns q where d[0] ~= qx and d[i]'s are close to multiples of x.
    The caller must find d[0] // q if they want to find x.
    """
    M = Matrix(ZZ, 3, 4)
    M[0, 0] = approx_error
    M[0, 1] = d[1]
    M[0, 2] = d[2]
    M[1, 1] = -d[0]
    M[2, 2] = -d[0]
    L = M.LLL()
    for row in L:
        if row[0] != 0:
            quot = abs(row[0] // approx_error)
            return quot


def main() -> None:
    # Find p and q
    d = [n1 - n0, n2 - n0, n3 - n0]
    k = 2 ** 400
    quot = approx_gcd(d, k) * 2
    rest = d[0] - quot * quot
    assert rest % quot == 0
    p_plus_q = rest // quot
    print(f'# {p_plus_q = }')
    d = p_plus_q^2 - 4 * n0
    sqrtd = d.sqrt()
    assert sqrtd^2 == d
    p = (p_plus_q + sqrtd) // 2
    q = (p_plus_q - sqrtd) // 2
    assert p * q == n0

    # Decrypt
    factors = []
    for val in ns:
        quot = (val - n0) // p_plus_q
        assert val == n0 + quot * quot + quot * p_plus_q
        factors.append((p + quot, q + quot))
    for i in range(4):
        (pp, qq) = factors[i]
        m = pow(cs[i], pow(65537, -1, (pp - 1) * (qq - 1)), ns[i])
        print(long_to_bytes(m).decode('ascii'), end='')
    print()


if __name__ == "__main__":
    main()

SchnorrLCG

「Schnorr 署名方式で署名と認証を行うサーバーがある。特定のメッセージの署名を偽造して受理せしめよ。」という問題。

x を秘密鍵とする。乱数 k が線形合同法 k[i+1] = a * k[i] + b (mod q) で生成される。このことを利用して、s[i] = k[i] + x * e[i] (mod q) であることから
s_{i+1} - as_i \equiv b + xe_{i+1} - xae_i \pmod q
という関係式ができる。これを LLL で解くことになる。(ECDSA に対する同じような攻撃を参考にする。)

詳細は実装に譲るが、注意点は以下。

  • 最終的に得られるベクトルの各要素が big = 2^1024 程度になるようにする
    • x, a は 384 ビットで xa は 768 ビットであるため、それが現れる位置の大きさが big 程度になるように係数で調整する

実装は Sage でやった。

# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs
import os
os.environ['TERM'] = 'linux'
os.environ['PWNLIB_NOTERM'] = '1'


import sys
import time
import subprocess
from pwn import process, remote
from Crypto.Hash import SHA256
from Crypto.Util.number import long_to_bytes


local = len(sys.argv) == 1
io = process(["sh", "./run.sh"]) if local else remote(sys.argv[1], int(sys.argv[2]))


def get_hashcash(cmd: str) -> str:
    out = subprocess.check_output(cmd.split()).decode().strip()
    return out


def fetch_sign(msg: bytes) -> tuple[int, int]:
    io.recvuntil(b'option> ')
    io.sendline(b'1')
    io.recvuntil(b'message(in hex)> ')
    io.sendline(msg.hex().encode())
    io.recvuntil(b'e=')
    e = int(io.recvline().strip().decode())
    io.recvuntil(b's=')
    s = int(io.recvline().strip().decode())
    return e, s


def find_x(es: list[tuple[int, int]], q: int) -> int:
    count = len(es)
    big = 2 ** 1024
    M = Matrix(ZZ, count + 5, count + 5)
    for i in range(count - 1):
        (e, s) = es[i]
        (en, sn) = es[i + 1]
        M[count, i] = -sn * big
        M[count + 1, i] = s * big
        M[count + 2, i] = big
        M[count + 3, i] = en * big
        M[count + 4, i] = -e * big
    M[count, count] = big
    M[count + 1, count + 1] = big // (2 ** 384)
    M[count + 2, count + 2] = big // (2 ** 384)
    M[count + 3, count + 3] = big // (2 ** 384)
    M[count + 4, count + 4] = big // (2 ** 768)
    for i in range(count):
        M[i, i] = q * big
    L = M.LLL()
    for row in L:
        if abs(row[count]) != big:
            continue
        coef = row[count] // big
        x = row[count + 3] // M[count + 3, count + 3] // coef
        print(f'# {x = }')
        print(f'# {x.bit_length() = }')
        break
    else:
        raise ValueError('x not found')
    return x


def _hash(message: bytes, r: int, q: int):
    hash_res = SHA256.new(message + long_to_bytes(r))
    return int(hash_res.hexdigest(), 16) % q


def forge_sign(message: bytes, x: int, g: int, p: int) -> tuple[int, int]:
    k = 1
    q = (p - 1) // 2
    r = pow(g, k, p)  # r = g^k mod p
    e = _hash(message, r, q)  # e = H(m || r)
    s = (k + x * e) % q  # s = (k + x * e) mod q
    return (e, s)


def main() -> None:
    start = time.time()
    io.recvuntil(b'running the following command:')
    io.recvline()
    cmd = io.recvline().strip().decode()
    io.recvuntil(b'hashcash token: ')
    io.sendline(get_hashcash(cmd).encode())
    print(f'# ({time.time() - start:.2f}s) hashcash token sent')

    io.recvuntil(b'p=')
    p = int(io.recvline().strip().decode())
    io.recvuntil(b'g=')
    g = int(io.recvline().strip().decode())
    io.recvuntil(b'pub_key=')
    pub_key = int(io.recvline().strip().decode())
    q = (p - 1) // 2

    count = 5

    # collect
    es = []
    for _ in range(count):
        (e, s) = fetch_sign(b'koba')
        es.append((e, s))
    print(f'# ({time.time() - start:.2f}s) signatures collected')

    # solve
    x = find_x(es, q)
    print(f'# ({time.time() - start:.2f}s) x found')
    assert pow(g, x, p) == pub_key

    # forge + submit
    target_msg = b'give me flag'
    (e, s) = forge_sign(target_msg, x, g, p)
    io.recvuntil(b'option> ')
    io.sendline(b'2')
    io.recvuntil(b'message(in hex)> ')
    io.sendline(target_msg.hex().encode())
    io.recvuntil(b'e> ')
    io.sendline(str(e).encode())
    io.recvuntil(b's> ')
    io.sendline(str(s).encode())
    io.recvline()
    io.recvuntil(b'Here is your flag: ')
    print(io.recvline().decode().strip())


if __name__ == "__main__":
    main()

まとめ

反省点は
(i) 基本的な道具 (線型代数、Approximate GCD) に対する理解不十分
(ii) NTRU に対するリサーチ不足
(iii) 実装力の衰え
あたりだと思われる。

単純に典型知識を適用するだけでは解けず、中身の理解を要求するという点で、問題の質はかなり良かったと思われる。まさに実装力不足で全完できなかったのが悔やまれる。

あと Sage 祭り、LLL 祭りだった気がする