Hatena::ブログ(Diary)

mfumiの日記

2014-11-12

PyQt5でさくっとGUIを作る

普段GUIを作成する場合はC++withQtで作成してます.実のところQtを使えば簡単にGUIが作成できるんですが,わざわざQtプロジェクトを作るのもあれだなーという時

Python用のQtバインディングであるPyQt5を使うともっとお手軽にできるのではないかと思って試してみました.(ちなみに,PyQt5は全然ドキュメントがないです.)


最初に一つ言っておくと,Python用のQtバインディングは2種類あって,1つがPyQt,もう一つがPySideです.PySideの方はQt開発元のNokiaがサポートしてます.今回PyQtを利用した理由としたのは2014/11/12現在Qt5をサポートしているのはPyQtのみだったからです.(Python3は両方サポートしてます).重要な点ですが,PyQt5のライセンスGPLもしくは商用ライセンスです(Qt自体はLGPLもしくは商用ライセンス).PyQt4では,GPLに例外事項が設けられていて,PyQt4を用いて作成したスクリプトだけであればMITBSDライセンスで配布することが許可されていましたが,どうやらPyQt5ではこの例外事項が削除されたようです.PySideの方はLGPLで利用できます.機能的にPyQtとPySideでそれほど大きな差はないようですが,特にシグナルとスロットを自分で定義する場合に違いがあるようです.基本的にはQt関数をラップしてあるだけなので,GUIの機能の詳細についてはQtのドキュメントを参照することになります.


以下に載せてある今回作成したプログラムno titleにあります.

インストール

(python3.4, Qt5.3.1を使ってMac OS Xインストールしました).インストールは公式からSIPPyQt5のソースを落とします.

まずSIPインストールします.

$ python configure.py
$ make
$ make install (virtualenvを使っている場合そのままmake installでokです)

SIPというのはC++プログラムpythonで利用できるようにラップしてくれるもので,

これを利用してPyQt5はQtPythonから使えるようにします.

SIPインストールできたらPyQt5をインストールします.configure時にオプションでqmakeの場所を

教えて上げます.

$ python configure.py --qmake /usr/local/Cellar/qt5/5.3.1/bin/qmake
$ make (ちょっと時間がかかる)
$ make install

2014/11/12現在ではまだQt5.3.2はサポートされていないようです(makeに失敗する).


ちなみに,作業ディレクトリに日本語パスが含まれていると上手くいきません.windowsの場合はインストーラを使えばよさそうです.


例1: フロントエンドとしての利用

さて,それでは実際にPyQt5でプログラムを作ってみます.PyQt5で作るときには何よりも楽に速く作れることを目指すことにします(大きいの書くならC++使いますからね).

まず,よくありそうなパターンの1つである,フロントエンドとしてのPyQt5アプリケーションを作ってみました.階乗を計算するだけ(笑)のプログラムです.


f:id:mFumi:20141113001354p:image

no title

#!/usr/bin/env python

from PyQt5.QtWidgets import (QApplication, QWidget,
                             QGridLayout, QVBoxLayout, QHBoxLayout,
                             QLabel, QLineEdit, QPushButton)

def factorial(n):
    if n < 0:
        return -1
    elif n == 0:
        return 1
    else:
        return n * factorial(n-1)

class MainWindow(QWidget):
    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent)

        self.inputLine = QLineEdit()
        self.outputLine = QLineEdit()
        self.outputLine.setReadOnly(True)

        self.calcButton = QPushButton("&Calc")
        self.calcButton.clicked.connect(self.calc)

        lineLayout = QGridLayout()
        lineLayout.addWidget(QLabel("num"), 0, 0)
        lineLayout.addWidget(self.inputLine, 0, 1)
        lineLayout.addWidget(QLabel("result"), 1, 0)
        lineLayout.addWidget(self.outputLine, 1, 1)

        buttonLayout = QVBoxLayout()
        buttonLayout.addWidget(self.calcButton)

        mainLayout = QHBoxLayout()
        mainLayout.addLayout(lineLayout)
        mainLayout.addLayout(buttonLayout)

        self.setLayout(mainLayout)
        self.setWindowTitle("Factorial")

    def calc(self):
        n = int(self.inputLine.text())
        r = factorial(n)
        self.outputLine.setText(str(r))

if __name__ == '__main__':
    import sys
    app = QApplication(sys.argv)
    main_window = MainWindow()

    main_window.show()
    sys.exit(app.exec_())

この場合は,QWidgetをメインウィンドウとして利用します.コンストラクタでボタンや入力フィールドの配置とシグナルのセットアップをおこないます.Qtを使ったことがあればほぼ違和感なくソースが読めると思います.Qtで書いてるように書けばだいたいOKです.Qtを使ってなくてもやってることは分かりやすいと思います.レイアウトはQHBoxLayoutやQVBoxLayoutに任せるのが楽でいいです.アプリケーションとして重要なのがself.calcButton.clicked.connect(self.calc)の部分で,ここでcalcボタンを押したときにcalcが呼ばれます.

例2: 簡単に描画する

先ほどの例では描画機能がありませんでしたが,実際にGUIが欲しくなるのは何かを描画したいときが多いと思います.ということで今度は描画メインのプログラムを作ってみました.簡単なマルバツです.

f:id:mFumi:20141113001355p:image

no title

#!/usr/bin/env python

from PyQt5.QtCore import (QLineF, QPointF, QRectF, Qt)
from PyQt5.QtGui import (QBrush, QColor, QPainter)
from PyQt5.QtWidgets import (QApplication, QGraphicsView, QGraphicsScene, QGraphicsItem,
                             QGridLayout, QVBoxLayout, QHBoxLayout,
                             QLabel, QLineEdit, QPushButton)

class TicTacToe(QGraphicsItem):
    def __init__(self):
        super(TicTacToe, self).__init__()
        self.board = [[-1, -1, -1],[-1, -1, -1], [-1, -1, -1]]
        self.O = 0
        self.X = 1
        self.turn = self.O

    def reset(self):
        for y in range(3):
            for x in range(3):
                self.board[y][x] = -1
        self.turn = self.O
        self.update()

    def select(self, x, y):
        if x < 0 or y < 0 or x >= 3 or y >= 3:
            return
        if self.board[y][x] == -1:
            self.board[y][x] = self.turn
            self.turn = 1 - self.turn

    def paint(self, painter, option, widget):
        painter.setPen(Qt.black)
        painter.drawLine(0,100,300,100)
        painter.drawLine(0,200,300,200)
        painter.drawLine(100,0,100,300)
        painter.drawLine(200,0,200,300)

        for y in range(3):
            for x in range(3):
                if self.board[y][x] == self.O:
                    painter.setPen(Qt.red)
                    painter.drawEllipse(QPointF(50+x*100, 50+y*100), 30, 30)
                elif self.board[y][x] == self.X:
                    painter.setPen(Qt.blue)
                    painter.drawLine(20+x*100, 20+y*100, 80+x*100, 80+y*100)
                    painter.drawLine(20+x*100, 80+y*100, 80+x*100, 20+y*100)

    def boundingRect(self):
        return QRectF(0,0,300,300)

    def mousePressEvent(self, event):
        pos = event.pos()
        self.select(int(pos.x()/100), int(pos.y()/100))
        self.update()
        super(TicTacToe, self).mousePressEvent(event)

class MainWindow(QGraphicsView):
    def __init__(self):
        super(MainWindow, self).__init__()
        scene = QGraphicsScene(self)
        self.tic_tac_toe = TicTacToe()
        scene.addItem(self.tic_tac_toe)
        scene.setSceneRect(0, 0, 300, 300)
        self.setScene(scene)
        self.setCacheMode(QGraphicsView.CacheBackground)
        self.setWindowTitle("Tic Tac Toe")

    def keyPressEvent(self, event):
        key = event.key()
        if key == Qt.Key_R:
            self.tic_tac_toe.reset()
        super(MainWindow, self).keyPressEvent(event)

if __name__ == '__main__':
    import sys
    app = QApplication(sys.argv)
    mainWindow = MainWindow()

    mainWindow.show()
    sys.exit(app.exec_())

