霞と側杖を食らう

ほしいものです。なにかいただけるとしあわせです。[https://www.amazon.jp/hz/wishlist/ls/2EIEFTV4IKSIJ?ref_=wl_share]

ガンマ分布の乱数からディリクレ分布の乱数を生成する方法に関する学習記録

【学習動機】

ディリクレ分布からの乱数を生成したいが, そんなものはないというケースがある. たとえば, Excel周りで実装しないといけない場合など. そういう場合にどうすれば良いのか調べてみると, ディリクレ乱数はガンマ分布からの乱数で生成できることを知った. その生成方法について書き, Rで実際にガンマ乱数から生成してみる, また, 既にあるパッケージの{MCMCpack}のrdirichlet関数と比較してみる.

【学習記録】

ディリクレ乱数の生成方法

以下の文献に詳しいので, 参考にした. Bela A Frigyik, Amol Kapila, and Maya R Gupta. 2010
Introduction to the Dirichlet distribution and related processes.
https://vannevar.ece.uw.edu/techsite/papers/documents/UWEETR-2010-0006.pdf [pdf注意]

ディリクレ分布の乱数生成の方法
(i) the urn-drawing method
ポリアの壺の方法. ちなみにポリアの壺は『いかにして問題をとくか』の 著者である数学者ジョージ・ポリアに因んで名付けられている.
(ii) The Stick-breaking Approach 棒折りのアプローチ. ちなみにノンパラベイズで棒折り過程の話は出てくる. 次の章にディリクレ過程の話も書かれている.
(iii) Generating the Dirichlet from Gamma RVs ガンマ確率変数による生成方法. ガンマ確率変数による生成方法は, (i)ポリアの壺の方法と(ii)棒折りのアプローチのいずれよりも計算効率が良いとのこと.

ここでは(iii)の方法のみ扱う. 数理やその他の方法などは上の文献を参照されたい.
(iii)のガンマ確率変数によるディリクレ乱数は以下の2ステップで生成される.

Step 1: ガンマ乱数生成
Generate gamma realizations: for \(i=1,...,k\), draw a number \(z_i\) from \(\Gamma\) (\(\alpha_i\), 1).

Step 2: 正規化
Normalize them to form a pmf: for \(i=1,...,k\), set


. Then q is a realization of Dir(\(\alpha\)).

Rで生成

# パッケージ読み込み
library(tidyverse)
# 乱数固定
set.seed(20230228)

乱数生成のためのサンプルサイズ設定とパラメータ設定.

# サンプルサイズ
N <- 10000
# ディリクレ分布のパラメータ設定
alpha <- c(1, 2, 3)

ガンマ分布から乱数を生成
仕組みを分かりやすくするために, 無駄にループを回す, Rとしては非効率な書き方になっているので注意(後述)

k <- length(alpha)
z <- matrix(numeric(k*N), ncol=k, byrow=TRUE)
q_gamma <- matrix(numeric(k*N), ncol=k, byrow=TRUE)
# Step 1 : ガンマ乱数生成
for(i in 1:k){
  z[,i] <- rgamma(n=N, alpha[i])
}
# Step 2 : 正規化
for(j in 1:N){
  q_gamma[j,] <- z[j,]/sum(z[j,])
}

{MCMCpack}パッケージの関数を使用して生成

q_mcmcpack <- MCMCpack::rdirichlet(n = N, alpha = alpha)

生成したディリクレ乱数を可視化したいが, 自分では作画できないので, 以下の記事のコードをほとんどそのまま使用させていただきました. 感謝. ディリクレ分布の乱数生成

www.anarchive-beta.com

自分でガンマ乱数から生成したもののヒートマップを描く.

# 軸目盛の位置を指定
axis_vals <- seq(from = 0, to = 1, by = 0.1)

# 枠線用の値を作成
ternary_axis_df <- tibble::tibble(
  y_1_start = c(0.5, 0, 1),         # 始点のx軸の値
  y_2_start = c(0.5*sqrt(3), 0, 0), # 始点のy軸の値
  y_1_end = c(0, 1, 0.5),           # 終点のx軸の値
  y_2_end = c(0, 0, 0.5*sqrt(3)),   # 終点のy軸の値
  axis = c("x_1", "x_2", "x_3")     # 元の軸
)

# グリッド線用の値を作成
ternary_grid_df <- tibble::tibble(
  y_1_start = c(
    0.5 * axis_vals, 
    axis_vals, 
    0.5 * axis_vals + 0.5
  ), # 始点のx軸の値
  y_2_start = c(
    sqrt(3) * 0.5 * axis_vals, 
    rep(0, times = length(axis_vals)), 
    sqrt(3) * 0.5 * (1 - axis_vals)
  ), # 始点のy軸の値
  y_1_end = c(
    axis_vals, 
    0.5 * axis_vals + 0.5, 
    0.5 * rev(axis_vals)
  ), # 終点のx軸の値
  y_2_end = c(
    rep(0, times = length(axis_vals)), 
    sqrt(3) * 0.5 * (1 - axis_vals), 
    sqrt(3) * 0.5 * rev(axis_vals)
  ), # 終点のy軸の値
  axis = c("x_1", "x_2", "x_3") |> 
    rep(each = length(axis_vals)) # 元の軸
)

# 軸ラベル用の値を作成
ternary_axislabel_df <- tibble::tibble(
  y_1 = c(0.25, 0.5, 0.75),               # x軸の値
  y_2 = c(0.25*sqrt(3), 0, 0.25*sqrt(3)), # y軸の値
  label = c("alpha[1]", "alpha[2]", "alpha[3]"),      # 軸ラベル
  h = c(3, 0.5, -2),  # 水平方向の調整用の値
  v = c(0.5, 3, 0.5), # 垂直方向の調整用の値
  axis = c("x_1", "x_2", "x_3") # 元の軸
)

# 軸目盛ラベル用の値を作成
ternary_ticklabel_df <- tibble::tibble(
  y_1 = c(
    0.5 * axis_vals, 
    axis_vals, 
    0.5 * axis_vals + 0.5
  ), # x軸の値
  y_2 = c(
    sqrt(3) * 0.5 * axis_vals, 
    rep(0, times = length(axis_vals)), 
    sqrt(3) * 0.5 * (1 - axis_vals)
  ), # y軸の値
  label = c(
    rev(axis_vals), 
    axis_vals, 
    rev(axis_vals)
  ), # 軸目盛ラベル
  h = c(
    rep(1.5, times = length(axis_vals)), 
    rep(1.5, times = length(axis_vals)), 
    rep(-0.5, times = length(axis_vals))
  ), # 水平方向の調整用の値
  v = c(
    rep(0.5, times = length(axis_vals)), 
    rep(0.5, times = length(axis_vals)), 
    rep(0.5, times = length(axis_vals))
  ), # 垂直方向の調整用の値
  angle = c(
    rep(-60, times = length(axis_vals)), 
    rep(60, times = length(axis_vals)), 
    rep(0, times = length(axis_vals))
  ), # ラベルの表示角度
  axis = c("x_1", "x_2", "x_3") |> 
    rep(each = length(axis_vals)) # 元の軸
)
  
# サンプルを三角座標に変換して格納(q_gamma)
data_df_q_gamma <- tibble::tibble(
  y_1 = q_gamma[, 2] + 0.5 * q_gamma[, 3], # 三角座標のx軸の値
  y_2 = sqrt(3) * 0.5 * q_gamma[, 3] # 三角座標のy軸の値
)


# 三角座標の値を作成
y_1_vals <- seq(from = 0, to = 1, length.out = 301)
y_2_vals <- seq(from = 0, to = 0.5*sqrt(3), length.out = 300)

# 格子点を作成
y_mat <- tidyr::expand_grid(
  y_1 = y_1_vals, 
  y_2 = y_2_vals
) |> # 格子点を作成
  as.matrix() # マトリクスに変換

# 3次元変数に変換
phi_mat <- tibble::tibble(
  phi_2 = y_mat[, 1] - y_mat[, 2] / sqrt(3), 
  phi_3 = 2 * y_mat[, 2] / sqrt(3)
) |> # 元の座標に変換
  dplyr::mutate(
    phi_2 = dplyr::if_else(phi_2 >= 0 & phi_2 <= 1, true = phi_2, false = as.numeric(NA)), 
    phi_3 = dplyr::if_else(phi_3 >= 0 & phi_3 <= 1 & !is.na(phi_2), true = phi_3, false = as.numeric(NA)), 
    phi_1 = 1 - phi_2 - phi_3, 
    phi_1 = dplyr::if_else(phi_1 >= 0 & phi_1 <= 1, true = phi_1, false = as.numeric(NA))
  ) |> # 範囲外の値をNAに置換
  dplyr::select(phi_1, phi_2, phi_3) |> # 順番を変更
  as.matrix() # マトリクスに変換


# パラメータラベル用の文字列を作成
param_text <- paste0(
  "list(", 
  "alpha==(list(", paste0(alpha, collapse = ", "), "))", 
  ", N==", N, 
  ")"
)

  
# サンプルの度数のヒートマップを作成
ggplot() + 
  geom_segment(data = ternary_axis_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50") + # 三角図の枠線
  geom_segment(data = ternary_grid_df, 
               mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), 
               color = "gray50", linetype = "dashed") + # 三角図のグリッド線
  geom_text(data = ternary_ticklabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v, angle = angle)) + # 三角図の軸目盛ラベル
  geom_text(data = ternary_axislabel_df, 
            mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v), 
            parse = TRUE, size = 6) + # 三角図の軸ラベル
  geom_bin_2d(data = data_df_q_gamma, 
              mapping = aes(x = y_1, y = y_2, fill = ..count..), 
              alpha = 0.8) + # サンプル
  scale_fill_distiller(palette = "Spectral") + # 塗りつぶしの色
  scale_x_continuous(breaks = c(0, 0.5, 1), labels = NULL) + # x軸
  scale_y_continuous(breaks = c(0, 0.25*sqrt(3), 0.5*sqrt(3)), labels = NULL) + # y軸
  coord_fixed(ratio = 1, clip = "off") + # アスペクト比
  theme(axis.ticks = element_blank(), 
        panel.grid.minor = element_blank()) + # 図の体裁
  labs(title = "Dirichlet Distribution (Gamma RVs method)", 
       subtitle = parse(text = param_text), 
       fill = "frequency", 
       x = "", y = "") +
  theme_bw()



