对于两个 n 位数的大整数,普通竖式乘法(逐位相乘再累加)的时间复杂度是 O(n²)。当 n 很大(比如上万位)时,平方级的开销会非常显著。
Karatsuba 算法 利用分治思想,将时间复杂度降到 O(n^{log₂3}) ≈ O(n^{1.585}),是第一个突破 O(n²) 的大数乘法算法。
核心思想
设我们要计算两个 n 位数 X 和 Y 的乘积(假设 n 是 2 的幂,不足则补零)。
将 X 和 Y 各分成两半:
X = A * 10^{m} + B
Y = C * 10^{m} + D
其中 m = n/2,A、C 是高半部分,B、D 是低半部分。那么:
X * Y = (A*10^m + B) * (C*10^m + D)
= AC * 10^{2m} + (AD + BC) * 10^{m} + BD
普通做法需要计算四次乘法:AC, AD, BC, BD。Karatsuba 观察到:
AD + BC = (A+B)(C+D) - AC - BD
这样只需计算 三次乘法:
ACBD(A+B)(C+D)
然后用加减法组合出最终结果。递归地对这三个乘积继续使用 Karatsuba 方法,直到数字足够小(例如 32 位以内)改用直接乘法。
算法步骤
- 输入两个大数
X和Y(字符串或小端序数组)。 - 若长度小于某阈值(如 32 或 64),直接用普通乘法返回。
- 将
X和Y补齐到相同长度,且长度为 2 的幂(方便分半)。 - 取分割点
m = n/2,得到 A、B、C、D。 - 递归计算:
z0 = karatsuba(A, C)z2 = karatsuba(B, D)z1 = karatsuba(A+B, C+D) - z0 - z2
- 合并结果:
z0 * 10^{2m} + z1 * 10^{m} + z2。 - 处理进位并返回。
代码实现
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <cmath>
using namespace std;
using BigInt = vector<int>;
// 普通高精度乘法(O(n^2)),用于小规模递归基础情况
BigInt mulBrute(const BigInt& a, const BigInt& b) {
BigInt c(a.size() + b.size(), 0);
for (size_t i = 0; i < a.size(); ++i) {
long long carry = 0;
for (size_t j = 0; j < b.size(); ++j) {
long long cur = c[i + j] + (long long)a[i] * b[j] + carry;
c[i + j] = cur % 10;
carry = cur / 10;
}
if (carry) c[i + b.size()] += carry; // 注意可能超过一位,但进位最多一位
}
while (c.size() > 1 && c.back() == 0) c.pop_back();
return c;
}
// 辅助函数:两个大数相加(小端序)
BigInt add(const BigInt& a, const BigInt& b) {
BigInt c;
int carry = 0;
size_t i = 0;
while (i < a.size() || i < b.size() || carry) {
int sum = carry;
if (i < a.size()) sum += a[i];
if (i < b.size()) sum += b[i];
c.push_back(sum % 10);
carry = sum / 10;
++i;
}
return c;
}
// 辅助函数:两个大数相减(要求 a >= b)
BigInt sub(const BigInt& a, const BigInt& b) {
BigInt c;
int borrow = 0;
for (size_t i = 0; i < a.size(); ++i) {
int diff = a[i] - borrow;
if (i < b.size()) diff -= b[i];
if (diff < 0) {
diff += 10;
borrow = 1;
} else {
borrow = 0;
}
c.push_back(diff);
}
while (c.size() > 1 && c.back() == 0) c.pop_back();
return c;
}
// 大数左移(乘以 10^k),即在低位补 k 个 0
BigInt shiftLeft(const BigInt& a, int k) {
if (a.size() == 1 && a[0] == 0) return a;
BigInt res(k, 0); // 低位补零
res.insert(res.end(), a.begin(), a.end());
return res;
}
// Karatsuba 乘法
BigInt karatsuba(const BigInt& a, const BigInt& b) {
// 基础情况:使用普通乘法
const int THRESHOLD = 64; // 经验值,可调节
if (a.size() < THRESHOLD || b.size() < THRESHOLD) {
return mulBrute(a, b);
}
// 确保两数长度一致,并取最大长度
size_t n = max(a.size(), b.size());
// 向上取整为偶数,方便分半(也可取 n/2 向上取整)
if (n % 2 != 0) ++n;
BigInt a2 = a, b2 = b;
a2.resize(n, 0);
b2.resize(n, 0);
size_t m = n / 2; // 分割点(低位部分长度)
// 分割: A = 高位部分 (a2[m..n-1]), B = 低位部分 (a2[0..m-1])
BigInt lowA(a2.begin(), a2.begin() + m);
BigInt highA(a2.begin() + m, a2.end());
BigInt lowB(b2.begin(), b2.begin() + m);
BigInt highB(b2.begin() + m, b2.end());
// 递归计算三个乘积
BigInt z0 = karatsuba(highA, highB); // AC
BigInt z2 = karatsuba(lowA, lowB); // BD
BigInt sumA = add(highA, lowA); // A+B
BigInt sumB = add(highB, lowB); // C+D
BigInt z1 = karatsuba(sumA, sumB); // (A+B)(C+D)
z1 = sub(z1, z0);
z1 = sub(z1, z2); // (A+B)(C+D) - AC - BD
// 组合结果: z0 * 10^{2m} + z1 * 10^{m} + z2
BigInt result = add(z2, shiftLeft(z1, m));
result = add(result, shiftLeft(z0, 2 * m));
// 去除前导零
while (result.size() > 1 && result.back() == 0)
result.pop_back();
return result;
}
// 字符串转小端序
BigInt strToVec(const string& s) {
BigInt v;
for (int i = s.size() - 1; i >= 0; --i)
v.push_back(s[i] - '0');
return v;
}
// 输出小端序
void printVec(const BigInt& v) {
for (int i = v.size() - 1; i >= 0; --i)
cout << v[i];
cout << endl;
}
int main() {
string s1, s2;
cin >> s1 >> s2;
BigInt a = strToVec(s1);
BigInt b = strToVec(s2);
BigInt c = karatsuba(a, b);
printVec(c);
return 0;
}
复杂度分析
- 时间复杂度:递推式
T(n) = 3T(n/2) + O(n),解得T(n) = O(n^{log₂3}) ≈ O(n^{1.585})。 - 空间复杂度:递归栈深度 O(log n),每层需要存储中间结果,总空间 O(n)。