今回はメインウィンドウとしてQGraphicsViewを利用しています(QGraphicsViewはQWidgetを継承してます).Qtの描画機能は,大雑把にいうとQGraphicsSceneを作って,その中にQGraphicsItemを配置して,最終的にQGraphicsSceneをQGraphicsViewにセットすることでおこないます.そこで,マルバツの本体はQGraphicsItemを継承する形で作成しています.QGraphicsItemで大事なのはpaint()関数とboudingRect()関数で,これを定義してあげることで描画がおこなわれます.また,mousePressEvent()を実装することでマウスイベントを,keyPressEvent()を実装することキーイベントがキャッチできます.今回はQGraphicsItemは一つしか使ってませんが,もちろん複数のQGraphicsItemを使う事ができます.QGraphicsItemを使うと,マウスイベントのディスパッチや衝突判定をQt側でおこなってくれるので,例えばグラフを可視化するプログラムを作って,グラフの頂点をマウスドラッグで動かせるようにするといったことが簡単にできます.ボタンや入力フィールドが無い場合はこの構成で十分だと思います.

例3: QWidgetとQGraphicsViewの組み合わせ

個人的にこれが一番使うかもしれません.メインウィンドウをQWidgetで作成して,その中にQGraphicsViewと操作用のボタンや入力フィールドを配置します.例として,1次元のセルオートマトンシミュレートするプログラムを作ってみました.


f:id:mFumi:20141113001353p:image

no title


このプログラムでは,QTimerを利用することでアニメーションをしています.QTimerの使い方は,Qtimerのインスタンスを作成してtimeoutシグナルに対し呼び出したい関数をconnectし,timerを起動させるという流れになります.このとき,QGraphicsItem自体はQObjectを継承していないためシグナルを扱う事が出来ないので*1,メインウィンドウ側でQTimerを実行し,定期的に描画をおこなう構成にしています.いろいろやり方はありますがこれがシンプルなんじゃないかなと思います.

例4: matplotlibtとの連携

PyQt5を使ってうれしい機能の一つに,matplotlibと簡単に連携できることがあげられます.matplotlibは1.4.0からPyQt5をサポートしてます.以下がmatplotlibをPyQt5から使用する例です.

f:id:mFumi:20141113001352p:image

no title

#!/usr/bin/env python

import random
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

from PyQt5.QtCore import (QLineF, QPointF, QRectF, Qt, QTimer)
from PyQt5.QtGui import (QBrush, QColor, QPainter)
from PyQt5.QtWidgets import (QApplication, QGraphicsView, QGraphicsScene, QGraphicsItem,
                             QGridLayout, QVBoxLayout, QHBoxLayout, QSizePolicy,
                             QLabel, QLineEdit, QPushButton)

# FigureCanvas inherits QWidget
class MainWindow(FigureCanvas):
    def __init__(self, parent=None, width=4, height=3, dpi=100):
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = fig.add_subplot(111)
        self.axes.hold(False)

        super(MainWindow, self).__init__(fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        timer = QTimer(self)
        timer.timeout.connect(self.update_figure)
        timer.start(50)

        self.x  = np.arange(0, 4*np.pi, 0.1)
        self.y  = np.sin(self.x)

        self.setWindowTitle("Sin Curve")

    def update_figure(self):
        self.axes.plot(self.x, self.y)
        self.y = np.roll(self.y,-1)
        self.draw()

if __name__ == '__main__':
    import sys
    app = QApplication(sys.argv)
    mainWindow = MainWindow()

    mainWindow.show()
    sys.exit(app.exec_())

matplotlibのFigureCanvasQTAggを継承してクラスを作って,その中でdraw()を呼ぶとグラフが描画されます.アニメーションをしたければQTimerを使ってdraw()を定期的に呼ぶようにすればokです(簡単!).こんなことができるのはFigureCanvasQTAggが(QGraphicsItemではなく)QWidgetを継承しているからです.QWidgetを継承しているのでQTimerのtimeoutシグナルも使えます.もちろんこれは一番単純な例で,セルオートマトンの例のように自分でQWidgetを使ってメインウィンドウを作り,その中にmatplotlibを埋め込むことができます.

例4: uiファイル(Qt Designer)を使う

さて,これまではQtuiファイルを活用してきませんでしたが,ボタンや入力フィールドがたくさんあるような場合だと やはりQt Designerでuiファイルを作成した方が楽です. uiファイルをPyQt5で使う方法は単純で,まずQt Designerを使ってuiファイルを作成した後,pyuic5コマンドを利用してuiファイルをpythonスクリプトに変換します.以下に例を示します.

f:id:mFumi:20141113001351p:image

https://github.com/mfumi/pyqt5-example/tree/master/echo

echo_ui.py

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'echo.ui'
#
# Created: Wed Nov 12 18:33:05 2014
#      by: PyQt5 UI code generator 5.3.2
#
# WARNING! All changes made in this file will be lost!

from PyQt5 import QtCore, QtGui, QtWidgets

class Ui_Echo(object):
    def setupUi(self, Echo):
        Echo.setObjectName("Echo")
        Echo.resize(250, 200)
        self.textBrowser = QtWidgets.QTextBrowser(Echo)
        self.textBrowser.setGeometry(QtCore.QRect(40, 90, 180, 79))
        self.textBrowser.setObjectName("textBrowser")
        self.horizontalLayoutWidget = QtWidgets.QWidget(Echo)
        self.horizontalLayoutWidget.setGeometry(QtCore.QRect(40, 30, 180, 41))
        self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget")
        self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget)
        self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.inputEdit = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
        self.inputEdit.setObjectName("inputEdit")
        self.horizontalLayout.addWidget(self.inputEdit)
        self.submitButton = QtWidgets.QPushButton(self.horizontalLayoutWidget)
        self.submitButton.setObjectName("submitButton")
        self.horizontalLayout.addWidget(self.submitButton)

        self.retranslateUi(Echo)
        QtCore.QMetaObject.connectSlotsByName(Echo)

    def retranslateUi(self, Echo):
        _translate = QtCore.QCoreApplication.translate
        Echo.setWindowTitle(_translate("Echo", "Echo"))
        self.submitButton.setText(_translate("Echo", "Submit"))

echo.py

#!/usr/bin/env python

from PyQt5.QtWidgets import (QApplication, QWidget,
                             QGridLayout, QVBoxLayout, QHBoxLayout,
                             QLabel, QLineEdit, QPushButton)
import echo_ui

class MainWindow(QWidget):
    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent)
        self.ui = echo_ui.Ui_Echo()
        self.ui.setupUi(self)
        self.ui.submitButton.clicked.connect(self.submit)

    def submit(self):
        text = self.ui.inputEdit.text()
        self.ui.textBrowser.append(text)

if __name__ == '__main__':
    import sys
    app = QApplication(sys.argv)
    main_window = MainWindow()

    main_window.show()
    sys.exit(app.exec_())

pyuic5で生成されたファイルは見ればわかるようにUIセットアップする関数を含んでいるので,あとはこれを呼ぶだけです.便利!!!

まとめ

ということで実際にPyQt5を使っていくつかアプリケーションを作ってみました.ここにあるのはどれも慣れれば数十分で作成できます.まぁ,これはPyQt5というよりそもそもQt5が使いやすいからというのが大部分を占めてると思いますが.Qtはqmakeあるいはcmakeを前提にしているとはいえ,Qt Creatorを使えばかなり簡単に作成できるので,PyQt5を使うのと速度的にそれほど大きな差はないかもしれません.でも小さいものではやはりコンパイル無しの1ファイルで作れるPyQt5は楽だと思いますし,pythonで何か描画したいというときに重宝すると思いました*2.今回作ったサンプルもベースとして結構活用できるんじゃないかなと思います.

*1:QGraphicsObjectを使えばできますが

*2:勿論,もっと複雑なものも作れます

2014-04-29

CPythonのgrennlet(グリーンスレッド)の実装について

OSが管理するスレッドと違って,ユーザのプログラムによって管理されるスレッドのことをグリーンスレッドといいます.他にもマイクロスレッドとか軽量スレッドとかいったりすることもあるようです.ネイティブスレッド(OSが管理するスレッド)と比較したときのグリーンスレッドの特徴として.

