fortran66のブログ

fortran について書きます。

【メモ帳】Fortran から Julia 呼び出し その2

Julia from Fortran

前回の続きで、関数呼び出し、配列受け渡しなどの例題をやってみました。

一応、配列構造体を定義してみました。しかしコンパイル時の定数で構造が変わるようなので、よく分かりません。

しかし構造体などをきちんと定義していないので、ちゃんとしたものにはなっていません。 fortran66.hatenablog.com

配列は、Fortran の allocatable 型にしろ、Julia の Array にしろ、データの前に次元やサイズの情報のヘッダーがついているので、そこを何とかしなければならないのですが、めんどくさいのでちゃんとしていません。

また、C 言語の例題では Julia のガベージ・コレクション (GC) の動作のコントロールをマクロを利用してやっているのですが、そこは難しいので、気になる場合は Fortran で Julia 側で確保したメモリー領域をいじる場合には GC を停止させればいいのではないかと思います。インターフェースは作っておきました。

Fortran ソース

module julia_mod
    use, intrinsic :: iso_fortran_env
    use, intrinsic :: iso_c_binding
    implicit none
!
! types (struct)
!
    type, bind(c) :: jl_array_t
        type(c_ptr) :: dat
        ! integer(c_size_t) :: length
        type(c_ptr) :: flags
        integer(c_int16_t) :: elsize
        integer(c_int32_t) :: offset
        integer(c_size_t) :: nrows
        integer(c_size_t) :: ncols ! union maxsize
    end type jl_array_t
! !
! ! julia-1.5.2/include/julia/julia.h
! !
!
!JL_EXTENSION typedef struct {
!    JL_DATA_TYPE
!    void *data;
!#ifdef STORE_ARRAY_LEN
!    size_t length;
!#endif
!    jl_array_flags_t flags;
!    uint16_t elsize;  // element size including alignment (dim 1 memory stride)
!    uint32_t offset;  // for 1-d only. does not need to get big.
!    size_t nrows;
!    union {
!        // 1d
!        size_t maxsize;
!        // Nd
!        size_t ncols;
!    };
!    // other dim sizes go here for ndims > 2
!
!    // followed by alignment padding and inline data, or owner pointer
!} jl_array_t;
!

!
! global variables
!
    type(c_ptr) :: jl_base_module
    bind(c, name = 'jl_base_module') :: jl_base_module

    type(c_ptr) :: jl_int8_type
    bind(c, name = 'jl_int8_type') :: jl_int8_type

    type(c_ptr) :: jl_int16_type
    bind(c, name = 'jl_int16_type') :: jl_int16_type

    type(c_ptr) :: jl_int32_type
    bind(c, name = 'jl_int32_type') :: jl_int32_type

    type(c_ptr) :: jl_int64_type
    bind(c, name = 'jl_int64_type') :: jl_int64_type

    type(c_ptr) :: jl_float16_type
    bind(c, name = 'jl_float16_type') :: jl_float16_type

    type(c_ptr) :: jl_float32_type
    bind(c, name = 'jl_float32_type') :: jl_float32_type

    type(c_ptr) :: jl_float64_type
    bind(c, name = 'jl_float64_type') :: jl_float64_type
