题意:给你m个数(m<=100),每个数的素因子仅来自于前t(t<=100)个素数,问这m个数的非空子集里,满足子集里的数的积为完全平方数的有多少个。
一开始就想进去里典型的dp世界观里,dp[n][mask]表示前n个数里为mask的有多少个,但显然这里t太大了。然后又YY了很多很多。像m少的时候应该用的是高消。即对每个因子列一个xor方程,然后高斯消元,其中自由元的个数就是可以随便取的,所以答案是2^(自由元个数),然后把空集的减掉,就是2^(自由元)-1,不过大数是必须的。
#include <iostream> #include <cstring> #include <string> #include <cstdio> #include <vector> #include <algorithm> using namespace std; #define ll long long #define maxn 110 int t,m; int b[maxn]; int p[1000+50]; int tot; int vis[1000+50]; void getPrime() { memset(vis,0,sizeof(vis)); tot=0; for(int i=2;i<=1000;++i){ if(!vis[i]) p[tot++]=i; for(int j=0;j<tot&&i*p[j]<=1000;++j){ vis[i*p[j]]=true; if(!(i%p[j])) break; } } } int a[maxn][maxn]; int gauss() { int row=t,col=m; int fix=0; int cur=0; int row_choose; for(int i=0;i<col&&fix<row;++i){ row_choose=-1; for(int j=cur;j<row;++j){ if(a[j][i]==1) row_choose=j; } if(row_choose==-1) { continue; } ++fix; swap(a[row_choose],a[cur]); for(int j=0;j<row;++j){ if(j==cur) continue; if(a[j][i]==1) { for(int k=i;k<col;++k){ a[j][k]^=a[cur][k]; } } } ++cur; } return col-fix; } const int base=10000; const int width=4; const int N=100; const int static ten[width]={1,10,100,1000}; struct bint { int ln; int v[N]; bint(int r=0){ for(ln=0;r>0;r/=base) v[ln++]=r%base; } bint & operator = (const bint &r){ memcpy(this,&r,(r.ln+1)*sizeof(int)); return *this; } }; bint operator + (const bint &a,const bint &b){ bint res;int i,cy=0; for(i=0;i<a.ln||i<b.ln||cy>0;i++){ if(i<a.ln) cy+=a.v[i]; if(i<b.ln) cy+=b.v[i]; res.v[i]=cy%base;cy/=base; } res.ln=i; return res; } bint operator- (const bint & a, const bint & b){ bint res; int i, cy = 0; for (res.ln = a.ln, i = 0; i < res.ln; i++) { res.v[i] = a.v[i] - cy; if (i < b.ln) res.v[i] -= b.v[i]; if (res.v[i] < 0) cy = 1, res.v[i] += base; else cy = 0; } while (res.ln > 0 && res.v[res.ln - 1] == 0) res.ln--; return res; } bint operator* (const bint & a, const bint & b){ bint res; res.ln = 0; if (0 == b.ln) { res.v[0] = 0; return res; } long long i, j, cy; for (i = 0; i < a.ln; i++) { for (j = cy = 0; j < b.ln || cy > 0; j++, cy /= base) { if (j < b.ln) cy += a.v[i] * b.v[j]; if (i + j < res.ln) cy += res.v[i + j]; if (i + j >= res.ln) res.v[res.ln++] = cy % base; else res.v[i + j] = cy % base; } } return res; } void write(const bint & v){ int i; printf("%d", v.ln == 0 ? 0 : v.v[v.ln - 1]); for (i = v.ln - 2; i >= 0; i--) printf("%04d", v.v[i]); // ! 4 == width // printf(" "); } int main() { getPrime(); while(~scanf("%d%d",&t,&m)){ memset(a,0,sizeof(a)); for(int i=0;i<m;++i){ scanf("%d",b+i); for(int j=0;j<t;++j){ int cnt=0; while(b[i]%p[j]==0){ b[i]/=p[j];cnt^=1; } a[j][i]=cnt; } } int res=gauss(); bint x(1); for(int i=0;i<res;++i){ x=x*2; } x=x-1; write(x);puts(""); } return 0; }