poj 3233 Matrix Power Series

$$
题意:给你一个矩阵A,要你求矩阵 S = A + A^2 + A^3 + … + A^k
$$

$$
思路:用不化简的矩阵快速幂直接求和会超时,要推导出一个数学公式
$$

$$
设 S(k) = A + A^2 + A^3 + … + A^k,
则有 S(k) = S(k-1) + A^k
$$

$$
设2×2的矩阵B满足 :B*
\left[
\begin{matrix}
S(k-1) \\
A^k
\end{matrix}
\right] = \left[
\begin{matrix}
S(k) \\
A^{k+1}
\end{matrix}
\right]
$$

$$
则可以得到B=
\left[
\begin{matrix}
E & E\\
O & A
\end{matrix}
\right]
(E为单位阵,O为零矩阵)
$$

$$
让B与自身相乘,可得到B^{k+1}=
\left[
\begin{matrix}
E & E + A + A^2 + A^3 + … + A^k\\
O & A^{k+1}
\end{matrix}
\right]
$$

$$
所以只要求出矩阵B的k+1次方,它的右上角的子矩阵减去单位阵E,即为答案。
$$

#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
ll n,k,mod;
struct node
{
    ll m[62][62];
};
node s,B;
node mul(node x,node y)
{
    for(int i=0;i<2*n;i++)
        for(int j=0;j<2*n;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<2*n;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s;
    for(int i=0;i<2*n;i++)
        for(int j=0;j<2*n;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>k>>mod;
    for(int i=0;i<n;i++)
        B.m[i][i]=B.m[i][i+n]=1;//B矩阵左上角和右上角的子矩阵均为E
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
        cin>>B.m[i+n][j+n];//B矩阵右下角的子矩阵为A
    s=quickpow(B,k+1);
    for(int i=0;i<n;i++)
        s.m[i][i+n]--;//右上角减去单位阵
    for(int i=0;i<n;i++)
        for(int j=n;j<2*n;j++)
        {
            s.m[i][j]=(s.m[i][j]%mod+mod)%mod;//防止出现负数
            j==2*n-1?printf("%lld\n",s.m[i][j]):printf("%lld ",s.m[i][j]);
        }
    return 0;
}

hdu 1588 Gauss Fibonacci

这题其实和上题有点联系,要求的和为 f(g(i)) for 0<=i<n,设这个和为S(n)
即S(n) = f(b) + f(k+b) + f(2*k+b) + ... + f((n-1)*k+b)
由于f(n)是斐波那契数列,则有
$$
2×2的矩阵A满足 :A*
\left[
\begin{matrix}
f(n-1) \\
f(n-2)
\end{matrix}
\right] = \left[
\begin{matrix}
f(n) \\
f (n-1)
\end{matrix}
\right] ,其中A= \left[
\begin{matrix}
1&1 \\
1&0
\end{matrix}
\right]
$$

$$
n>=2时,有A^n*
\left[
\begin{matrix}
f(1) \\
f(0)
\end{matrix}
\right] = \left[
\begin{matrix}
f(n+1) \\
f (n)
\end{matrix}
\right]
其中f(1)=1,f(0)=0
$$

$$
设A^n(2×2矩阵)的左下角元素为A.m[1][0],右下角元素为A.m[1][1]
$$

$$
根据矩阵乘法,有A.m[1][0] * f(1) + A.m[1][1] * f(0) = f(n),即f(n) = A.m[1][0]
$$

$则S(n) = A^b + A^{k+b} + A^{2k+b} + ... + A^{(n-1)k+b}(A的幂次取[1][0]位置,也就是矩阵左下角的元素)$

$S(n) = A^b * (E + A^k + A^{2k} + ... + A^{(n-1)k})$

$设B=A^k,则S(n) = A^b * (E + B + B^2 + ... + B^{n-1})$

括号内的形式就和上题poj 3233差不多了,直接用上题解法求括号内的矩阵和,之后再乘以$A^b$,得到矩阵ans的左下角元素即为答案。

#include <bits/stdc++.h>
using namespace std;
const int N=4;
typedef long long ll;
ll n,k,b,mod;
struct node
{
    ll m[N][N];
};
node t,s1,s2,s3,ans,B,A={1,1,0,0,1};//不能写A={1,1,1,0},因为此处默认A是4*4的矩阵
node mul(node x,node y,ll n)
{
    node s;
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<n;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b,ll n)
{
    node s;
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a,n);}
        a=mul(a,a,n);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>k>>b>>n>>mod)
    {
        s1=quickpow(A,b,2);
        s2=quickpow(A,k,2);
        memset(B.m,0,sizeof(B.m));
        for(int i=0;i<2;i++)
        B.m[i][i]=B.m[i][i+2]=1;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++)
            B.m[i+2][j+2]=s2.m[i][j];
        s3=quickpow(B,n,4);
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++)
            t.m[i][j]=s3.m[i][j+2];
        ans=mul(s1,t,2);
        printf("%lld\n",ans.m[1][0]);
    }
    return 0;
}