# サンプルを三角座標に変換して格納(q_mcmcpack) data_df_q_mcmcpack <- tibble::tibble( y_1 = q_mcmcpack[, 2] + 0.5 * q_mcmcpack[, 3], # 三角座標のx軸の値 y_2 = sqrt(3) * 0.5 * q_mcmcpack[, 3] # 三角座標のy軸の値 ) # サンプルの度数のヒートマップを作成 ggplot() + geom_segment(data = ternary_axis_df, mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), color = "gray50") + # 三角図の枠線 geom_segment(data = ternary_grid_df, mapping = aes(x = y_1_start, y = y_2_start, xend = y_1_end, yend = y_2_end), color = "gray50", linetype = "dashed") + # 三角図のグリッド線 geom_text(data = ternary_ticklabel_df, mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v, angle = angle)) + # 三角図の軸目盛ラベル geom_text(data = ternary_axislabel_df, mapping = aes(x = y_1, y = y_2, label = label, hjust = h, vjust = v), parse = TRUE, size = 6) + # 三角図の軸ラベル geom_bin_2d(data = data_df_q_mcmcpack, mapping = aes(x = y_1, y = y_2, fill = ..count..), alpha = 0.8) + # サンプル scale_fill_distiller(palette = "Spectral") + # 塗りつぶしの色 scale_x_continuous(breaks = c(0, 0.5, 1), labels = NULL) + # x軸 scale_y_continuous(breaks = c(0, 0.25*sqrt(3), 0.5*sqrt(3)), labels = NULL) + # y軸 coord_fixed(ratio = 1, clip = "off") + # アスペクト比 theme(axis.ticks = element_blank(), panel.grid.minor = element_blank()) + # 図の体裁 labs(title = "Dirichlet Distribution (MCMCpack)", subtitle = parse(text = param_text), fill = "frequency", x = "", y = "") + theme_bw()


