AMM_algorithm

0x01 How come?

昨天拿到一道奇怪的RSA题目。乍一看很简单,pq都给出来了,仔细分析有问题:phi%e居然等于0!?顿时觉得有点意思。

0x02 Analysis

之前讲过,求逆的条件是互素,因此这道题显然求不了d,于是只能求解方程: 将该式化为 分别求x后,用CRT组合一下就可得到mod n内的解。

这道题特殊之处在于,。对于这一特殊类型的求模根,存在一种特殊的算法:Adleman-Manders-Miller rth Root Extraction Method


Algorithm: Adleman-Manders-Miller rth Root Extraction Method

Data: and a th residue , .

Result: A th root of .

  • Step 1: Choose uniformly at random from .

  • Step 2: If , go to Step 1.

  • Step 3: Compute such that .

    Compute the least connegative integer such that .

    Change ,

  • Step 4: For to

    compute

    if , change ,

    else (compute the discrete logarithm)

    change

    end for

  • Step 5: return


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import random
import time

# About 3 seconds to run
def AMM(o, r, q):
start = time.time()
print('\n----------------------------------------------------------------------------------')
print('Start to run Adleman-Manders-Miller Root Extraction Method')
print('Try to find one {:#x}th root of {} modulo {}'.format(r, o, q))
g = GF(q)
o = g(o)
p = g(random.randint(1, q))
while p ^ ((q-1) // r) == 1:
p = g(random.randint(1, q))
print('[+] Find p:{}'.format(p))
t = 1
s = q - 1
while s % r == 0:
t += 1
s = s // r
print('[+] Find s:{}, t:{}'.format(s, t))

k = 1
while (k * s + 1) % r != 0:
k += 1
alp = (k * s + 1) // r
print('[+] Find alp:{}'.format(alp))
a = p ^ (r**(t-1) * s)
b = o ^ (r*alp - 1)
c = p ^ s
h = 1
for i in range(1, t):
d = b ^ (r^(t-1-i))
if d == 1:
j = 0
else:
print('[+] Calculating DLP...')
j = - discrete_log(d, a)
print('[+] Finish DLP...')
b = b * (c^r)^j
h = h * c^j
c = c^r
result = o^alp * h
end = time.time()
print("Finished in {} seconds.".format(end - start))
print('Find one solution: {}'.format(result))
return result

实际上,sympy库集成了AMM算法: nthroot_mod(a,n,p)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def nthroot_mod(a, n, p, all_roots=False):
"""
Find the solutions to ``x**n = a mod p``

Parameters
==========

a : integer
n : positive integer
p : positive integer
all_roots : if False returns the smallest root, else the list of roots

Examples
========

>>> from sympy.ntheory.residue_ntheory import nthroot_mod
>>> nthroot_mod(11, 4, 19)
8
>>> nthroot_mod(11, 4, 19, True)
[8, 11]
>>> nthroot_mod(68, 3, 109)
23
"""

这个算法能求出的一个根。但对于开次方根,实际上最多会有个根。那么如何找到其他根呢?

0x03 primitive r root of 1

参考这个链接,我们不难得出结论:将AMM算法求得的根,分别乘以 的所有根,就能得到所有的个解。求得1的根的过程如下:

  • 随机选择内的数
  • 计算
  • 即是其中的一个解。
  • 重复上述过程直到解的数量到为止。

结合AMM算法,我们成功得到了满足 的所有,他们各含个解。将这两组解两两合成crt,再经过一定验证算法,就能得到正确的明文了。

0x04 exp

回顾整个过程,攻击流程如下:

  • 先用Adleman-Manders-Miller rth Root Extraction MethodGF(p)GF(q)上对ce次根,分别得到一个解。
  • 然后去找到所有的0x1336primitive nth root of 1,乘以上面那个解,得到所有的0x1337个解。
  • 再用CRTGF(p)GF(q)上的两组0x1337个解组合成mod n下的解,可以得到0x1337**2==24196561mod n的解。最后能通过check的即为flag

exp1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import random
import time

# About 3 seconds to run
def AMM(o, r, q):
start = time.time()
print('\n----------------------------------------------------------------------------------')
print('Start to run Adleman-Manders-Miller Root Extraction Method')
print('Try to find one {:#x}th root of {} modulo {}'.format(r, o, q))
g = GF(q)
o = g(o)
p = g(random.randint(1, q))
while p ^ ((q-1) // r) == 1:
p = g(random.randint(1, q))
print('[+] Find p:{}'.format(p))
t = 1
s = q - 1
while s % r == 0:
t += 1
s = s // r
print('[+] Find s:{}, t:{}'.format(s, t))

k = 1
while (k * s + 1) % r != 0:
k += 1
alp = (k * s + 1) // r
print('[+] Find alp:{}'.format(alp))
a = p ^ (r**(t-1) * s)
b = o ^ (r*alp - 1)
c = p ^ s
h = 1
for i in range(1, t):
d = b ^ (r^(t-1-i))
if d == 1:
j = 0
else:
print('[+] Calculating DLP...')
j = - discrete_log(d, a)
print('[+] Finish DLP...')
b = b * (c^r)^j
h = h * c^j
c = c^r
result = o^alp * h
end = time.time()
print("Finished in {} seconds.".format(end - start))
print('Find one solution: {}'.format(result))
return result
p= 12408795636519868275579286477747181009018504169827579387457997229774738126230652970860811085539129972962189443268046963335610845404214331426857155412988073
q= 12190036856294802286447270376342375357864587534233715766210874702670724440751066267168907565322961270655972226761426182258587581206888580394726683112820379
c= 68960610962019321576894097705679955071402844421318149418040507036722717269530195000135979777852568744281930839319120003106023209276898286482202725287026853925179071583797231099755287410760748104635674307266042492611618076506037004587354018148812584502385622631122387857218023049204722123597067641896169655595
e=65537
cp = c % p
cq = c % q
mp = int(AMM(cp, e, p))
mq = int(AMM(cq, e, q))
print('mp=',mp)
print('mq=',mq)

exp2:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

import random
import time
import gmpy2 as gp
from Crypto.Util.number import *
c = 10562302690541901187975815594605242014385201583329309191736952454310803387032252007244962585846519762051885640856082157060593829013572592812958261432327975138581784360302599265408134332094134880789013207382277849503344042487389850373487656200657856862096900860792273206447552132458430989534820256156021128891296387414689693952047302604774923411425863612316726417214819110981605912408620996068520823370069362751149060142640529571400977787330956486849449005402750224992048562898004309319577192693315658275912449198365737965570035264841782399978307388920681068646219895287752359564029778568376881425070363592696751183359
p = 199138677823743837339927520157607820029746574557746549094921488292877226509198315016018919385259781238148402833316033634968163276198999279327827901879426429664674358844084491830543271625147280950273934405879341438429171453002453838897458102128836690385604150324972907981960626767679153125735677417397078196059
q = 112213695905472142415221444515326532320352429478341683352811183503269676555434601229013679319423878238944956830244386653674413411658696751173844443394608246716053086226910581400528167848306119179879115809778793093611381764939789057524575349501163689452810148280625226541609383166347879832134495444706697124741
e = 0x1337
cp = c % p
cq = c % q
mp= 136784290804277183072700480012548374152759378257011909034887899431169564849889825231722097177532308859891247525446862463516714084156176213423817683313381104257495877452349025654729725084057737068700329218272100110079053344486570311487112533045352589682271379218190591008056852730068565778374497299633096500583
mq= 80786575232147566895029557320394831470404570012083703265978992757988977676053279231922614362269752289759581069078476928974891541910097315125063831847133444149619061499072328063992127358184272740474176891250834314811798181109024856754976329843311618683772106042618526497158475956938905731342153265365708634577


def CRT(aList, mList):
M = 1
for i in mList:
M = M * i # 计算M = ∏ mi
# print(M)
x = 0
for i in range(len(mList)):
Mi = M // mList[i] # 计算Mi
Mi_inverse = gp.invert(Mi, mList[i]) # 计算Mi的逆元
x += aList[i] * Mi * Mi_inverse # 构造x各项
x = x % M
return x

def findAllPRoot(p, e):

print("Start to find all the Primitive {:#x}th root of 1 modulo {}.".format(e, p))
start = time.time()
proot = set()
while len(proot) < e:
proot.add(pow(random.randint(2, p-1), (p-1)//e, p))

print(len(proot))
end = time.time()
print("Finished in {} seconds.".format(end - start))
return proot

def findAllSolutions(mp, proot, cp, p):

print("Start to find all the {:#x}th root of {} modulo {}.".format(e, cp, p))
start = time.time()
all_mp = set()
for root in proot:
mp2 = mp * root % p
#assert(pow(mp2, e, p) == cp)
if pow(mp2,e,p) != cp:
print('wrong pow!!')
break
exit()
all_mp.add(mp2)
end = time.time()
print("Finished in {} seconds.".format(end - start))
return all_mp


p_proot = findAllPRoot(p, e)
q_proot = findAllPRoot(q, e)
mps = findAllSolutions(mp, p_proot, cp, p)
mqs = findAllSolutions(mq, q_proot, cq, q)
print ('mps=',mps)
print('mqs=',mqs)

def check(m):
if len(str(m)) == 1:
return False
if b'NCTF' in long_to_bytes(m):
print(long_to_bytes(m))
return True
else:
return False



start = time.time()
print('Start CRT...')
for mpp in mps:
for mqq in mqs:
solution = int(CRT([int(mpp), int(mqq)], [p, q]))
if check(solution):
print(solution)
print(time.time() - start)

end = time.time()
print("Finished in {} seconds.".format(end - start))
# NCTF{T4k31ng_Ox1337_r00t_1s_n0t_th4t_34sy}

运行exp需要一定时间(主要花在求1的根以及crt上)。但不用很久就能跑出结果。

0x05 Revisit

这次不经意间的学习机会来源于同学问的一道题。事实上,我完成这道题的过程非常失败,因为经过一天的学习(其中包含漫长的读paper,调试代码过程)后,我发现一个问题:这道题完全不用花费这么多事情来搞!因为题目的特殊设置,一个sympy库就能彻底搞定。所以虽然学到了新的算法,但在比赛中更重要的是:直接挪用现成的库总比自己写要好。

当然,大佬博客中的库一定要慎用,其中的很多问题不一定是大佬算法出问题,而是版本或各种其他因素造成的,这使得现成的轮子不能正常使用。平常积累些成熟的库非常有用,比如这道题的crt函数就是我之前写的。