2021年度训练联盟热身训练赛第八场 B Gene Tree

题意很简单:求树中所有叶结点之间的距离平方和(若树根的度为1也看作叶结点)。

对于树根root,取两个叶结点x,y,设其距离为dis(x,y),x,y最近公共祖先为lca(x,y),则有:
$$
dis(x,y)=dis(root,x)+dis(root,y)-2*dis(root,lca(x,y))
$$我们进行一次dfs即可求得dis(root,x)和dis(root,y),实现$O(n)$预处理。

dis(root,x)简写为$d_x$,dis(root,y)简写为$d_y$,dis(root,lca(x,y))简写为$d_{lca}$,化简$[dis(x,y)]^2$如下:
$$[dis(x,y)]^2=(d_x)^2+(d_y)^2-4d_{lca}(d_x+d_y)+4(d_{lca})^2+2d_x*d_y
$$

维护三个值sz、sum1、sum2,放入结构体tr[i]中表示(用 $j∈son(i)$ 表示ji的子孙结点):
- tr[i].sz表示点i的子树中叶结点个数
- tr[i].sum1表示点i的子树中所有点j到子树根i距离和,即$\sum_{j∈son(i)} d_j$。
- tr[i].sum2表示点i的子树中所有点j到子树根i距离平方和,即$\sum_{j∈son(i)} (d_j)^2$。

那么要求$[dis(x,y)]^2$就转化成了维护四部分相加。设x是u的子孙结点,y是v的子孙结点,lca(x,y)为点u,点u的直接孩子为点v,那么写成代码如下:
1. $p_1=\sum_{x∈son(u)}\sum_{y∈son(v)}\ [(d_x)^2+(d_y)^2]$
p1=tr[v].sz*tr[u].sum2+tr[u].sz*tr[v].sum2;
2. $p_2=\sum_{x∈son(u)}\sum_{y∈son(v)}\ [-4d_{lca}(d_x+d_y)]$
p2=(-4)*dis[u]*(tr[v].sz*tr[u].sum1+tr[u].sz*tr[v].sum1);
3. $p_3=\sum_{x∈son(u)}\sum_{y∈son(v)}\ [4(d_{lca})^2]$
p3=4*dis[u]*dis[u]*tr[u].sz*tr[v].sz;
3. $p_4=\sum_{x∈son(u)}\sum_{y∈son(v)}\ [2
d_x*d_y]$
p4=2*tr[u].sum1*tr[v].sum1;

说明一下第二次dfs每层的具体操作。搜索点$u$时,假设其孩子有$v_1$,$v_2$,...,$v_m$,当搜索完孩子$v_i$,返回本层时,$v_1$,...,$v_{i-1}$这些子树都已经被搜索完了,它们的叶结点距离和也已经更新到答案中,它们的各个维护值也已经加到点$u$上。然后把点$u$(这时点$u$集合了$v_1$,...,$v_{i-1}$的所有叶结点信息)和$v_i$的各个维护值(集合了$v_i$子树的所有叶结点信息)进行计算求和,更新答案。最后再把$v_i$加到点$u$上,以便下一次搜索到$v_{i+1}$时能与点$u$(这时点$u$集合了$v_1$,...,$v_{i}$的所有叶结点信息)的各维护值进行计算求和。

