5.3.1 trid.c

「公開プログラムのページ」 にある trid-lu.c が叩き台。 Cモジュールを実現するための方法は, 「numpyの配列を受け取るCモジュールを作る」 を読んで理解しました。


/*
 * trid.c --- 三重対角行列のLU分解をする Python 用 C モデュール
 *
 *  written by mk, on 5 January 2013.
 *  http://www.math.meiji.ac.jp/~mk/program/linear/trid-lu.c
 *  http://d.hatena.ne.jp/ousttrue/20091205/1260035679
 *  http://codeit.blog.fc2.com/blog-entry-9.html
 *
 */

/*
 * LU-factorize, in place and without pivoting, the n x n tridiagonal matrix
 * whose sub-diagonal is al[1..n-1], diagonal ad[0..n-1] and super-diagonal
 * au[0..n-2].
 *
 * Afterwards:
 *  - al[1..n-1] holds the elimination multipliers, i.e. the sub-diagonal of
 *    the unit lower factor L.
 *  - ad[0..n-1] holds the diagonal of the upper factor U.
 *  - au is only read; it is already U's super-diagonal.
 *
 * al[0] is never touched.  Zero pivots are not checked for.  n <= 1 leaves
 * the arrays unchanged.
 */
void trilu(int n, double *al, double *ad, double *au)
{
  int k;

  for (k = 1; k < n; k++) {
    /* multiplier that eliminates row k's sub-diagonal entry using row k-1 */
    double m = al[k] / ad[k - 1];

    al[k] = m;
    ad[k] -= m * au[k - 1];
  }
}

/*
 * Solve the tridiagonal system A x = b in place, where A has already been
 * factored by trilu().
 *
 * Inputs:
 *  - al[1..n-1]: sub-diagonal of the unit lower factor L.
 *  - ad[0..n-1]: diagonal of the upper factor U.
 *  - au[0..n-2]: super-diagonal of U.
 *
 * On entry b[0..n-1] is the right-hand side; on return it holds x.
 * al, ad and au are only read.  n <= 0 is a no-op; previously it wrote to
 * b[n - 1], which is out of bounds.
 */
void trisol(int n, double *al, double *ad, double *au, double *b)
{
  int i, nm1;

  /* Empty system: nothing to do, and b[n - 1] below would be out of bounds. */
  if (n <= 0) return;
  nm1 = n - 1;

  /* Forward substitution L y = b (L has a unit diagonal, so no division). */
  for (i = 0; i < nm1; i++) b[i + 1] -= b[i] * al[i + 1];

  /* Back substitution U x = y. */
  b[nm1] /= ad[nm1];
  for (i = n - 2; i >= 0; i--) b[i] = (b[i] - au[i] * b[i + 1]) / ad[i];
}

/*
 * Solve A x = b, without pivoting, for the n x n tridiagonal matrix A with
 * sub-diagonal al[1..n-1], diagonal ad[0..n-1] and super-diagonal au[0..n-2].
 *
 * Calls trilu() and then trisol():
 *  - al and ad are overwritten with the LU factors.
 *  - au is left unchanged.
 *  - b is overwritten with the solution x.
 *
 * n <= 0 is a no-op.
 */
void trid(int n, double *al, double *ad, double *au, double *b)
{
  /* Empty system: nothing to factor or solve, and trisol never sees n <= 0. */
  if (n <= 0) return;
  trilu(n, al, ad, au);
  trisol(n, al, ad, au, b);
}

#include <Python.h>
#include <numpy/arrayobject.h>
#include <numpy/arrayscalars.h>
#include <stdlib.h>

PyObject *trid_trilu(PyObject *self, PyObject *args)
{
  int n;
  PyArrayObject *al, *ad, *au;

  if (!PyArg_ParseTuple(args, "iOOO", &n, &al, &ad, &au))
    return NULL;

  if (al->nd != 1 || al->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg2 types does not much");
    return NULL;
  }
  if (ad->nd != 1 || ad->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg3 types does not much");
    return NULL;
  }
  if (au->nd != 1 || au->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg4 types does not much");
    return NULL;
  }
  trilu(n, (double*)al->data, (double*)ad->data, (double*)au->data);
  return Py_BuildValue(""); // return Py_RETURN_NONE; もOK?
}

PyObject *trid_trisol(PyObject *self, PyObject *args)
{
  int n;
  PyArrayObject *al, *ad, *au, *b;

  if (!PyArg_ParseTuple(args, "iOOOO", &n, &al, &ad, &au, &b))
    return NULL;

  if (al->nd != 1 || al->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg2 types does not much");
    return NULL;
  }
  if (ad->nd != 1 || ad->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg3 types does not much");
    return NULL;
  }
  if (au->nd != 1 || au->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg4 types does not much");
    return NULL;
  }
  if (b->nd != 1 || b->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg5 types does not much");
    return NULL;
  }

  trisol(n,
         (double*)al->data, (double*)ad->data, (double*)au->data,
         (double*)b->data);
  return Py_BuildValue("");
}

/*
 * Method table for the "trid" module.  Each entry gives, in order:
 *  - the Python-visible name,
 *  - the C implementation,
 *  - the calling convention (a positional-argument tuple),
 *  - the docstring.
 * The all-NULL entry marks the end of the table.
 */
static PyMethodDef trid_methods[] = {
  {
    "trilu",
    trid_trilu,
    METH_VARARGS,
    "LU factorize a tridiagonal matrix"
  },
  {
    "trisol",
    trid_trisol,
    METH_VARARGS,
    "solve linear eq."
  },
  { NULL, NULL, 0, NULL }   /* sentinel */
};

/*
 * Module entry point.
 *
 * CPython finds the entry point by name.  For a module called "trid" it must
 * be named:
 *  - inittrid under Python 2,
 *  - PyInit_trid under Python 3.
 *
 * import_array() loads NumPy's C API table and must run before any PyArray_*
 * call.  If it fails, it sets ImportError and returns from this function.
 */
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef trid_module = {
  PyModuleDef_HEAD_INIT,
  "trid",          /* m_name */
  NULL,            /* m_doc */
  -1,              /* m_size: single-phase init, no per-module state block */
  trid_methods     /* m_methods */
};

PyMODINIT_FUNC PyInit_trid(void)
{
  import_array();
  return PyModule_Create(&trid_module);
}
#else
PyMODINIT_FUNC inittrid(void)
{
  import_array();
  (void)Py_InitModule("trid", trid_methods);
}
#endif

桂田 祐史
2017-12-10