AlpacaHack Round 5 (Crypto) に参加した。日本時間で 2024-10-12 12:00 から 2024-10-12 18:00 まで。
結果は 9/247 位。
解法集
XorshiftStream
(key を hex string にしたもの) + (key と FLAG を xor したもの) を Xorshift で作ったストリームによって暗号化する (XOR)。
前半部分は平文のバイトが [0x30,0x39] または [0x61,0x66] の範囲に収まる。ここから、
- 第 7 ビットが常に 0
- 第 5 ビットが常に 1
- 第 4 ビット xor 第 6 ビットが常に 1
あたりが言えるので、乱数ストリームの該当する位置のビットがわかる。
Xorshift はビットごとに見ると GF(2) の上の線型写像になっているので、1 << 0 から 1 << 63 までの 64 通りの seed によって得られる乱数ストリームのいくつかの xor になっているはず。それを線型代数で求める。
https://doc.sagemath.org/html/ja/tutorial/afterword.html#sagepython
にあるように、xor の記号が ^^ だと知った時かなり嫌な気持ちになった。(時間を無駄にした。)
実装は Sage でやった。
output = bytes.fromhex(open("output.txt").read().strip()) K = GF(2) class XorshiftStream: def __init__(self, key: int): self.state = key % 2**64 def _next(self): self.state = (self.state ^^ (self.state << 13)) % 2**64 self.state = (self.state ^^ (self.state >> 7)) % 2**64 self.state = (self.state ^^ (self.state << 17)) % 2**64 return self.state def encrypt(self, data: bytes): ct = b"" for i in range(0, len(data), 8): pt_block = data[i : i + 8] ct += (int.from_bytes(pt_block, "little") ^^ self._next()).to_bytes( 8, "little" )[: len(pt_block)] return ct def next(x: int) -> int: x = (x ^ (x << 13)) x = (x ^ (x >> 7)) x = (x ^ (x << 17)) return x def get_synd(dat: bytes, i: int) -> list[int]: u = [] d = dat[i] # 0x3? or 0x6? u.append(d >> 7 & 1) c3a = d >> 5 & 1 u.append(c3a ^^ 1) u.append((d >> 4 ^^ d >> 6 ^^ 1) & 1) return u def main() -> None: keylen = len(output) // 3 print(f'# {keylen=}') # Collect constraints units = [] for h in range(64): tmp = XorshiftStream(1 << h) dat = tmp.encrypt(b'\x00' * (keylen * 2 + 7)) u = [] for i in range(keylen * 2): u += get_synd(dat, i) units.append(u) synd = [] for i in range(keylen * 2): synd += get_synd(output, i) # Solve the equation A = matrix(K, units) b = matrix(K, [synd]) print(f'# {A=}') print(f'# {b=}') x = A.solve_left(b) print(f'# {x=}') # Decrypt the flag seed = 0 for i in range(64): seed |= int(x[0, i]) << i xss = XorshiftStream(seed) decrypted = xss.encrypt(output) key = bytes.fromhex(decrypted[:keylen*2].decode()) flag = [decrypted[keylen*2+i] ^^ key[i] for i in range(keylen)] print(bytes(flag).decode()) if __name__ == "__main__": main()
NNNN
「n[0] = p * q, n[i] = (p + d[i]) * (q + d[i]) (1 <= i <= 3) が与えられる。ただし p, q は 768 ビットで d[i] は 192 ビット。このとき p, q, d[i] を求めよ。」という問題。
n[i] - n[0] = d[i] * (p + q) + d[i]^2 なので、Approximate GCD を使って p + q を求めれば良い。しかし単純にやると以下のような問題が発生する。
- p + q が 770 ビットになる (期待される値より 2 倍くらい大きい)
- Approximate GCD において、どの値を 0 番目として使うかによって p + q の値が異なる
これらの原因を調査したところ、Approximate GCD の値として得られる値が d[i] の 1/2 の値だった。
原因は、d[i] が常に偶数であることと、Approximate GCD が GCD としてなるべく大きい値を得ようとする (その結果戻り値が小さくなる) ことであった。
実装は Sage でやった。
from Crypto.Util.number import long_to_bytes for line in open('output.txt').readlines(): exec(line, globals()) ns = [n0, n1, n2, n3] cs = [c0, c1, c2, c3] def approx_gcd(d: list[int], approx_error: int) -> int: """ Returns q where d[0] ~= qx and d[i]'s are close to multiples of x. The caller must find d[0] // q if they want to find x. """ M = Matrix(ZZ, 3, 4) M[0, 0] = approx_error M[0, 1] = d[1] M[0, 2] = d[2] M[1, 1] = -d[0] M[2, 2] = -d[0] L = M.LLL() for row in L: if row[0] != 0: quot = abs(row[0] // approx_error) return quot def main() -> None: # Find p and q d = [n1 - n0, n2 - n0, n3 - n0] k = 2 ** 400 quot = approx_gcd(d, k) * 2 rest = d[0] - quot * quot assert rest % quot == 0 p_plus_q = rest // quot print(f'# {p_plus_q = }') d = p_plus_q^2 - 4 * n0 sqrtd = d.sqrt() assert sqrtd^2 == d p = (p_plus_q + sqrtd) // 2 q = (p_plus_q - sqrtd) // 2 assert p * q == n0 # Decrypt factors = [] for val in ns: quot = (val - n0) // p_plus_q assert val == n0 + quot * quot + quot * p_plus_q factors.append((p + quot, q + quot)) for i in range(4): (pp, qq) = factors[i] m = pow(cs[i], pow(65537, -1, (pp - 1) * (qq - 1)), ns[i]) print(long_to_bytes(m).decode('ascii'), end='') print() if __name__ == "__main__": main()
SchnorrLCG
「Schnorr 署名方式で署名と認証を行うサーバーがある。特定のメッセージの署名を偽造して受理せしめよ。」という問題。
x を秘密鍵とする。乱数 k が線形合同法 k[i+1] = a * k[i] + b (mod q) で生成される。このことを利用して、s[i] = k[i] + x * e[i] (mod q) であることから
という関係式ができる。これを LLL で解くことになる。(ECDSA に対する同じような攻撃を参考にする。)
詳細は実装に譲るが、注意点は以下。
- 最終的に得られるベクトルの各要素が big = 2^1024 程度になるようにする
- x, a は 384 ビットで xa は 768 ビットであるため、それが現れる位置の大きさが big 程度になるように係数で調整する
実装は Sage でやった。
# https://stackoverflow.com/questions/65579133/every-time-i-run-my-script-it-returns-curses-error-must-call-setupterm-firs import os os.environ['TERM'] = 'linux' os.environ['PWNLIB_NOTERM'] = '1' import sys import time import subprocess from pwn import process, remote from Crypto.Hash import SHA256 from Crypto.Util.number import long_to_bytes local = len(sys.argv) == 1 io = process(["sh", "./run.sh"]) if local else remote(sys.argv[1], int(sys.argv[2])) def get_hashcash(cmd: str) -> str: out = subprocess.check_output(cmd.split()).decode().strip() return out def fetch_sign(msg: bytes) -> tuple[int, int]: io.recvuntil(b'option> ') io.sendline(b'1') io.recvuntil(b'message(in hex)> ') io.sendline(msg.hex().encode()) io.recvuntil(b'e=') e = int(io.recvline().strip().decode()) io.recvuntil(b's=') s = int(io.recvline().strip().decode()) return e, s def find_x(es: list[tuple[int, int]], q: int) -> int: count = len(es) big = 2 ** 1024 M = Matrix(ZZ, count + 5, count + 5) for i in range(count - 1): (e, s) = es[i] (en, sn) = es[i + 1] M[count, i] = -sn * big M[count + 1, i] = s * big M[count + 2, i] = big M[count + 3, i] = en * big M[count + 4, i] = -e * big M[count, count] = big M[count + 1, count + 1] = big // (2 ** 384) M[count + 2, count + 2] = big // (2 ** 384) M[count + 3, count + 3] = big // (2 ** 384) M[count + 4, count + 4] = big // (2 ** 768) for i in range(count): M[i, i] = q * big L = M.LLL() for row in L: if abs(row[count]) != big: continue coef = row[count] // big x = row[count + 3] // M[count + 3, count + 3] // coef print(f'# {x = }') print(f'# {x.bit_length() = }') break else: raise ValueError('x not found') return x def _hash(message: bytes, r: int, q: int): hash_res = SHA256.new(message + long_to_bytes(r)) return int(hash_res.hexdigest(), 16) % q def forge_sign(message: bytes, x: int, g: int, p: int) -> tuple[int, int]: k = 1 q = (p - 1) // 2 r = pow(g, k, p) # r = g^k mod p e = _hash(message, r, q) # e = H(m || r) s = (k + x * e) % q # s = (k + x * e) mod q return (e, s) def main() -> None: start = time.time() io.recvuntil(b'running the following command:') io.recvline() cmd = io.recvline().strip().decode() io.recvuntil(b'hashcash token: ') io.sendline(get_hashcash(cmd).encode()) print(f'# ({time.time() - start:.2f}s) hashcash token sent') io.recvuntil(b'p=') p = int(io.recvline().strip().decode()) io.recvuntil(b'g=') g = int(io.recvline().strip().decode()) io.recvuntil(b'pub_key=') pub_key = int(io.recvline().strip().decode()) q = (p - 1) // 2 count = 5 # collect es = [] for _ in range(count): (e, s) = fetch_sign(b'koba') es.append((e, s)) print(f'# ({time.time() - start:.2f}s) signatures collected') # solve x = find_x(es, q) print(f'# ({time.time() - start:.2f}s) x found') assert pow(g, x, p) == pub_key # forge + submit target_msg = b'give me flag' (e, s) = forge_sign(target_msg, x, g, p) io.recvuntil(b'option> ') io.sendline(b'2') io.recvuntil(b'message(in hex)> ') io.sendline(target_msg.hex().encode()) io.recvuntil(b'e> ') io.sendline(str(e).encode()) io.recvuntil(b's> ') io.sendline(str(s).encode()) io.recvline() io.recvuntil(b'Here is your flag: ') print(io.recvline().decode().strip()) if __name__ == "__main__": main()
まとめ
反省点は
(i) 基本的な道具 (線型代数、Approximate GCD) に対する理解不十分
(ii) NTRU に対するリサーチ不足
(iii) 実装力の衰え
あたりだと思われる。
単純に典型知識を適用するだけでは解けず、中身の理解を要求するという点で、問題の質はかなり良かったと思われる。まさに実装力不足で全完できなかったのが悔やまれる。
あと Sage 祭り、LLL 祭りだった気がする