莫队算法

自从上次邀请赛之后说要学习莫队算法,一直拖到现在,💊总的来说,莫队算法是一种离线分块的算法,将总区间分成若干个块($\sqrt{n}$),然后对每一块更新查询$ans$的值。时间复杂度$O\ (n\sqrt{n})$

这里强烈安利一篇blog

SPOJ D-query

题意:给定一个长度为n序列,再给定m次查询,每次查询区间内出现元素的种类数

思路:听说可以主席树或者离线树状数组做,莫队算法的入门题,耗时270ms。将所有查询分成$\sqrt{n}$个块,再按照块排序,然后对于每个块可以利用之前算过的答案(即数量)来更新当前的答案(数量)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include<stdio.h>
#include<algorithm>
#include<math.h>
using namespace std;
const int maxn=1e6+10;
int a[maxn],n,m,cnt[maxn],ans=0,anss[maxn],block;
struct node{
int L,R,i;
}q[maxn];
bool cmp(node x,node y){
if(x.L/block != y.L/block) return x.L/block < y.L/block;
return x.R < y.R;
}
void add(int p){
cnt[a[p]]++;
if(cnt[a[p]] == 1) ans++;
}
void del(int p){
cnt[a[p]]--;
if(cnt[a[p]] == 0) ans--;
}
int main(){
scanf("%d",&n);
block=sqrt(n);
for(int i=1; i<=n; i++) scanf("%d",&a[i]);
scanf("%d",&m);
for(int i=1; i<=m; i++){
scanf("%d %d",&q[i].L,&q[i].R);
q[i].i=i;
}
sort(q+1,q+m+1,cmp);
int cL=1,cR=0;
for(int i=1; i<=m; i++){
int L=q[i].L,R=q[i].R;
while(cL < L) del(cL++);
while(cL > L) add(--cL);
while(cR < R) add(++cR);
while(cR > R) del(cR--);
anss[q[i].i]=ans;
}
for(int i=1; i<=m; i++) printf("%d\n",anss[i]);
}

codeforces-Powerful array

题意:给定n个元素序列,m次询问,要统计区间内[l,r]元素个数乘以元素的值的和$\sum_{l}^{r}cnt_i^2*i$,对于不同元素区间只计算一次。

思路:莫队算法,只需要修改del函数和add函数就行,在每次更新cnt前面,加上ans-=(1LL*a[p]*cnt[a[p]]*cnt[a[p]]);,更新之后再加上ans+=(1LL*a[p]*cnt[a[p]]*cnt[a[p]]);,注意不能把全部变量变成long long,会超时的😑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include<stdio.h>
#include<algorithm>
#include<math.h>
using namespace std;
typedef long long ll;
const int maxn=1e6+10;
int a[maxn],n,m,cnt[maxn],block;
ll ans=0,anss[maxn];
struct node{
int L,R,i;
}q[maxn];
bool cmp(node x,node y){
if(x.L/block != y.L/block) return x.L/block < y.L/block;
return x.R < y.R;
}
void add(ll p){
ans-=(1LL*a[p]*cnt[a[p]]*cnt[a[p]]);
cnt[a[p]]++;
ans+=(1LL*a[p]*cnt[a[p]]*cnt[a[p]]);
}
void del(ll p){
ans-=(1LL*a[p]*cnt[a[p]]*cnt[a[p]]);
cnt[a[p]]--;
ans+=(1LL*a[p]*cnt[a[p]]*cnt[a[p]]);
}
int main(){
scanf("%d %d",&n,&m);
block=sqrt(n);
for(int i=1; i<=n; i++) scanf("%d",&a[i]);
for(int i=1; i<=m; i++){
scanf("%d %d",&q[i].L,&q[i].R);
q[i].i=i;
}
sort(q+1,q+m+1,cmp);
int cL=1,cR=0;
for(int i=1; i<=m; i++){
int L=q[i].L,R=q[i].R;
while(cL < L) del(cL++);
while(cL > L) add(--cL);
while(cR < R) add(++cR);
while(cR > R) del(cR--);
anss[q[i].i]=ans;
}
for(ll i=1; i<=m; i++) printf("%I64d\n",anss[i]);
}

区间求和

题意:给定$n$个元素,$m$次询问,求区间$[l\ ,\ r]$内$\sum_{i=l}^{r}{a_i * cnt_{a_i}}$

