fortran66のブログ

fortran について書きます。

行列積を Cannon の algorithm で

昔、コネクションマシンの本を読んだときに、各並列プロセッサは 1bit 演算で、浮動小数点数も 1bit 演算を繰り返して計算すると知って、これは駄目ですわと思ったのですが、グスタフソンの動的可変長浮動小数点数フォーマットのことを思うと、任意長に出来る 1bit 演算もありかなと思ってパラパラと斜め読みしなおしてみました。


日本語版には訳者がまとめた CM-2 の資料がありますが、その中に Fortran8x に基づく CM-FORTRAN のサンプルプログラムがあったので、今の Fortran で動かしてみました。CM-FORTRAN はおおむね Fortran-95 水準なので、ほぼそのまま動きます。(1bit 演算とは関係ありませんw)

コネクションマシン―65,536台のプロセッサから構成される超並列コンピュータ

コネクションマシン―65,536台のプロセッサから構成される超並列コンピュータ

サンプルプログラムは、サブルーチンで、Cannon のアルゴリズムによる行列積の並列計算です。元々二次元アレイプロセッサ用のものなので、あまり意味はありません。CSHIFT/EOSHIFT 関数が 並列計算用組み込み関数に分類されるのは、二次元トーラス上のアレイプロセッサに配列要素を割り当てる前提があったためです。


Cannon のアルゴリズムの説明動画を貼っておきます。

Cannon's algorithm for matrix multiplication

実行結果

乱数で生成した行列 A,B の積 C=A*B を、組み込み関数の MATMUL と Cannon 法の二通りで求めて、各行列要素の差の絶対値の最大値を求めています。また、各ステップごとの経過時間を秒単位で書き出しています。

Intel Fortran で自動並列化をかけましたが、たいして効果はありません、MATMUL は MKL のルーチンへの自動置換オプションを選んだので並列化が効いています。

行列要素ではなく、ブロック行列に適用しなければいけなかろうかと思います。

  80610.66     :start
 2.1000000E-02 :matmul
  3.666000     :Cannon
 0.0000000E+00 :maxval
max difference  5.6457520E-04
続行するには何かキーを押してください . . .

ソース

サブルーチンは、原文そのままを書こうと思ったのですが、コメント文と CSHIFT の分を直してあります。CSHIFT は引数の順番の定義が一部変わっていました。

    module m_mod
    contains
      SUBROUTINE CANNON(M1, M2, R, N)
      INTEGER N
      REAL M1(N,N), M2(N,N), R(N,N)
!C    Skew argument arrays.
      M1 = CSHIFT(M1, (/0:N-1/), 2)
      M2 = CSHIFT(M2, (/0:N-1/), 1)
!C    Multiply matrices
      R = 0.0
      DO I = 1,N
           R = R + M1*M2
           M1 = CSHIFT(M1,1,2)
           M2 = CSHIFT(M2,1,1)
      END DO
!C    Restore argument arrays.
      M1 = CSHIFT(M1, -(/0:N-1/), 2)
      M2 = CSHIFT(M2, -(/0:N-1/), 1)
      RETURN
      END 
    end module m_mod

    program CM
      use m_mod
      implicit none
      integer, parameter :: n = 1000 
      real :: a(n, n), b(n, n), c(n, n), d(n, n)
      call RANDOM_NUMBER(a)
      call RANDOM_NUMBER(b)
      call stamp('start')
      d = matmul(a, b)
      call stamp('matmul')
      call CANNON(a, b, c, n)
      call stamp('Cannon')
      call stamp('maxval')
      print *, 'max difference', maxval(abs(c - d))
    contains
      subroutine stamp(text)
        character(*), intent(in) :: text
        integer :: ic0 = 0, ic, irate
        call SYSTEM_CLOCK(ic, irate)
        print *, (ic - ic0) / real(irate), ':', text
        ic0 = ic
      end subroutine stamp  
    end program CM