ShuffledPlaylist (SRM443 Div1 Hard)

だいぶ前にtopoderの王者petr氏のブログで言及されていた問題を勉強した。
http://petr-mitrichev.blogspot.com/2009/06/srm-443.html

The hard problem for the round is a reasonably complex example how matrix multiplication can chime in to speed up DP. If you want to learn the more advanced DP techniques, this problem might give you a level-up :)

こう書かれてはやらずにはおれない、と読んだときに思ったはずなのに早1年以上、ついに昨日やってみた。

問題は、
「色々な曲(最大数百曲くらい)が与えられていて、それぞれの曲にジャンル(最大9通り)と曲の長さ(1〜9)が設定されている。
これらの曲をそれぞれどういう順番で何回使ってもよいという条件のもと、minLength(1..10億)以上、maxLength(minLength..10億)以下の長さのアルバムを作る方法は何通りあるか?
ただしジャンル間には連続してはいけないものがあり、それは(最大)9x9のテーブルで与えられている」

minLength,maxLenghtが最大10億という条件から、行列の累乗に帰着というのは常識。
A^n = (A^(n/2))^2という再帰によりlog(n)のオーダーになるから。
しかしどんな行列?
あっさりあきらめて解説読んだ。

まず、グラフを考える。
各ノードは、ジャンルと現在演奏中の曲の残り時間(g,l)
これが81通り。
それにsourceとsinkを加える。
エッジをたどるのが1秒消費に相当する。
求めるものは、sourceからスタートしてこのグラフをぐるぐるとめぐりめぐってL秒以下でsinkまで行く場合の数。
隣接行列のL乗がノード間をL歩でいく場合の数になる、という事実を知っていると(解説を読んで知りました^^;)、上のグラフのエッジに適切な値を設定してL乗すればいいことが分かる。
エッジはどう設定するか。
・まず、sourceから各曲のジャンルg、長さlの(g,l)を接続。
 複数曲が同じ(g,l)に行く場合があるので、その場合は重みをつけて。
・l>1の場合は(g,l)から(g,l-1)へは重み1のエッジ。
・(g,1)から各曲の(g',l)へこれも曲数の重みつきエッジを設定。ただし、g=>g'が接続してよいジャンルの場合だけ。
・全ての(g,1)からsinkへ重み1のエッジ。

これで83x83の行列が出来たので、L乗すればよいようだが、それだとちょうどL秒のアルバムの総数になる。
いや、最初のsourceからスタートする最初の1秒は本当はないものなのでL-1秒。
いずれにしても、「ちょうど・・秒」で「・・秒以下」にならない。
これはもう一つエッジを加えることで簡単に解決する。
・sinkからsinkへのエッジ
うまいなあ。

以下に私が書いたコードを載せておく。
行列の演算ライブラリと引数の処理以外は行列の設定がほとんど全て。
ちなみに、行列の積で、tmp > 1000000000000000000LL と書いているのは余分なmodを減らすため。
これ入れないと最悪ケースで2秒以内で終わるか微妙なところだった。


#include
#include
#include
#include
#include
#include
using namespace std;

#define MOD 600921647
typedef long long ll;

template
class MyMatrix {
vector< vector > elements;
int mod;

public:
MyMatrix( int n, int mod_ ) : elements( n, vector( n ) ), mod(mod_) {}
static MyMatrix identity( int n, int mod ) {
MyMatrix m( n, mod );
for ( int i = 0; i < n; i ++ )
m( i, i ) = 1;
return m;
}

T operator()( int i, int j ) const {
return elements[i][j];
}
T& operator()( int i, int j ) {
return elements[i][j];
}

MyMatrix operator*( const MyMatrix& m ) const {
MyMatrix res( elements.size(), mod );
for ( int i = 0; i < (int)elements.size(); i ++ ) {
for ( int j = 0; j < (int)elements.size(); j ++ ) {
T tmp = 0;
for ( int k = 0; k < (int)elements.size(); k ++ ) {
tmp += elements[i][k] * m( k, j );
if ( mod && tmp > 1000000000000000000LL ) {
tmp %= mod;
}
}
if ( mod ) tmp %= mod;
res( i, j ) = tmp;
}
}
return res;
}

MyMatrix pow( int n ) {
if ( n == 0 )
return MyMatrix::identity( elements.size(), mod );
if ( n == 1 )
return *this;
if ( n & 1 )
return (*this) * pow( n - 1 );
MyMatrix m = pow( n / 2 );
return m * m;
}
};

int node( int g, int l )
{
return g * 9 + l + 1;
}

class ShuffledPlaylist {
public:
int count(vector songs, vector transitions, int minLength, int maxLength)
{
int nGenres = transitions.size();
string s = accumulate( songs.begin(), songs.end(), string("") );
for ( int i = 0; i < s.size(); i ++ ) if ( s[i] == ',' ) s[i] = ' ';
stringstream ss( s );
vector genres, lengths;
for ( int t; ss >> t; ) {
int u; ss >> u;
genres.push_back( t );
lengths.push_back( u - 1 );
}

int nSongs = genres.size();
int weight[83] = { 0 };
MyMatrix A( 83, MOD );

for ( int i = 0; i < nSongs; i ++ )
weight[node( genres[i], lengths[i] )] ++;
for ( int i = 1; i < 82; i ++ )
A( 0, i ) = weight[i];
for ( int g = 0; g < nGenres; g ++ )
for ( int l = 1; l < 9; l ++ )
A( node( g, l ), node( g, l - 1 ) ) = 1;
for ( int g = 0; g < nGenres; g ++ )
for ( int i = 0; i < nSongs; i ++ ) {
if ( transitions[g][genres[i]] == 'Y' )
A( node( g, 0 ), node( genres[i], lengths[i] ) ) = weight[node( genres[i], lengths[i] )];
}
for ( int g = 0; g < nGenres; g ++ )
A( node( g, 0 ), 82 ) = 1;
A( 82, 82 ) = 1;

MyMatrix mh = A.pow( maxLength + 1 );
MyMatrix ml = A.pow( minLength );

return ( ( mh( 0, 82 ) - ml( 0, 82 ) ) % MOD + MOD ) % MOD;
}
};