#include <bits/stdc++.h>
using namespace std;
const int N=1e5+10,inf=0x3f3f3f3f;
typedef long long ll;
int head[N],cnt,d[N]; // 度数d
ll ans,dis[N]; // 根结点到i的路径长度dis
struct edge
{
    int to,next;
    ll w;
}e[N<<1];
void add(int x,int y,ll z)
{
    e[cnt].to=y;
    e[cnt].w=z;
    e[cnt].next=head[x];
    head[x]=cnt++;
}
struct node
{
    ll sz,sum1,sum2;
}tr[N];
void dfs1(int u,int fa,ll di) // 预处理得到dis数组
{
    dis[u]=di;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)continue;
        dfs1(v,u,di+e[i].w);
    }
}
void dfs2(int u,int fa)
{
    tr[u].sz=0;
    tr[u].sum1=0;
    tr[u].sum2=0;
    if(d[u]==1&&u!=fa) // 叶结点
    {
        tr[u].sz=1;
        tr[u].sum1=dis[u];
        tr[u].sum2=dis[u]*dis[u];
        ans+=dis[u]*dis[u]; // ans加上所有叶结点到根结点的距离
        return;
    }
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)continue;
        dfs2(v,u);
        ll p1=tr[v].sz*tr[u].sum2+tr[u].sz*tr[v].sum2;
        ll p2=(-4)*dis[u]*(tr[v].sz*tr[u].sum1+tr[u].sz*tr[v].sum1);
        ll p3=4*dis[u]*dis[u]*tr[u].sz*tr[v].sz;
        ll p4=2*tr[u].sum1*tr[v].sum1;
        ans+=p1+p2+p3+p4;
        tr[u].sum1+=tr[v].sum1;
        tr[u].sum2+=tr[v].sum2;
        tr[u].sz+=tr[v].sz;
    }
}
int main()
{
    ios::sync_with_stdio(false);
    int n,x,y,z,root;
    cin>>n;
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;i++)
    {
        cin>>x>>y>>z;
        d[x]++,d[y]++; // 度数
        add(x,y,z);
        add(y,x,z);
    }
    for(int i=1;i<=n;i++)
        if(d[i]==1){root=i;break;} // 根结点度数为1
    dfs1(root,-1,0);
    dfs2(root,root);
    printf("%lld\n",ans);
    return 0;
}

附赠一个相似的题,求树中所有点(不仅仅包括叶结点)之间的距离平方和

牛客练习赛55 E题 树

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N=1e6+10,mod=998244353;
int head[N],cnt,tot,sz[N];
ll ans,dis[N]; // 根结点到i的路径长度dis
struct edge
{
    int to,next;
    ll w;
}e[N<<1];
void add(int x,int y,ll z)
{
    e[cnt].to=y;
    e[cnt].w=z;
    e[cnt].next=head[x];
    head[x]=cnt++;
}
struct node
{
    ll sz,sum1,sum2;
}tr[N];
void dfs1(int u,int fa,ll di) // 预处理得到dis数组
{
    dis[u]=di;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)continue;
        dfs1(v,u,di+e[i].w);
    }
}
void dfs2(int u,int fa)
{
    tr[u].sz=1;
    tr[u].sum1=dis[u]%mod;
    tr[u].sum2=dis[u]*dis[u]%mod;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa)continue;
        dfs2(v,u);
        ll p1=(tr[v].sz*tr[u].sum2%mod+tr[u].sz*tr[v].sum2%mod)%mod;
        ll p2=(-4)*dis[u]*(tr[v].sz*tr[u].sum1%mod+tr[u].sz*tr[v].sum1%mod)%mod;
        ll p3=4*dis[u]*dis[u]%mod*tr[u].sz%mod*tr[v].sz%mod;
        ll p4=2*tr[u].sum1*tr[v].sum1%mod;
        ans=(ans+p1+p2+p3+p4+mod)%mod;
        tr[u].sum1=(tr[u].sum1+tr[v].sum1)%mod;
        tr[u].sum2=(tr[u].sum2+tr[v].sum2)%mod;
        tr[u].sz=tr[u].sz+tr[v].sz;
    }
}
int main()
{
    ios::sync_with_stdio(false);
    int n,x,y;
    cin>>n;
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;i++)
    {
        cin>>x>>y;
        add(x,y,1);
        add(y,x,1);
    }
    dfs1(1,-1,0);
    dfs2(1,1);
    printf("%lld\n",ans*2%mod); // 一定要记得%mod
    return 0;
}