読者です 読者をやめる 読者になる 読者になる

解いた問題のソースコードと解説など。


POJ 3233 Matrix Power Series

問題

nxnの行列Aが与えられる。S=A+A^2+A^3+...+A^kを計算し、各要素をmで割ったあまりを求めよ。

やりかた

行列の累乗は繰り返し二乗でできるとしても単純にやったら確実にTLEする。なのでうまく分割統治しつつメモ化しておくことで通すことができる。

数列を左右に二分割する。すると
A + A^2 + A^3 + A^4 + A^5 + A^6 + A^7 + A^8 = A(E + A + A^2 + A^3) + A^4(E + A + A^2 + A^3)
となり、(E + A + A^2 + A^3)という構造が左右に現れる。この構造はさらにE(E + A) + A^2(E + A)と分割できる。これを繰り返して、要素数が一つ(=E)になるまで分割してから統合していくことで計算ができる。分割後の要素数に対するメモ化を行うことで無駄な計算を省く。

最初何度もTLEしたので、要素数が奇数の場合でも左右の分割数を同じにしてメモ化に引っかかりやすくしてみたり、行列の計算結果を戻り値で返さずにグローバル変数に入れて参照する、などしてようやく2100MS程度で通せた。。

以下きたないソース。

typedef vector<int> vec;
typedef vector<vec> mat;

mat add_res;

int MOD;

inline mat identity(int n){
  mat E(n, vec(n, 0));
  for(int i = 0; i < n; i++) E[i][i] = 1;
  return E;
}

inline void add(const mat &A, const mat &B){
  int N = A.size(), M = A[0].size();
  for(int n = 0; n < N; n++)
    for(int m = 0; m < M; m++)
      add_res[n][m] = ((A[n][m] + B[n][m]) % MOD);
}

inline mat mult(const mat &A, const mat &B){
  int N = A.size(), M = B.size(), L = B[0].size();
  mat C(N, vec(L, 0));
  for(int n = 0; n < N; n++)
    for(int l = 0; l < L; l++)
      for(int m = 0; m < M; m++)
	C[n][l] = ((C[n][l] + A[n][m] * B[m][l]) % MOD);
  return C;
}

inline mat pow(const mat &A, int p){
  mat R = identity(A.size());
  mat B = A;
  while(p){
    if(p & 1) R = mult(R, B);
    B = mult(B, B);
    p >>= 1;
  }
  return R;
}

map<int, mat> memo;

mat rec(const mat &A, int w){
  if(w == 1) return identity(A.size());
  if(memo.count(w)) return memo[w];

  mat L = rec(A, w / 2);
  mat R = L;
  mat B = pow(A, w / 2);
  R = mult(B, R);
  
  add(L, R);

  //要素数が奇数の場合は(0,1,..,w/2-1)(w/2,...,w-2), w という分割にする
  //ここでは最後のw番目の行列A^(w-1)の計算を個別に行っている
  if(w % 2) add(add_res, pow(A, w - 1));
  return memo[w] = add_res;
}

int main(int argc, char **argv){
  int n, k, m;
  scanf("%d %d %d", &n, &k, &m);
  MOD = m;
  mat A = identity(n);
  for(int i = 0; i < n; i++)
    for(int j = 0; j < n; j++)
      scanf("%ld", &A[i][j]);
  
  add_res = identity(n);

  mat ans = rec(A, k);
  ans = mult(A, ans);
  for(int i = 0; i < n; i++){
    for(int j = 0; j < n; j++){
      printf("%ld ", ans[i][j] % MOD);
    }
    printf("\n");
  }
  return 0;
}

ここまで書いたが、蟻本によれば
 
\left(
\begin{array}{c}
A^k \\
\hline
S^k
\end{array}
\right)
=
\left(
\begin{array}{c|c}
A & 0 \\
\hline
I & I
\end{array}
\right)
\left(
\begin{array}{c}
A^{k-1} \\
\hline
S^{k-1}
\end{array}
\right)
という漸化式が成り立つのでもっと高速に計算できる。なるほど~。

しょげないでよBaby 眠れば治る