矩阵快速幂

关于什么是快速幂:快速幂代码
现在回想一下斐波那契数列问题:$f(1) = 1,\ f(2) = 1$。在 $n > 2$ 时,$f(n) = f(n - 1) + f(n - 2)$。求 $f(n)$。
朴素动态规划解法:

1
2
3
4
5
6
7
8
/// Returns the n-th Fibonacci number with F(1) = F(2) = 1, computed iteratively.
///
/// @param n 1-based index of the term. Values n <= 0 yield 0 (i.e. F(0)).
/// @return F(n). The result only fits in a 32-bit int for n <= 46; overflow
///         is deliberately not handled here.
///
/// Only the previous two terms are kept, so no variable-length array is
/// needed and n == 1 no longer writes past the end of a one-element buffer.
int fib(int n) {
    if (n <= 0) {
        return 0;
    }
    int prev = 0;  // F(i - 1)
    int cur = 1;   // F(i), starting at i = 1
    for (int i = 1; i < n; i++) {
        const int next = prev + cur;
        prev = cur;
        cur = next;
    }
    return cur;
}

现在考虑 $F_{n}$ 是由 $F_{n-1}$ 和 $F_{n-2}$ 经线性变换得出的。则有:

$$[F_{n} \ \ \ F_{n-1}] = [F_{n-1}\ \ \ F_{n-2}] \times \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}$$

设矩阵 $\begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}$ 为 $P$,
则有

$$[F_{n+1}\ \ \ F_{n}] = [F_{n} \ \ \ F_{n-1}] \times \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} = [F_{n-1}\ \ \ F_{n-2}] \times \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^2 = [F_{n-1}\ \ \ F_{n-2}] \times P^2$$

因为矩阵乘法满足结合律,所以我们可以直接用矩阵 $P$ 的幂来得到 $F_n$ 的值。而用快速幂求矩阵 $P^n$ 只需要 $O(\log n)$ 次矩阵乘法。
C++代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
const int maxn = 105;  // capacity of each matrix dimension

/// Fixed-capacity integer matrix with value semantics.
///
/// Only the top-left n x m corner of `v` is treated as matrix content by
/// operator* and print(). Arithmetic is plain `int`; overflow is not handled.
struct Matrix{
    int n, m;           // number of rows / columns in use
    int v[maxn][maxn];  // element storage, indexed v[row][column]

    /// Creates an n x m matrix with every entry set to zero.
    /// Zeroing here means copies never read indeterminate values, even when
    /// the caller fills in only a few entries afterwards.
    Matrix(int n, int m) : n(n), m(m), v{} {}

    /// Resets every entry of the backing storage to zero.
    void init() {
        std::memset(v, 0, sizeof(v));
    }

    /// Returns the matrix product (*this) * B.
    ///
    /// The result has n rows and B.m columns. The shared dimension is taken
    /// from this->m; the caller is responsible for B.n matching it.
    /// B is taken by const reference to avoid copying the ~44 KB storage on
    /// every multiplication.
    Matrix operator* (const Matrix& B) const {
        Matrix ans(n, B.m);  // starts out all zero
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < B.m; j++) {
                int sum = 0;
                for (int k = 0; k < m; k++) {
                    sum += v[i][k] * B.v[k][j];
                }
                ans.v[i][j] = sum;
            }
        }
        return ans;
    }

    /// Writes the n x m entries to standard output, one row per line, each
    /// entry followed by a single space.
    void print() const {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                std::cout << v[i][j] << " ";
            }
            std::cout << '\n';
        }
    }
};

/// Raises the square matrix A to the power b by binary exponentiation.
///
/// @param A base matrix; it is read only and left unchanged for the caller.
/// @param b exponent. For b <= 0 the identity matrix is returned.
/// @return A^b, computed with a number of matrix products logarithmic in b.
///
/// The running power is kept in a local copy, and it is squared only while
/// more exponent bits remain. The former trailing squaring produced a power
/// that was never used and could overflow `int` on its own.
Matrix q_pow (const Matrix& A, int b) {
    Matrix ret(A.n, A.m);

    ret.init();
    for (int i = 0; i < ret.n; i++) {  // start from the identity matrix
        ret.v[i][i] = 1;
    }

    Matrix base = A;  // A^(2^k) for the bit currently being examined
    while (b > 0) {
        if (b & 1) {
            ret = ret * base;
        }
        b >>= 1;
        if (b > 0) {
            base = base * base;
        }
    }
    return ret;
}

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
using namespace std;
using ll = long long;
const int maxn = 105;  // capacity of each matrix dimension

