考察回溯法的题目。
难点在于如何枚举天平结构的各种情况。
思路1:自底向上,用类似二叉树的结构储存(类似霍夫曼树,挂坠全部在叶节点),每次选择两个节点组成一个子树同时算出子树的左右臂长度,递归建树。但是这样会有较多重复的情况。
思路2:自顶向下,把集合分为左右子集(分别为左右子树所含的挂坠集合),在递归调用左右子集。枚举子集的思路用的是二进制枚举集合的思路,每个二进制数分别对应挂坠集合能组成的所有天平的左右臂长度,用vector<Node> node[MAXN]储存,[]内是二进制数。还用到了二进制&,^运算来处理集合间的关系。
思路2的代码。
#define _CRT_SECURE_NO_WARNINGS #include<iostream> #include<algorithm> #include<string> #include<sstream> #include<set> #include<vector> #include<stack> #include<map> #include<queue> #include<cassert> #include<cstdlib> #include<cstdio> #include<ctime> #include<cmath> #include<cstring> #include<functional> using namespace std; #define INF 0x3f3f3f3f #define max(a,b) ((a)>(b)?(a):(b)) #define min(a,b) ((a)<(b)?(a):(b)) using namespace std; const int N = 6; const int MAXN = (1 << N); int t, n, i, j, vis[MAXN]; double w[N], sumw[MAXN], r; struct Node { double l, r; Node() {} Node(double ll, double rr) { l = ll; r = rr; } }; vector<Node> node[MAXN]; int bitcount(int x) { //计算二进制中1的个数 if (x == 0) return 0; return bitcount(x / 2) + (x & 1); } void dfs(int s) { if (vis[s]) return;//添加了记忆数组,如果状态s已经被搜索过,直接返回 vis[s] = 1; if (bitcount(s) == 1) { //当只有一个1时,说明是叶子,天平的两臂都是0 node[s].push_back(Node(0, 0)); return; } for (int l = (s - 1)&s; l > 0; l = (l - 1)&s) { //枚举左右子集情况,(此处利用二进制枚举左右子集的方法值得学习) int r = s^l; dfs(l); dfs(r); for (int i = 0; i < node[l].size(); i++) { for (int j = 0; j < node[r].size(); j++) { double ll = min(-sumw[r] / (sumw[l] + sumw[r]) + node[l][i].l, sumw[l] / (sumw[l] + sumw[r]) + node[r][j].l);//比较 左臂+左子天平的左臂 与 右子天平的左臂-右臂 谁更小 double rr = max(sumw[l] / (sumw[l] + sumw[r]) + node[r][j].r, -sumw[r] / (sumw[l] + sumw[r]) + node[l][i].r);//比较 右臂+右子天平的右臂 与 左子天平的右臂-左臂 谁更大 node[s].push_back(Node(ll, rr));//将得到的该根节点的左右臂长度放入数组 } } } } void solve() { double ans = -1; int s = (1 << n) - 1; dfs(s); for (int i = 0; i < node[s].size(); i++) { if (node[s][i].r - node[s][i].l < r) {//s结点是根结点,存有所有二叉树的左右臂的长度,选出差值<r的最大值即可 if (node[s][i].r - node[s][i].l > ans) ans = node[s][i].r - node[s][i].l; } } if (ans == -1) printf("-1 "); else printf("%.10lf ", ans); } int main() { scanf("%d", &t); while (t--) { memset(vis, 0, sizeof(vis)); memset(node, 0, sizeof(node)); scanf("%lf%d", &r, &n); for (i = 0; i < n; i++) scanf("%lf", &w[i]); for (i = 0; i < (1 << n); i++) { sumw[i] = 0; for (j = 0; j < n; j++) { if (i&(1 << j)) sumw[i] += w[j]; } } solve(); } return 0; }