!
! interfaces
!
    interface
        subroutine jl_init() bind(c, name = 'jl_init__threading')
        end subroutine jl_init

        subroutine jl_atexit_hook(status) bind(c, name = 'jl_atexit_hook')
            import
            integer(c_int), value :: status
        end subroutine jl_atexit_hook

        type(c_ptr) function jl_eval_string(pstr) bind(c, name = 'jl_eval_string')
            import
            type(c_ptr), value :: pstr
        end function jl_eval_string

        real(real64) function jl_unbox_float64(ptr) bind(c, name = 'jl_unbox_float64')
            import
            type(c_ptr), value :: ptr
        end function jl_unbox_float64

        type(c_ptr) function jl_box_float64(x) bind(c, name = 'jl_box_float64')
            import
            real(real64), value :: x
        end function jl_box_float64

        real(real32) function jl_unbox_float32(ptr) bind(c, name = 'jl_unbox_float32')
            import
            type(c_ptr), value :: ptr
        end function jl_unbox_float32

        type(c_ptr) function jl_box_float32(x) bind(c, name = 'jl_box_float32')
            import
            real(real32), value :: x
        end function jl_box_float32

        real(int64) function jl_unbox_int64(ptr) bind(c, name = 'jl_unbox_int64')
            import
            type(c_ptr), value :: ptr
        end function jl_unbox_int64

        type(c_ptr) function jl_box_int64(x) bind(c, name = 'jl_box_int64')
            import
            real(int64), value :: x
        end function jl_box_int64

        real(int32) function jl_unbox_int32(ptr) bind(c, name = 'jl_unbox_int32')
            import
            type(c_ptr), value :: ptr
        end function jl_unbox_int32

        type(c_ptr) function jl_box_int32(x) bind(c, name = 'jl_box_int32')
            import
            real(int32), value :: x
        end function jl_box_int32

        type(c_ptr) function jl_ptr_to_array_1d(typ, pdata, sz, own_buffer) bind(c, name = 'jl_ptr_to_array_1d')
            import
            type(c_ptr), value :: typ
            type(c_ptr), value :: pdata
            integer(c_size_t), value :: sz
            integer(c_int)   , value :: own_buffer
        end function jl_ptr_to_array_1d

        type(c_ptr) function jl_ptr_to_array(typ, pdata, dims, own_buffer) bind(c, name = 'jl_ptr_to_array')
            import
            type(c_ptr), value :: typ
            type(c_ptr), value :: pdata
            type(c_ptr), value :: dims
            integer(c_int)   , value :: own_buffer
        end function jl_ptr_to_array

        type(c_ptr) function jl_apply_array_type(typ, n) bind(c, name = 'jl_apply_array_type')
            import
            type(c_ptr), value :: typ
            integer(c_size_t), value :: n
        end function jl_apply_array_type

        type(c_ptr) function jl_alloc_array_1d(typ, n) bind(c, name = 'jl_alloc_array_1d')
            import
            type(c_ptr), value :: typ
            integer(c_size_t), value :: n
        end function jl_alloc_array_1d

        type(c_ptr) function jl_alloc_array_2d(typ, nr, nc) bind(c, name = 'jl_alloc_array_2d')
            import
            type(c_ptr), value :: typ
            integer(c_size_t), value :: nr, nc
        end function jl_alloc_array_2d

        type(c_ptr) function jl_alloc_array_3d(typ, nr, nc, nz) bind(c, name = 'jl_alloc_array_3d')
            import
            type(c_ptr), value :: typ
            integer(c_size_t), value :: nr, nc, nz
        end function jl_alloc_array_3d

        type(c_ptr) function jl_arrayref(iarr, n) bind(c, name = 'jl_arrayref')
            import
            type(c_ptr), value :: iarr
            integer(c_size_t), value :: n
        end function jl_arrayref

        type(c_funptr) function jl_get_global(modu, psymbol) bind(c, name = 'jl_get_global')
            import
            type(c_ptr), value :: modu
            type(c_ptr), value :: psymbol
        end function jl_get_global

        type(c_ptr) function jl_symbol(ptext) bind(c, name = 'jl_symbol')
            import
            type(c_ptr), value :: ptext
        end function jl_symbol

        type(c_ptr) function jl_call(pfunc, pargs, nargs) bind(c, name = 'jl_call')
            import
            type(c_funptr), value :: pfunc
            type(c_ptr), value :: pargs
            integer(c_int32_t), value :: nargs
        end function jl_call

        type(c_ptr) function jl_call0(pfunc) bind(c, name = 'jl_call0')
            import
            type(c_funptr), value :: pfunc
        end function jl_call0

        type(c_ptr) function jl_call1(pfunc, parg) bind(c, name = 'jl_call1')
            import
            type(c_funptr), value :: pfunc
            type(c_ptr), value :: parg
        end function jl_call1

        type(c_ptr) function jl_call2(pfunc, parg1, parg2) bind(c, name = 'jl_call2')
            import
            type(c_funptr), value :: pfunc
            type(c_ptr), value :: parg1, parg2
        end function jl_call2

        type(c_ptr) function jl_call3(pfunc, parg1, parg2, parg3) bind(c, name = 'jl_call3')
            import
            type(c_funptr), value :: pfunc
            type(c_ptr), value :: parg1, parg2, parg3
        end function jl_call3

    end interface