設定したパラメータから計算される平均

alpha/sum(alpha)
## [1] 0.1666667 0.3333333 0.5000000

自分でガンマ乱数から生成したものの平均値

colSums(q_gamma)/N
## [1] 0.1645176 0.3340670 0.5014154

{MCMCpack}パッケージの関数を使用して生成したものの平均値

colSums(q_mcmcpack)/N
## [1] 0.1665876 0.3357620 0.4976504

グラフ, 平均値の比較をみてみたところ, 合っていそう.

MCMCpack::rdirichletの中身

さて, MCMCpack::rdirichletの中身を確認する.

MCMCpack::rdirichlet
## function (n, alpha) 
## {
##     l <- length(alpha)
##     x <- matrix(rgamma(l * n, alpha), ncol = l, byrow = TRUE)
##     sm <- x %*% rep(1, l)
##     return(x/as.vector(sm))
## }
## <bytecode: 0x0000022e10de5e00>
## <environment: namespace:MCMCpack>

実はMCMCpack::rdirichletは, ガンマ分布の乱数から同じように生成していることが分かる.
xで一括でガンマ乱数を生成して, matrixに折りたたむ.
smは行列の積で行和を計算.
smをベクトル化して正規化して終了.
Rで実装するなら, 無駄にループするよりも当然こちらの方が効率が良い.

【学習予定】

いつかディリクレ過程周りも理解してまとめておきたい. 

関係ないけど, ディリクレ分布は単体(simplex)上の話で, ゲーム理論を少し思い出したりもした.