You are given an $n \times n$ binary (0/1) matrix $M$.
Define $\mathrm{cost}(i,j)$ as the minimum number of elements that must be changed so that every element in row $i$ and every element in column $j$ becomes 1.
The pain of the matrix, $\mathrm{pain}(M)$, is defined as:
$$\mathrm{pain}(M) = \left(\sum_{i=1}^{n}\sum_{j=1}^{n} \mathrm{cost}(i,j)^2\right) \bmod (10^9+7)$$
Find the pain of the initial matrix and the pain after each modification.
Input
The first line contains three integers $n, k, q$ ($2 \le n \le 2 \cdot 10^5$, $1 \le k \le \min(n^2, 2 \cdot 10^5)$, $0 \le q \le 2 \cdot 10^5$): the size of the matrix, the number of ones in the initial matrix, and the number of modification operations.
Each of the next $k$ lines contains two integers $x_i, y_i$ ($1 \le x_i, y_i \le n$), meaning that the cell in row $x_i$, column $y_i$ contains a 1. All pairs $(x_i, y_i)$ are guaranteed to be distinct; every other cell contains 0.
Each of the next $q$ lines contains two integers $u_i, v_i$ ($1 \le u_i, v_i \le n$) describing a modification of the cell in row $u_i$, column $v_i$: if the cell currently contains 0 it becomes 1, and if it contains 1 it becomes 0.
Output
Print $q+1$ lines: the pain of the initial matrix (before any modification), followed by the pain after each of the $q$ modifications, in order.
Examples
Input
3 4 9
1 1
1 2
2 3
3 1
3 3
1 2
1 3
2 2
2 2
2 1
3 1
1 1
2 3
Output
73
48
75
52
29
52
33
52
77
104
#include <cstddef>
#include <cstdio>
#include <set>
#include <vector>

// Computes pain(M) of an n x n 0/1 matrix, then again after each point toggle.
//
// Let r[i] / c[j] be the number of zeros in row i / column j, z(i,j) = 1 when
// cell (i,j) is 0 (else 0), and Z the total number of zeros. Turning all of
// row i and column j into 1 requires changing exactly their zeros, with the
// shared cell (i,j) counted once:
//     cost(i,j) = r[i] + c[j] - z(i,j)
// Squaring and summing over all (i,j), using sum_j z(i,j) = r[i],
// sum_i z(i,j) = c[j] and sum r = sum c = Z, gives
//     pain = (n-2) * sum r[i]^2 + (n-2) * sum c[j]^2 + Z + 2 * Z^2
// (everything taken mod 1e9+7). Every toggle moves one r[x] and one c[y] by
// +-1, so the aggregates are updated in O(1), plus an O(log) std::set lookup
// to learn whether the cell currently holds a 1.
namespace {

constexpr long long kMod = 1000000007LL;

class PainTracker {
public:
    // Starts from the all-zero n x n matrix (every row/column holds n zeros).
    explicit PainTracker(long long n)
        : n_(n),
          row_zeros_(static_cast<std::size_t>(n + 1), n),
          col_zeros_(static_cast<std::size_t>(n + 1), n),
          ones_(static_cast<std::size_t>(n + 1)),
          sum_row_sq_(n * n * n),
          sum_col_sq_(n * n * n),
          zeros_(n * n) {}

    // Makes cell (x, y) a 1; repeating the call for the same cell is a no-op.
    void set_one(int x, int y) {
        if (ones_[x].insert(y).second) {
            adjust(x, y, -1);
        }
    }

    // Flips cell (x, y) between 0 and 1.
    void toggle(int x, int y) {
        std::set<int>& row = ones_[x];
        const auto it = row.find(y);
        if (it != row.end()) {
            row.erase(it);
            adjust(x, y, +1);  // 1 -> 0: one more zero
        } else {
            row.insert(y);
            adjust(x, y, -1);  // 0 -> 1: one fewer zero
        }
    }

    // Returns pain(M) mod kMod for the current matrix, always in [0, kMod).
    long long pain() const {
        // Reduce (n - 2) into [0, kMod) so no term can ever be negative.
        const long long w = ((n_ - 2) % kMod + kMod) % kMod;
        const long long z = zeros_ % kMod;
        long long ans = w * (sum_row_sq_ % kMod) % kMod;
        ans = (ans + w * (sum_col_sq_ % kMod)) % kMod;
        ans = (ans + z) % kMod;
        ans = (ans + 2 * z % kMod * z) % kMod;
        return ans;
    }

private:
    // Changes the zero count of row x and column y by d (+1 or -1) and keeps
    // the exact sums of squares in step: (v + d)^2 - v^2 = d * (2v + d).
    void adjust(int x, int y, long long d) {
        long long& r = row_zeros_[x];
        sum_row_sq_ += d * (2 * r + d);
        r += d;
        long long& c = col_zeros_[y];
        sum_col_sq_ += d * (2 * c + d);
        c += d;
        zeros_ += d;
    }

    long long n_;
    std::vector<long long> row_zeros_;  // index 0 unused; rows are 1-based
    std::vector<long long> col_zeros_;  // index 0 unused; columns are 1-based
    std::vector<std::set<int>> ones_;   // ones_[x] = columns holding a 1 in row x
    // Exact (unreduced) aggregates, so they are never negative; for
    // n <= 2e5, n^3 <= 8e15 fits comfortably in long long.
    long long sum_row_sq_;
    long long sum_col_sq_;
    long long zeros_;
};

}  // namespace

int main() {
    long long n = 0;
    long long k = 0;
    long long q = 0;
    if (std::scanf("%lld %lld %lld", &n, &k, &q) != 3 || n < 1) {
        return 0;
    }

    PainTracker tracker(n);
    // Coordinates are assumed to lie in [1, n], as the input format guarantees.
    for (long long i = 0; i < k; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d %d", &x, &y) != 2) {
            return 0;
        }
        tracker.set_one(x, y);
    }
    std::printf("%lld\n", tracker.pain());

    for (long long i = 0; i < q; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d %d", &x, &y) != 2) {
            return 0;
        }
        tracker.toggle(x, y);
        std::printf("%lld\n", tracker.pain());
    }
    return 0;
}