并查集模板(2020.10.1更新)

P3367 【模板】并查集

#include <bits/stdc++.h>
using namespace std;
const int N=1e4+10;
int n,m,fa[N];
void init()
{
    for(int i=1;i<=n;i++)
        fa[i]=i;
}
int find_fa(int x)
{
    return fa[x]=(x==fa[x]?x:find_fa(fa[x]));
}
void join(int x,int y)
{
    fa[find_fa(x)]=find_fa(y);
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>m;
    int x,y,opt;
    init();
    for(int i=1;i<=m;i++)
    {
        cin>>opt>>x>>y;
        if(opt==1)
        {
            join(x,y);
        }
        else
        {
            if(find_fa(x)==find_fa(y))printf("Y\n");
            else printf("N\n");
        }
    }
    return 0;
}

本次训练共7题,本文附AC代码和题目链接。

A题 hdu 1232 畅通工程

并查集模板题,掌握find(x)函数和join(a,b)函数的用法即可。
find(x)函数表示找x的祖先节点,使用了路径压缩算法,在找x的祖父节点的同时,还使从x结点搜索祖先结点的过程中所经过的所有结点都指向该祖先节点。
join(a,b)函数表示把a的祖先节点以及a下面的所有子节点全都指向b的祖先节点,这是由于find函数运用了路径压缩算法,使从a开始搜索时经过的所有节点都指向了b的祖先节点,从b开始搜索时的经过的所有节点也都指向了b的祖先节点,这样就使刚才搜索过的所有节点全都指向了b的祖先节点,这也就是完成了两个集合的合并操作。

#include <bits/stdc++.h>
using namespace std;
int n,m,a,b,sum,pre[1001];
int find(int x)
{
    return pre[x]=(x==pre[x]?x:find(pre[x]));
}
void join(int a,int b)
{
    pre[find(a)]=find(b);
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n&&n)//n个点,m条边
    {
        cin>>m;
        for(int i=1;i<=n;i++)
            pre[i]=i;
        sum=0;//sum统计已经连通的边的数量,总共要连n-1条边,则还需要连n-1-sum条边
        for(int i=1;i<=m;i++)
        {
            cin>>a>>b;
            if(find(a)!=find(b))
            {join(a,b);sum++;}
        }
        printf("%d\n",n-1-sum);
    }
    return 0;
}

B题 hdu 1272 小希的迷宫

这题注意判断集合数大于1的情况,大于1输出No。还有如果刚开始输入的a、b均为0要输出yes。

#include <bits/stdc++.h>
#define min3(a,b,c) min(min(a,b),min(b,c))//定义3个数的最小值
#define max3(a,b,c) max(max(a,b),max(b,c))//定义3个数的最大值
using namespace std;
int a,b,mn,mx,cnt,flag,pre[100001],vis[100001];
int find(int x)
{
    if(x!=pre[x])pre[x]=find(pre[x]);
    return pre[x];
}
void join(int a,int b)
{
    int a1=find(a),b1=find(b);
    if(a1!=b1){pre[a1]=b1;flag=0;}
    else flag=1;//a、b祖先相同,表示a、b已经连接
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>a>>b&&!(a==-1&&b==-1))
    {
        if(a==0&&b==0){printf("Yes\n");continue;}//注意如果刚开始输入的a、b均为0也要输出yes
        memset(vis,0,sizeof(vis));
        for(int i=1;i<=100001;i++)
            pre[i]=i;
        mn=0x3f3f3f3f;mx=cnt=flag=0;
        while(!(a==0&&b==0))
        {
            vis[a]=vis[b]=1;
            mn=min3(a,b,mn);mx=max3(a,b,mx);
            if(flag==0)join(a,b);
            cin>>a>>b;
        }
        if(flag==1){printf("No\n");continue;}
        for(int i=mn;i<=mx;i++)
        if(vis[i]&&pre[i]==i)cnt++;
        cnt==1?printf("Yes\n"):printf("No\n");//集合数大于1为No
    }
    return 0;
}

C题 nefu 209 湖南修路