スケジューリングをユーザプログラム自身がおこなう

・一般にスレッドの切り替えコストはネイティブスレッドより低い

・一つのプログラム内において実行されるグリーンスレッドは一つ (マルチコアでも,OSから見るとあくまでユーザプログラムは一つなので,グリーンスレッドが並列実行されることがない)などが上げられます.



greenletとは,PythonC言語実装であるCPythonにおけるグリーンスレッドの実装の一つです.(2014/4/29現在,greenletはpython2.xのみをサポートしています.また,以下での説明でgreenletのバージョンは0.4.2です.)

greenletを使用することでプログラムの処理を一旦停止して別の処理をおこなうコルーチンが簡単に実現できます.公式のサンプルコードそのままですが,

    from greenlet import greenlet

    def test1():
        print 12
        gr2.switch()
        print 34

    def test2():
        print 56
        gr1.switch()
        print 78

    gr1 = greenlet(test1)
    gr2 = greenlet(test2)
    gr1.switch()

このようにして実行すると,

12
56
34

が表示されます.greenlet()で新たなグリーンスレッド(greenletと呼びます)を作成し,switch()でそのgreenletに処理を切り替えます.greenletには親子関係が存在し,子が終了したら親の処理が再開されます.特に親を指定しない場合はメインスレッドが親になります.この例ではtest1()が終了したあと,メインスレッドに戻ってプログラムが終了します(test2()の残りの部分は実行されずに終わります).

プログラムを一旦停止するだけならpythonにはジェネレータがあるのでyieldを使えばできますが,greenletを使うことで関数の呼び出し階層にかかわらず同一スレッド内なら任意のgreenletへ処理を切り替える事が出来ます.


さて,このgreenletの中身ですが,pythonで実装されているのではなく,CPythonのモジュール(C言語拡張)として実装されています.CPythonの拡張方法の詳細はここに書いてあります.

greenletのオブジェクト(PyGrrenlet)は以下のように定義されています.


greenlet.h

typedef struct _greenlet {
    PyObject_HEAD
    char* stack_start;
    char* stack_stop;
    char* stack_copy;
    intptr_t stack_saved;
    struct _greenlet* stack_prev;
    struct _greenlet* parent;
    PyObject* run_info;
    struct _frame* top_frame;
    int recursion_depth;
    PyObject* weakreflist;
    PyObject* exc_type;
    PyObject* exc_value;
    PyObject* exc_traceback;
    PyObject* dict;
} PyGreenlet;

ここで特に重要なのが,stack_start, stack_stop, stack_copy, stack_prev, parentあたりです.一般にプログラムの処理を切り替える(コンテキストスイッチ)ために何が必要かというと,各種レジスタを保存することと,スタックを切り替えることです.greenletでもスタック上にレジスタを退避させ,そしてそのスタックを切り替えることで,コンテキストスイッチを実現します.greenletにおけるスタックレイアウトは以下のようになっています(ソースコードのコメントより)

Stack layout for a greenlet:

               |     ^^^       |
               |  older data   |
               |               |
  stack_stop . |_______________|
        .      |               |
        .      | greenlet data |
        .      |   in stack    |
        .    * |_______________| . .  _____________  stack_copy + stack_saved
        .      |               |     |             |
        .      |     data      |     |greenlet data|
        .      |   unrelated   |     |    saved    |
        .      |      to       |     |   in heap   |
 stack_start . |     this      | . . |_____________| stack_copy
               |   greenlet    |
               |               |
               |  newer data   |
               |     vvv       |

PyGreenletのstack_stopがそのgreenletのスタックの一番下,stack_startがスタックの一番上を指します.またスタック領域の一部は他のgreenletを実行する際にはヒープに退避されます.ヒープ領域を覚えるために必要なのがstack_copyとstack_savedです.全ての領域をヒープに退避せずに一部だけ退避させるのはメモリ効率のためでしょうか.


greenletのsiwtch()を呼んだとき,一般に以下のような順番で関数が呼ばれます.


g_switch() => g_switchstack() => g_slp_switch() => SLP_SAVESTATE => SLP_RESTORE_STATE


まずはg_switch()から見てきます.

static PyObject *
g_switch(PyGreenlet* target, PyObject* args, PyObject* kwargs)
{
    ...
    /* find the real target by ignoring dead greenlets,
       and if necessary starting a greenlet. */
    while (target) {
        if (PyGreenlet_ACTIVE(target)) {
            ts_target = target;
            err = g_switchstack();  # スタックの切り替え
            break;
        }
        if (!PyGreenlet_STARTED(target)) {
            void* dummymarker;
            ts_target = target;
            err = g_initialstub(&dummymarker);
            if (err == 1) {
                continue; /* retry the switch */
            }
            break;
        }
        target = target->parent;
    }
    ...
}

PyGreenlet_ACTIVE(target)が真のとき,つまり切り替え対象のgreenletが既に一度は実行されているとき,g_switchstack()を呼びます.

static int g_switchstack(void)
{
    ...
    int err;
    {   /* save state */
        PyGreenlet* current = ts_current;   // ts_currentが現在のgreenlet
        PyThreadState* tstate = PyThreadState_GET();
        current->recursion_depth = tstate->recursion_depth;
        current->top_frame = tstate->frame;
        current->exc_type = tstate->exc_type;
        current->exc_value = tstate->exc_value;
        current->exc_traceback = tstate->exc_traceback;
    }
    err = slp_switch();
    if (err < 0) {   /* error */
        ...
    }
    else {
        PyGreenlet* target = ts_target;
        PyGreenlet* origin = ts_current;
        ...
    }
    return err;
}

g_switchstack()では現在の状態をcurrentに保存したあと,slp_switch()を呼びます.このslp_switch()が実際のレジスタの退避とスタックの切り替えを行いますが,そういった処理はアセンブラを使わなければできないため,CPUごとにplatform/switch_amd64_unix.c のように定義されています.


switch_amd64_unix.c


#define REGS_TO_SAVE "r12", "r13", "r14", "r15"

static int
slp_switch(void)
{
    int err = 0;
    void* rbp;
    void* rbx;
    unsigned int csr;
    unsigned short cw;
    register long *stackref, stsizediff;
    __asm__ volatile ("" : : : REGS_TO_SAVE);
    __asm__ volatile ("fstcw %0" : "=m" (cw));
    __asm__ volatile ("stmxcsr %0" : "=m" (csr));
    __asm__ volatile ("movq %%rbp, %0" : "=m" (rbp));
    __asm__ volatile ("movq %%rbx, %0" : "=m" (rbx));
    __asm__ ("movq %%rsp, %0" : "=g" (stackref));
    {
        SLP_SAVE_STATE(stackref, stsizediff);
        __asm__ volatile (
            "addq %0, %%rsp\n"
            "addq %0, %%rbp\n"
            :
            : "r" (stsizediff)
            );
        SLP_RESTORE_STATE();
        err = fancy_return_zero();
    }
    __asm__ volatile ("movq %0, %%rbx" : : "m" (rbx));
    __asm__ volatile ("movq %0, %%rbp" : : "m" (rbp));
    __asm__ volatile ("ldmxcsr %0" : : "m" (csr));
    __asm__ volatile ("fldcw %0" : : "m" (cw));
    __asm__ volatile ("" : : : REGS_TO_SAVE);
    return err;
}

前半部分でスタック上に必要なレジスタを退避し,SLP_SAVE_STATE()とその次の部分でスタックの切り替えをおこないます.

    __asm__ volatile ("" : : : REGS_TO_SAVE);

とするとgccインラインアセンブラの記述によってREGS_TO_SAVEで指定したレジスタがワーカレジスタであることがコンパイラに伝わるので,これらのレジスタを退避するコードをコンパイラが生成してくれます.具体的には,

 0x100000f54:  pushq  %r15
 0x100000f56:  pushq  %r14
 0x100000f58:  pushq  %r13
 0x100000f5a:  pushq  %r12
 ... (処理)
 0x100000f72:  popq   %r12
 0x100000f74:  popq   %r13
 0x100000f76:  popq   %r14
 0x100000f78:  popq   %r15

