orz qyc
看很多写法都是 (mathcal{O}(n^4)) 的,其实稍微预处理下就能做到 (mathcal{O}(n^3)) 的了。
并不那么显然地看出是个区间 dp,后面就很好做了。
发现平方聚在一起是更优的,则区间 dp 应该是枚举一列让它尽可能的多选。
基于这个贪心的思路,设 (f_{l,r}) 为仅考虑左右端点均在 ([l,r]) 内的段的答案,转移就枚举哪一行是尽可能填的。设 (s_{k,l,r}) 为代表左右端点都在 ([l,r]) 内且经过 (k) 列的区间个数。
枚举尽可能多选的那一列是第 (k) 列,则有:
[f_{l,r}=max{f_{l,k-1}+{s_{k,l,r}}^2+f_{k+1,r}}
]
然后考虑怎样算 (s)。
对于给出的一个段 (x,y),实际上是对所有的 (lleq x,rgeq y,xleq kleq y),(s_{k,l,r}) 都加上 (1)。
也许读者注意到一个细节,(s) 的状态设计中,我将 (k) 放在第一维,实际上是方便对 (mathcal{O}(n^3)) 预处理 (s) 作铺垫。考虑对每个 (k) 建立一个平面直角坐标系,横坐标为 (l),纵坐标为 (r),((l,r)) 上的值为 (s_{k,l,r})。
那么所有满足 (lleq x,rgeq y) 的 ((l,r)),构成了一个左上角上的矩形,矩形加最后单点查,可以二维差分解决。
但实际上都是左上角的矩形,在 ((l,r)) 打个 tag 然后做二维前缀和就完全足够。
时间复杂度 (mathcal{O(n^3)}).
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
typedef long long ll;
const ll mod = 1000000007;
template <typename T> T Add(T x, T y) { return (x + y >= mod) ? (x + y - mod) : (x + y); }
template <typename T> T cAdd(T x, T y) { return x = (x + y >= mod) ? (x + y - mod) : (x + y); }
template <typename T> T Mul(T x, T y) { return x * y % mod; }
template <typename T> T Mod(T x) { return x < 0 ? (x + mod) : x; }
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T> T Abs(T x) { return x < 0 ? -x : x; }
template <typename T> T chkmax(T &x, T y) { return x = x > y ? x : y; }
template <typename T>
T &read(T &r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = (r << 3) + (r <<1) + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
ll qpow(ll x, ll y) {
ll sumq = 1;
while(y) {
if(y & 1) sumq = sumq * x % mod;
x = x * x % mod;
y >>= 1;
}
return sumq;
}
const int N = 110;
int n, m;
int len[N], f[N][N], s[N][N][N];
signed main() { //freopen("in.txt", "r", stdin); freopen("out.txt", "w", stdout);
read(n); read(m);
for(int i = 1; i <= n; ++i) {
read(len[i]);
for(int j = 1; j <= len[i]; ++j) {
int x, y; read(x); read(y);
for(int k = x; k <= y; ++k)
++s[k][x][y];
}
}
for(int k = 1; k <= m; ++k) {
for(int l = m; l; --l)
for(int r = 1; r <= m; ++r)
s[k][l][r] += s[k][l][r-1];
for(int l = m; l; --l)
for(int r = 1; r <= m; ++r)
s[k][l][r] += s[k][l+1][r];
}
for(int k = 1; k <= m; ++k) f[k][k] = s[k][k][k] * s[k][k][k];
for(int len = 1; len < m; ++len)
for(int i = 1; i + len <= m; ++i) {
int j = i + len;
for(int k = i; k <= j; ++k) f[i][j] = Max(f[i][j], f[i][k-1] + s[k][i][j] * s[k][i][j] + f[k+1][j]);
}
printf("%d
", f[1][m]);
return 0;
}