http://acm.hdu.edu.cn/showproblem.php?pid=4812
题意:有一棵树,每个点有一个权值要求找最小的一对点,路径上的乘积mod1e6+3为k
题解:点分治,挨个把子树更新,每次把子树和现有的map里找满足条件的点对,然后更新子树到map里,map维护的是每个到根的乘积值最小的点.
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast,no-stack-protector")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
//#pragma GCC optimize("unroll-loops")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define pi acos(-1.0)
#define ll long long
#define vi vector<int>
#define mod 1000003
#define C 0.5772156649
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
#define pil pair<int,ll>
#define pli pair<ll,int>
#define pii pair<int,int>
#define cd complex<double>
#define ull unsigned long long
#define base 1000000000000000000
#define fio ios::sync_with_stdio(false);cin.tie(0)
using namespace std;
const double eps=1e-6;
const int N=100000+10,maxn=1000000+10,inf=0x3f3f3f3f,INF=0x3f3f3f3f3f3f3f3f;
struct edge{
int to,Next;
}e[N*2];
int cnt,head[N];
void add(int u,int v)
{
e[cnt].to=v;
e[cnt].Next=head[u];
head[u]=cnt++;
}
int n,k,inv[maxn];
bool vis[N];
int sz[N],zx[N],a[N];
map<int,int>ma;
int x,y;
vector<pair<ll,int> >te;
void init()
{
inv[1]=1;
for(int i=2;i<maxn;i++)
inv[i]=(mod-mod/i)*1ll*inv[mod%i]%mod;
}
void dfssz(int u,int f)
{
sz[u]=1;
for(int i=head[u];~i;i=e[i].Next)
{
int x=e[i].to;
if(x!=f && !vis[x])
{
dfssz(x,u);
sz[u]+=sz[x];
}
}
}
void dfszx(int u,int f,int root,int &ans)
{
zx[u]=sz[root]-sz[u];
for(int i=head[u];~i;i=e[i].Next)
{
int x=e[i].to;
if(x!=f && !vis[x])
{
zx[u]=max(zx[u],sz[x]);
dfszx(x,u,root,ans);
}
}
if(zx[u]<zx[ans])ans=u;
}
int findzx(int root)
{
dfssz(root,-1);
int ans=root;
zx[ans]=inf;
dfszx(root,-1,root,ans);
return ans;
}
void dfs1(int u,int f,int root,ll ans)
{
if(a[root]*ans%mod==k)
{
int xx=root,yy=u;
if(xx>yy)swap(xx,yy);
if(xx<x || (xx==x&&yy<y))x=xx,y=yy;
}
ll te=1ll*k*inv[ans*a[root]%mod]%mod;
if(ma.find(te)!=ma.end())
{
int xx=ma[te],yy=u;
if(xx>yy)swap(xx,yy);
if(xx<x || (xx==x&&yy<y))x=xx,y=yy;
}
for(int i=head[u];~i;i=e[i].Next)
{
int x=e[i].to;
if(x!=f && !vis[x])dfs1(x,u,root,ans*a[x]%mod);
}
}
void dfs2(int u,int f,ll ans)
{
if(ma.find(ans)==ma.end())ma[ans]=u;
else if(u<ma[ans])ma[ans]=u;
for(int i=head[u];~i;i=e[i].Next)
{
int x=e[i].to;
if(x!=f && !vis[x])dfs2(x,u,ans*a[x]%mod);
}
}
void solve(int root)
{
int zx=findzx(root);
ma.clear();
vis[zx]=1;
for(int i=head[zx];~i;i=e[i].Next)
{
int x=e[i].to;
if(!vis[x])
{
te.clear();
dfs1(x,zx,zx,a[x]);
dfs2(x,zx,a[x]);
}
}
for(int i=head[zx];~i;i=e[i].Next)
if(!vis[e[i].to])
solve(e[i].to);
}
int main()
{
init();
while(~scanf("%d%d",&n,&k))
{
cnt=0;
memset(head,-1,sizeof head);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
vis[i]=0;
}
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b);add(b,a);
}
x=n+1,y=n+1;
solve(1);
if(x==n+1)puts("No solution");
else printf("%d %d
",x,y);
}
return 0;
}
/********************
5 5
2 5 2 3 3
1 2
1 3
2 4
2 5
********************/