[BZOJ3625] [Generating Function] [NTT] Children and Binary Trees

Posted by houssam_ballout on Tue, 08 Oct 2019 15:30:13 +0200

BZOJ3625

To count binary trees, note that a binary tree consists of a root node together with a left subtree and a right subtree, each of which is itself a (possibly empty) binary tree.
Because the left and right subtrees have exactly the same structure as the whole tree, they are counted by the same generating function as the tree itself.
So let $F(x)$ be the generating function of binary trees (the coefficient of $x^s$ is the number of trees whose total weight is $s$); we also need a generating function $G(x)$ for a single node.
Why do we need a generating function for a single node?
The weight of the root node contributes to the total weight of the tree, so $G(x)$ is the generating function of the allowed node weights.
Specifically, $g_i = 1$ if and only if the weight $i$ appears in the sequence $C$, and $g_i = 0$ otherwise.
Attaching a root to two subtrees then gives $F = F^2 G + 1$, where the $+1$ term accounts for the empty tree.
By the quadratic formula:
$$F = \frac{1 \pm \sqrt{1-4G}}{2G}$$
Multiplying numerator and denominator by $1 \mp \sqrt{1-4G}$ (and using $(1 \pm \sqrt{1-4G})(1 \mp \sqrt{1-4G}) = 4G$) gives
$$F = \frac{2}{1 \mp \sqrt{1-4G}}$$
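If the denominator uses the minus sign, its constant term is $1 - 1 = 0$ (all weights are positive, so $G(0) = 0$ and $\sqrt{1-4G}$ has constant term $1$), so it has no power-series inverse. Hence $F = \dfrac{2}{1 + \sqrt{1-4G}}$.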
So it suffices to compute the power-series square root of $1 - 4G$, add $1$, and take the power-series inverse (all modulo $x^{m+1}$).

Code:

#include<bits/stdc++.h>
#define poly vector<int>
#define ll long long
#define mod 998244353
using namespace std;
// Reads the next decimal integer from stdin using getchar().
// Non-digit characters before the number are skipped; the result is negative
// when the character immediately before the first digit is '-'.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
// The character that terminates the number is consumed.
inline int read(){
	int ch=getchar();
	bool negative=false;
	// isdigit(EOF) is false, so without the EOF check this loop would never end.
	while(ch!=EOF&&!isdigit(ch)){
		negative=(ch=='-');
		ch=getchar();
	}
	int value=0;
	while(isdigit(ch)){
		value=value*10+(ch-'0');
		ch=getchar();
	}
	return negative?-value:value;
}
// Modular arithmetic helpers over Z/mod.
// add/dec/inc assume both operands are already in [0, mod) and keep the
// result in [0, mod) with a single conditional correction.
// mul/Mul reduce with C++ '%', which keeps the sign of the product, so they
// only return values in [0, mod) when both operands are non-negative.
inline int add(int x,int y){const int s=x+y;return s>=mod?s-mod:s;}
inline int dec(int x,int y){const int d=x-y;return d<0?d+mod:d;}
inline int mul(int x,int y){return static_cast<int>(static_cast<long long>(x)*y%mod);}
inline void Mul(int &x,int y){x=mul(x,y);}
inline void inc(int &x,int y){x=add(x,y);}
// Square-and-multiply: returns a^b reduced modulo mod, for b >= 0.
inline int ksm(int a,int b){
	int result=1;
	while(b){
		if(b&1) result=mul(result,a);
		b>>=1;
		a=mul(a,a);
	}
	return result;
}
namespace Ntt{
	// Number-theoretic transform modulo 998244353 with primitive root 3.
	// 998244353 - 1 = 119 * 2^23, so every power-of-two length up to
	// 2^LOG_MAX (= 2^21) has the roots of unity tabulated by init_w().
	const int N=1e6+5;
	const int LOG_MAX=21;
	// w[l][j] = r_l^j for 0 <= j < 2^(l-1), where r_l is a primitive 2^l-th
	// root of unity. Stored in std::vector so repeated init_w() calls don't leak.
	std::vector<int> w[LOG_MAX+1];
	// Bit-reversal permutation for the length most recently passed to init().
	int rev[N<<2];

	// Fills rev[0..n) with the bit-reversal of each index over log2(n) bits.
	// n must be a power of two; ntt(f, n, ...) uses the table from init(n).
	inline void init(int n){
		for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
	}

	// Builds the twiddle tables w[1..LOG_MAX]; must run before any ntt().
	inline void init_w(){
		const int top=1<<(LOG_MAX-1);
		const int wn=ksm(3,(mod-1)/(1<<LOG_MAX));
		w[LOG_MAX].assign(top,1);
		for(int i=1;i<top;i++) w[LOG_MAX][i]=mul(w[LOG_MAX][i-1],wn);
		// A primitive 2^l-th root is the square of a primitive 2^(l+1)-th
		// root, so each coarser table is every other entry of the finer one.
		for(int l=LOG_MAX-1;l>=1;l--){
			w[l].resize(1<<(l-1));
			for(int j=0;j<(1<<(l-1));j++) w[l][j]=w[l+1][j<<1];
		}
	}

	// In-place cyclic transform of f[0..n): kd == 1 is the forward transform,
	// kd == -1 the inverse (including the 1/n scaling).
	// Requires: n a power of two with n <= 2^LOG_MAX, f.size() >= n,
	// init(n) and init_w() already called, and every f[i] in [0, mod).
	inline void ntt(poly &f,int n,int kd){
		assert(n>=1&&(n&(n-1))==0&&n<=(1<<LOG_MAX));
		assert(static_cast<int>(f.size())>=n);
		for(int i=0;i<n;i++) if(i<rev[i]) std::swap(f[i],f[rev[i]]);
		for(int mid=1,l=1;mid<n;mid<<=1,l++){
			const std::vector<int> &wl=w[l];
			for(int i=0;i<n;i+=(mid<<1)){
				for(int j=0;j<mid;j++){
					const int a0=f[i+j];
					const int a1=mul(f[i+j+mid],wl[j]);
					f[i+j]=add(a0,a1);
					f[i+j+mid]=dec(a0,a1);
				}
			}
		}
		if(kd==-1){
			// Transforming with w^{-1} equals transforming with w and then
			// mapping index k to n-k, i.e. reversing f[1..n-1]; then scale by 1/n.
			std::reverse(f.begin()+1,f.begin()+n);
			const int inv=ksm(n,mod-2);
			for(int i=0;i<n;i++) Mul(f[i],inv);
		}
	}

	// Maps any int scalar into [0, mod) so the scalar operators below keep
	// coefficients canonical (e.g. a*(-4) stores mod-4 rather than -4).
	inline int normalize_scalar(int b){b%=mod;return b<0?b+mod:b;}
	// NOTE: these scalar operators act on EVERY coefficient (element-wise),
	// not only on the constant term as polynomial +/- would.
	inline poly operator -(poly a,int b){b=normalize_scalar(b);for(int &x:a) x=dec(x,b);return a;}
	inline poly operator *(poly a,int b){b=normalize_scalar(b);for(int &x:a) Mul(x,b);return a;}
	inline poly operator +(poly a,int b){b=normalize_scalar(b);for(int &x:a) inc(x,b);return a;}

	// Polynomial product (length a.size()+b.size()-1). Short results use the
	// schoolbook method; longer ones use NTT, which re-initialises rev[].
	inline poly operator *(poly a,poly b){
		if(a.empty()||b.empty()) return poly();  // product with the zero polynomial
		const int m=static_cast<int>(a.size()+b.size())-1;
		if(m<=128){
			poly c(m,0);
			for(size_t i=0;i<a.size();i++)
				for(size_t j=0;j<b.size();j++) inc(c[i+j],mul(a[i],b[j]));
			return c;
		}
		int n=1;
		while(n<m) n<<=1;
		init(n);
		a.resize(n);ntt(a,n,1);
		b.resize(n);ntt(b,n,1);
		for(int i=0;i<n;i++) Mul(a[i],b[i]);
		ntt(a,n,-1);a.resize(m);
		return a;
	}
}
using namespace Ntt;
// Modular inverse of 2; assigned in main() via ksm(2, mod-2) and used by sqr()
// to halve coefficients in the square-root Newton step.
int inv2;
// Returns the first n coefficients of the power-series inverse of a
// (i.e. 1/a mod x^n). Requires a non-empty with a[0] != 0 (mod p) and
// coefficients in [0, mod).
// Newton iteration: if r = 1/a mod x^k, then r*(2 - a*r) = 1/a mod x^{2k}.
// Each round calls init() for its transform length, overwriting the shared
// bit-reversal table rev[].
inline poly Inv(poly a,int n){
	poly res(1,ksm(a[0],mod-2));
	poly head;
	for(int len=4;len<4*n;len*=2){
		const int half=len/2;
		init(len);
		// Only a mod x^{half} matters at this precision; zero-pad to len so the
		// cyclic product r*r*a (degree < len) does not wrap around.
		head.assign(len,0);
		for(int i=0;i<half&&i<static_cast<int>(a.size());i++) head[i]=a[i];
		res.resize(len);
		ntt(head,len,1);
		ntt(res,len,1);
		for(int i=0;i<len;i++){
			const int prod=mul(res[i],head[i]);
			res[i]=mul(res[i],dec(2,prod));
		}
		ntt(res,len,-1);
		res.resize(half);
	}
	res.resize(n);
	return res;
}
// Returns the first n coefficients of sqrt(a) mod x^n, choosing the root whose
// constant term is 1. Requires a non-empty with a[0] == 1 and coefficients in
// [0, mod); inv2 must already hold the inverse of 2.
// Newton iteration: if b^2 = a mod x^k, then (b + a/b)/2 is the root mod x^{2k}.
inline poly sqr(poly a,int n){
	assert(!a.empty()&&a[0]==1);
	poly b(1,1),c,d;
	for(int lim=4;lim<(n<<2);lim<<=1){
		c=a;c.resize(lim>>1);
		d=Inv(b,lim>>1);
		// Inv() calls init() for its own lengths, overwriting rev[]; initialise
		// it for lim only now, right before this round's transforms.
		init(lim);
		c.resize(lim);ntt(c,lim,1);
		d.resize(lim);ntt(d,lim,1);
		for(int i=0;i<lim;i++) Mul(c[i],d[i]);
		ntt(c,lim,-1);b.resize(lim>>1);
		for(int i=0;i<(lim>>1);i++) b[i]=mul(inv2,add(b[i],c[i]));
	}
	b.resize(n);return b;
}
poly a,b,ans;
int main(){
	inv2=ksm(2,mod-2);init_w();int n=read(),m=read();a.resize(100001);
	for(int x,i=1;i<=n;i++) x=read(),a[x]=1;
	a=a*(-4);
	a[0]+=1;
	b=sqr(a,m+1);
	b[0]+=1;
	ans=Inv(b,m+1);
	ans=ans*2;
	for(int i=1;i<=m;i++) cout<<ans[i]<<"\n";
	return 0;
}