矩阵乘法的Strassen算法
2019-11-25 21:33阅读:
目录
一、矩阵相乘
二、矩阵相乘的朴素算法
三、Strassen其人及其算法
四、 Strassen的理想
五、 附Strassen的一种代码实现
六、 参考
一、矩阵相乘
矩阵相乘,是指一个m*n的矩阵和一个n*k的矩阵相乘而得到一个m*k矩阵的一种运算。矩阵相乘可用于线性变换和矩阵分解[1]。
二、矩阵相乘的朴素算法
矩阵相乘的一般步骤是第一个矩阵的第i行*第二个矩阵的第j列,得到第三个矩阵的第i行第
j列元素
因此其素朴算法是用三层循环的出计算结果
for i ← 1
to n
do for j ← 1 to n
do c[i][j] ← 0
for k ← 1 to n
do c[i][j] ← c[i][j] + a[i][k]⋅
b[k][j]
显然其时间复杂度为O(n^3)
三、Strassen其人及其算法
矩阵中的元素排列整齐,很容易分块,因此采用分治思想,将其划分为块,对各块进行相乘,分而治之,这样就会降低乘法运算的规模。
我们可以把一个n*n的矩阵划分为4个n/2*n/2的子矩阵进行运算。
这样递归进行预算,其形式如下,其中A11-A22代表a-d,B11-B22代表e-h:[5]
由于划分的规模缩小为n/2,总共划分成了8块,且各块相加的时间复杂度总和为O(n^2),故其时间复杂度可以近似表示为T(n)=8T(n/2)+O(n^2)
根据用主方法(the master method)求解递归式的方法O(n^2),因此此分治方法的时间复杂度仍为O(n^3)。
Volker Strassen是一位出生于1936年的德国数学家,他因为在概率论上的工作而广为人知,但是在计算机科学和算法领域,他却因为矩阵相乘算法而被大部分人认识,这个算法目前仍然是比通用矩阵相乘算法性能好的主要算法之一。[2]
Strassen在1969年第一次发表关于这个算法的文章,并证明了复杂度为n^3的算法并不是最优算法。[2]
他做了一个巧妙的组合计算:
先计算如何组合
再计算如下
最后得出
这时中间的分治计算只需P1-P7这七个乘法,因此时间复杂度公式也随之而变为:
T(n)=7T(n/2)+O(n^2)
同样据用主方法求解递归式的方法O(n^2)<</span>O(n^2.807),因此此分治方法的时间复杂度约为O(n^2.807)。
虽然,Strassen给出的解决方案只是好了一点点,但是,他的贡献却是相当巨大的,就是因为这导致了矩阵相乘领域更多的研究,产生了更快的算法,比如复杂度为O(n^2.3737)的Coppersmith-Winograd算法。
四、Strassen的理想
对于研究时间复杂度理论而言,Strassen算法贡献巨大,这个算法鼓励我们要朝着完美一步步接近,这是非常值得肯定的,理想一定要有。
然而,对于Strassen算法的实际测试我们也要进行关注,压力实测来源于网络:[3]
数据取600位上界,即超过10分钟跳出。可以看到使用Strassen算法时,耗时不但没有减少,反而剧烈增多,在n=700时计算时间就无法忍受。
造成如此结果的原因根据网上查阅资料,现罗列如下:[4]
1)采用Strassen算法作递归运算,需要创建大量的动态二维数组,其中分配堆内存空间将占用大量计算时间,从而掩盖了Strassen算法的优势
2)于是对Strassen算法做出改进,设定一个界限。当n<</font>界限时,使用普通法计算矩阵,而不继续分治递归。需要合理设置界限,不同环境(硬件配置)下界限不同
3)矩阵乘法一般意义上还是选择的是朴素的方法,只有当矩阵变稠密,而且矩阵的阶数很大时,才会考虑使用Strassen算法。
改进策略为:设定一个界限。当n<界限时,使用普通法计算矩阵,而不继续分治递归。
改进后算法优势明显,就算时间大幅下降。之后,针对不同大小的界限进行试验。在初步试验中发现,当数据规模小于1000时,下界S法的差别不大,规模大于1000以后,n取值越大,消耗时间下降。最优的界限值在32~128之间。
五、附Strassen的一种代码实现[6]
#include
#define N 4
//matrix + matrix
void plus( int t[N/2][N/2], int r[N/2][N/2], int
s[N/2][N/2] ) {
int i, j;
for( i = 0; i < N / 2; i++ )
{
for( j = 0; j < N / 2; j++ )
{
t[i][j] = r[i][j] + s[i][j];
}
}
}
//matrix - matrix
void minus( int t[N/2][N/2], int r[N/2][N/2], int
s[N/2][N/2] ) {
int i, j;
for( i = 0; i < N / 2; i++ )
{
for( j = 0; j < N / 2; j++ )
{
t[i][j] = r[i][j] - s[i][j];
}
}
}
//matrix * matrix
void mul( int t[N/2][N/2], int r[N/2][N/2], int
s[N/2][N/2] ) {
int i, j, k;
for( i = 0; i < N / 2; i++ )
{
for( j = 0; j < N / 2; j++ )
{
t[i][j] = 0;
for( k = 0; k < N / 2; k++ )
{
t[i][j] += r[i][k] * s[k][j];
}
}
}
}
int main() {
int i, j, k;
int mat[N][N];
int m1[N][N];
int m2[N][N];
int
a[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2];
int
e[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2];
int
p1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2];
int
p5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2];
int r[N/2][N/2], s[N/2][N/2],
t[N/2][N/2], u[N/2][N/2], t1[N/2][N/2], t2[N/2][N/2];
printf('Input the first
matrix...:');
for( i = 0; i < N; i++ )
{
for( j = 0; j < N; j++ )
{
scanf('%d', &m1[i][j]);
}
}
printf('Input the second
matrix...:');
for( i = 0; i < N; i++ )
{
for( j = 0; j < N; j++ )
{
scanf('%d', &m2[i][j]);
}
}
// a b c d e f g h
for( i = 0; i < N / 2; i++ )
{
for( j = 0; j < N / 2; j++ )
{
a[i][j] = m1[i][j];
b[i][j] = m1[i][j + N / 2];
c[i][j] = m1[i + N / 2][j];
d[i][j] = m1[i + N / 2][j + N / 2];
e[i][j] = m2[i][j];
f[i][j] = m2[i][j + N / 2];
g[i][j] = m2[i + N / 2][j];
h[i][j] = m2[i + N / 2][j + N / 2];
}
}
//p1
minus( r, f, h );
mul( p1, a, r );
//p2
plus( r, a, b );
mul( p2, r, h );
//p3
plus( r, c, d );
mul( p3, r, e );
//p4
minus( r, g, e );
mul( p4, d, r );
//p5
plus( r, a, d );
plus( s, e, f );
mul( p5, r, s );
//p6
minus( r, b, d );
plus( s, g, h );
mul( p6, r, s );
//p7
minus( r, a, c );
plus( s, e, f );
mul( p7, r, s );
//r = p5 + p4 - p2 + p6
plus( t1, p5, p4 );
minus( t2, t1, p2 );
plus( r, t2, p6 );
//s = p1 + p2
plus( s, p1, p2 );
//t = p3 + p4
plus( t, p3, p4 );
//u = p5 + p1 - p3 - p7 = p5 + p1 -
( p3 + p7 )
plus( t1, p5, p1 );
plus( t2, p3, p7 );
minus( u, t1, t2 );
for( i = 0; i < N / 2; i++ )
{
for( j = 0; j < N / 2; j++ )
{
mat[i][j] = r[i][j];
mat[i][j + N / 2] = s[i][j];
mat[i + N / 2][j] = t[i][j];
mat[i + N / 2][j + N / 2] = u[i][j];
}
}
printf('下面是strassen算法处理结果:');
for( i = 0; i < N; i++ )
{
for( j = 0; j < N; j++ )
{
printf('%d ', mat[i][j]);
}
printf('');
}
//下面是朴素算法处理
printf('下面是朴素算法处理结果:');
for( i = 0; i < N; i++ )
{
for( j = 0; j < N; j++ )
{
mat[i][j] = 0;
for( k = 0; k < N; k++ )
{
mat[i][j] += m1[i][j] * m2[i][j];
}
}
}
for( i = 0; i < N; i++ )
{
for( j = 0; j < N; j++ )
{
printf('%d ', mat[i][j]);
}
printf('');
}
return 0;
}
六、参考
1、https://www.zhihu.com/question/21351965?sort=created
2、https://yq.aliyun.com/articles/3591
3、https://blog.csdn.net/handawnc/article/details/7987107
4、https://www.cnblogs.com/zhoutaotao/p/3963048.html
5、《算法导论》第三版
6、https://www.2cto.com/kf/201303/197291.html