与上题不同的是,每个元素都要算到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include<stdio.h>
#include<algorithm>
#include<math.h>
using namespace std;
typedef long long ll;
const ll maxn=1e6+10;
ll a[maxn],n,m,cnt[maxn],block;
ll ans=0,anss[maxn];
struct node{
ll L,R,i;
}q[maxn];
bool cmp(node x,node y){
if(x.L/block != y.L/block) return x.L/block < y.L/block;
return x.R < y.R;
}
void add(ll p){
ans+=(a[p]*cnt[a[p]]);
cnt[a[p]]++;
ans+=(a[p]*cnt[a[p]]);
}
void del(ll p){
ans-=(a[p]*cnt[a[p]]);
cnt[a[p]]--;
ans-=(a[p]*cnt[a[p]]);
}
int main(){
scanf("%lld %lld",&n,&m);
block=sqrt(n);
for(ll i=1; i<=n; i++) scanf("%lld",&a[i]);
for(ll i=1; i<=m; i++){
scanf("%lld %lld",&q[i].L,&q[i].R);
q[i].i=i;
}
sort(q+1,q+m+1,cmp);
ll cL=1,cR=0;
for(ll i=1; i<=m; i++){
ll L=q[i].L,R=q[i].R;
while(cL < L) del(cL++);
while(cL > L) add(--cL);
while(cR < R) add(++cR);
while(cR > R) del(cR--);
anss[q[i].i]=ans;
}
for(ll i=1; i<=m; i++) printf("%lld\n",anss[i]);
}

HDU-6534.Chika and Friendly Pairs

湘潭邀请赛C题,当时卡的题,现在补掉

题意:给定一个长度为$n$的序列,$m$次询问,一个数$k$。询问你在区间$[l\ ,\ r]$中,符合$|a_i-a_j|\leq k$并且$i < j$有多少对这样的数。

思路:莫队+树状数组+离散化,首先对于每次查询$a_i$,我只要查询区间内$[a_i-k\ ,\ a_i+k]$的种数即可,利用树状数组就行,然后由于序列中的数最大有$10^9$,而数的个数只有$27000$,就可以通过离散化将其缩小,再用莫队算法。

对于新加进来的数$a_i$,ans+=(query(up[l])-query(lowa[l])),然后再add(mida[l],1),仔细想想就会发现,如果你先写add函数,就会把本身的值也算进去。

对于去掉的数$a_i$,则需要先add(mida[l],-1),然后再更新ans,ans-=(query(up[l])-query(lowa[l])),因为query查询的时候会把mida[l]也计算进来。

离散化的模版

1
2
3
4
5
6
7
8
9
10
//n 原数组大小 num 原数组中的元素 lsh 离散化的数组 cnt 离散化后的数组大小 
int lsh[MAXN],cnt,num[MAXN],n;
for(int i=1; i<=n; i++) {
scanf("%d",&num[i]);
lsh[i]=num[i];
}
sort(lsh+1,lsh+n+1);
cnt = unique(lsh+1,lsh+n+1)-(lsh+1);
for(int i=1; i<=n; i++)
num[i]=lower_bound(lsh+1,lsh+cnt+1,num[i])-lsh;

该题ac的代码(390MS)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>
using namespace std;
const int maxn=27010;
int mp[maxn*3],a[maxn],upa[maxn],lowa[maxn],mida[maxn];
int tree[maxn*3],ans=0,anss[maxn];
struct node{
int l,r;
int i;
}q[maxn];
int n,m,block,k;
bool cmp(node i,node j){
if(i.l/block != j.l/block) return i.l/block < j.l/block;
return i.r < j.r;
}
int lowbit(int x){
return x&(-x);
}
void add(int x,int val){
for(int i=x; i<maxn*3; i+=lowbit(i)) tree[i]+=val;
}
int query(int x){
int ans=0;
for(int i=x; i; i-=lowbit(i)) ans+=tree[i];
return ans;
}
int main(){
scanf("%d %d %d",&n,&m,&k);
int sum=1;
block=sqrt(n);
for(int i=1; i<=n; i++){
scanf("%d",&a[i]);
mp[sum++]=a[i];
mp[sum++]=a[i]+k;
mp[sum++]=a[i]-k-1;
}
sort(mp+1,mp+sum);
int num=unique(mp+1,mp+sum)-(mp+1);
int pos;
for(int i=1; i<=n; i++){
pos=lower_bound(mp+1,mp+num+1,a[i]+k)-mp;
upa[i]=pos;
pos=lower_bound(mp+1,mp+num+1,a[i]-k-1)-mp;
lowa[i]=pos;
pos=lower_bound(mp+1,mp+num+1,a[i])-mp;
mida[i]=pos;
}
for(int i=1; i<=m; i++){
scanf("%d %d",&q[i].l,&q[i].r);
q[i].i=i;
}
sort(q+1,q+m+1,cmp);
int cL=1,cR=0;
for(int i=1; i<=m; i++){
int L=q[i].l;
int R=q[i].r;
while(cR<R){
cR++;
ans+=(query(upa[cR])-query(lowa[cR]));
add(mida[cR],1);
}
while(cR>R){
add(mida[cR],-1);
ans-=(query(upa[cR])-query(lowa[cR]));
cR--;
}
while(cL<L){
add(mida[cL],-1);
ans-=(query(upa[cL])-query(lowa[cL]));
cL++;
}
while(cL>L){
cL--;
ans+=(query(upa[cL])-query(lowa[cL]));
add(mida[cL],1);
}
anss[q[i].i]=ans;
}
for(int i=1; i<=m; i++) printf("%d\n",anss[i]);
}
thanks!