题目大意
有$m$条路径在一棵有$n$个节点的树上,问每个点恰为多少条路径起点出发$w_i$长度处
数据范围
$n,m\leqslant 3\times 10^5$
树退化成链的题解
由于链上的点依次是$1,2,3\dots n$,而路径有从编号小的点走到编号大的点,或者从编号大的点走到编号小的点
我们将这两种路径分开考虑,先考虑从编号小的点走到编号大的点$(u,v)(u \leqslant v)$
对于$(u,v)$,如果路径上的点$i(u\leqslant i\leqslant v)$要满足条件的话,必然是满足$w_i=i-u$,移项得$u=w_i-i$
对于一个确定的$i$,式子右边是一个常数
也就是说,路径上的点$ans_i+1$的充分必要条件是满足$w_i-i$等于一个定值$u$
现在问题转化成了支持两种操作:
在一段路径上的每个点上插入一个数
查询点$i$上有多少个数是$u$
如果只是在链上的做法有很多,可以打个差分标记或者暴力排序一下每条路径处理一下即可
代码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cctype>
#include<algorithm>
#define ll long long
#define For(i,l,r) for(int i=l;i<=r;++i)
#define Ford(i,r,l) for(int i=r;i>=l;--i)
inline int read(){
int x=0;char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) {x=x*10+ch-48;ch=getchar();}
return x;
}
using namespace std;
int n,m,top1,top2;
int a[100500],num[1000500],ans[100500];
struct data{int s,t;}q1[100500],q2[100500],q3[100500],q4[100500];
inline bool cmp1(data a,data b){return a.s<b.s;}
inline bool cmp2(data a,data b){return a.t<b.t;}
inline bool cmp3(data a,data b){return a.s>b.s;}
inline bool cmp4(data a,data b){return a.t>b.t;}
int main(){
n=read(),m=read();
For(i,1,n-1) read(),read();
For(i,1,n) a[i]=read();
For(i,1,m){
int s=read(),t=read();
if (s<=t) q1[++top1]=(data){s,t},q2[top1]=(data){s,t};
else q3[++top2]=(data){s,t},q4[top2]=(data){s,t};
}
sort(q1+1,q1+top1+1,cmp1);
sort(q2+1,q2+top1+1,cmp2);
sort(q3+1,q3+top2+1,cmp3);
sort(q4+1,q4+top2+1,cmp4);
int cnt1=0,cnt2=0;
For(i,1,n){
while (cnt1!=top1 && q1[cnt1+1].s==i) cnt1++,num[q1[cnt1].s]++;
ans[i]+=num[i-a[i]];
while (cnt2!=top1 && q2[cnt2+1].t==i) cnt2++,num[q2[cnt2].s]--;
}
int cnt3=0,cnt4=0;
Ford(i,n,1){
while (cnt3!=top2 && q3[cnt3+1].s==i) cnt3++,num[q3[cnt3].s]++;
ans[i]+=num[i+a[i]];
while (cnt4!=top2 && q4[cnt4+1].t==i) cnt4++,num[q4[cnt4].s]--;
}
For(i,1,n) printf("%d ",ans[i]);
}
参考:https://blog.sengxian.com/solutions/noip-2016-day1
树链剖分正解
为了利用在链上得到的结论,其实可以树链剖分试一下
其实部分分的数据已经给了提示,把一条路径在$LCA$处拆成两条路径,深度都是单调递增
处理方法和链一样的,为了优化效率还是需要打上差分标记
在每条路径的起点和终点打上一个数的差分标记(增加或者删除),表示在这段路径上的每个点上插入一个数
最后从根节点开始扫一遍,扫到每个点处理一下差分标记,再往下扫就好啦
注意,扫到每个点先处理增加一个数的标记,统计答案后,再处理删除一个数的标记
代码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cctype>
#include<algorithm>
#define N 300500
#define For(i,l,r) for(int i=l;i<=r;++i)
#define goedge(i,x) for(int i=last[x];i;i=e[i].next)
inline int read(){
int x=0;char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) {x=x*10+ch-48;ch=getchar();}
return x;
}
using namespace std;
int n,m,et,cnt,tot;
struct edge{int to,next;}e[N*2];
int last[N],deep[N],f[N],top[N],size[N],w[N];
int pos[N],real[N];
int tag[N*40][3],next[N*40],head[N];
int sum1[N*2],sum2[N*2],ans[N];
inline void addedge(int u,int v){
e[++et]=(edge){v,last[u]};last[u]=et;
e[++et]=(edge){u,last[v]};last[v]=et;
}
inline int lca(int u,int v){
while (1){
if (top[u]==top[v]) return deep[u]<deep[v]?u:v;
else if (deep[top[u]]>deep[top[v]]) u=f[top[u]];
else v=f[top[v]];
}
}
void dfs(int x){
size[x]=1;
goedge(i,x){
if (e[i].to==f[x]) continue;
deep[e[i].to]=deep[x]+1;
f[e[i].to]=x;
dfs(e[i].to);
size[x]+=size[e[i].to];
}
}
void gochain(int x,int chain){
pos[x]=++cnt;real[cnt]=x;top[x]=chain;
int k=0;
goedge(i,x)
if (e[i].to!=f[x]&&size[e[i].to]>size[k]) k=e[i].to;
if (k) gochain(k,chain);
goedge(i,x)
if (e[i].to!=f[x]&&e[i].to!=k) gochain(e[i].to,e[i].to);
}
inline void add(int data,int xpos,int c,int flag){
tag[++tot][0]=data;
tag[tot][1]=c;
tag[tot][2]=flag;
next[tot]=head[xpos];
head[xpos]=tot;
}
int main(){
n=read(),m=read();
For(i,1,n-1) addedge(read(),read());
For(i,1,n) w[i]=read();
dfs(1);
gochain(1,1);
For(i,1,m){
int x=read(),y=read(),h=lca(x,y);
int len=deep[x]+deep[y]-2*deep[h];
int sum=0;
while (top[x]!=top[h]){
add(1,pos[top[x]],sum+deep[x],1);
add(1,pos[x]+1,sum+deep[x],-1);
sum+=deep[x]-deep[f[top[x]]];
x=f[top[x]];
}
add(1,pos[h],sum+deep[x],1);
add(1,pos[x]+1,sum+deep[x],-1);
while (top[y]!=top[h]){
add(2,pos[top[y]],(len-deep[y])+n+1,1);
add(2,pos[y]+1,(len-deep[y])+n+1,-1);
len-=deep[y]-deep[f[top[y]]];
y=f[top[y]];
}
if (h!=y){
add(2,pos[h]+1,(len-deep[y])+n+1,1);
add(2,pos[y]+1,(len-deep[y])+n+1,-1);
}
}
For(i,1,n){
for(int j=head[i];j;j=next[j]){
if (tag[j][0]==1) sum1[tag[j][1]]+=tag[j][2];
else sum2[tag[j][1]]+=tag[j][2];
}
ans[real[i]]=sum1[w[real[i]]+deep[real[i]]]+sum2[(w[real[i]]-deep[real[i]])+n+1];
}
For(i,1,n) printf("%d ",ans[i]);
return 0;
}
线段树合并题解
其实可以不用树链剖分,用别的数据结构直接在树上维护
还是要把一条路径拆成两条路径,选择用倍增找$LCA$
考虑线段树合并,动态开点处理每个点上新增加的差分标记
统计的时候把子树中所有的线段树合并起来,记录答案
代码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cctype>
#include<algorithm>
#define N 300500
#define For(i,l,r) for(int i=l;i<=r;++i)
#define Ford(i,r,l) for(int i=r;i>=l;--i)
#define goedge(i,x) for(int i=last[x];i;i=e[i].next)
inline int read(){
int x=0;char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) {x=x*10+ch-48;ch=getchar();}
return x;
}
using namespace std;
int n,m,et,tot,cnt;
struct edge{int to,next;}e[N*2];
int last[N],deep[N],f[N][25],w[N],ans[N];
int tag[N*4][3],next[N*4],head[N];
struct tree{int l,r,c;}t[N*20];
int root1[N],root2[N];
inline void addedge(int u,int v){
e[++et]=(edge){v,last[u]};last[u]=et;
e[++et]=(edge){u,last[v]};last[v]=et;
}
inline int lca(int u,int v){
if (deep[u]<deep[v]) swap(u,v);
Ford(i,20,0)
if (deep[f[u][i]]>=deep[v]) u=f[u][i];
if (u==v) return u;
Ford(i,20,0)
if (f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
inline int getson(int x,int y){
if (deep[x]<deep[y]) swap(x,y);
Ford(i,20,0)
if (deep[f[x][i]]>deep[y]) x=f[x][i];
return x;
}
void dfs1(int x){
deep[x]=deep[f[x][0]]+1;
For(i,1,20) f[x][i]=f[f[x][i-1]][i-1];
goedge(i,x){
if (e[i].to==f[x][0]) continue;
f[e[i].to][0]=x;
dfs1(e[i].to);
}
}
void update(int data,int x,int deep,int flag){
tag[++tot][0]=data;
tag[tot][1]=deep;
tag[tot][2]=flag;
next[tot]=head[x];
head[x]=tot;
}
int merge(int x,int y){
if (!x || !y) return x+y;
t[x].c+=t[y].c;
t[x].l=merge(t[x].l,t[y].l);
t[x].r=merge(t[x].r,t[y].r);
return x;
}
void inc(int &p,int l,int r,int x,int c){
if (!p) p=++cnt;
t[p].c+=c;
if (l==r) return;
int mid=(l+r)>>1;
if (x<=mid) inc(t[p].l,l,mid,x,c);
else inc(t[p].r,mid+1,r,x,c);
}
int query(int p,int l,int r,int x){
if (l==r || !p) return t[p].c;
int mid=(l+r)>>1;
if (x<=mid) return query(t[p].l,l,mid,x);
else return query(t[p].r,mid+1,r,x);
}
void dfs2(int x){
goedge(i,x){
if (e[i].to==f[x][0]) continue;
dfs2(e[i].to);
root1[x]=merge(root1[x],root1[e[i].to]);
root2[x]=merge(root2[x],root2[e[i].to]);
}
for(int i=head[x];i;i=next[i])
if (tag[i][2]==1){
if (tag[i][0]==1) inc(root1[x],1,n*3,tag[i][1],1);
else inc(root2[x],1,n*3,tag[i][1],1);
}
ans[x]=query(root1[x],1,n*3,deep[x]+w[x]+n);
ans[x]+=query(root2[x],1,n*3,deep[x]-w[x]+n);
for(int i=head[x];i;i=next[i])
if (tag[i][2]==-1){
if (tag[i][0]==1) inc(root1[x],1,n*3,tag[i][1],-1);
else inc(root2[x],1,n*3,tag[i][1],-1);
}
}
int main(){
n=read(),m=read();
For(i,1,n-1) addedge(read(),read());
dfs1(1);
For(i,1,n) w[i]=read();
For(i,1,m){
int x=read(),y=read(),h=lca(x,y);
int len=deep[x]+deep[y]-2*deep[h];
update(1,x,deep[x]+n,1);
update(1,h,deep[x]+n,-1);
if (y!=h){
int hsy=getson(h,y);
update(2,y,deep[y]-len+n,1);
update(2,hsy,deep[y]-len+n,-1);
}
}
dfs2(1);
For(i,1,n) printf("%d ",ans[i]);
return 0;
}
参考:http://blog.csdn.net/qq_33229466/article/details/53426943