题意
给你一份代码,仅包含 loop
op
break
continue
四个命令,分别表示循环若干次,执行若干单位操作,跳出此层循环,跳过此层循环剩余部分。要你算出此代码的时间复杂度。
分析
我们可以对每一个 loop
和 op
操作建一个点,点权即为该操作后面的 x
值,再按照他们的包含关系建一棵树。如果遇到 break
,将此层 loop
对应的点的点权设置为 (1),再略过此层剩余部分的代码;如果遇到 continue
,将此层剩余部分的代码略过即可。
注意点上还有保存一个标记,表示该点是否存在实际贡献(就是该点对应的操作是不是 op
),因为 loop
操作本身显然并不对答案存在贡献。
代码
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
const int maxDep = 25;
class Node {
public:
Node(Node* fa, int value, bool isLeaf) : father(fa), val(value), leaf(isLeaf) { return; }
int val; //权值
bool leaf; //是否为叶子(是否需要算入复杂度)
std::vector<Node> child; //儿子节点
Node* father; //父亲指针
};
int stringToInt(const std::string& str) {
int num = 0;
for (std::string::const_iterator i = str.begin(); i != str.end(); i++) num = (num << 3) + (num << 1) + (*i ^ 48);
return num;
}
void readUntilEnd(void) {
std::string str;
int cnt = 1;
while (cnt) {
std::cin >> str;
if (str == "loop") cnt++;
if (str == "end") cnt--;
}
return;
}
Node root(NULL, 0, false); //根节点
int answer[maxDep];
void dfs(const Node& p, int x, int y) {
if (p.leaf) {
answer[x] += y; //统计答案
return;
}
for (std::vector<Node>::const_iterator i = p.child.begin(); i != p.child.end(); i++)
(i->val == -1) ? dfs(*i, x + 1, y) : dfs(*i, x, y * i->val); //根据此子节点权值更改参数
return;
}
int main() {
std::string cmd;
Node* ptr = &root;
while (true) {
std::cin >> cmd;
if (cmd == "op") {
std::string x;
std::cin >> x;
ptr->child.push_back((Node){ptr, x == "n" ? -1 : stringToInt(x), true});
} else if (cmd == "loop") {
std::string x;
std::cin >> x;
if (x == "0") x = "n"; //数据问题,详见https://www.luogu.com.cn/discuss/show/271086
ptr->child.push_back((Node){ptr, x == "n" ? -1 : stringToInt(x), false});
ptr = &ptr->child.back();
} else if (cmd == "break") {
if (ptr == &root) continue;
ptr->val = 1; //由于循环至此必然结束,就相当于循环只运行了一遍
ptr = ptr->father;
readUntilEnd();
} else if (cmd == "continue") {
if (ptr == &root) continue;
ptr = ptr->father;
readUntilEnd(); //从这句话到结尾的所有话都可以无视
} else if (cmd == "end") {
if (ptr == &root) break;
ptr = ptr->father;
}
}
dfs(root, 0, 1);
bool first = true;
for (int i = maxDep - 1; ~i; i--)
if (answer[i]) {
if (!first) putchar('+');
first = false;
if (answer[i] != 1 || !i) printf("%d", answer[i]);
if (i > 1)
printf("n^%d", i);
else if (i == 1)
putchar('n');
}
if (first) putchar('0');
puts("");
return 0;
}