题目大意
给定一棵有$n$个节点的无根树和$m$个操作,操作有2类:
将节点$a$到节点$b$路径上所有点都染成颜色$c$
询问节点$a$到节点$b$路径上的颜色段数量（连续相同颜色被认为是同一段，如“112221”由$3$段组成：“11”、“222”和“1”）
数据范围
$n\leqslant 10^5,m\leqslant 10^5$,所有颜色$c$均为整数且$0\leqslant c\leqslant 10^9$
题解
树链剖分之后用线段树维护区间颜色：每个线段树结点记录区间左端颜色、右端颜色和颜色段数。注意区间合并时颜色段数的变化——若左区间的右端颜色与右区间的左端颜色相同，两段会连成一段，总数要减$1$；沿重链向上跳时同理，若链顶与其父结点颜色相同，也要减$1$。
程序
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
// Inclusive counting loop: i runs from l up to r.
// Both bounds are parenthesised so expression arguments (e.g. `a==b`,
// `c?x:y`) are evaluated as a whole instead of binding to `i<=`.
#define For(i,l,r) for(int i=(l);i<=(r);i++)
// Iterate over every arc leaving vertex x in the forward-star list built
// from last[] / e[]; i is the arc index (0 terminates the list).
#define goedge(i,x) for(int i=last[x];i;i=e[i].next)
using namespace std;
// Segment-tree node over HLD positions:
//   lc / rc - colour at the leftmost / rightmost position of the node's range
//   s       - number of maximal same-colour segments inside the range
//   tag     - pending "paint the whole range" colour; -1 means nothing pending
struct tree{
    int lc,rc,s,tag;
}t[400050];
// Forward-star arc: target vertex and index of the next arc from the same vertex.
struct edge{
    int to,next;
}e[200050];
// Per-vertex arrays, indexed from 1:
//   last - first arc index, size - subtree size, deep - depth, f - parent,
//   top  - head of the heavy chain, pos - segment-tree position, v - initial colour
// NOTE(review): `size` may be ambiguous with C++17 std::size under
// `using namespace std`; qualify it as ::size if a compiler complains.
int last[100050];
int size[100050];
int deep[100050];
int f[100050];
int top[100050];
int pos[100050];
int v[100050];
// n vertices, m operations, et arcs stored so far, cnt positions assigned so far.
int n,m,et,cnt;
// Reads the next integer from stdin, skipping any non-digit characters.
// A '-' seen while skipping makes the result negative.
// Returns 0 if input ends before any digit is found.
// The character is kept in an int so EOF (negative) is distinguishable even
// where plain char is unsigned; the EOF check stops the skip loop from
// spinning forever on truncated input.
inline int read(){
    int x=0,sign=1,ch=getchar();
    while ((ch<'0'||ch>'9')&&ch!=EOF){
        if (ch=='-') sign=-1;
        ch=getchar();
    }
    while (ch>='0'&&ch<='9'){
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*sign;
}
// Stores the undirected edge (u,v) as two directed arcs, each prepended to
// the forward-star list of its source vertex (last[] holds the list head).
// Plain member assignment replaces the `(edge){...}` compound literal,
// which is a GNU extension rather than ISO C++.
inline void addedge(int u,int v){
    e[++et].to=v;e[et].next=last[u];last[u]=et;
    e[++et].to=u;e[et].next=last[v];last[v]=et;
}
// Lowest common ancestor of u and v via the heavy-chain decomposition.
// While u and v sit on different chains, lift the one whose chain head is
// deeper to the parent of that head. Once they share a chain, the shallower
// of the two is the LCA.
inline int lca(int u,int v){
    while (top[u]!=top[v]){
        if (deep[top[u]]<deep[top[v]]) v=f[top[v]];
        else u=f[top[u]];
    }
    return deep[u]<=deep[v]?u:v;
}
inline void pushup(int p){
t[p].lc=t[p<<1].lc;t[p].rc=t[p<<1|1].rc;
if(t[p<<1].rc^t[p<<1|1].lc) t[p].s=t[p<<1].s+t[p<<1|1].s;
else t[p].s=t[p<<1].s+t[p<<1|1].s-1;
}
// Forwards node p's pending paint tag (if any) to both children and clears it.
// A painted child becomes a single segment whose two end colours equal the
// paint. A tag of -1 means no assignment is pending.
inline void pushdown(int p){
    const int paint=t[p].tag;
    t[p].tag=-1;
    if (paint==-1) return;
    for (int child=p<<1;child<=(p<<1|1);++child){
        t[child].s=1;
        t[child].tag=paint;
        t[child].lc=t[child].rc=paint;
    }
}
void dfs(int x){
size[x]=1;
goedge(i,x)
if (e[i].to!=f[x]){
f[e[i].to]=x;
deep[e[i].to]=deep[x]+1;
dfs(e[i].to);
size[x]+=size[e[i].to];
}
}
// Second HLD pass: assigns chain heads (top[]) and segment-tree positions
// (pos[], numbered from cnt+1) to the subtree of x, which joins chain `chain`.
// The heavy child (the child with the largest size[]) continues the current
// chain and is numbered right after its parent, so every heavy chain occupies
// a contiguous pos range. Each light child starts a new chain headed by itself.
// This is an explicit-stack version of the recursive preorder; it produces
// exactly the same numbering without risking stack overflow on deep trees.
// `::size` avoids ambiguity with C++17 std::size under `using namespace std`.
void gochain(int x,int chain){
    static int stk[100050],head[100050];  // pending vertex and the chain it joins
    int sp=0;
    stk[sp]=x;head[sp]=chain;++sp;
    while (sp){
        --sp;
        const int u=stk[sp],h=head[sp];
        top[u]=h;pos[u]=++cnt;
        int k=0;  // heavy child; stays 0 for a leaf (size[0] == 0)
        goedge(i,u)
            if (e[i].to!=f[u]&&::size[e[i].to]>::size[k]) k=e[i].to;
        // Push light children, then reverse them so they are popped in
        // adjacency-list order, as the recursive version visited them.
        const int first=sp;
        goedge(i,u)
            if (e[i].to!=f[u]&&e[i].to!=k){stk[sp]=head[sp]=e[i].to;++sp;}
        std::reverse(stk+first,stk+sp);
        std::reverse(head+first,head+sp);
        // The heavy child goes on top, so its whole subtree is numbered next.
        if (k){stk[sp]=k;head[sp]=h;++sp;}
    }
}
// Initialises every node covering [wl,wr] under node p.
// Each node gets a segment count of 1 and no pending tag (-1).
// End colours are not set here; the caller colours every position afterwards.
void build(int p,int wl,int wr){
    t[p].s=1;
    t[p].tag=-1;
    if (wl<wr){
        const int mid=(wl+wr)>>1;
        build(p<<1,wl,mid);
        build(p<<1|1,mid+1,wr);
    }
}
// Paints positions [l,r] with colour c.
// Node p covers [wl,wr]; callers keep wl <= l <= r <= wr, and the recursion
// clips the range to each child.
// A fully covered node becomes one solid segment and records c as its lazy tag.
void change(int p,int wl,int wr,int l,int r,int c){
    if (wl!=wr) pushdown(p);
    if (l==wl&&r==wr){
        t[p].lc=t[p].rc=t[p].tag=c;
        t[p].s=1;
        return;
    }
    const int mid=(wl+wr)>>1;
    if (l<=mid) change(p<<1,wl,mid,l,r<mid?r:mid,c);
    if (r>mid) change(p<<1|1,mid+1,wr,l>mid+1?l:mid+1,r,c);
    pushup(p);
}
// Counts the colour segments in positions [l,r]; node p covers [wl,wr].
// When the range spans both children, the two partial counts are added and
// then reduced by one if the colours at positions mid and mid+1 match.
int query(int p,int wl,int wr,int l,int r){
    if (wl!=wr) pushdown(p);
    if (l==wl&&r==wr) return t[p].s;
    const int mid=(wl+wr)>>1;
    const bool goLeft=l<=mid;
    const bool goRight=r>mid;
    int segs=0;
    if (goLeft) segs+=query(p<<1,wl,mid,l,r<mid?r:mid);
    if (goRight) segs+=query(p<<1|1,mid+1,wr,l>mid+1?l:mid+1,r);
    if (goLeft&&goRight&&t[p<<1].rc==t[p<<1|1].lc) --segs;
    return segs;
}
// Returns the colour stored at position x.
// It walks down from node p (covering [wl,wr]) to the leaf, pushing pending
// tags down on the way so the leaf is up to date.
int colour(int p,int wl,int wr,int x){
    while (wl!=wr){
        pushdown(p);
        const int mid=(wl+wr)>>1;
        if (x<=mid){
            p=p<<1;
            wr=mid;
        }
        else{
            p=p<<1|1;
            wl=mid+1;
        }
    }
    return t[p].lc;
}
// Counts colour segments on the vertical path from `low` up to `high`.
// `high` is expected to be an ancestor of `low`; main passes the LCA.
// Each chain piece is counted separately. A chain head and its parent are
// adjacent on the path, so equal colours there merge two pieces into one.
inline int getquery(int low,int high){
    int segs=0;
    for (;top[low]!=top[high];low=f[top[low]]){
        const int head=top[low];
        segs+=query(1,1,n,pos[head],pos[low]);
        if (colour(1,1,n,pos[head])==colour(1,1,n,pos[f[head]])) segs--;
    }
    return segs+query(1,1,n,pos[high],pos[low]);
}
// Paints every vertex on the vertical path from `low` up to `high` with
// colour c. `high` is expected to be an ancestor of `low`; main passes the LCA.
// It paints one contiguous chain interval per jump.
inline void getchange(int low,int high,int c){
    for (;top[low]!=top[high];low=f[top[low]])
        change(1,1,n,pos[top[low]],pos[low],c);
    change(1,1,n,pos[high],pos[low],c);
}
// Reads the tree, its initial colours and m operations.
// "Q a b" prints the number of colour segments on the path a-b; any other
// operation letter reads "a b c" and paints the path a-b with colour c.
int main(){
    n=read(),m=read();
    For(i,1,n) v[i]=read();
    For(i,1,n-1){int a=read(),b=read();addedge(a,b);}
    dfs(1);gochain(1,1);build(1,1,n);
    // Load each vertex's initial colour with point updates.
    For(i,1,n) change(1,1,n,pos[i],pos[i],v[i]);
    For(i,1,m){
        char ch[10];
        // Width-limited read (8 chars + NUL fit in ch+1). Stop on truncated
        // input instead of inspecting an uninitialised buffer.
        if (scanf("%8s",ch+1)!=1) break;
        if (ch[1]=='Q'){
            int a=read(),b=read(),h=lca(a,b);
            // The LCA's segment is counted in both half-paths but is a single
            // segment on the joined path, hence the -1.
            printf("%d\n",getquery(a,h)+getquery(b,h)-1);
        }
        else{
            int a=read(),b=read(),c=read(),h=lca(a,b);
            getchange(a,h,c);getchange(b,h,c);
        }
    }
    return 0;
}