我们可以考虑更一般的问题:
给定一个随机数生成器randN
, 这个方法可以生成 [1,N] 范围内的随机整数,现在我们想要使用这个生成器,来得到另一个随机数生成器randM
,即能够生成 [1,M] 范围内的随机整数
预处理
首先,我们可以假设 N\geq M , 如果 N<M 的话,那么我们就可以投 k=\lceil \log_NM\rceil 次来得到一个 [1, N^k] 的随机数生成器,这样问题就变成了N\geq M的形式。
我们记[1, N^k]的随机数生成器为randNk
, 注意到[1, N^k]由区间[0, N^k-1]平移一位得到, 而区间[1, N^k]可以视作一个 k 位 N 进制,这样我们就可以将randN
转换为randNk
, 代码如下:
def randN(N):
return random.randint(1, N)
def randNk(N: int, k: int):
x = 0
for i in range(k):
x += (randN(N) - 1) * N**i
return x + 1
接下来,我们就可以使用拒绝采样(reject sampling)由randNk
生成randM
.
算法1
最简单的拒绝采样算法就是
- 从
randN
中采样k=\lceil \log_NM\rceil次,将结果相乘得到 x - 如果 x\leq M , 那么采样完成;否则,跳到第一步重新采样
写成python代码就是:
def randM_from_randN1(N, M):
k = math.ceil(math.log(M, N))
while True:
x = randNk(N, k)
if x <= M:
return x
我们可以来分析一下调用的次数, 记调用次数为 C , 注意到C是一个随机变量,其期望为
\mathbb{E}[C] = k*1+k\frac{N^k-M}{N^k}+\cdots+k\left(\frac{N^k-M}{N^k}\right)^i=k\sum_{i=0}^{\infty}\left(\frac{N^k-M}{N^k}\right)^i=\frac{kN^k}{M}
这里 k*1 代表我们至少要调用randN
k 次来得到一个结果,第二项代表如果我们第一次失败了(以概率 (N^k-M)/N^k ),则需要再调用 k 次。。。
当 N=6 , M=7 时,我们平均需要 \lceil \log_67\rceil*6^2/7\approx 10.28 次才能成功,这显然是不能接受的。
算法2
算法1的问题是,我们抛弃了太多的值,当 N=6 , M=7 时,我们每轮拒绝的比例是 (6^2-7)/6^2=29/36 , 因此我们需要改进.
注意到区间 [1,kM] 的随机数可以变成区间 [1,M] 的随机数,这只要进行一个简单的模运算就可以了:对 x\in [1,kM] , 我们有 x\ \mathrm{mod}\ M\in[0, M-1] . 这样我们就利用了更多的数,还是以 , M=7为例,对于 [1, 28] 中的数,我们都接受,然后使用模运算得到 [0, 6] 中的均匀随机数,这样就减少了平均调用次数,代码如下:
def randM_from_randN2(N, M):
k = math.ceil(math.log(M, N))
r = N**k // M
while True:
x = randNk(N, k)
if x <= r * M:
return (x - 1) % M + 1
与上面的分析一样,但是现在失败的概率变成了 (N^k-rM)/N^k , 其中 r=\lfloor \frac{N^k}{M} \rfloor . 这样
\mathbb{E}[C] =\frac{kN^k}{rM}
当 N=6 , M=7 时,我们平均需要 \lceil \log_67\rceil * 6^2/(7*5)\approx 2.05 次成功,这相当于算法1显然是非常大的提升。这时,我们每轮拒绝的比例为 (6^2-7*5)/6^2=1/36 .
最后我们可以验证一下算法的正确性:
import random
import math
from collections import defaultdict
def randN(N):
return random.randint(1, N)
def randNk(N: int, k: int) -> int:
x = 0
for i in range(k):
x += (randN(N) - 1) * N**i
return x + 1
def randM_from_randN1(N, M):
k = math.ceil(math.log(M, N))
count = 0
while True:
x = 0
count += 1
x = randNk(N, k)
if x <= M:
return x, count * k
def randM_from_randN2(N, M):
k = math.ceil(math.log(M, N))
r = N**k // M
count = 0
while True:
x = 0
count += 1
x = randNk(N, k)
if x <= r * M:
return (x - 1) % M + 1, count * k
iters = 1000000
counts = []
result = defaultdict(int)
for _ in range(iters):
x, count = randM_from_randN1(6, 7)
# x, count = randM_from_randN2(6, 7)
counts.append(count)
result[x] += 1
print(sum(counts) / iters)
print(result)
分别运行算法1和算法2 1000000 次得到的结果如下:
# randM_from_randN1
10.296348
defaultdict(<class 'int'>, {4: 142374, 7: 142717, 5: 142928, 2: 142686, 6: 143172, 1: 143148, 3: 142975})
# randM_from_randN2
2.05696
defaultdict(<class 'int'>, {7: 142504, 2: 143050, 3: 142237, 4: 142711, 1: 143311, 5: 143193, 6: 142994})