のようになります.r12-r15を退避させているのはx64の呼び出し規約によります.fstcwでFPUの制御レジスタの退避,stmxcsrでSSE関連の制御レジスタの退避をします."=m"(cw)のようにしているので,スタック上のローカル変数に値が退避されていることが分かります.スタックが切り替わった後半部分ではレジスタの復元をおこないます.スタックが切り替わっているのでこの関数から戻った時点で既に切り替え対象だったgreenletに処理が移っています.分かりにくいですが前半部分で処理を停止したgreenletが次に再開されるときはSLP_RESTORE_STATE()からということです.(ところで一番最後のREGS_TO_SAVEの文はいらないような..?)


SLP_SAVE_STATE()とSLP_RESTORE_STATE()で一体何をしているかというと,この2つはマクロとしてgreenlet.cに定義されています.

#define SLP_SAVE_STATE(stackref, stsizediff)        \
  stackref += STACK_MAGIC;              \
  if (slp_save_state((char*)stackref)) return -1;   \
  if (!PyGreenlet_ACTIVE(ts_target)) return 1;      \
  stsizediff = ts_target->stack_start - (char*)stackref

#define SLP_RESTORE_STATE()         \
  slp_restore_state()

既にgreenletが存在している場合(switch()の呼び出しが2回目以降),SLP_SAVE_STATE()はslp_save_state()関数を呼び,さらにstsizediffに切り替え先のgreenletと現在のスタックの先頭との差を格納します.SLP_SAVE_STATE()の次にあるインラインアセンブラ部分で%rspと%rbpにその差を加える事でスタックの切り替えを実現している訳です.

slp_save_state()では,ヒープに必要なだけスタックの退避をおこないます.つまり,切り替え先のstack_stop(target_stop)が,現在のstack_stopよりも上にあるなら,その分をヒープに格納します.

static int GREENLET_NOINLINE(slp_save_state)(char* stackref)
{
    /* must free all the C stack up to target_stop */
    char* target_stop = ts_target->stack_stop; # 切り替え先のスタックの下限
    PyGreenlet* owner = ts_current;
    assert(owner->stack_saved == 0);
    if (owner->stack_start == NULL)
        owner = owner->stack_prev;  /* not saved if dying */
    else
        owner->stack_start = stackref;
    
    while (owner->stack_stop < target_stop) {
        /* ts_current is entierely within the area to free */
        if (g_save(owner, owner->stack_stop))
            return -1;  /* XXX */
        owner = owner->stack_prev;
    }
    if (owner != ts_target) {
        if (g_save(owner, target_stop))
            return -1;  /* XXX */
    }
    return 0;
}

g_save()で実際にヒープへのコピーをおこないます.

static int g_save(PyGreenlet* g, char* stop)
{
    /* Save more of g's stack into the heap -- at least up to 'stop'

       g->stack_stop |________|
                     |        |
                     |    __ stop       . . . . .
                     |        |    ==>  .       .
                     |________|          _______
                     |        |         |       |
                     |        |         |       |
      g->stack_start |        |         |_______| g->stack_copy

     */
    intptr_t sz1 = g->stack_saved;
    intptr_t sz2 = stop - g->stack_start;
    assert(g->stack_start != NULL);
    if (sz2 > sz1) {
        char* c = (char*)PyMem_Realloc(g->stack_copy, sz2);
        if (!c) {
            PyErr_NoMemory();
            return -1;
        }
        memcpy(c+sz1, g->stack_start+sz1, sz2-sz1);
        g->stack_copy = c;
        g->stack_saved = sz2;
    }
    return 0;
}

slp_restore_state()では逆にヒープに退避していたスタック領域をスタックに復元します.実行するgreenletのスタック領域は必ず全てスタックになければいけないため,ヒープに退避していたもの全てを復元します(復元しても良いように事前にslp_save_state()で重複する部分はヒープに退避されています).

static void GREENLET_NOINLINE(slp_restore_state)(void)
{
    PyGreenlet* g = ts_target;
    PyGreenlet* owner = ts_current;
    
    /* Restore the heap copy back into the C stack */
    if (g->stack_saved != 0) {
        memcpy(g->stack_start, g->stack_copy, g->stack_saved);
        PyMem_Free(g->stack_copy);
        g->stack_copy = NULL;
        g->stack_saved = 0;
    }
    if (owner->stack_start == NULL)
        owner = owner->stack_prev; /* greenlet is dying, skip it */
    while (owner && owner->stack_stop <= g->stack_stop)
        owner = owner->stack_prev; /* find greenlet with more stack */
    g->stack_prev = owner;
}

ということで以上がgreenletにおけるコンテキストスイッチの概要ですが, 一番最初に greenletを作って実行するときはどのようになっているでしょうか.


まず,最初にgreeletオブジェクトを作成するときにはgreen_new()が呼ばれます.

static PyObject* green_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
    PyObject* o;
    if (!STATE_OK)
        return NULL;
    
    o = type->tp_alloc(type, 0);
    if (o != NULL) {
        Py_INCREF(ts_current);
        ((PyGreenlet*) o)->parent = ts_current;
    }
    return o;
}

ここで,STATE_OKは以下のマクロで,

#define STATE_OK    (ts_current->run_info == PyThreadState_GET()->dict \
            || !green_updatecurrent())

最初はgreen_updatecurrent()が呼ばれます.

static int green_updatecurrent(void)
{
    ...
    /* first time we see this tstate */
    current = green_create_main();
    ...
    ts_current = current;
    ...
    return 0;
}

この中でgreen_create_main()が呼ばれ,メインスレッドのgreenletが作成され,それがts_currentになります.

static PyGreenlet* green_create_main(void)
{
    PyGreenlet* gmain;
    PyObject* dict = PyThreadState_GetDict();
    if (dict == NULL) {
        if (!PyErr_Occurred())
            PyErr_NoMemory();
        return NULL;
    }

    /* create the main greenlet for this thread */
    gmain = (PyGreenlet*) PyType_GenericAlloc(&PyGreenlet_Type, 0);
    if (gmain == NULL)
        return NULL;
    gmain->stack_start = (char*) 1;
    gmain->stack_stop = (char*) -1;
    gmain->run_info = dict;
    Py_INCREF(dict);
    return gmain;
}

したがって,最初に作成したgreenletの親はメインスレッドのgreenletになります.(その後green_init()が呼ばれ,明示的に親を指定した場合はそれが親になります).さて,新規のgreenletに対してswitch()をした場合,g_switch()内からg_initialstub()が呼ばれます.

static PyObject *
g_switch(PyGreenlet* target, PyObject* args, PyObject* kwargs)
{
    ...
    /* find the real target by ignoring dead greenlets,
       and if necessary starting a greenlet. */
    while (target) {
        if (PyGreenlet_ACTIVE(target)) {
            ts_target = target;
            err = g_switchstack()
            break;
        }
        if (!PyGreenlet_STARTED(target)) {
            void* dummymarker;
            ts_target = target;
            err = g_initialstub(&dummymarker); # 初めてgreenletを実行する場合
            if (err == 1) {
                continue; /* retry the switch */
            }
            break;
        }
        target = target->parent;
    }
    ...
}

g_initialstub()では,初期設定ののち,g_switchstack()をしてスタックを切り替えます.

