思路
代码
#include <cstdio>
#include <cmath>
using namespace std;
// Sentinel meaning "no valid stacking found yet"; main prints 0 while ans still equals it.
const int inf = 0x3f3f3f3f;
// n, m: the two integers read by main (n is the initial value of dfs's tj,
// m is the number of layers dfs must place).
// ans: smallest accumulated area of any complete stacking found so far.
int n, m, ans = inf;

/**
 * Depth-first search that places layers one at a time (layer 1 first) and
 * records the smallest total area in the global `ans`.
 *
 * Each layer has an integer "radius" i and "height" j, both strictly smaller
 * than those of the previous layer. A layer consumes i*i*j of the remaining
 * volume and contributes 2*i*j to the area; the first layer additionally
 * contributes i*i. (These match cylinder volume / surface in units of pi --
 * presumably the classic layered-cake problem.)
 *
 * @param tj   volume still to be consumed by the remaining layers
 * @param cs   1-based index of the layer being placed; cs == m + 1 means done
 * @param preR exclusive upper bound for this layer's radius
 * @param preH exclusive upper bound for this layer's height
 * @param mj   area accumulated so far (ignored for cs == 1, where the
 *             accumulated area starts from the first layer's i*i)
 */
void dfs(int tj, int cs, int preR, int preH, int mj) {
    // Overshot the volume, or this branch already cannot beat the best answer.
    if (tj < 0 || mj >= ans) return;

    // All m layers placed: accept only if the volume is matched exactly.
    if (cs > m) {
        if (cs == m + 1 && tj == 0) ans = mj;
        return;
    }

    // Every remaining radius is below preR, so the remaining side area is at
    // least 2 * tj / preR. Guarded so preR == 0 can never divide by zero.
    if (preR > 0 && 2 * tj / preR + mj >= ans) return;

    // The m - cs layers still to come after this one each need a strictly
    // smaller positive radius and height, so this layer needs at least
    // m - cs + 1 of both (always >= 1 here because cs <= m).
    const int lowest = m - cs + 1;
    for (int i = preR - 1; i >= lowest; --i) {
        const int disc = i * i;
        // Heights above tj / disc would overshoot the remaining volume; the
        // recursion would reject them immediately, so skip them up front.
        int top = tj / disc;
        if (top > preH - 1) top = preH - 1;
        // The first layer defines the exposed top-facing area.
        const int base = (cs == 1) ? disc : mj;
        for (int j = top; j >= lowest; --j) {
            dfs(tj - disc * j, cs + 1, i, j, base + 2 * i * j);
        }
    }
}
/**
 * Reads n and m from standard input, runs the search, and prints the smallest
 * area found, or 0 when no valid stacking exists (or the input is unusable).
 */
int main() {
    // Malformed or non-positive input cannot describe a valid stacking.
    if (scanf("%d %d", &n, &m) != 2 || n <= 0 || m <= 0) {
        printf("0");
        return 0;
    }

    // dfs treats preR / preH as EXCLUSIVE upper bounds (its loops start at
    // preR - 1 and preH - 1). The largest usable radius is floor(sqrt(n)) and
    // the largest usable height is n, so pass one past each; otherwise cases
    // such as n = 1, m = 1 wrongly report "no solution".
    const int maxR = static_cast<int>(std::sqrt(static_cast<double>(n)));
    dfs(n, 1, maxR + 1, n + 1, 0);

    if (ans == inf) printf("0");
    else printf("%d", ans);
    return 0;
}