Stanでガンマ回帰(動かす編)
回帰問題で目的変数が正の連続値をとる場合、ガンマ回帰は選択肢の一つになります。個人的に適用シーンは多いと思うのですが比較的情報が少ない気がします。本投稿ではトイデータとStanでサクッとモデルを動かしてみたいと思います。なおガンマ回帰はGLM(一般化線形モデル)の枠組みで推定できるためRのglm()
でも簡単にfitできます。シンプルにモデリングしたいだけであれば、あえてStanでやる意味はないかもしれません
設定
今回の設定として目的変数の期待値が、説明変数のべき乗の線形和 で表現できるとします。
推定するパラメタはです。あえてを入れたのは、Stanなら冪乗の項の対数をとるなど工夫せずに直接推定できそうだと思ったからです。またGLMでいうリンク関数はidentity
(恒等リンク関数)です。そのためとならないように説明変数の範囲を注意しないといけません。
さて期待値でガンマ分布に従うを考えたいのですが、ガンマ分布はshape
パラメタ:とrate
パラメタ:によって次のように決まります。(は正の実数)
われわれに馴染み深い期待値と分散は、で次のように表現できます。正規分布と違ってとは互いに強く依存しています。
ガンマ分布ではとを決める必要があるため、よりとを とで表します。
今回の設定ではを説明変数の関数で表します。またはGLMの枠組みでは一定と仮定するのが一般的なようです。今回も同様の仮定を置きます。 以上が今回のデータ生成過程の設定になります。モデリングではこれらの関係性と観測データをStanに直接指示すれば推定できそうです。まずはトイデータを生成しましょう。
トイデータ生成
Rを使って学習データを生成します。としています。
set.seed(1234) N <- 150 # 推定したいパラメタ a <- 0.3 b0 <- 1.0 b1 <- 1.5 phi <- 0.1 # 一様分布からxを生成 x <- runif(n=N, 0.1, 20) # 期待値と線形予測子 mu <- b0 + b1 * x ^ a # パラメタの変換 shape <- mu ^ 2 / phi rate <- mu / phi # ガンマ分布に従う乱数生成 y <- rgamma(n = length(x), shape = shape, rate = rate)
生成したデータをプロットします。との散布図です。
library(tidyverse) # x vs. y の散布図プロット g <-list(x = x, y = y, mu = mu) %>% data.frame() %>% ggplot() g <- g + geom_point(aes(x=x, y=y)) g <- g + geom_line(aes(x=x, y=mu, color='mu')) g <- g + ggtitle('Toy data $mu = 0.2 + 1.5 * x ^ 0.3')) g
Stanファイル
先程のデータ生成過程をそのままStanファイルに記載するだけです。これをexample01.stan
として保存します。
data { int N; // 観測数 vector[N] x; // 説明変数x vector[N] y; // 目的変数y } parameters { // 推定したいパラメタ real<lower=0> a; real<lower=0> b0; real<lower=0> b1; real<lower=0> phi; // variance of gamma } transformed parameters { vector<lower=0>[N] alpha; // ガンマ分布のshapeパラメタ vector<lower=0>[N] beta; // ガンマ分布のratioパラメタ vector<lower=0>[N] mu; // 期待値 for (n in 1:N) { mu[n] = b0 + b1 * pow(x[n], a); } for (n in 1:N) { alpha[n] = pow(mu[n], 2.) / phi; } beta = mu / phi; } model { y ~ gamma(alpha, beta); // likelihood } generated quantities { vector<lower=0>[N] y_new; for (n in 1:N) { y_new[n] = gamma_rng(alpha[n], beta[n]); } }
MCMC実行
rstan
を使って事後分布からのランダムサンプリングを実行します。
library(rstan) options(mc.cores = parallel::detectCores()) rstan_options(auto_write = TRUE) data <- list(N = N, x = x, y = y) fit <- stan(file = 'example01.stan', data = data, seed = 1234, iter = 2000)
divergent transitions
の警告が出ますが、警告の遷移数が少ないのでここでは無視します。
Warning message: “There were 2 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup” Warning message: “Examine the pairs() plot to diagnose sampling problems
警告の詳細は以下を参考にしてください。 messefor.hatenablog.com
収束も問題なさそうです。
パラメタ推定結果
どのパラメタも真の値 を分布の25-75%20-80%の間に含んでいるので、大きく外していないようです。
mean | se_mean | sd | 2.5% | 25% | 50% | 75% | 97.5% | n_eff | Rhat | |
---|---|---|---|---|---|---|---|---|---|---|
a | 0.2643425 | 0.0018063278 | 0.04469031 | 0.20230515 | 0.22890738 | 0.2563850 | 0.2925427 | 0.3675151 | 612.1158 | 1.0053146 |
b0 | 0.6953202 | 0.0173214550 | 0.43094516 | 0.02972518 | 0.33524315 | 0.6733556 | 1.0266480 | 1.5374601 | 618.9781 | 1.0035314 |
b1 | 1.8591395 | 0.0164183020 | 0.40759706 | 1.06712580 | 1.54369900 | 1.8705854 | 2.1966878 | 2.5121268 | 616.3191 | 1.0038491 |
phi | 0.1051311 | 0.0003396671 | 0.01246250 | 0.08323707 | 0.09635969 | 0.1042309 | 0.1122532 | 0.1334601 | 1346.1805 | 0.9992741 |
推定したパラメタからのサンプリング
推定したパラメタからのサンプリングデータを使って、学習データと乖離していないか確認します。
# 推定パラメタからのサンプリングを抽出 is.pred <- str_detect(rownames(result.summay), 'y_new.') data <- data.frame(result.summay[is.pred,]) colnames(data) <- c('mean', 'se_mean', 'sd', 'p2.5', 'p25', 'p50', 'p75', 'p97.5', 'n_eff', 'Rhat') data$x <- x data$y <- y data$mu <- mu # 図示 g <- ggplot(data=data) g <- g + geom_line(aes(x=x, y=mean, color='post_mean')) + geom_line(aes(x=x, y=mu, color='mu')) + geom_point(aes(x=x, y=y)) g <- g + geom_ribbon(aes(x=x, ymin=p25,ymax=p75), fill="blue", alpha=0.2) + geom_ribbon(aes(x=x, ymin=p2.5,ymax=p97.5), fill="blue", alpha=0.2) g <- g + ggtitle('True and Predicted') + labs(y='y') g
水色の線が事後分布の期待値です。また濃い青が50%予測区間、薄い青が95%予測区間です。 大きな乖離はなさそうですし、推定できているとしましょう。
コード全体はここに置いてあります。Rのコード、Stanのコード
まとめ
Stanを使ってガンマ回帰を行いました。正規分布を仮定した線形回帰での推定は、目的変数の分布の対称性が仮定されていることや負の値が出てきてしまうため、今回のようなケースではガンマ回帰の方がより自然な選択肢になります。GLMでも推定は行えますが、不確実性やモデルを柔軟に構築したい場合はStanを使うのも手かもしれません。
統計モデリングでは確率分布の知識が要求されますね。日々精進