最小生成树模板题,可用根据贪心思想和并查集的kruskal算法解决,按边权从小到大排序后再开始联通每个点,也就是贪心思想,而把点联通到各个相同或者不同的集合中,则用到了并查集的模板。

#include <bits/stdc++.h>
using namespace std;
int n,sum,pre[101];//pre记录每个点的祖父节点
struct node
{
    int a,b,val,flag;//a表示边的起点,b表示边的终点,val表示边权(连接该边的花费)
}p[101*101];//记录所有边的信息,有向图最多能连n*(n-1)/2条边,无向图最多能连n*(n-1)条边,此题为无向图
bool cmp(node x,node y)
{
    if(x.flag!=y.flag)return x.flag>y.flag;//flag=1的排在前面,表示已经联通的边
    else if(x.val!=y.val) return x.val<y.val;//花费少的排在前面
}
int find(int x)
{
    if(x!=pre[x])pre[x]=find(pre[x]);
    return pre[x];
}
void join(int a,int b)
{pre[find(a)]=find(b);}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n)
    {
        for(int i=1;i<=n;i++)//祖父节点初始化,认为每个点的祖父都是它自身
            pre[i]=i;
        for(int i=1;i<=n*(n-1)/2;i++)
        cin>>p[i].a>>p[i].b>>p[i].val>>p[i].flag;
        sort(p+1,p+n*(n-1)/2+1,cmp);
        sum=0;
        for(int i=1;i<=n*(n-1)/2;i++)
        {
            if(p[i].flag==1&&find(p[i].a)!=find(p[i].b))
                join(p[i].a,p[i].b);
            if(p[i].flag==0&&find(p[i].a)!=find(p[i].b))
            {join(p[i].a,p[i].b);sum=sum+p[i].val;}
        }
        printf("%d\n",sum);
    }
    return 0;
}

D题 nefu 205 最小树1

C题简单版本,不多说。

#include <bits/stdc++.h>
using namespace std;
int n,pre[51];
struct node
{
    int a,b;
    double val;
}p[51*51];
bool cmp(node x,node y)
{return x.val<y.val;}//花费少的排在前面
int find(int x)
{
    if(x!=pre[x])pre[x]=find(pre[x]);
    return pre[x];
}
void join(int a,int b)
{pre[find(a)]=find(b);}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n&&n)
    {
        for(int i=1;i<=n;i++)
            pre[i]=i;
        for(int i=1;i<=n*(n-1)/2;i++)
        cin>>p[i].a>>p[i].b>>p[i].val;
        sort(p+1,p+n*(n-1)/2+1,cmp);
        double sum=0;
        for(int i=1;i<=n*(n-1)/2;i++)
        {
            if(find(p[i].a)!=find(p[i].b))
            {join(p[i].a,p[i].b);sum=sum+p[i].val;}
        }
        printf("%.2lf\n",sum);
    }
    return 0;
}

E题 nefu 129 修路工程

依然是最小生成树的模板题,不多说。

#include <bits/stdc++.h>
using namespace std;
int n,m,cnt,sum,flag,pre[101];
struct node
{
    int a,b,val;
}p[101*101];
bool cmp(node x,node y)
{return x.val<y.val;}//花费少的排在前面
int find(int x)
{
    if(x!=pre[x])pre[x]=find(pre[x]);
    return pre[x];
}
void join(int a,int b)
{pre[find(a)]=find(b);}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n&&n)//n条边,m个点
    {
        cin>>m;
        for(int i=1;i<=m;i++)
            pre[i]=i;
        for(int i=1;i<=n;i++)
        cin>>p[i].a>>p[i].b>>p[i].val;
        sort(p+1,p+n+1,cmp);
        flag=cnt=sum=0;
        for(int i=1;i<=n;i++)
        {
            if(find(p[i].a)!=find(p[i].b))
            {
                cnt++;
                join(p[i].a,p[i].b);
                sum=sum+p[i].val;
                if(cnt==m-1){flag=1;break;}
            }
        }
        flag==1?printf("%d\n",sum):printf("?\n");
    }
    return 0;
}

F题 nefu 1525 一道图论一

这题有点难度,先要把题目看懂:“从 s 连接到 t 的所有路径中单边长度最大值的最小值