hdu 4965 Fast Matrix Calculation

按照题目的步骤来算肯定是会超时的,因为A×B最大是1000×1000,再快速幂1e6次就超时了。
可以利用矩阵乘法的结合律,先算B×A,B×A最大只有6×6,这样快速幂能省很多时间。
原式$(A * B)^{n * n} = A * (B * A)^{n * n - 1} * B$,利用这个公式计算即可。

#include <bits/stdc++.h>
using namespace std;
const int N=10,mod=6;
int n,k,ans,a[1010][10],b[10][1010],t[1010][10],s[1010][1010];
struct node
{
    int m[N][N];
};
node C,M;
node mul(node x,node y)
{
    node s;
    for(int i=0;i<k;i++)
        for(int j=0;j<k;j++)
        {
            s.m[i][j]=0;
            for(int p=0;p<k;p++)
                s.m[i][j]=(s.m[i][j]+x.m[i][p]*y.m[p][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,int b)
{
    node s;
    for(int i=0;i<k;i++)
        for(int j=0;j<k;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n>>k&&!(n==0&&k==0))
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                cin>>a[i][j];
        for(int i=0;i<k;i++)
            for(int j=0;j<n;j++)
                cin>>b[i][j];
        memset(C.m,0,sizeof(C.m));
        for(int i=0;i<k;i++)
            for(int j=0;j<k;j++)
                for(int p=0;p<n;p++)
                C.m[i][j]=(C.m[i][j]+b[i][p]*a[p][j]%mod)%mod;
        M=quickpow(C,n*n-1);//M=(B*A)^(n*n-1)
        memset(t,0,sizeof(t));
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                for(int p=0;p<k;p++)
                t[i][j]=(t[i][j]+a[i][p]*M.m[p][j]%mod)%mod;//t=A*M
        memset(s,0,sizeof(s));
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                for(int p=0;p<k;p++)
                s[i][j]=(s[i][j]+t[i][p]*b[p][j]%mod)%mod;//s=t*B
        ans=0;
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
            ans=ans+s[i][j];
        printf("%d\n",ans);
    }
    return 0;
}

hdu 4920 Matrix multiplication

普通的两矩阵相乘取模,时间优化在于取模,如果你取模的顺序写得不对,就超时了,比如以下代码:

#include <bits/stdc++.h>
using namespace std;
const int N=810,mod=3;
int n,a[N][N],b[N][N],s[N][N];
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n)
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                cin>>a[i][j];
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                cin>>b[i][j];
        memset(s,0,sizeof(s));
        for(int i=0;i<n;i++)
            for(int k=0;k<n;k++)
                for(int j=0;j<n;j++)
                    s[i][j]=(s[i][j]+a[i][k]*b[k][j]%mod)%mod;
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                j==n-1?printf("%d\n",s[i][j]):printf("%d ",s[i][j]);
    }
    return 0;
}

超时原因应该是取模都写到三重循环里了,O(n^3^)取模比较费时间。~~(取模也很耗时间啊我枯了)~~
其实只要先对原矩阵a、b的每个元素取模,最后对答案矩阵的每个元素取模就AC了。

#include <bits/stdc++.h>
using namespace std;
const int N=810,mod=3;
int n,a[N][N],b[N][N],s[N][N];
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n)
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
            {
                cin>>a[i][j];
                a[i][j]=a[i][j]%mod;//先对原矩阵每个元素取模
            }
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
            {
                cin>>b[i][j];
                b[i][j]=b[i][j]%mod;
            }
        memset(s,0,sizeof(s));
        for(int i=0;i<n;i++)
            for(int k=0;k<n;k++)
                for(int j=0;j<n;j++)
                    s[i][j]=s[i][j]+a[i][k]*b[k][j];
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                j==n-1?printf("%d\n",s[i][j]%mod):printf("%d ",s[i][j]%mod);//输出答案时取模
    }
    return 0;
}

