本文只记录AC自动机的入门练习题,不再详解算法原理。

在掌握字典树(Trie树)和KMP思想的基础上,学习AC自动机算法原理,推荐阅读以下文章:
1. 洛谷日报 强势图解AC自动机 https://www.luogu.com.cn/blog/3383669u/qiang-shi-tu-xie-ac-zi-dong-ji
2. AC自动机 - 多模式匹配算法 https://blog.csdn.net/xaiojiang/article/details/52299258
3. AC 自动机 - OI Wiki https://oi-wiki.org/string/ac-automaton/

一、HDU 2222 Keywords Search

AC自动机模板题,具体细节详见代码注释。代码参考:洛谷日报 强势图解AC自动机

#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10,M=26;
struct node
{
    int ch[N][M];
    int cnt[N];
    int fail[N];
    int tot;
    queue<int>q;
    void init() // 初始化
    {
        memset(cnt,0,sizeof(cnt));
        memset(ch,0,sizeof(ch));
        memset(fail,0,sizeof(fail));
        tot=0;
    }
    void ins(char s[]) // insert代码同Trie树
    {
        int u=0; // 根节点为0,从根节点开始往下走
        for(int i=0;s[i];i++)
        {
            int x=s[i]-'a'; // a~z -> 0~25
            if(!ch[u][x])ch[u][x]=++tot; // 没有节点就造节点
            u=ch[u][x]; // 向下遍历
        }
        // 此模式串终点为u,更新以u结尾的模式串个数
        cnt[u]++;
    }
    void build_fail()
    {
        fail[0]=0; // 根节点0的fail指针指向自身
        for(int i=0;i<M;i++)
        {
            if(ch[0][i])
                q.push(ch[0][i]); // 与根节点直接相连的一层节点入队
        }
        while(!q.empty()) // 树的层次遍历
        {
            int u=q.front();q.pop(); // 父节点u
            for(int i=0;i<M;i++)
            {
                int &v=ch[u][i]; // 子节点v,因为之后可能要修改,所以用&取引用
                if(v)
                {
                    fail[v]=ch[fail[u]][i];
                    q.push(v);
                }
                // else这短短一行代码,是算法优化,体现出路径压缩的思想
                // 修改Trie树,从而将Trie树改造成Trie图
                // 使得在之后的遍历过程中能更快得到fail指针
                // 具体操作是子节点不存在,就造一个子节点,指向父节点fail指针对应的子节点
                else v=ch[fail[u]][i]; 
            }
        }
    }
    int query(char s[])
    {
        int u=0,ans=0;
        for(int i=0;s[i];i++)
        {
            int x=s[i]-'a';
            u=ch[u][x];
            // if(~cnt[j]) 等价于 if(cnt[j]!=-1)
            // 不断向上找fail指针,找到cnt值存在的就更新答案
            // 遍历到根或之前遍历过的点时停止
            for(int j=u;j&&~cnt[j];j=fail[j])
            {
                ans+=cnt[j];
                cnt[j]=-1;
                // cnt在这里可以起到标记的作用,这样每个点至多被遍历一次
            }
        }
        return ans;
    }
}ac;
char t[N];
int n,T;
int main()
{
    ios::sync_with_stdio(false);
    cin>>T;
    while(T--)
    {
        ac.init();
        cin>>n;
        for(int i=1;i<=n;i++)
        {
            cin>>t;
            ac.ins(t);
        }
        ac.build_fail();
        cin>>t;
        printf("%d\n",ac.query(t));
    }
    return 0;
}
/*
1
2
she
h
she
ans:2
*/

二、HDU 2896 病毒侵袭

基本上也是模板题吧。

坑的地方在于,交了几次,数组开小了会RE,数组开大了会MLE,后来发现错误原因是用的二维vector记录模式串终点,而实际上题目说了不会出现两个相同的模式串(那么插入模式串后,每个终点都只唯一对应一个模式串),开二维的MLE,那就改成一维数组记录模式串终点对应的编号,大小开500*200就行了;

还有就是字符串中可能会有空格(空格ASCII码为32),应该用gets读入,gets之前记得写getchar()吸收回车。

