ghmmにおける過学習の回避

ghmmで特徴ベクトルの次元数を大きくすると, BaumWelchアルゴリズムの結果がnanでいっぱいになってしまう. これは, 過学習が起こった際の数値的不安定さが原因と考えられる. (違うかも)

そこで, 出力分布の共分散行列を対角行列に制限することで, パラメーター数を減らしたい. この機能は存在するだろうと予想して, 「univariate ghmm」などでぐぐったが出ない・・・.
ひらめいたのが「ghmm diagonal covariance」というキーワードで, その結果たどりついたのが, このメーリングリストである.

要はconfigure時にオプションとして, 「--enable-gsl --enable-gsl-diagonal-hack」を追加して, makeすればよい. gslを使わなければ, オンに出来ないが, ちゃんとソースコードを読めばgslの使用も回避できそう.

これを使うと実際に過学習を回避することが出来た. ただし, 共分散行列の非対角成分は0にならず, その逆行列の非対角成分のみ0になる. まあ, 動くからいいけど, 気持ち悪い.

しかし, このgsl-diagonal-hackは本当にアドホックな解決っぽいなーw. プログラミングが面倒だから力技で解決してるように見えるw.

Pythonによるリアルタイムグラフ描画とマイクからの読み取り

なぜかマイクからの読み取りとリアルタイムグラフ描画を混ぜる.

マイクからの読み取り

ライブラリとしてpyaudioを使用した. aptで入る.
以下のプログラムはほとんど公式のサンプル通りだけど, 2秒マイクから音を受け取って, グラフに描画する.

import pyaudio
import sys
import pylab
import numpy

chunk = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100
RECORD_SECONDS = 2
WAVE_OUTPUT_FILENAME = "output.wav"

p = pyaudio.PyAudio()

stream = p.open(format = FORMAT,
                channels = CHANNELS,
                rate = RATE,
                input = True,
                frames_per_buffer = chunk)

print "* recording"
all = []
for i in range(0, RATE / chunk * RECORD_SECONDS):
    data = stream.read(chunk)
    all.append(data)
print "* done recording"

stream.close()
p.terminate()

# write data to WAVE file
data = ''.join(all)
result = numpy.frombuffer(data,dtype="int16") / float(2**15)

pylab.plot(result)
pylab.ylim([-1,1])
pylab.show()

frames_per_bufferという概念がよく分からない. 公式リファレンスには

Specifies the number of frames per buffer.

と書いてある. うーん, どういうことなんだろう. 1024以外を何種類か試したけど, ことごとくエラーになった・・・.

あと, 最初は

bt_audio_service_open: connect() failed: Connection refused (111)

というエラーが出たけど, こちらのページのおかげで解決.

リアルタイムグラフ描画

次によくありがちな, 録音データをリアルタイムでグラフ描画するプログラムを作成する.
こんなの

ちょっと表示がずれてるのは, PrintScreen側の問題. グラフのx軸は録音開始からの経過時間. 実際のプログラムではグラフは左に流れていく.
ライブラリはGUI用にwx, グラフ描画にmatplotlibを用いている.

プログラムのグラフ描画部分はこちらのページを参考にした.

以下, ソースコード

#!/usr/bin/python
# -*- coding:utf-8 -*-

import pyaudio
import sys
import numpy

from time import time

"描画用ライブラリを利用"
import matplotlib
matplotlib.use('WXAgg')
from matplotlib.figure import Figure
from matplotlib.backends.backend_wxagg import \
    FigureCanvasWxAgg as FigCanvas, \
    NavigationToolbar2WxAgg as NavigationToolbar
import numpy as np
import pylab
import wx


"設定"
CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100
RECORD_SECONDS = 20


class MyAudio:
    "データを生み出すクラス"
    def __init__(self):
        self.audio = pyaudio.PyAudio()
        
        self.stream = self.audio.open(
            format = FORMAT,
            channels = CHANNELS,
            rate = RATE,
            input = True,
            frames_per_buffer = CHUNK)

        self.chunk = CHUNK
        #横軸を秒にする
        self.sampling_rate = RATE
        self.record_seconds = RECORD_SECONDS

    def next(self):
        num_frame = self.stream.get_read_available()
        if num_frame == 0:
            return []

        data = []
        #1024ずつしかreadできない?
        for i in range(num_frame / self.chunk):
            data.append(self.stream.read(self.chunk))

        aft = self.stream.get_read_available()

        data = ''.join(data)
        #data = stream.read(chunk)
        signal = numpy.frombuffer(data,dtype="int16") / float(2**15)
        
        return signal


