该题题义是求一个序列中非递减的子序列的个数,其实就是一个求逆序对的题,这里当然就是用树状数组来解决了。
首先对输入数据进行离散化,以便于在树状数组上面工作,然后利用DP公式计算ans[i] = sum{ ans[j], j < i },可以理解为在前面的所有满足要求的集合上加上这个较大的数。
参看http://www.cppblog.com/menjitianya/archive/2011/04/06/143510.aspx
代码如下:
#include <cstring> #include <cstdio> #include <algorithm> #include <cstdlib> #include <map> #define MAXN 100005 #define MOD 1000000007 typedef long long ll; using namespace std; ll c[MAXN]; int N, val[MAXN], t[MAXN]; inline int lowbit( int x ) { return x & (-x); } inline void modify( int pos, ll val ) { for (int i = pos; i <= N; i += lowbit(i) ) { c[i] += val; if (c[i] >= MOD) c[i] %= MOD; } } inline ll sum( int pos ) { int s = 0; for (int i = pos; i > 0; i -= lowbit(i)) { s += c[i]; if (s >= MOD) s %= MOD; } return s; } int main( ) { while (scanf("%d", &N) == 1) { ll ans = 0; map<int,int>mp; memset(c, 0, sizeof(c)); for (int i = 1; i <= N; ++i) { scanf("%d", &val[i]); t[i] = val[i]; } sort(t + 1, t + N + 1); int cnt = unique(t + 1, t + N + 1) - t - 1; for (int i = 1; i <= cnt; ++i) { mp[t[i]] = i; } modify(1, 1); for (int i = 1; i <= N; ++i) { int x = sum(mp[val[i]]); ans += x; if (ans >= MOD) ans %= MOD; modify(mp[val[i]], x); } printf("%lld\n", ans); } return 0; }
以下是我的错误代码,套用了公式 2^N-1 (其中N为逆序对数,求解时-1就用自身的单点集合来中和了,所以直接就用2^N来计算了) 来计算。可是对于下面的情况无法给出正确的结果。
3
2 1 2
其结果会与
3
2 1 3
相同,多算了( 2, 1, 2 ) 这类集合......
#include <cstring> #include <cstdio> #include <algorithm> #include <cstdlib> #include <map> #define MAXN 100005 #define MOD 1000000007 typedef long long ll; using namespace std; ll c[MAXN]; int rec[MAXN]; int N, val[MAXN], t[MAXN]; inline int lowbit( int x ) { return x & (-x); } inline void modify( int pos, ll val ) { for (int i = pos; i <= N; i += lowbit(i) ) { c[i] += val; } } inline ll sum( int pos ) { int s = 0; for (int i = pos; i > 0; i -= lowbit(i)) { s += c[i]; } return s; } int main( ) { rec[0] = 1; for (int i = 1; i < MAXN; ++i) { rec[i] = (rec[i-1]*2)%MOD; } while (scanf("%d", &N) == 1) { ll ans = 0; map<int,int>mp; memset(c, 0, sizeof(c)); for (int i = 1; i <= N; ++i) { scanf("%d", &val[i]); t[i] = val[i]; } sort(t + 1, t + N + 1); int cnt = unique(t + 1, t + N + 1) - t - 1; for (int i = 1; i <= cnt; ++i) { mp[t[i]] = i; } for (int i = 1; i <= N; ++i) { int x = sum(mp[val[i]]); ans += rec[x]; if (ans >= MOD) ans %= MOD; modify(mp[val[i]], 1); } printf("%lld\n", ans); } return 0; }