题解:
dp[l][r][lc][rc]
表示：对于一对匹配的括号 l、r，在左端点颜色为 lc、右端点颜色为 rc 的前提下，区间 [l, r] 内部的合法染色方案数（颜色 0 表示不染色，1、2 表示两种颜色）。
然后用记忆化搜索进行 DP。
转移时，把区间内部拆分成若干个相邻的、各自匹配的括号子序列，从左到右依次处理；相邻两个括号若都染色，则颜色不能相同。
由于整个串最外层不一定是一对匹配的括号，所以在位置 0 和 n+1 各补一个不染色的虚拟括号，从 solve(0, n+1, 0, 0) 开始计算。
代码:
// Counts the colorings of a balanced bracket string, modulo 1e9+7, with a
// memoised DP over matched bracket pairs.
//
// Color codes: 0 = uncolored, 1 and 2 = the two colors.  solve() is only
// ever asked about pairs colored (c,0) or (0,c) with c in {1,2}, i.e. exactly
// one bracket of every matched pair is colored, and the transitions forbid
// two adjacent brackets from sharing the same non-zero color.
#include <array>
#include <iostream>
#include <string>
#include <vector>

namespace {

using LL = long long;

constexpr LL kMod = 1000000007LL;
constexpr int kUncolored = 0;

// match_of[i] is the index of the bracket paired with bracket i (1-based).
// Slots 0 and n+1 hold an uncolored virtual pair that wraps the whole string,
// so the outermost level need not itself be a single matched pair.
// (Deliberately not named `link`, which can clash with POSIX ::link().)
std::vector<int> match_of;

// memo[l][cl][cr] caches solve(l, match_of[l], cl, cr); -1 = not computed.
// The right end r is determined by l, so it is not part of the key.
std::vector<std::array<std::array<int, 3>, 3>> memo;

// Returns the number of ways (mod kMod) to color the brackets strictly inside
// the matched pair (l, r), given that l has color cl and r has color cr.
int solve(int l, int r, int cl, int cr) {
    int& cached = memo[l][cl][cr];  // memo is never resized during recursion
    if (cached != -1) {
        return cached;
    }

    // ways[c]: colorings of the children processed so far whose most recent
    // bracket has color c.  Before any child, that bracket is l itself.
    std::array<LL, 3> ways{};
    ways[cl] = 1;

    // Walk the top-level children (ll, rr) of the interior from left to right.
    for (int ll = l + 1; ll < r; ll = match_of[ll] + 1) {
        const int rr = match_of[ll];
        const LL open1 = solve(ll, rr, 1, 0);   // child's '(' has color 1
        const LL open2 = solve(ll, rr, 2, 0);   // child's '(' has color 2
        const LL close1 = solve(ll, rr, 0, 1);  // child's ')' has color 1
        const LL close2 = solve(ll, rr, 0, 2);  // child's ')' has color 2

        // A child whose '(' has color c may not follow a bracket of color c;
        // a child with an uncolored '(' may follow anything.
        const LL prev_not1 = (ways[0] + ways[2]) % kMod;
        const LL prev_not2 = (ways[0] + ways[1]) % kMod;
        const LL prev_any = (ways[0] + ways[1] + ways[2]) % kMod;

        // Each product is < kMod^2 (< 1e18), so the sums fit in long long.
        std::array<LL, 3> updated{};
        updated[0] = (prev_not1 * open1 + prev_not2 * open2) % kMod;
        updated[1] = prev_any * close1 % kMod;
        updated[2] = prev_any * close2 % kMod;
        ways = updated;
    }

    // The last bracket before r must not share r's (non-zero) color.
    // An empty interior falls out naturally: only ways[cl] is non-zero.
    LL total = 0;
    for (int c = 0; c < 3; ++c) {
        if (cr != kUncolored && c == cr) {
            continue;
        }
        total += ways[c];
    }
    cached = static_cast<int>(total % kMod);
    return cached;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::string s;
    std::cin >> s;  // no input leaves s empty; the answer is then 1
    const int n = static_cast<int>(s.size());

    // Pair up brackets with a stack of open positions; reject bad input
    // instead of popping an empty stack (undefined behavior).
    match_of.assign(n + 2, 0);
    std::vector<int> open;
    open.reserve(n);
    for (int i = 1; i <= n; ++i) {
        const char ch = s[i - 1];
        if (ch == '(') {
            open.push_back(i);
        } else if (ch == ')' && !open.empty()) {
            const int j = open.back();
            open.pop_back();
            match_of[i] = j;
            match_of[j] = i;
        } else {
            std::cerr << "invalid or unbalanced bracket sequence\n";
            return 1;
        }
    }
    if (!open.empty()) {
        std::cerr << "invalid or unbalanced bracket sequence\n";
        return 1;
    }

    // Virtual uncolored pair (0, n+1) around the whole string.
    match_of[0] = n + 1;
    match_of[n + 1] = 0;

    std::array<std::array<int, 3>, 3> unknown{};
    for (auto& row : unknown) {
        row.fill(-1);
    }
    memo.assign(n + 2, unknown);

    std::cout << solve(0, n + 1, kUncolored, kUncolored) << '\n';
    return 0;
}