class GraphFrame(wx.Frame):
    """ The main frame of the application
    """
    title = 'Demo: dynamic matplotlib graph'
    
    def __init__(self):
        wx.Frame.__init__(self, None, -1, self.title,size=(600,400))
        self.audio = MyAudio()
        self.data = []
        #dataの先頭データが何秒時点のものか
        self.data_start_time = 0.0
        #self.data = self.audio.next()
        
        self.create_main_panel()
        
        self.redraw_timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.on_redraw_timer, self.redraw_timer)
        #100msecごとに更新?
        self.redraw_timer.Start(100)

        #何秒分表示するか
        self.draw_sec = 3.0

        
        #描画する点のレート(秒間何点使うか. (負荷軽減のため))
        self.draw_sampling_rate = 1000
        #元データのサンプリングレート
        self.sampling_rate = self.audio.sampling_rate

    def create_main_panel(self):
        self.panel = wx.Panel(self)

        self.dpi = 100
        #恐ろしいことに第一引数はinchでの指定らしい
        self.fig = Figure((6.0, 4.0), dpi=self.dpi)

        self.axes = self.fig.add_subplot(111)
        self.axes.set_axis_bgcolor('black')
        self.axes.set_title('Very important random data', size=12)
        pylab.setp(self.axes.get_xticklabels(), fontsize=8)
        pylab.setp(self.axes.get_yticklabels(), fontsize=8)

        # plot the data as a line series, and save the reference 
        # to the plotted line series
        #
        self.plot_data = self.axes.plot(
            self.data, 
            linewidth=1,
            color=(1, 1, 0),
            )[0]

        self.canvas = FigCanvas(self.panel, -1, self.fig)

    def add_draw_data(self,data):
        """
        描画対象のデータを追加する
        data : 追加したいデータの配列
        仮定として前回のデータから時系列的に連続しているものとしている
        境界線でちょっとsampling rateがずれるけど, 気にしない
        """
        sampling_rate = self.sampling_rate
        sampling_sec = 1. / sampling_rate
        
        draw_sampling_rate = self.draw_sampling_rate
        draw_sampling_sec = 1. / draw_sampling_rate

        newdata = []
        time = 0.0

        for s in data:
            time += sampling_sec

            if time >= draw_sampling_sec:
                time -= draw_sampling_sec
                newdata.append(s)

        
        self.data += newdata
        remain_frame_length = int(self.draw_sec * draw_sampling_rate)

        #切り落とした分, 時間を進める
        self.data_start_time += max((len(self.data) - remain_frame_length),0) / float(draw_sampling_rate)

        self.data = self.data[-remain_frame_length:]


    def draw_plot(self):
        """ Redraws the plot
        """
        num_draw_frame = int(self.draw_sec * self.draw_sampling_rate)
        draw_sampling_rate = self.draw_sampling_rate

        xmin = self.data_start_time
        xmax = xmin + self.draw_sec 
        print xmin,xmax

        ymin = -1.0
        ymax = 1.0

        self.axes.set_xbound(lower=xmin,upper=xmax)
        self.axes.set_ybound(lower=ymin, upper=ymax)

        #gridを有効に
        self.axes.grid(True, color='gray')        

        # スタイルを設定
        pylab.setp(self.axes.get_xticklabels(), 
            visible=True)

        xaxis = [float(con) / self.draw_sampling_rate + self.data_start_time for con in range(len(self.data))]
        self.plot_data.set_xdata(xaxis)
        self.plot_data.set_ydata(np.array(self.data))
        
        self.canvas.draw()

    def on_redraw_timer(self, event):
        #タイマーから呼ばれる
        self.add_draw_data(self.audio.next().tolist())
        self.draw_plot()

    def on_exit(self, event):
        self.Destroy()
        

if __name__ == '__main__':
    app = wx.PySimpleApp()
    app.frame = GraphFrame()
    app.frame.Show()
    app.MainLoop()

重要な点として, グラフの秒間描画回数(FPS)が落ちないように,

  • 指定した秒数より昔のデータは破棄
  • 描画するサンプリング周期(秒間, 何点データを取るか)を元のサンプリング周期から落とす

この2つの処理を行なっている. これをやらないとみるみるうちにFPSが落ちていく.

GUIの勉強ももうちょっとした方がいいかもなー.