题面
有一个 (X)、(Y) 轴坐标范围为 (1sim n) 的范围的方阵,每个点上有块黄金。一阵风来 ((x,y)) 上的黄金到了 ((f(x),f(y))),(f(x)) 为 (x) 各位上数字的乘积,如果黄金飘出方阵就没了。求在 (k) 个格子上采集黄金最多可以采集的黄金数。
数据范围:(1le nle 10^{12}),(kle min(n^2,10^5))。
蒟蒻语
蒟蒻跟着 (it srf) 大师的日报来做这题,然后发现自己的裸代码跑得比题解都快,方法也比较神奇,于是来跟巨佬们讲讲。
蒟蒻解
首先众所周知,对于 (1le ile 10^{12}),(f(i)) 只有 (8282) 种,所以可以先找出这 (8282) 种 (f(i)),蒟蒻有三种方法:set、枚举质因数和数位 dp。
为了优化可以用数位 dp(这是第一次数位 dp),正好求出 (1sim n) 的所有 (f(i) ot=0):
int dn,d[13],cnt=0;
ll t[N]; unordered_map<ll,int> nt;
bool vis[13][N];
void init(int w,ll now,bool ava){ //w:位 now:乘积 ava=true:可以自由选择数字
if(!~w){if(!nt.count(now)) nt[t[cnt]=now]=cnt,cnt++;return;}
if(vis[w][nt[now]]) return;
vis[w][nt[now]]=true;
int up=ava?9:d[w];
for(int i=1;i<=up;i++) init(w-1,now*i,true||i<up);
}
void INIT(){
while(n) d[dn++]=n%10,n/=10;
for(int i=0;i<=dn;i++) init(i-1,1,i<dn);
}
然后是最重要的部分:求每种 (f(i)) 有多少个 (i)。
蒟蒻原来的做法是对每种 (f(i)) 来一次数位 dp,瞬间被 TLE 打脸,于是蒟蒻想出了一个不同于别的巨佬的做法:
记 (f_{w,now}) 表示到第 (w) 位,剩下 (w) 位乘积为 (now) 的方案数。
用记忆化搜索转移,dp 中会用到除法。
ll f[13][N];
ll dp(int w,int now,bool ava){ //w:位 now:剩下期望乘积 ava=true:可以自由选择数字
if(!~w) return t[now]==1;
if(ava&&~f[w][now]) return f[w][now];
int up=ava?9:d[w]; ll res=0;
for(int i=1;i<=up;i++)if(t[now]%i==0)
res+=dp(w-1,nt[t[now]/i],ava||i<up);
if(ava) f[w][now]=res;
return res;
}
void DP(){
memset(f,-1,sizeof f);
for(int i=0;i<cnt;i++)
for(int j=1;j<=dn;j++) a[i]+=dp(j-1,i,j<dn); // ai 是每种 fi 出现次数
}
然后给 (a_i) 排个序,用堆维护找最大 (k) 乘积即可,考虑到相乘可能会爆 long long,蒟蒻用了除法。
时间复杂度 (Theta(8282*12*10))。
代码
很明显蒟蒻的题解都是废话,只好放代码了。
#include <bits/stdc++.h>
using namespace std;
//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair(a,b)
#define x first
#define y second
#define be(a) a.begin()
#define en(a) a.end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
//Data
const int N=8282;
const int mod=1e9+7;
ll n,a[N];
int k,nex[N],ans;
struct cmp{
bool operator()(pair<int,int> p,pair<int,int> q){
return 1.*a[p.x]/a[q.y]<1.*a[q.x]/a[p.y];
}
};
priority_queue<pair<int,int>,vector<pair<int,int>>,cmp> q;
//Digitdp
int dn,d[13],cnt=0;
ll t[N]; unordered_map<ll,int> nt;
bool vis[13][N];
void init(int w,ll now,bool ava){
if(!~w){if(!nt.count(now)) nt[t[cnt]=now]=cnt,cnt++;return;}
if(vis[w][nt[now]]) return;
vis[w][nt[now]]=true;
int up=ava?9:d[w];
for(int i=1;i<=up;i++) init(w-1,now*i,true||i<up);
}
void INIT(){
while(n) d[dn++]=n%10,n/=10;
for(int i=0;i<=dn;i++) init(i-1,1,i<dn);
}
ll f[13][N];
ll dp(int w,int now,bool ava){
if(!~w) return t[now]==1;
if(ava&&~f[w][now]) return f[w][now];
int up=ava?9:d[w]; ll res=0;
for(int i=1;i<=up;i++)if(t[now]%i==0)
res+=dp(w-1,nt[t[now]/i],ava||i<up);
if(ava) f[w][now]=res;
return res;
}
void DP(){
memset(f,-1,sizeof f);
for(int i=0;i<cnt;i++)
for(int j=1;j<=dn;j++) a[i]+=dp(j-1,i,j<dn);
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n>>k,INIT(),DP();
sort(a+0,a+cnt,[&](ll p,ll q){return p>q;});
for(int i=0;i<cnt;i++) nex[i]=0,q.push(mp(i,nex[i]++));
while(k--&&sz(q)){
pair<int,int> u=q.top(); q.pop();
(ans+=(a[u.x]%mod)*(a[u.y]%mod)%mod)%=mod;
if(nex[u.x]<cnt) u.y=nex[u.x]++,q.push(u);
}
cout<<ans<<'
';
return 0;
}
祝大家学习愉快!