实际上的意思是从点1到点2有很多条路径,每一条抵达的路径都有一个最大值,要求的是这些最大值里面的最小值。

举个例子:点1到点2,有两条路径可以到达,一条是1->3->2,一条是1->4->2。
然后,1->3的值是5,3->2的值是3,那这条路径的最大值是5。
之后1->4的距离是4,4->2的距离是3,那这条到达的路径最大值是4。
所以要求的最大值的最小值为4。

那么如何求这个最大值的最小值呢?我们可以先按边权val从小到大排序所有边,之后在每次查询的过程中进行边的遍历,用join函数把点依次联通,直到输入的a0、b0属于同一个集合时,记录此时的边权val,这就是我们要求的所有路径中边权最大值的最小值。

#include <bits/stdc++.h>
using namespace std;
const int maxn=1001,maxm=100001;
int n,m,k,a0,b0,flag,pre[maxn];
long long ans;
struct node
{
    int a,b,val;
}p[maxm<<1];
bool cmp(node x,node y)
{return x.val<y.val;}
int find(int x)
{
    if(x!=pre[x])pre[x]=find(pre[x]);
    return pre[x];
}
void join(int a,int b)
{pre[find(a)]=find(b);}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n>>m>>k)//n个点,m条边,k次询问
    {
        for(int i=1;i<=m;i++)
        cin>>p[i].a>>p[i].b>>p[i].val;
        sort(p+1,p+m+1,cmp);
        while(k--)
        {
            cin>>a0>>b0;
            for(int i=1;i<=n;i++)
                pre[i]=i;//每次查询都要把点进行初始化
            flag=0;
            for(int i=1;i<=m;i++)
            {
                if(find(p[i].a)!=find(p[i].b))
                    join(p[i].a,p[i].b);
                if(find(a0)==find(b0))//如果要查询的a0、b0为一个集合,则此时的val就是答案,结束循环
                   {ans=p[i].val;flag=1;break;}
            }
            flag==1?printf("%lld\n",ans):printf("-1\n");
        }
    }
    return 0;
}

G题 nefu 1791 藤原千花的星星图

(我永远喜欢藤原千花.jpg)
在这里插入图片描述

咳咳,说正事 ~~(差点就暴露出我喜欢二次元了)~~

这题是E题的加强版,数据比较大,需要用到快读模板,同时最好也把join函数优化为按秩合并的算法,
其他的都和求最小生成树的算法一样,时限500ms也可以不超时AC了~

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
int n,m,cnt,flag,pre[maxn],rk[maxn];
long long sum;
struct node
{
    int a,b,val;
}p[maxn<<1];
bool cmp(node x,node y)
{return x.val<y.val;}//花费少的排在前面
int find(int x)
{
    if(x!=pre[x])pre[x]=find(pre[x]);
    return pre[x];
}
void join(int a,int b)
{
    int a1=find(a),b1=find(b);
    if(rk[a1]>rk[b1])swap(a1,b1);//按秩合并,秩小的向秩大的合并,也就是子节点少的祖父节点向子节点多的祖父节点合并
    pre[a1]=b1;
    if(rk[a1]==rk[b1])rk[b1]++;
}
inline int read()//快读模板,在循环中输入int类型,提速效果很好,使用方法:int x=read();
{
    register int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}
int main()
{
    while(scanf("%d%d",&n,&m)!=-1)//n个点,m条边
    {
        for(int i=1;i<=n;i++)
        {pre[i]=i;rk[i]=0;}
        for(int i=1;i<=m;i++)
        {p[i].a=read();p[i].b=read();p[i].val=read();}//scanf("%d%d%d",&p[i].a,&p[i].b,&p[i].val)会超时
        sort(p+1,p+m+1,cmp);
        flag=cnt=sum=0;
        for(int i=1;i<=m;i++)
        {
            if(find(p[i].a)!=find(p[i].b))
            {
                cnt++;
                join(p[i].a,p[i].b);
                sum=sum+p[i].val;
                if(cnt==n-1){flag=1;break;}
            }
        }
        flag==1?printf("%lld\n",sum):printf("-1\n");
    }
    return 0;
}

那么这篇文章就写到这里了,我们下次再见~