static int GREENLET_NOINLINE(g_initialstub)(void* mark)
{
    ...
    /* start the greenlet */
    self->stack_start = NULL;
    self->stack_stop = (char*) mark;
    if (ts_current->stack_start == NULL) {
        /* ts_current is dying */
        self->stack_prev = ts_current->stack_prev;
    }
    else {
        self->stack_prev = ts_current;
    }
    ...

    /* restore arguments in case they are clobbered */
    ts_target = self;
    ts_passaround_args = args;
    ts_passaround_kwargs = kwargs;

    /* perform the initial switch */
    err = g_switchstack();   # スタックの切り替え

    /* returns twice!
       The 1st time with err=1: we are in the new greenlet
       The 2nd time with err=0: back in the caller's greenlet
    */
    if (err == 1) {
        ...

        if (args == NULL) {
            /* pending exception */
            result = NULL;
        } else {
            /* call g.run(*args, **kwargs) */
            result = PyEval_CallObjectWithKeywords(   # 実際の処理の開始
                run, args, kwargs);
            ...
        }
        ...
        # greenletの処理が終了したので,親を再開させる
        /* jump back to parent */
        self->stack_start = NULL;  /* dead */
        for (parent = self->parent; parent != NULL; parent = parent->parent) {
            result = g_switch(parent, result, NULL);  # ここで親に処理を切り替える.成功すればもう戻ってこない
            /* Return here means switch to parent failed,
             * in which case we throw *current* exception
             * to the next parent in chain.
             */
            assert(result == NULL);
        }
        /* We ran out of parents, cannot continue */
        PyErr_WriteUnraisable((PyObject *) self);
        Py_FatalError("greenlets cannot continue");
    }
    ...
    return err;
}

ここで,重要なのが,ソースコメントにも書いてあるとおり,g_switchstack()からは2回リターンするということです(...!!!).

どういう事かというと,g_switchstack()を実行してスタックを切り替えると,スタックが切り替わった時点で既に実行は新しいgreenletになっていますが,この新しいgreenletの状態でg_switchstack()がリターンします.そしてしばらくした後,その新しいgreenletが終了(あるいは明示的にswitch()した場合),切り替え元のgreenletの処理が再開されるわけですが,その切り替え元のgreenletの状態でg_switchstack()が再びリターンするということです.g_switchstack()から戻ってきたとき,新しいgreenletなのか,それとも切り替えもとのgreenletなのかでg_switchstack()の戻り値は異なります.

これはslp_switch()をよく見ればわかります.slp_switch()で通常の切り替えが発生した場合,その戻り値は0です.しかし,切り替え先のgreenletが新しい場合,つまりヒープやスタックから情報を復元する必要がない場合,SLP_SAVE_STATE()でreturn 1が返されます.

#define SLP_SAVE_STATE(stackref, stsizediff)		\
  stackref += STACK_MAGIC;				\
  if (slp_save_state((char*)stackref)) return -1;	\
  if (!PyGreenlet_ACTIVE(ts_target)) return 1;		\  # ここが新しいgreenletの場合実行される
  stsizediff = ts_target->stack_start - (char*)stackref

g_switchstack()の戻り値が1の場合,新しいgreenletなのでそのgreenletを実行します.そして,そのgreenletが終了した場合親の処理を再開させます.勿論.そのgreenletが新たに別のswitch()を呼んで処理を切り替えることもあるでしょう.


