霞と側杖を食らう

ほしいものです。なにかいただけるとしあわせです。[https://www.amazon.jp/hz/wishlist/ls/2EIEFTV4IKSIJ?ref_=wl_share]

トピックモデル(混合ユニグラムモデル・EMアルゴリズム)の学習記録

 

【学習動機】

前回、ユニグラムモデルを学習した。

moratoriamuo.hatenablog.com

これからは、混合ユニグラムモデルを学習する。混合ユニグラムモデルをEMアルゴリズムで推定する。

【学習内容】

混合ユニグラムモデル(mixture of unigram models)は各文書がそれぞれ1つのトピックをもっていてトピックごとに異なった単語分布をもち、そこから単語が生成されるモデル。文書のトピックはトピック分布から生成され、割り当てられる。Rで、このモデルの人工データを生成し、EMアルゴリズムでパラメータを最尤推定する。

人工データ生成

ライブラリの読み込み。

library(tidyverse)
library(ggplot2)
library(MCMCpack)    #rdirichletでディリクレ乱数を生成する
## Warning: package 'MCMCpack' was built under R version 3.5.2
set.seed(271)

トピックとトピックごとに頻出する語彙を設定する。トピックとして、しりとりとRとアジカンを選択した。しりとりを選んでしまったが、分布から独立に単語が生成されることを考えると例としてふさわしくないことに後で気付いたがそのままにする。語彙の種類が多すぎると大変なので、この程度に絞った。

topics <- c("shiritori","r","akg")
allword <- c("りんご","ゴリラ","ラッパ",
             "宇宙","神","パイプ",
             "スタンダード","センスレス","リライト","ループ" )
nvoc <- length(allword)    #語彙の総数

各トピックに対応する単語分布の設定。トピックを想定しながら手動で設定した。ここは、パラメータを上手く決めてディリクレ分布から乱数発生させてもよい。各トピックの単語分布を決めたら、リストにしてまとめておく。

#トピック: しりとり
df_word_dist_shiritori <- data.frame(
  words = allword,
  prob = c(0.25,0.2,0.2,
           rep(0.05,7))
)

#トピック: R
df_word_dist_r <- data.frame(
  words = allword,
  prob = c(0.01,0.01,0.02,
           0.3,0.2,0.31,
           0.01,0.01,0.015,0.115)
)

#トピック: AKG
df_word_dist_akg <- data.frame(
  words = allword,
  prob = c(rep(0.01,3),
           0.06,0.1,0.01,
           0.2,0.2,0.2,0.2)
)

word_dists <- list(df_word_dist_shiritori, df_word_dist_r, df_word_dist_akg)
names(word_dists) <- c("shiritori","r","akg")

文書の数と各文書の単語数(決めるのが面倒なので一様乱数で生成)を決定。

numD <- 100
vnumW <- sample(140:271,numD,replace = TRUE)
numW <- sum(vnumW)    #総単語数

トピック分布をディリクレ分布によって生成し、トピック分布から文書へのトピックの割り当てを生成。

theta <- rdirichlet(1,c(1,2,7))    #alpha=c(1,2,7)
vtopic <- sample(topics, numD, replace = TRUE, prob = theta)

あとで推定されたパラメータと比較するために真のパラメータ(thetaとphi)を保存しておく。

truepara <- list(
  truetheta = theta %>% as.vector(),
  truephi = word_dists
)

以上で指定してものを使えば混合ユニグラムモデルからのデータを生成する関数が作れる。引数は文書の数と各文書の単語数のベクトル、単語分布のデータフレームのリストをとる。

sim_mixunigram <- function(numD, vnumW, word_dists){
  list_word <- vector("list", numD)
  for(d in 1:numD){
    list_word[[d]] <- sample(word_dists[[vtopic[d]]]$words, vnumW[d], replace=TRUE,
                             prob = word_dists[[vtopic[d]]]$prob)
  }
  return(list_word)
}

人口データを発生させてみる。前回同様、頻度テーブルにして文書1の頻度テーブルを見てみる。

DW <- sim_mixunigram(numD=numD, vnumW = vnumW, word_dists = word_dists)    #DocumentWordsList
DWtab <- DW %>% purrr::map(factor,levels = allword) %>% purrr::map(table)
DWtab[[1]]
## 
##       りんご       ゴリラ       ラッパ         宇宙           神 
##            0            2            2           53           29 
##       パイプ スタンダード   センスレス     リライト       ループ 
##           51            1            1            1           33

1つ目の文書のトピックはRだったに違いない。

後の計算をしやすいように、文書ターム行列(Document Term Matrix)に変換しておく。

DTM <- matrix(DWtab %>% unlist() %>% as.numeric(), ncol = nvoc, byrow=T)    #DocumentTermMatrix
colnames(DTM) <- allword

EMアルゴリズム最尤推定

トピックの数は決めておかないといけないので、トピックは3つと天下り的に仮定しておく。

numT <- 3    ##ofTopics

対数尤度を計算する関数を定義しておく。引数には、文書ターム行列とトピック分布のパラメータthetaと各トピックの単語分布のデータフレームのリストphiをとる。

loglikelihood_mug <- function(DTM, theta, phi){
  ##mixture of unigramの対数尤度を計算する
  oneL <- numeric(nrow(DTM))    #文書ごとの尤度ベクトル
  pr <- numeric(nrow(DTM))    #計算用のproductで計算したもの
  
  for(k in 1:length(theta)){
    for(d in 1:nrow(DTM)){
      pr[d] <- prod(purrr::map2_dbl(phi[[k]]$prob, DTM[d,], `^`))    #p(wd|phik)
    }
    oneL <- oneL + theta[k]*pr
  }
  LL <- oneL %>% log() %>% sum()
  return(LL)
}

岩田『トピックモデル』のEMアルゴリズム擬似コードを見ながら、EMアルゴリズム関数を作成。引数には、文書ターム行列とトピック数、全ての語彙、繰り返し回数(対数尤度を見ている限り、想像以上に局所最適への収束が速いので、デフォで10としている)をとる。返り値として、推定されたthetaとphi、対数尤度の履歴のリストを返す。

fEM <- function(DTM, numT, allword, numRep=10){

  #対数尤度履歴
  LLhistory <- numeric(numRep)
  
  #基準化関数
  fScale <- function(x){
    y <- x/sum(x)
    return(y)
  }
  
  #初期値: 全て同じ値で入れてしまうと上手く動かなかった
  theta <- runif(3) %>% fScale()
  phi <- vector("list",numT)
  for(i in 1:numT){
    phi[[i]] <- data.frame(
      words = allword,
      prob = runif(10) %>% fScale()
    )
  }
  
  for(nrep in 1:numRep){
    #次ステップのパラメータを0に初期化
    theta_new <- rep(0,numT); phi_new <- vector("list",numT);
    for(k in 1:numT){
      phi_new[[k]] <- data.frame(
        words = allword,
        prob = rep(0,length(allword))
      )
    }

    for(d in 1:nrow(DTM)){
      Qd <- numeric(numT)
      for(k in 1:numT){
        Qd[k] <- prod(purrr::map2_dbl(phi[[k]]$prob, DTM[d,], `^`))*theta[k]
      }

      Qd <- fScale(Qd)
      for(k in 1:numT){
        theta_new[k] <- theta_new[k] + Qd[k]
        phi_new[[k]]$prob <- phi_new[[k]]$prob + Qd[k]*as.vector(DTM[d,])
      }
    }
    
    theta <- fScale(theta_new)
    for(k in 1:numT){
      phi[[k]]$prob <- phi_new[[k]]$prob %>% fScale()
    }
    
    LLhistory[nrep] <- loglikelihood_mug(DTM, theta, phi)
  }
  
  
  ret <- list(
    EMtheta = theta,
    EMphi = phi,
    LLhistory = LLhistory
  )
}

実際に推定してみて、対数尤度の動きを見てみる。

result <- fEM(DTM = DTM, numT = 3, allword = allword)
result$LLhistory
##  [1] -37824.38 -37803.02 -37803.02 -37803.02 -37803.02 -37803.02 -37803.02
##  [8] -37803.02 -37803.02 -37803.02

二回目には収束している感じだろうか。

推定されたphiをみていく。

result$EMphi[[1]]
##           words       prob
## 1        りんご 0.01059916
## 2        ゴリラ 0.01078621
## 3        ラッパ 0.00885342
## 4          宇宙 0.06004115
## 5            神 0.09626535
## 6        パイプ 0.01028742
## 7  スタンダード 0.19963838
## 8    センスレス 0.20163352
## 9      リライト 0.19970073
## 10       ループ 0.20219465

どうやらアジカンのトピックの単語分布を推定しているようなので、真の分布と比較してみると

truepara[["truephi"]][["akg"]]
##           words prob
## 1        りんご 0.01
## 2        ゴリラ 0.01
## 3        ラッパ 0.01
## 4          宇宙 0.06
## 5            神 0.10
## 6        パイプ 0.01
## 7  スタンダード 0.20
## 8    センスレス 0.20
## 9      リライト 0.20
## 10       ループ 0.20

次は

result$EMphi[[2]]
##           words        prob
## 1        りんご 0.011339662
## 2        ゴリラ 0.008175105
## 3        ラッパ 0.021360759
## 4          宇宙 0.294831224
## 5            神 0.209651899
## 6        パイプ 0.300105485
## 7  スタンダード 0.010812236
## 8    センスレス 0.010284810
## 9      リライト 0.017405063
## 10       ループ 0.116033755

Rのトピックの単語分布を推定しているようなので、真の分布と比較してみると

truepara[["truephi"]][["r"]]
##           words  prob
## 1        りんご 0.010
## 2        ゴリラ 0.010
## 3        ラッパ 0.020
## 4          宇宙 0.300
## 5            神 0.200
## 6        パイプ 0.310
## 7  スタンダード 0.010
## 8    センスレス 0.010
## 9      リライト 0.015
## 10       ループ 0.115

最後に

result$EMphi[[3]]
##           words       prob
## 1        りんご 0.24165554
## 2        ゴリラ 0.19225634
## 3        ラッパ 0.19759680
## 4          宇宙 0.05740988
## 5            神 0.04405874
## 6        パイプ 0.04806409
## 7  スタンダード 0.04806409
## 8    センスレス 0.06809079
## 9      リライト 0.04138852
## 10       ループ 0.06141522

しりとりの真の分布と比較する。

truepara[["truephi"]][["shiritori"]]
##           words prob
## 1        りんご 0.25
## 2        ゴリラ 0.20
## 3        ラッパ 0.20
## 4          宇宙 0.05
## 5            神 0.05
## 6        パイプ 0.05
## 7  スタンダード 0.05
## 8    センスレス 0.05
## 9      リライト 0.05
## 10       ループ 0.05

推定されたtheta(ここではakg,r,shiritoriの順)と真のtheta(shiritori,r,akgの順)を比較してみる。

result$EMtheta %>% round(digits = 3)
## [1] 0.78 0.18 0.04
truepara[["truetheta"]]%>% round(digits = 3)
## [1] 0.031 0.217 0.752

そこそこ近くまでは来ているみたい。

もう少し精度を上げようと単語数や文書数を増やしてみようとしたのだが、fEM関数で小数の積をとりまくることになることがおそらくの原因で、NaNが発生してしまった。こういうときはどうやって対処するのか分からなかったので、数値計算・数値解析のお話しが気になった。

【学習予定】

次は混合ユニグラムモデルを変分ベイズで推定していく。