5.3.2 モジュールを作ってみる

さて、では作ってみよう。


/*
 * tridmodule.c --- 三重対角行列のLU分解をする Python 用 C モデュール
 */

/*
 * trilu --- in-place LU factorization of an n x n tridiagonal matrix,
 * without pivoting.
 *
 *   al : sub-diagonal,   al[k] = A(k, k-1) for k = 1 .. n-1 (al[0] unused)
 *   ad : main diagonal,  ad[k] = A(k, k)   for k = 0 .. n-1
 *   au : super-diagonal, au[k] = A(k, k+1) for k = 0 .. n-2 (read only)
 *
 * On return al[1..n-1] holds the multipliers of the unit lower factor L
 * and ad[] holds the diagonal of U; U's super-diagonal is au itself.
 * Because there is no pivoting, a zero pivot ad[k-1] divides by zero.
 * Does nothing when n <= 1.
 */
void trilu(int n, double *al, double *ad, double *au)
{
  int k;

  for (k = 1; k < n; k++) {
    /* multiplier l(k) = A(k, k-1) / U(k-1, k-1) */
    al[k] = al[k] / ad[k - 1];
    /* U(k, k) = A(k, k) - l(k) * U(k-1, k) */
    ad[k] = ad[k] - al[k] * au[k - 1];
  }
}

/*
 * trisol --- solve A x = b for a tridiagonal A that has already been
 * factorized in place by trilu().
 *
 *   al : multipliers of the unit lower factor L in al[1..n-1]
 *   ad : diagonal of the upper factor U in ad[0..n-1]
 *   au : super-diagonal of U in au[0..n-2]
 *   b  : on entry the right-hand side b[0..n-1]; on return the solution x
 *
 * al, ad and au are only read.  Does nothing when n <= 0 (previously that
 * case indexed b[-1] and ad[-1]).
 */
void trisol(int n, double *al, double *ad, double *au, double *b)
{
  int i, nm1 = n - 1;

  if (n <= 0)
    return;   /* empty system: avoid b[nm1] with nm1 < 0 */

  /* forward substitution: solve L y = b (L has a unit diagonal) */
  for (i = 0; i < nm1; i++) b[i + 1] -= b[i] * al[i + 1];
  /* backward substitution: solve U x = y */
  b[nm1] /= ad[nm1];
  for (i = n - 2; i >= 0; i--) b[i] = (b[i] - au[i] * b[i + 1]) / ad[i];
}

#include <Python.h>
#include <numpy/arrayobject.h>
#include <numpy/arrayscalars.h>
#include <stdlib.h>

static PyObject *trid_trilu(PyObject *self, PyObject *args)
{
  int n;
  PyArrayObject *al, *ad, *au;

  if (!PyArg_ParseTuple(args, "iOOO", &n, &al, &ad, &au))
    return NULL;

  if (al->nd != 1 || al->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg2 types does not much");
    return NULL;
  }
  if (ad->nd != 1 || ad->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg3 types does not much");
    return NULL;
  }
  if (au->nd != 1 || au->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg4 types does not much");
    return NULL;
  }
  trilu(n, (double*)al->data, (double*)ad->data, (double*)au->data);
  return Py_BuildValue(""); // return Py_RETURN_NONE; もOK?
}

static PyObject *trid_trisol(PyObject *self, PyObject *args)
{
  int n;
  PyArrayObject *al, *ad, *au, *b;

  if (!PyArg_ParseTuple(args, "iOOOO", &n, &al, &ad, &au, &b))
    return NULL;

  if (al->nd != 1 || al->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg2 types does not much");
    return NULL;
  }
  if (ad->nd != 1 || ad->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg3 types does not much");
    return NULL;
  }
  if (au->nd != 1 || au->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg4 types does not much");
    return NULL;
  }
  if (b->nd != 1 || b->descr->type_num != PyArray_DOUBLE) {
    PyErr_SetString(PyExc_ValueError, "arg5 types does not much");
    return NULL;
  }

  trisol(n,
         (double*)al->data, (double*)ad->data, (double*)au->data,
         (double*)b->data);
  return Py_BuildValue("");
}

static PyMethodDef tridMethods[] = {
  {"trilu", trid_trilu, METH_VARARGS, "LU factorize a tridiagonal matrix"},
  {"trisol", trid_trisol, METH_VARARGS, "solve linear equation"},
  {NULL,NULL,0,NULL} /* Sentinel */
};

static struct PyModuleDef tridmodule = {
  PyModuleDef_HEAD_INIT,
  "trid",   /* name of module */
  NULL, /* module  documentation, may be NULL --- "trid_doc" みたいの */
  -1,        /* size of per-interpreter state of the module,
                or -1 if the module keeps state in global variables. */
  tridMethods
};

// この辺は Python 2 とは全然違う
PyMODINIT_FUNC PyInit_trid(void)
{
  return PyModule_Create(&tridmodule);
}

桂田 祐史
2018-01-07