contains

    subroutine eval(text, iret)
        character(*), intent(in) :: text
        type(c_ptr), intent(out), optional :: iret
        character(:), allocatable, target :: buff
        type(c_ptr) :: kret
        buff = text // achar(0)
        kret = jl_eval_string(c_loc(buff))
       if (present(iret)) iret = kret
    end subroutine eval

    function jl_get_function(modu, namen)
        type(c_ptr), value :: modu
        character(*), intent(in) :: namen
        type(c_funptr) :: jl_get_function
        character(:), allocatable, target :: text
        text = namen // achar(0)
        jl_get_function = jl_get_global(modu, jl_symbol(c_loc(text)))
    end function jl_get_function

end module julia_mod

program array_test
    use julia_mod
    implicit none
    real(real64), parameter :: pi = 4 * atan(1.0_real64)
    character(:), allocatable, target :: text
    type(c_ptr) :: iret, argument, iarr, jarr, iatyp
    real(real64) :: x
    type(c_funptr) :: ifun
    integer, allocatable, target :: m(:)
    integer, pointer :: m2(:)
    integer :: i
    integer(int64) :: n
    type(jl_array_t), pointer :: parr
    character, pointer :: chars(:)
    call jl_init()

    call eval('println("Julia from Fortran")')


    text = 'sqrt(2.0)'//achar(0)
    iret = jl_eval_string(c_loc(text))
    x = jl_unbox_float64(iret)
    print *, 'SQRT(2.0)=', x, sqrt(2.0_real64)

    ifun = jl_get_function(jl_base_module, "sinpi")
    argument = jl_box_float64(1.0_real64 / 6.0_real64)
    iret = jl_call1(ifun, argument)
    x = jl_unbox_float64(iret)
    print *, 'Julia: sinpi(1/6)', x, ' Fortran:sin(pi/6)', sin(pi / 6)

    n = 20
    m = [(i, i = 1, n)]
    iatyp = jl_apply_array_type(jl_int32_type, 1_int64)
    iarr = jl_ptr_to_array_1d(iatyp, c_loc(m), n, 0)


    print '(a, *(i3))', 'array 1:n ', m
    ifun = jl_get_function(jl_base_module, "reverse!")
    iret = jl_call1(ifun, iarr)
    print '(a, *(i3))', 'reverse!  ', m


    ifun = jl_get_function(jl_base_module, "reverse")
    jarr = jl_call1(ifun, iarr)
    call c_f_pointer(jarr, parr)
    call c_f_pointer(parr%dat, m2, [n])
    print '(a, *(i3))', 'reverse   ', m2

    call jl_atexit_hook(0)
end program array_test

実行結果

  • sqrt(2) を Julia 側と Fortran 側の両方で計算しています。
  • Julia 側では π を単位とする sinpi 関数で sinpi(1/6) を、Fortran 側では sin(π/6.0) を計算しています。π を単位とする計算は多くなされますが、初めからπを基準としておくと精度よく計算できることが知られています。Fortran202X で Fortran にも同様な関数が導入される予定です。
  • Fortran 側の要素数 n(=20) 個の整数配列を、Julia 側に渡してINTENT(IN OUT) 的に値の並びを反転させました。
  • 反転した整数配列を intent(in) 的に Julia 側に渡して、もう一度反転して元に戻した配列を関数の戻り値として Fortran 側の配列で受け取りました。
hp8@HP8:~/forjulia$ gfortran -o test -fPIC -L$JULIA_DIR/lib -Wl,-rpath,$JULIA_DIR/lib j.f90 -ljulia
hp8@HP8:~/forjulia$ ./test
Julia from Fortran
 SQRT(2.0)=   1.4142135623730951        1.4142135623730951
 Julia: sinpi(1/6)  0.50000000000000000       Fortran:sin(pi/6)  0.49999999999999994
array 1:n   1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
reverse!   20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3  2  1
reverse     1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20

docs.julialang.org

1から始める Juliaプログラミング

1から始める Juliaプログラミング