• 【题解】矩阵乘法


    题目描述

            有一个N行M列的二维数组,共N×M个格子,每个格子都有一个价值。第1行格子的价值从左往右是1至M,第2行格子的价值从左往右是M+1至M+M,最后1行格子的价值从左往右价值是(N-1)×M+1至N×M。例如N=3,M=4,那么该二维数组的价值如下所示:

    Failed to load picture

            现在有K个操作,每个操作是如下的两种情况之一:

            1、格式是R X Y,表示行操作,第X行的每个格子价值都乘上一个非负整数Y。

            2、格式是S X Y,表示列操作,第X列的每个格子价值都乘上一个非负整数Y。

            当进行完K次操作之后,你要输出所有格子的价值总和,由于答案可能很大,所以答案要模1000000007。

     

    输入格式

            第一行,三个整数,N,M,K。

            接下来是K行,每行的格式如题目所述。

    输出格式

            一个整数。

    输入样例

    3 4 4

    R 2 4

    S 4 1

    R 3 2

    R 2 0

    输出样例

    94

    数据规模

            对于50%的数据,1≤N,M≤1000,1≤K≤1000;0≤Y≤1000000000;

            对于100%的数据,1≤N,M≤1000000,1≤K≤1000;0≤Y≤1000000000。

    题解

            显然$O(mn) $会超时,所以我们要直接枚举操作。

            我们设第一种操作为行操作,第二种操作为列操作。

            首先是预处理,我们将相同行(列)的操作合并起来。同时我们可以将没有操作时的总价值求出来,用$ans$记录。

            我们先枚举行操作,把每行增加的价值也累加到$ans$里(等差数列求和应该都会吧)。对于每一个行操作$i$,我们同时枚举列操作$j$,加上$i$和$j$相交的格子进行操作$j$后再增加的价值。

            然后再从头枚举列操作$i$,对于每一个$i$,同时枚举行操作$j$,因为$i$,$j$相交的格子之前已经加过了,所以这里再把多加的减回去即可。

            注意:如果操作为$ imes 0$,则实际上是减去刚开始预处理多加的价值。

            这样就可以做到$O(k^{2})$了。具体细节可以看下面的代码。

    #include <iostream>
    #include <algorithm>
    #include <cstring>
    
    #define MAX_K 1000
    
    using namespace std;
    
    typedef long long ll;
    const ll mod = 1000000007LL;
    ll n, m;
    int k;
    struct Fish
    {
        ll x;
        ll y;
        inline bool operator < (const Fish & a) const {return x < a.x;}
    };
    Fish r[MAX_K + 5], s[MAX_K + 5];
    int cr, cs;
    
    inline ll S(ll, ll, ll);
    void Initiation();
    void Solution();
    
    int main()
    {
        Initiation();
        Solution();
        return 0;
    }
    
    void Initiation()
    {
        cin >> n >> m >> k;
        Fish tr[MAX_K + 5], ts[MAX_K + 5];
        memset(tr, 0, sizeof tr);
        memset(ts, 0, sizeof ts);
        int ctr = 0, cts = 0;
        char ch;
        ll tx;
        ll ty;
        for(register int i = 1; i <= k; ++i)
        {
            cin >> ch >> tx >> ty;
            if(ch == 'R') tr[++ctr].x = tx, tr[ctr].y = ty % mod;
            else ts[++cts].x = tx, ts[cts].y = ty % mod;
        }
        sort(tr + 1, tr + ctr + 1);
        sort(ts + 1, ts + cts + 1);
        for(register int i = 1; i <= ctr; ++i)
        {
            if(tr[i].x == tr[i + 1].x) tr[i + 1].y = tr[i + 1].y * tr[i].y % mod;
            else r[++cr] = tr[i];
        }
        for(register int i = 1; i <= cts; ++i)
        {
            if(ts[i].x == ts[i + 1].x) ts[i + 1].y = ts[i + 1].y * ts[i].y % mod;
            else s[++cs] = ts[i];
        }
        return;
    }
    
    void Solution()
    {
        ll ans = S(1, 1, n * m);
        for(register int i = 1; i <= cr; ++i)
        {
            if(r[i].y)
            {
                ans += S((r[i].x - 1) * m + 1, 1, m) * (r[i].y - 1) % mod;
                if(ans >= mod) ans -= mod;
                for(register int j = 1; j <= cs; ++j)
                {
                    if(s[j].y)
                    {
                        ans += ((r[i].x - 1) * m + s[j].x) % mod * r[i].y % mod * (s[j].y - 1) % mod;
                        if(ans >= mod) ans -= mod;
                    }
                    else
                    {
                        ans -= ((r[i].x - 1) * m + s[j].x) % mod * r[i].y % mod;
                        if(ans < 0) ans += mod;
                    }
                }
            }
            else 
            {
                ans -= S((r[i].x - 1) * m + 1, 1, m);
                if(ans < 0) ans += mod;
            }
        }
        for(register int i = 1; i <= cs; ++i)
        {
            if(s[i].y)
            {
                ans += S(s[i].x, m, n) * (s[i].y - 1) % mod;
                if(ans >= mod) ans -= mod;
                for(register int j = 1; j <= cr; ++j)
                {
                    ans -= ((r[j].x - 1) * m + s[i].x) % mod * (s[i].y - 1) % mod;
                    if(ans < 0) ans += mod;
                }
            }
            else
            {
                ans -= S(s[i].x, m, n);
                if(ans < 0) ans += mod;
                for(register int j = 1; j <= cr; ++j)
                {
                    ans += ((r[j].x - 1) * m + s[i].x) % mod;
                    if(ans >= mod) ans -= mod; 
                }
            }
        }
        cout << ans;
        return;
    }
    
    inline ll S(ll a, ll d, ll cnt)
    {
        if(cnt & 1) return ((a % mod) * (cnt % mod) % mod + ((cnt - 1 >> 1) % mod) * (cnt % mod) % mod * (d % mod) % mod) % mod; 
        return ((a % mod) * (cnt % mod) % mod + ((cnt >> 1) % mod) * ((cnt - 1) % mod) % mod * (d % mod) % mod) % mod; 
    }
    参考程序
  • 相关阅读:
    SQL Server如何使用表变量
    Msys/MinGW与Cygwin/GCC(转)
    内存段划分:代码段、数据段、堆、栈
    Codeblocks+MinGW+wxWidgets搭建方法(转)
    Java GUI图形界面开发工具
    MinGW离线安装方法汇总(转)
    Linux系统的启动过程(转)
    详解VOLATILE在C++中的作用(转)
    C++虚函数与纯虚函数用法与区别(转)
    C++中值传递、指针传递和引用传递的比较 (转)
  • 原文地址:https://www.cnblogs.com/kcn999/p/11224183.html
Copyright © 2020-2023  润新知