#include <bits/stdc++.h>
using namespace std;
const int N=1e5+10,M=130,K=1e4+10;
int n,m,sum;
char t[K];
struct node
{
    int ch[N][M];
    int tot=0;
    int cnt[N];
    int fail[N];
    queue<int>q;
    void ins(int num,char s[]) // 插入模式串s[],编号为num
    {
        int u=0;
        for(int i=0;s[i];i++)
        {
            int x=s[i]; // 直接用ASCII码
            if(!ch[u][x])ch[u][x]=++tot;
            u=ch[u][x];
        }// 遍历到的终点是u
        cnt[u]=num; // 以u结尾的模式串编号是num(唯一对应)
    }
    void build_fail()
    {
        for(int i=0;i<M;i++)
        {
            if(ch[0][i])
                q.push(ch[0][i]);
        }
        while(!q.empty())
        {
            int u=q.front();q.pop();
            for(int i=0;i<M;i++)
            {
                int &v=ch[u][i];
                int f=ch[fail[u]][i];
                if(v)
                {
                    fail[v]=f;
                    q.push(v);
                }
                else v=f;
            }
        }
    }
    set<int>ans;
    int tmp[N];
    bool query(char s[])
    {
        int u=0;
        ans.clear();
        memcpy(tmp,cnt,sizeof(cnt)); // tmp是cnt的一个拷贝,之后要修改tmp
        bool flag=0;
        for(int i=0;s[i];i++)
        {
            int x=s[i];
            u=ch[u][x];
            for(int j=u;j&&tmp[j];j=fail[j])
            {
                flag=1;
                ans.insert(tmp[j]);
                tmp[j]=0;
            }
        }
        return flag;
    }
}ac;
int main()
{
    cin>>n;
    getchar();
    for(int i=1;i<=n;i++)
    {
        gets(t); // 不是cin!
        ac.ins(i,t);
    }
    ac.build_fail();
    cin>>m;
    getchar();
    for(int i=1;i<=m;i++)
    {
        gets(t); // 不是cin!
        if(ac.query(t))
        {
            sum++;
            printf("web %d:",i);
            for(auto j:ac.ans)
            {
                printf(" %d",j);
            }
            printf("\n");
        }
    }
    printf("total: %d\n",sum);
    return 0;
}

三、HDU 3065 病毒侵袭持续中

这题和其他题不同的地方在于,在query函数,fail指针不断向上跳的过程中,跳回根节点才停止,每个点是可以被多次遍历到的,所以不需要(不能)标记访问过的点,这样才满足题目要求。虽然不能保证每个点只被遍历一次,时间复杂度会大一些,但是必须这样做,计数才不会有遗漏。

看这个样例,就很好明白了:

Input

3
A
AA
AAA
AAAA

Output

A: 4
AA: 3
AAA: 2

为什么A被遍历了4次,原因就在于,每次从目标串的当前位置向上跳fail指针的时候,都会遍历到单个的A。目标串长度为4,单个的A总共就被计数了4次。

还有坑的地方就是题目不说清楚是多组输入(出题人出来挨打!),单组输入会给你评测WA;

如果把gets写成了cin,评测结果不是WA而是TLE,非常的误导人,我先还以为是fail指针向上跳的时候会重复遍历点,时间复杂度太大,后来才发现是cin错了...

#include <bits/stdc++.h>
using namespace std;
const int N=5e4+10,M=26,K=2e6+10;
char ans[1005][55];// 模式串
char t[K]; // 目标串
map<int,int>vis; // 模式串编号对应出现的次数
struct node
{
    int ch[N][M];
    int cnt[N];
    int tot;
    int fail[N];
    queue<int>q;
    void init()
    {
        memset(ch,0,sizeof(ch));
        memset(fail,0,sizeof(fail));
        memset(cnt,0,sizeof(cnt));
        while(!q.empty())q.pop();
        tot=0;
    }
    void ins(int num,char s[])
    {
        int u=0;
        for(int i=0;s[i];i++)
        {
            int x=s[i]-'A'; // A~Z -> 0~25
            if(!ch[u][x])ch[u][x]=++tot;
            u=ch[u][x];
        }
        cnt[u]=num;
    }
    void build_fail()
    {
        for(int i=0;i<M;i++)
        {
            if(ch[0][i])
                q.push(ch[0][i]);
        }
        while(!q.empty())
        {
            int u=q.front();q.pop();
            for(int i=0;i<M;i++)
            {
                int &v=ch[u][i];
                int f=ch[fail[u]][i];
                if(v)
                {
                    fail[v]=f;
                    q.push(v);
                }
                else v=f;
            }
        }
    }
    void query(char s[])
    {
        int u=0;
        vis.clear();
        for(int i=0;s[i];i++)
        {
            if(s[i]<'A'||s[i]>'Z') // 跳回根节点
            {
                u=0;
                continue;
            }
            int x=s[i]-'A';
            u=ch[u][x];
            for(int j=u;j&&cnt[j];j=fail[j])
            {
                vis[cnt[j]]++; //j对应模式串编号为cnt[j],个数+1
                // 此处不能修改cnt[j]=0
                // 因为按照题目要求,之后可能需要再次遍历该点
            }
        }
    }
}ac;
int main()
{
    int n;
    while(cin>>n) // 多组输入(题目没说清楚是多组!)
    {
        ac.init();
        for(int i=1;i<=n;i++)
        {
            cin>>ans[i];
            ac.ins(i,ans[i]);
        }
        ac.build_fail();
        getchar();
        gets(t); // gets!含空格!
        ac.query(t);
        for(auto it:vis)
        {
            int a=it.first; // 模式串下标
            int b=it.second; // 出现次数
            printf("%s: %d\n",ans[a],b);
        }
    }
    return 0;
}
/*
3
A
AA
AAA
AAAA

A: 4
AA: 3
AAA: 2
*/