/// Fixed-capacity integer matrix with value semantics.
///
/// Only the top-left n x m corner of `v` is treated as matrix content by
/// operator* and print(). Arithmetic is plain `int`; overflow is not handled.
struct Matrix{
    int n, m;           // number of rows / columns in use
    int v[maxn][maxn];  // element storage, indexed v[row][column]

    /// Creates an n x m matrix with every entry set to zero.
    /// Zeroing here means copies never read indeterminate values, even when
    /// the caller fills in only a few entries afterwards.
    Matrix(int n, int m) : n(n), m(m), v{} {}

    /// Resets every entry of the backing storage to zero.
    void init() {
        std::memset(v, 0, sizeof(v));
    }

    /// Returns the matrix product (*this) * B.
    ///
    /// The result has n rows and B.m columns. The shared dimension is taken
    /// from this->m; the caller is responsible for B.n matching it.
    /// B is taken by const reference to avoid copying the ~44 KB storage on
    /// every multiplication.
    Matrix operator* (const Matrix& B) const {
        Matrix ans(n, B.m);  // starts out all zero
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < B.m; j++) {
                int sum = 0;
                for (int k = 0; k < m; k++) {
                    sum += v[i][k] * B.v[k][j];
                }
                ans.v[i][j] = sum;
            }
        }
        return ans;
    }

    /// Writes the n x m entries to standard output, one row per line, each
    /// entry followed by a single space.
    void print() const {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                std::cout << v[i][j] << " ";
            }
            std::cout << '\n';
        }
    }
};

/// Raises the square matrix A to the power b by binary exponentiation.
///
/// @param A base matrix; it is read only and left unchanged for the caller.
/// @param b exponent. For b <= 0 the identity matrix is returned.
/// @return A^b, computed with a number of matrix products logarithmic in b.
///
/// The running power is kept in a local copy, and it is squared only while
/// more exponent bits remain. The former trailing squaring produced a power
/// that was never used and could overflow `int` on its own.
Matrix q_pow (const Matrix& A, int b) {
    Matrix ret(A.n, A.m);

    ret.init();
    for (int i = 0; i < ret.n; i++) {  // start from the identity matrix
        ret.v[i][i] = 1;
    }

    Matrix base = A;  // A^(2^k) for the bit currently being examined
    while (b > 0) {
        if (b & 1) {
            ret = ret * base;
        }
        b >>= 1;
        if (b > 0) {
            base = base * base;
        }
    }
    return ret;
}

/// Returns the n-th Fibonacci number with F(1) = F(2) = 1, computed iteratively.
///
/// @param n 1-based index of the term. Values n <= 0 yield 0 (i.e. F(0)).
/// @return F(n). The result only fits in a 32-bit int for n <= 46; overflow
///         is deliberately not handled here.
///
/// Only the previous two terms are kept, so no variable-length array is
/// needed and n == 1 no longer writes past the end of a one-element buffer.
int fib(int n) {
    if (n <= 0) {
        return 0;
    }
    int prev = 0;  // F(i - 1)
    int cur = 1;   // F(i), starting at i = 1
    for (int i = 1; i < n; i++) {
        const int next = prev + cur;
        prev = cur;
        cur = next;
    }
    return cur;
}

/// Demo driver: computes Fibonacci numbers once with the iterative fib() and
/// once with matrix fast exponentiation, printing each result and the number
/// of clock() ticks each approach consumed.
int main()
{
    const int n = 44;

    // Baseline: linear-time iteration.
    std::clock_t startTime = std::clock();  // processor time, in clock ticks
    std::cout << fib(n) << std::endl;
    std::clock_t endTime = std::clock();
    std::cout << (endTime - startTime) << std::endl;

    // [F(n) F(n - 1)] = [F(n - 1) F(n - 2)] * [1 1]
    //                                         [1 0]
    // Zero both matrices before filling them, so no indeterminate entries
    // are ever copied along with the meaningful ones.
    Matrix start = Matrix(1, 2);  // row vector [F(2) F(1)]
    start.init();
    start.v[0][0] = 1;
    start.v[0][1] = 1;

    Matrix P = Matrix(2, 2);  // transition matrix
    P.init();
    P.v[0][0] = 1;
    P.v[0][1] = 1;
    P.v[1][0] = 1;
    P.v[1][1] = 0;

    // [F(2) F(1)] * P^(n - 1) = [F(n + 1) F(n)], so the second printed entry
    // is the value that fib(n) printed above.
    startTime = std::clock();
    Matrix ret = start * q_pow(P, n - 1);
    endTime = std::clock();
    std::cout << (endTime - startTime) << std::endl;
    ret.print();
    return 0;
}

这里暂且不考虑溢出问题。当 $n$ 足够大的时候,可以看到快速幂需要的时钟周期明显缩短。