题面
给定一张有向无环图,定义重要节点x为:对于图中每一个点y,都可以从y出发到x或者从x出发到y.
次重要节点为删除某个节点后能满足上面条件的节点(不包括重要节点),求重要节点和次重要节点一共有多少个
(2<=n, m<= 300000)
题解
我们考虑这样两个性质:
- 当我们在拓扑排序的时候,同一时刻出现在队列中的点互不可达。
- 除去已经排序完的点和当前在队列的点,剩下的点全都是直接or间接由当前队列中的点拓展出来的。
我们记f[i]为点i可以到达or可以到达点i的点数之和。那么题目要求的点就是满足f[i] >= n - 2的点。
当队列里点数大于等于3的时候,由于互不可达,因此他们必然不是我们要求的点
当队列里点数只有1的时候,由性质2可以得出,剩下的点都可以加入f[x]
当队列里点数为2时,假设点为x和y,那么如果后续有一个点只能被x到达,那么它显然不能对y产生贡献。打个标记就行
一个细节:
如果x在y前面,那么我们应该用y来判x,因为当前在更新x的f值.而当x被更新完后,只能被x到达的点已经加入队列了。
考虑如何用y来判x:
我们定义内部点为y可达的所有点,内部边为两个端点都是内部点的边
如果y的某个邻居入度不是1,那么说明x有可能直接or间接指向它。
如果x不能指向它,说明这个邻居是被内部点指向了,那么我们沿着内部边反向走,一定可以走到一个(也可能是多个)点,满足这个点没有内部点连向它,不然的话就出现了环,不符合题意。那么对于这样的点,如果x不能连向它,说明它只有一个入度且入度为y(不然的话xy都不连向它,内部点也不连向它,它早入度为0进队了。),那么y会根据这个点,给x打上标记。
#include<bits/stdc++.h>
using namespace std;
#define R register int
#define AC 301000
#define ac 601000
int n, m, ans;
int f[AC], in[AC];
int Head[AC], Next[ac], date[ac], tot;
int q[AC], head, tail;
bool z[AC];
struct node{
int f, w;
}way[AC];
inline int read()
{
int x = 0;char c = getchar();
while(c > '9' || c < '0') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
inline void add(int f, int w){
date[++ tot] = w, Next[tot] = Head[f], Head[f] = tot, in[w] ++;
}
void pre()
{
n = read(), m = read();
for(R i = 1; i <= m; i ++)
{
way[i].f = read(), way[i].w = read();
add(way[i].f, way[i].w);
}
}
void get(int cnt)//tmp == 2时,前者若不合法,在get中被筛出
{
int x = q[head - 1], y = q[head];
bool flag = 0;
for(R i = Head[y]; i; i = Next[i])
if(in[date[i]] == 1) {flag = 1; break;}
if(flag) z[x] = 1;
else f[x] += n - cnt;
}
void t_sort()
{
int cnt = 0;
head = 1, tail = 0;
for(R i = 1; i <= n; i ++)
if(!in[i]) q[++ tail] = i, ++ cnt;
while(head <= tail)
{
int x = q[head ++];
int tmp = tail - head + 1 + 1;
if(tmp == 1) f[x] += n - cnt;
if(tmp == 2) get(cnt);
for(R i = Head[x]; i; i = Next[i])
{
int now = date[i];
if(!(-- in[now])) q[++ tail] = now, ++ cnt;//tmp == 2时,后者若不合法,在这里保证他的f值正确
}
}
}
void work()
{
t_sort();
tot = 0, memset(Head, 0, (n + 2) * 4), memset(in, 0, (n + 2) * 4);
for(R i = 1; i <= m; i ++) add(way[i].w, way[i].f);
t_sort();
for(R i = 1; i <= n; i ++)
if(!z[i] && f[i] >= n - 2) ++ ans;
printf("%d
", ans);
}
int main()
{
// freopen("in.in", "r", stdin);
pre();
work();
return 0;
}