题目传送门:https://ac.nowcoder.com/acm/contest/3002/F

在这里插入图片描述在这里插入图片描述

基本思路

要求只经过一个黑点的路径数量,可以分两种情况:
1.起点和终点均是白点,其余路径中只有一个黑点。
2.起点和终点为一黑一白,其余路径中全部是白点。

把树分成若干个连通块构成,连通块有两种,一是全为白点,二是只有一黑点,其他全白点(特殊地,也可以只有一黑点,无白点)。这里的“连通块”,意义与图论中“含限制条件的最大连通分量”相同。

设dp[i][0]表示以 i 为根的子树(包括 i )全为白的连通块有多少个点。
dp[i][1]表示以 i 为根的子树(包括 i )一黑其他全白的连通块有多少个点。
或者这么认为:dp[i][0]表示i的子孙节点到i路径上无黑点的个数,dp[i][1]表示i的子孙节点到i路径上有一个黑点的个数。

我们要选出起点终点分布于两个不同种类连通块的所有情况,然后每次将答案加上这两个连通块中点的个数的乘积。

这题的关键在于怎么写递归的dfs函数,以及如何更新答案ans和dp值。

用dfs进行树的后序遍历,也就是对于每个点,先从左到右访问完它的所有子树后,再访问它自己。
先更新完所有子孙节点的dp值,再更新自己的dp值。

用u表示当前访问的节点编号,在向下递归的过程中(从根到叶子),进行dp的初始化:

if(s[u]=='W')dp[u][0]=1,dp[u][1]=0;//dp[u][0]=1,现在表示u自身为白,之后回溯再更新dp
else dp[u][0]=0,dp[u][1]=1;//dp[u][1]=1,现在表示u自身为黑,之后回溯再更新dp

在进行回溯的过程中,也就是从叶子到根,向上返回的过程中,进行ans的更新:
(u表示当前节点,v表示与u直接相连的孩子节点)

ans+=dp[u][0]*dp[v][1]+dp[u][1]*dp[v][0];

更新ans之后,进行dp的更新:(可以理解为将以v点为根的子树合并到u点上,之后u点的dp值是它和它的子树一起构成的

if(s[u]=='W')dp[u][0]+=dp[v][0],dp[u][1]+=dp[v][1];
else dp[u][0]=0,dp[u][1]+=dp[v][0];//全白加上一黑,要把全白的连通块从0开始计数,所以要把dp[u][0]变成0

完整的dfs函数代码:(用vector存图)

void dfs(int u,int fa)//u表示当前访问的节点,fa是u的父亲节点
{
    if(s[u]=='W')dp[u][0]=1,dp[u][1]=0;
    else dp[u][0]=0,dp[u][1]=1;
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i];//v是与u直接相连的孩子节点
        if(v==fa)continue;//防止向上递归
        dfs(v,u);
        ans+=dp[u][0]*dp[v][1]+dp[u][1]*dp[v][0];
        if(s[u]=='W')dp[u][0]+=dp[v][0],dp[u][1]+=dp[v][1];
        else dp[u][0]=0,dp[u][1]+=dp[v][0];
    }
}

AC代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
char s[N];
int n,x,y;
ll dp[N][2],ans;//dp[i][0]表示全白 dp[i][1]表示一黑其他全白
vector<int>g[N];
void dfs(int u,int fa)
{
    if(s[u]=='W')dp[u][0]=1,dp[u][1]=0;
    else dp[u][0]=0,dp[u][1]=1;
    //for(int i=0;i<g[u].size();i++)
    for(int v:g[u])
    {
        //int v=g[u][i];
        if(v==fa)continue;
        dfs(v,u);
        ans+=dp[u][0]*dp[v][1]+dp[u][1]*dp[v][0];
        if(s[u]=='W')dp[u][0]+=dp[v][0],dp[u][1]+=dp[v][1];
        else dp[u][0]=0,dp[u][1]+=dp[v][0];
    }
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>s+1;
    for(int i=1;i<=n-1;i++)
    {
        cin>>x>>y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs(1,0);
    printf("%lld\n",ans);
    return 0;
}

新收获

关于遍历vector中的所有元素,也可以这么写:

for(int v:g[u])
{
    //code...
}

等效于:

for(int i=0;i<g[u].size();i++)
{
    int v=g[u][i];
    //code...
}

(参考:C++11之for循环的新用法