剩下的题目基本上都是一个套路,就是利用题目给你的递推方程构造矩阵A,然后求A的多少次幂乘前几项初始值构成的矩阵就能得到答案

HIT 2060 - Fibonacci Problem Again

#include <bits/stdc++.h>
using namespace std;
const int N=3,mod=1e9;
typedef long long ll;
ll a,b,ans1,ans2;
struct node
{
    ll m[N][N];
};
node s1,s2,A={1,1,1,0,1,1,0,1,0},E={1,0,0,0,1,0,0,0,1};
node mul(node x,node y)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s=E;
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>a>>b&&!(a==0&&b==0))
    {
        s1=quickpow(A,b-1);//b>=1恒成立
        ans1=s1.m[0][0]*2+s1.m[0][1]+s1.m[0][2];
        if(a==0) ans2=0;
        else if(a==1) ans2=1;
        else
        {
            s2=quickpow(A,a-2);
            ans2=s2.m[0][0]*2+s2.m[0][1]+s2.m[0][2];
        }
        printf("%lld\n",((ans1-ans2)%mod+mod)%mod);//相减后可能负数要加上mod再取模,否则会错
    }
    return 0;
}

HIT 2255 - Not Fibonacci

#include <bits/stdc++.h>
using namespace std;
const int N=3,mod=1e7;
typedef long long ll;
ll t,a,b,p,q,s,e,ans1,ans2;
struct node
{
    ll m[N][N];
};
node s1,s2,A,E={1,0,0,0,1,0,0,0,1};
node mul(node x,node y)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s=E;
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>t;
    while(t--)
    {
        cin>>a>>b>>p>>q>>s>>e;//[s,e]区间和
        A={1,p,q,0,p,q,0,1,0};
        if(e==0)ans1=a;
        else
        {
            s1=quickpow(A,e-1);
            ans1=s1.m[0][0]*(a+b)+s1.m[0][1]*b+s1.m[0][2]*a;
        }
        if(s==0) ans2=0;
        else if(s==1) ans2=a;
        else
        {
            s2=quickpow(A,s-2);
            ans2=s2.m[0][0]*(a+b)+s2.m[0][1]*b+s2.m[0][2]*a;
        }
        printf("%lld\n",((ans1-ans2)%mod+mod)%mod);
    }
    return 0;
}

hdu 3306 Another kind of Fibonacci

#include <bits/stdc++.h>
using namespace std;
const int N=4,mod=10007;
typedef long long ll;
ll n,p,q,ans;
struct node
{
    ll m[N][N];
};
node s,A,E={1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1};
node mul(node x,node y)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s=E;
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n>>p>>q)
    {
        A={1,p*p%mod,2*p*q%mod,q*q%mod,0,p*p%mod,2*p*q%mod,q*q%mod,0,p%mod,q%mod,0,0,1,0,0};
        s=quickpow(A,n-1);
        ans=s.m[0][0]*2%mod+s.m[0][1]%mod+s.m[0][2]%mod+s.m[0][3]%mod;//一定要取模,否则会错
        printf("%lld\n",ans%mod);
    }
    return 0;
}

hdu 1757 A Simple Math Problem

#include <bits/stdc++.h>
using namespace std;
const int N=10;
typedef long long ll;
ll k,mod,ans,a[N];
struct node
{
    ll m[N][N];
};
node s,A;
node mul(node x,node y)
{
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>k>>mod)
    {
        for(int i=0;i<N;i++)
            cin>>a[i];
        if(k<10){printf("%lld\n",a[k]%mod);continue;}
        /*A={a[0],a[1],a[2],a[3],a[4],a[5],a[6],a[7],a[8],a[9],
        1,0,0,0,0,0,0,0,0,0,
        0,1,0,0,0,0,0,0,0,0,
        0,0,1,0,0,0,0,0,0,0,
        0,0,0,1,0,0,0,0,0,0,
        0,0,0,0,1,0,0,0,0,0,
        0,0,0,0,0,1,0,0,0,0,
        0,0,0,0,0,0,1,0,0,0,
        0,0,0,0,0,0,0,1,0,0,
        0,0,0,0,0,0,0,0,1,0};*/
        for(int i=0;i<N;i++)
        {
            A.m[0][i]=a[i];
            if(i>=1)A.m[i][i-1]=1;
        }
        s=quickpow(A,k-9);
        ans=0;
        for(int i=0;i<N;i++)
            ans=ans+s.m[0][i]*(9-i)%mod;
        printf("%lld\n",ans%mod);
    }
    return 0;
}