ということで以上がgreenletの処理の概要になります.つくづくよくできてますね (´ρ`)

2014-02-13

グラフ(ネットワーク)を奇麗に描画するアルゴリズム

グラフはノードと辺の集合から構成されているだけなので,その描画方法は任意です.例として,以下の3つのグラフはどれも同じです.


f:id:mFumi:20140213192203p:image:w360


f:id:mFumi:20140213192202p:image:w360


f:id:mFumi:20140213192201p:image:w360


グラフをどうやって奇麗に描画するかという研究は昔からおこなわれていて,そのうちの一つに力学モデルがあります(Wikipedia).このモデルではノード間にある力が作用すると考えてそれが安定するような場所を探します.簡単な例ではノードを電荷だと仮定してそれぞれに働くクーロン力を計算し,それをもとにノード位置の修正をおこないます.


力学モデルのアルゴリズムで代表的なものにFruchterman-Reingold アルゴリズムがあります.このアルゴリズムは多くのグラフを扱うライブラリでサポートされているようです.

このアルゴリズムの考え方はシンプルで,ノードに影響する力が2つあります.一つは自分以外の全てのノードからの斥力,もう一つが隣接ノードからの引力です.斥力f_rと引力f_aは以下の式で定義されます.




ここで k は以下の式で定義されます.



areaは描画領域の面積(横幅W,縦幅Lだとするとarea = W*L),|V|は頂点数,cはある定数です.f_rとf_aの関係を図にすると以下のようになります(k=10の場合).


f:id:mFumi:20140213192405p:image:w360


図に示したように,d=kのときf_a + f_rは0となります.アルゴリズムではdをノード間の距離だとして,それぞれのノードがなるべくkだけ離れているような場所を探します.つまり,kはノード間の理想距離として定義されています.もしあるノード間の距離がkよりも大きくはずれていたらそれを修正するように力が働きます.

斥力と引力の式は実験的に求めたもので,どちらもフックの法則に似ています.論文では斥力・引力の候補として




を考えてみたものの,この場合だと極所解に陥ることが多かったということが書かれています.


また,もう一つアルゴリズムにおける重要な考えとして,温度パラメータtがあります.この温度パラメータは,変位の最大値を決定します.つまり,引力と斥力の影響を考慮してノードを動かすわけですが,あまり大きく動いてしまっても困るので,t以下に制限するわけです.いまいちこの温度パラメータに関しては論文に載っていないような気がしますが,一つの例としてはtの初期値を病が領域の横幅の10分の1とし,それを1回の繰り返しごとに徐々に減少させていく方法があります.


実際のアルゴリズムの処理の流れは以下のようになります.


1. 各頂点間に働く引力/斥力を計算

2. 1. で計算した値をもとに頂点を動かす.ただし,このとき枠外に出ないようにする.また,動かす変位の大きさは温度パラメータt以下とする.

3. 温度パラメータを下げる

4. 以上の処理をある一定回数繰り返す.


基本的な計算量はO(|V|^2 + |E|)になります.計算を高速化する手法として論文では領域をいくつかに分割し,ある領域内にいるノードはその領域内の他のノードからのみ力を受けると仮定して計算する方法が書かれています(この辺りちゃんと読んでない..).計算の終了条件がただのループのみなので,一定時間での計算が保障されています.斥力は自分以外の全てのノードから,引力は自分が繋がっているノードからのみ受けるという点も割とポイントだと思います.


さて,このアルゴリズムを実際にpythonで実装をしてみます.グラフ自体の表現にはNetworkXを使います (なお,NetworkXではspring_layoutがFruchterman-Reingold アルゴリズムの実装です.デフォルトでNetowkXで描画するとこのアルゴリズムが使用されます).論文に擬似コードが載っていますので,そのまま素直に実装します.



fruchterman_reingold()が実際のアルゴリズム部分です.この実装では領域は中心(0.5,0.5),半径0.5の円としています(領域を四角にしてしまうとグラフが四角くなりやすくなるため).

実際にグラフの構造が変化する流れは以下のようになります(この例ではわざとtを小さくとって1回の変化を小さくしています).

f:id:mFumi:20140213192411g:image

before

f:id:mFumi:20140213192409p:image


after

f:id:mFumi:20140213192410p:image

NetworkXにもともとあるfruchterman-reingoldを使って描画すると以下のようになります.

f:id:mFumi:20140213192412p:image


こんなものですかね.NetworkXの方が若干良さそうな感じがしなくもないですが.


別の例

f:id:mFumi:20140213192922g:image


ということで以上のような感じで比較的シンプルにグラフを奇麗に描画することができます.ちなみに上のpythonの例ではほぼ擬似コードそのままで最適化もなにもしていないしそもそも計算方法自体が最も基本的なものなので遅いです.まぁ参考ということで.NetworkXなら素直にspring_layout使いましょう.

2014-02-12

NetworkXによるスモールワールドネットワークの生成

NetworkXpython製の複雑ネットワークのためのライブラリです.NetworkXを使うと,お手軽にグラフ構造が作成できます.平均経路長(2点のノード間の平均距離)やクラスタリング係数(隣接ノード同士が接続している割合)といったネットワークの特徴量を求める関数が用意されているのは勿論のこと,経路探索や,さらにはPageRankやHITSといったものも簡単に計算できます.また,matplotlibを使うことで描画もできます.

インストール

$ pip install networkx matplotlib

でできます(ちなみにMac OS 10.9ではln -s /usr/local/opt/freetype/include/freetype2 /usr/local/include/freetypeしないと駄目でした).

NetworkXのサンプルは以下のようになります(以下python3です).

#!/usr/bin/env python
import networkx as nx
import matplotlib.pyplot as plt

G = nx.Graph()
G.add_node(1)
G.add_node(2)
G.add_node(3)
G.add_edge(1,2)
G.add_edge(2,3)
nx.draw(G)
plt.savefig("a.png")

これで以下のグラフが生成されます.

f:id:mFumi:20140212222847p:image

見たままですね.ちなみにadd_node()では基本的にどんなオブジェクトでも追加できます.今回はこのNetworkXを使ってスモールワールドネットワークの検証をしてみることにしました.


スモールワールドネットワークとは複雑系ネットワークに分類されるネットワークの一種で,簡単にいってしまうと

(1)全体のノード数に対して各ノードが持つリンク数(辺の数)が小さいが,平均経路長は短い,

(2)クラスタリング係数が高い

という性質を持つネットワークです.これは現実世界のネットワークがよく持つ性質です.例えば,人間関係のネットワークを考えてみると,全人口に対して一人の人間が持つリンク数(知人数)は小さいです.また,ある人の友人同士が友人である確率はそれ以外の場合に比べると高い(クラスタリング係数は高い)と考えられます.平均経路長が短いというのは直感には合わないかもしれませんが,有名なミルグラムの実験の結果で示されるように,どんな人でも友人を辿っていけば6人以下でほぼ到達できることが分かっています.

1998年にWatssらがこのスモールワールドネットワークを定式化し,生成する方法を発表しました(論文はこれ).この論文の登場によって一気に複雑系ネットワークの研究が進んだそうで,引用数2万overとよく分からないことになってます.

この論文でのスモールワールドネットワーク生成方法は発表者の氏名をとってWatts & Strogtzモデル(WSモデル)と呼ばれます.WSモデルの生成方法は以下のようになります.


1. グラフ内にN個のノードを追加する(それぞれ1~Nと番号をつける)

2. 各ノードを隣接k個のノードと接続する (例えばk=1ならノード1はノード2とノードN と接続)

3. グラフ内の各辺それぞれについて,確率pでリンクを適当なノードに繋ぎ変える(重複は禁止)


ここでp=0の場合は何も繋ぎ変えが起こらず,この場合平均経路長は長く,クラスタリング係数は大きくなります.またp=1の場合はいわゆるランダムネットワークが生成され,平均経路長が短く,クラスタリング係数は小さくなります.pが中間ぐらいのとき,先ほどの言った平均経路長が短くクラスタリング係数が高いというスモールワールドネットワークが生成されます.

NetworkXにはこのWSモデルでネットワークを生成する関数が当然ながら用意されているのですが,まぁ折角なので自分で論文の方法に忠実に作ってみました.



make_small_world_netork()がWSモデルによるスモールワールドネットワーク生成関数です.ただ作るだけでは意味がないので,論文と同じように確率pを変化させ平均経路長L(p)とクラスタリング係数C(p)のp=0に対する比のグラフを生成してみます.ノード数と各ノード辺の数はそれぞれ1000と10です.


f:id:mFumi:20140212222853p:image


(データはこちら.左からp,L(p)/L0,C(p)/C0です)

論文のFigure2と同じような結果が得られました(ならないと困りますが..).x軸は対数スケールです.このグラフから(1)平均経路長はp=0.01でも十分小さい(p=1のランダムネットワークの平均経路長に近い),(2)一方pが1付近でなければクラスタリング係数はランダムネットワークよりも十分大きい(p=0のクラスタリング係数に近い),ということが分かります.これがスモールワールドネットワークの持つ性質です.


さて,論文の後半ではこのスモールワールドネットワークを使って簡単な病気の伝染のシミュレーションをおこなっているので,それもやってみることにします.シミュレーションの流れは以下のようになります.


1. スモールワールドネットワークを生成する.(各ノードが人を表す)

2. ノード0 が伝染病に感染しているとする(初期条件)

3. 確率rで感染しているノードの隣接ノードに感染する

4. 感染したノードは一定ステップ後に取り除かれる(死亡)

5. 全てのノードが感染 or 感染しているノードが存在しなくなったらシミュレーションを終了


論文にはノードの大きさや辺の数が書いていないようなので,先ほどと同じノード数1000,各ノードのリンク数10,また感染してから死亡するまでの時間を2としてシミュレーションをしてみました.コードは以下のようになります.


実行結果

f:id:mFumi:20140212222851p:image


f:id:mFumi:20140212222852p:image


最初の図は各pに対してrを変化させ,半分以上のノードが感染する確率r_halfをグラフにしたものです.また,2番目の図はr=1(常に隣接ノードに感染)と固定した場合に全てのノードが感染するまでにかかる時間T(p)をp=0のときの値T0で割ったグラフです(データはこちら.左からp,r_half,T(p)/T0,L(p)/L0です).x軸は先ほど同様に対数スケールです.これらの図は論文のFigure3に相当します.この図から(1)r_halfはpが小さいところで急激に減少する(p > 0.1 ではあまり変化がない),(2)T(p)/T(0)のグラフはL(p)/L(0)のグラフに近く,つまりpが0.01程度でも急速に感染が進む,ということが分かります.


また,r=1のときにp=0,0.05,1のそれぞれについてどのように伝染していくか図にしてみると以下のようになります(ソース)

(1) p = 0

必要ステップ数: 101

平均経路長: 50.450450450450454

クラスタリング係数: 0.6666666666666636

f:id:mFumi:20140212222849g:image


(2) p = 0.05

必要ステップ数: 9

平均経路長: 5.318636636636636

クラスタリング係数: 0.5792653735153729

f:id:mFumi:20140212222850g:image


(3) p = 1

必要ステップ数: 5

平均経路長: 3.2708968968968968

クラスタリング係数: 0.009309777477424539

f:id:mFumi:20140212222848g:image



このようにp=0.05でも十分平均経路長は短く,あっという間に伝染していく様子が観察できます.平均経路長が短くなる理由は,辺の繋ぎ変えをおこなうことで遠くのノード同士を繋ぐショートカットが生成されるからです.


さて,このような感じでとても簡単な方法でスモールワールドネットワークが生成できて,現実世界のネットワークをよく近似しますが,実は現実世界のネットワークはスモールワールドネットワークが持っていない重要なある性質を持っています.それはスケールフリー性です.スケールフリー性を持つネットワークは次数分布(次数はあるノードにおける辺の数のこと)がベキ分布に従い,高次元のノード数が少なく,低次元のノードが多くなります(WSモデルでは全てのノードの辺の数は等しいです).このスケールフリー性を持つネットワーク(スケールフリーネットワーク)を生成する方法がWSモデルの1年後の1999年にBarabasiらによって発表されました.例えば,Twitterなんかではごく一部のユーザのみが膨大なフォロワー数を持っていますね.勿論NetworkXにはスケールフリーネットワークを生成する関数が用意されています.


スケールフリーネットッワークについては次回?

2013-08-12

pythonのitertoolsのrecipeのメモ

公式にitertoolsのrecipe集があるのでそれのメモ.

バージョン: python3.3

参考: no title


・take(n,iterable)

最初のn個の要素をリストとして返します

from itertools import islice 
def take(n, iterable):
    return list(islice(iterable, n))

In [76]: take(3,[1,2,3,4,5])
Out[76]: [1, 2, 3]

・tabulate(function,start=0)

function(0),function(1),... を返すイテレータを返します.

from itertools import count 
def tabulate(function, start=0):
    return map(function, count(start))

In [82]: i = tabulate(lambda x: x*x)

In [83]: next(i)
Out[83]: 0

In [84]: next(i)
Out[84]: 1

In [85]: next(i)
Out[85]: 4

In [86]: next(i)
Out[86]: 9

・consume(iterator,n)

イテレータをn進める,もしnがNoneならイテレータを末尾まで進める

import collections
from itertools import islice
def consume(iterator, n):
    if n is None:
        collections.deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

In [105]: it = iter([1,2,3,4,5,6,7,8,9,10])

In [106]: consume(it,n=5)

In [107]: for i in it:
   .....:     print(i)
   .....:
6
7
8
9
10

・nth(iterable,n,default=None)

n番目の要素を返します.もしnが範囲外ならdefaultを返します

from itertools import islice
def nth(iterable, n, default=None):
    return next(islice(iterable, n, None), default)

In [110]: nth([1,2,3],2)
Out[110]: 3

・quantify(iterable,pred=bool)

predがTrueになる要素の数を返します

def quantify(iterable, pred=bool):
    return sum(map(pred, iterable))

In [115]: quantify([1,2,3,4,5],pred=lambda x: x<3)
Out[115]: 2  # 3より小さい数の数

・padnone(iterable):

iterableな要素の後にNoneを付け加えたイテレータを返す

from itertools imprt chain,repeat
def padnone(iterable):
    return chain(iterable, repeat(None))

In [125]: a = padnone([1,2,3])

In [126]: next(a)
Out[126]: 1

In [127]: next(a)
Out[127]: 2

In [128]: next(a)
Out[128]: 3

In [129]: next(a)

・ncycles(iterable,n):

iterableをn回繰り返すイテレータを返す

frmo itertools import chain,repeat
def ncycles(iterable, n):
    return chain.from_iterable(repeat(tuple(iterable), n))

In [136]: it = ncycles([1,2],3)

In [137]: for i in it:
   .....:     print(i)
   .....:
1
2
1
2
1
2

・dotproduct(vec1,vec2)

vec1,vec2の内積を返します.

import operator
def dotproduct(vec1, vec2):
    return sum(map(operator.mul, vec1, vec2))

In [141]: dotproduct([1,2],[3,4])
Out[141]: 11

・flatten(listOfLists)

リストのリストを一つのリストにまとめます

from itertools import chain
def flatten(listOfLists):
    return chain.from_iterable(listOfLists)

In [144]: it = flatten([[1,2],['a','b']])

In [145]: for i in it:
   .....:     print(i)
   .....:
1
2
a
b

・repeatfunc(func,times=None,*args)

funcをtimes回リピートした結果を返すイテレータを返します.

timesを指定しない場合は無限リストになります.

from itertools import starmap,repeat
def repeatfunc(func, times=None, *args):
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))

In [150]: r = repeatfunc(random.random)

In [151]: next(r)
Out[151]: 0.5425125491886116

In [152]: next(r)
Out[152]: 0.37101369831757924

・pairwise(iterable)

1つだけ指す場所が違う2つのイテレータを返します.連続した2つの要素を操作するのに使えます.

from itertools import tee
def pairwise(iterable):
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

In [163]: xs = [1,4,9,16,25]

# xs[i] - xs[i-1] を計算したい
In [164]: for a,b in pairwise(xs):
   .....:     print(b-a)
   .....:
3
5
7
9

・grouper(n,iterable,fillvalue=None)

要素をn個ずつにまとめます.要素数がnで割り切れなかったらfillvalueで埋められます.

from itertools import zip_longest
def grouper(n, iterable, fillvalue=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

In [168]: it = grouper(3,'ABCDEFG','x')

In [169]: for i in it:
   .....:     print(i)
   .....:
('A', 'B', 'C')
('D', 'E', 'F')
('G', 'x', 'x')

なんでこういう動作をするかというと,args = [iter(iterable)]*n とイテレータを複製していますが,このイテレータは全て同一であるため,いずれか一つのイテレータを操作すると他のイテレータも操作されるためです.


・roundrobin(*iterables)

それぞれのiterableなものから順番に要素をとってきます.

from itertools import cycle,islice
def roundrobin(*iterables):
    # Recipe credited to George Sakkis
    pending = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while pending:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            pending -= 1
            nexts = cycle(islice(nexts, pending))

In [171]: it = roundrobin([1,2],'AB','xy')

In [172]: for i in it:
   .....:     print(i)
   .....:
1
A
x
2
B
y

・partition(pred,iterable)

predがTrueかどうかに応じてiterableを2つに分割します

from itertools import filterfalse
def partition(pred, iterable):
    t1, t2 = tee(iterable)
    return filterfalse(pred, t1), filter(pred, t2)

In [174]: it1,it2 = partition(lambda x: x<3, [1,2,3,4,5])

In [175]: for i in it1:
    print(i)
   .....:
3
4
5

In [176]: for i in it2:
    print(i)
   .....:
1
2

・powerset(iterable):

iterableの冪集合を返します

from itertools import chain,combinations
def powerset(iterable):
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

In [178]: it = powerset([1,2,3])

In [179]: for i in it:
   .....:     print(i)
   .....:
()
(1,)
(2,)
(3,)
(1, 2)
(1, 3)
(2, 3)
(1, 2, 3)

・unique_everseen(iterable,key=None)

iterableな要素のうちユニークなもののみを返すイテレータを返します

from itertools import filterfalse
def unique_everseen(iterable, key=None):
    seen = set()
    seen_add = seen.add
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen_add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen_add(k)
                yield element

In [189]: it = unique_everseen([1,2,1,1,2,4,2,5])

In [190]: for i in it:
    print(i)
   .....:
1
2
4
5

・unique_justseen(iterable,key=None)

連続する要素を1つにまとめます.

from operator import itemgetter
from itertools import groupby
def unique_justseen(iterable, key=None):
    return map(next, map(itemgetter(1), groupby(iterable, key)))

In [200]: it = unique_justseen([1,1,2,3,3,1,2,2])

In [201]: for i in it:
   .....:     print(i)
   .....:
1
2
3
1
2

下の例でも結果は同じだけど上のようにやると遅延評価できる

In [206]: def unique_justseen(iterable,key=None):
    return map(itemgetter(0),itertools.groupby(iterable,key))
   .....:

In [207]: for i in it:
    print(i)
   .....:

In [208]: it = unique_justseen([1,1,2,3,3,1,2,2])

In [209]: for i in it:
    print(i)
   .....:
1
2
3
1
2

・iter_except(func,exception,first=None)

funcを例外が発生するまで呼び続けます.firstは最初に1回だけ実行される(dbの初期化とかに使う).

def iter_except(func, exception, first=None):
    try:
        if first is not None:
            yield first()
        while 1:
            yield func()
    except exception:
        pass

In [219]: it = iter_except(d.popitem,KeyError)

In [220]: for i in it:       # KeyErrorで中断されることなく処理できる
   .....:     print(i)
   .....:
('a', 1)
('b', 2)
('c', 3)

・random_product(*args,repeat=1)

直積をランダムに1つ返します

import random
def random_product(*args, repeat=1):
    pools = [tuple(pool) for pool in args] * repeat
    return tuple(random.choice(pool) for pool in pools)

In [277]: random_product([1,2],['x','y'])
Out[277]: (1, 'x')

In [278]: random_product([1,2],['x','y'])
Out[278]: (1, 'y')

・random_permutation(iterable,r=None)

順列をランダムに1つ返します

import random
def random_permutation(iterable, r=None):
    pool = tuple(iterable)
    r = len(pool) if r is None else r
    return tuple(random.sample(pool, r))

In [267]: random_permutation([1,2,3])
Out[267]: (2, 1, 3)

In [268]: random_permutation([1,2,3])
Out[268]: (3, 2, 1)

・random_combination(iterable,r)

組み合わせをランダムに一つ返します

import random
def random_combination(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    indices = sorted(random.sample(range(n), r))
    return tuple(pool[i] for i in indices)

In [280]: random_combination([1,2,3],2)
Out[280]: (1, 3)

In [281]: random_combination([1,2,3],2)
Out[281]: (1, 2)

・random_combination_with_replacement(iterable,r)

重複組み合わせをランダムに一つ返します

def random_combination_with_replacement(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    indices = sorted(random.randrange(n) for i in range(r))
    return tuple(pool[i] for i in indices)

In [286]: random_combination([1,2,3],2)
Out[286]: (3, 3)

In [287]: random_combination([1,2,3],2)
Out[287]: (1, 3)

pythonのitertoolsメモ

バージョン: python3.3

参考: http://docs.python.org/3/library/itertools.html

itertoolsのpythonによる実装が書いてあるので勉強になります.


・itertools.accumulate(iterable[, func])

In [2]: it = itertools.accumulate([1,2,3])

In [3]: for i in it:
   ...:     print(i)
   ...:
1
3   # 1+2
6   # 3+3

In [6]: it = itertools.accumulate([1,2,3],lambda x,y: x*y)

In [7]: for i in it:
   ...:     print(i)
   ...:
1
2    # 1*2
6    # 2*3

python3.2で追加されてfuncは3.3から指定できるようになったらしい.


・itertools.chain(*iterables)

iterableなアイテムをつなげたイテレータを返す

In [10]: it = itertools.chain([1,2,3],['a','b','c'])

In [11]: for i in it:
   ....:     print(i)
   ....:
1
2
3
a
b
c

・itertools.chain.from_iterable(iterable)

iteratools.chain(*iterables)と似てるけど引数が一つのリスト

In [13]: it = itertools.chain.from_iterable([[1,2,3],['a','b','c']])

In [14]: for i in it:
   ....:     print(i)
   ....:
1
2
3
a
b
c

・itertools.combinations(iterable,r)

長さrの組み合わせを返します

In [15]: it = itertools.combinations([1,2,3],2)

In [16]: for i in it:
   ....:     print(i)
   ....:
(1, 2)
(1, 3)
(2, 3)

・itertools.combinations_with_replacement(iterable,r)

重複を含めた長さrの組み合わせを返します

In [17]: it = itertools.combinations_with_replacement([1,2,3],2)

in [18]: for i in it:
   ....:     print(i)
   ....:
(1, 1)
(1, 2)
(1, 3)
(2, 2)
(2, 3)
(3, 3)

・itertools.compress(data,selectors)

selectorsがTrueの要素のdataについてのイテレータを返します

例を見た方が早い.

In [21]: it = itertools.compress(['A','B','C'],[True,False,True])

In [22]: for i in it:
   ....:     print(i)
   ....:
A
C

python3.1で追加


・itertools.count(start=0,step=1)

startから初めてstepずつ増えてくiteratorを返します

In [27]: def generator():
   ....:     for i in range(10):
   ....:         yield i

In [30]: for i,v in zip(generator(),itertools.count()):
    print("{0}:{1}".format(i,v))
   ....:
0:0
1:1
2:2
3:3
4:4
5:5
6:6
7:7
8:8
9:9

ちなみにpython3ではzip()がpython2のitertools.izip()に相当します.


・itetools.cycle(iterable)

iterableなものを巡回するイテレータを返します.

In [2]: it = itertools.cycle([1,2,3])

In [3]: next(it)
Out[3]: 1

In [4]: next(it)
Out[4]: 2

In [5]: next(it)
Out[5]: 3

In [6]: next(it)
Out[6]: 1

In [7]: next(it)
Out[7]: 2

In [8]: next(it)
Out[8]: 3
||<k


・itertools.dropwhile(predicate,iterable)
先頭から見ていって初めてpredicateがFalseになる要素の前までの要素を除いたイテレータを返します
>|python|
In [12]: it = itertools.dropwhile(lambda x : x < 5, [1,2,4,2,5,3,6,2])

In [13]: for i in it:
   ...:     print(i)
   ...:
5
3
6
2

・itertools.filterfalse(predicate,iterable)

iterableをpredicateでフィルターした結果を返すイテレータを返します

In [14]: it = itertools.filterfalse(lambda x : x < 5, [1,2,4,2,5,3,6,2])

In [15]: for i in it:
   ....:     print(i)
   ....:
5
6

・itetools.groupby(iterable,key=None)

連続する要素をkeyで比較して一致した場合には一つにまとめる.keyを指定しない場合は同じ要素が連続した場合一つにまとめます

uniqコマンド的な.

戻り値は2つあって,1つがキー,もう一つが連続する要素をまとめたもの(これもイテレータ).


In [24]: for k,g in itertools.groupby('AAABDDDCCC'):
    print(k)
    print(list(g))
   ....:
A
['A', 'A', 'A']
B
['B']
D
['D', 'D', 'D']
C
['C', 'C', 'C']

・itertools.islice(iterable,stop), itertools.islice(iterable,start,stoip[,step])

iterableな要素の中で選択した範囲を返すイテレータを返します

In [25]: it = itertools.islice('ABCDEFG',1,5,2)

In [26]: for i in it:
   ....:     print(i)
   ....:
B
D

・itertools.permutation(iterable,r=None)

長さrの順列を返します.rを指定しない場合はiterableの長さがrになります.

In [27]: it = itertools.permutations('ABC')

In [28]: for i in it:
   ....:     print(i)
   ....:
('A', 'B', 'C')
('A', 'C', 'B')
('B', 'A', 'C')
('B', 'C', 'A')
('C', 'A', 'B')
('C', 'B', 'A')

・itertools.product(*iterables,repeat=1)

直積を返します.product([0,1],repeat=3)はproduct([0,1],[0,1],[0,1])と同じ意味.

In [34]: it = itertools.product([0,1],['x','y'])

In [35]: for i in it:
   ....:     print(i)
   ....:
(0, 'x')
(0, 'y')
(1, 'x')
(1, 'y')

・itertools.repeat(object[, times])

objectをtimes回繰り返すイテレータを返します.timesを指定しない場合無限リストになります.

主な使い方としてはmap()やzip()に一定値を与えるのに使うみたい

In [37]: list(map(pow,range(10),itertools.repeat(2)))
Out[37]: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

・itertools.starmap(function,iterable)

iterableな各要素についてfunctionで計算した結果のイテレータを返します.

map()に似てるけど要はmapの引数が既にzip化されてるときに使います.

In [46]: [i for i in itertools.starmap(pow,[(2,5),(3,2),(10,3)])]
Out[46]: [32, 9, 1000]

これと等価

In [50]: x = zip([2,3,10],[5,2,3])

In [51]: for i in itertools.starmap(pow,x):
    print(i)
    ....:
32
9
1000

ちなみにmap()を使うなら

In [52]: [i for i in map(pow,[2,3,10],[5,2,3])]
Out[53]: [32, 9, 1000]

・itertools.takewhile(predicate,iterable)

先頭から見ていってpredicateが初めてFalseになるまでの要素を返すイテレータを返す.

In [64]: it = itertools.dropwhile(lambda x:x<5, [1,2,3,5,2,1])

In [65]: for i in it:
   ....:     print(i)
   ....:
5
2
1

・itertools.tee(iterable,n = 2)

1つのiterableからn個のイテレータを返します.nのデフォルトは2です.

In [66]: it1,it2 = itertools.tee([1,2,3])

In [67]: for i in it1:
   ....:     print(i)
   ....:
1
2
3

In [68]: for i in it2:
   ....:     print(i)
   ....:
1
2
3

teeで複数イテレータを作ったあと元のiterableを操作することはイテレータが勝手に進んだりするのやめましょう.イテレータを複数作るのに比べてteeのメリットは何かというと,teeを使った方がバッファリングをするためI/Oが少なくなり,動作が速くなる可能性があります.逆に言うとメモリをよく消費します.このことはteeの実装を見ると分かります.

def tee(iterable, n=2):
    it = iter(iterable)
    deques = [collections.deque() for i in range(n)]
    def gen(mydeque):
        while True:
            if not mydeque:             # when the local deque is empty
                newval = next(it)       # fetch a new value and
                for d in deques:        # load it to all the deques
                    d.append(newval)
            yield mydeque.popleft()
    return tuple(gen(d) for d in deques)

この用に,teeで何をしているかというと,iterableから一つイテレータを作って,それからn個のdequeを用意して,いずれかのteeの戻り値のイテレータがが新しい値を取り出す度に他のdequeに値を追加していきます.


・itertools.zip_longest(*iterables,fillvalue=None)

zip()に似ていますが,zipと違うのはiterablesの長さが異なる場合,短いiterableが

終わるとそれ以降対応する場所はfillvalueで埋められます.


In [69]: it = itertools.zip_longest(['A','B','C','D'],[1,2],fillvalue='-')

In [70]: for i in it:
   ....:     print(i)
   ....:
('A', 1)
('B', 2)
('C', '-')
('D', '-')