無粋な日々に

頭の中のメモ。分からないことを整理する

Stanでガンマ回帰(動かす編)

回帰問題で目的変数が正の連続値をとる場合、ガンマ回帰は選択肢の一つになります。個人的に適用シーンは多いと思うのですが比較的情報が少ない気がします。本投稿ではトイデータとStanでサクッとモデルを動かしてみたいと思います。なおガンマ回帰はGLM(一般化線形モデル)の枠組みで推定できるためRのglm()でも簡単にfitできます。シンプルにモデリングしたいだけであれば、あえてStanでやる意味はないかもしれません

設定

今回の設定として目的変数\displaystyle{y_i}の期待値\displaystyle{\mu_i}が、説明変数\displaystyle{x_i}のべき乗の線形和 \displaystyle{b_{0} + b_{1}x_{i}^ {a}} で表現できるとします。

\displaystyle{


E[y_i] = \mu_{i} = b_{0} + b_{1}x_{i}^{a}

}

推定するパラメタは\displaystyle{b_0, b_1, a}です。あえて\displaystyle{a}を入れたのは、Stanなら冪乗の項の対数をとるなど工夫せずに直接推定できそうだと思ったからです。またGLMでいうリンク関数はidentity(恒等リンク関数)です。そのため\displaystyle{\mu_i \leq 0}とならないように説明変数の範囲を注意しないといけません。

さて期待値\displaystyle{\mu_i}でガンマ分布に従う\displaystyle{y_i}を考えたいのですが、ガンマ分布はshapeパラメタ:\displaystyle{\alpha}rateパラメタ:\displaystyle{\beta}によって次のように決まります。(\displaystyle{y, \alpha, \beta}は正の実数)

\displaystyle{


\operatorname{Gamma}(y|\alpha, \beta) = \frac{\beta^{\alpha}}{\Gamma(\alpha)}y^{\alpha-1}\operatorname{exp}^{-\beta y}

}

われわれに馴染み深い期待値\displaystyle{\mu}と分散\displaystyle{\phi}は、\displaystyle{\alpha, \beta}で次のように表現できます。正規分布と違って\displaystyle{\mu}\displaystyle{\phi}は互いに強く依存しています。

\displaystyle{


\begin{align}
\mu &= \frac{\alpha}{\beta} \tag{1} \\

\phi &= \frac{\alpha}{\beta^2} \tag{2}
\end{align}

}

ガンマ分布では\displaystyle{\alpha}\displaystyle{\beta}を決める必要があるため、\displaystyle{(1)(2)}より\displaystyle{\alpha}\displaystyle{ \beta}\displaystyle{\mu}\displaystyle{\phi}で表します。

\displaystyle{


\begin{align}
\alpha &= \frac{\mu^2}{\phi} \\
\beta &= \frac{\mu}{\phi} \\
\end{align}

}

今回の設定では\displaystyle{\mu}を説明変数の関数で表します。また\displaystyle{\phi}はGLMの枠組みでは一定と仮定するのが一般的なようです。今回も同様の仮定を置きます。 以上が今回のデータ生成過程の設定になります。モデリングではこれらの関係性と観測データをStanに直接指示すれば推定できそうです。まずはトイデータを生成しましょう。

トイデータ生成

Rを使って学習データを生成します。\displaystyle{a = 0.3, b_0 = 1.0, b_1 = 1.5, \phi = 0.1}としています。

set.seed(1234)

N <- 150

# 推定したいパラメタ
a <- 0.3
b0 <- 1.0
b1 <- 1.5
phi <- 0.1

# 一様分布からxを生成
x <- runif(n=N, 0.1, 20)

# 期待値と線形予測子
mu <- b0 + b1 * x ^ a

# パラメタの変換
shape <- mu ^ 2 / phi
rate <- mu / phi

# ガンマ分布に従う乱数生成
y <- rgamma(n = length(x), shape = shape, rate = rate)

生成したデータをプロットします。\displaystyle{x}\displaystyle{y}の散布図です。

library(tidyverse)

# x vs. y の散布図プロット
g <-list(x = x, y = y, mu = mu) %>% data.frame() %>% ggplot()
g <- g + geom_point(aes(x=x, y=y))
g <- g + geom_line(aes(x=x, y=mu, color='mu'))
g <- g + ggtitle('Toy data $mu = 0.2 + 1.5 * x ^ 0.3'))
g

f:id:messefor:20200905230011p:plain

Stanファイル

先程のデータ生成過程をそのままStanファイルに記載するだけです。これをexample01.stanとして保存します。

data {
  int N;  // 観測数
  vector[N] x;  // 説明変数x
  vector[N] y; // 目的変数y
}

parameters {
    // 推定したいパラメタ
  real<lower=0> a;
  real<lower=0> b0;
  real<lower=0> b1;
  real<lower=0> phi;  // variance of gamma
}

transformed parameters {
  vector<lower=0>[N] alpha;  // ガンマ分布のshapeパラメタ
  vector<lower=0>[N] beta;  // ガンマ分布のratioパラメタ
  vector<lower=0>[N] mu;  // 期待値

  for (n in 1:N) {
    mu[n] = b0 + b1 * pow(x[n], a);
  }

  for (n in 1:N) {
    alpha[n] = pow(mu[n], 2.) / phi;
  }
  
  beta = mu / phi;
}

model {
  y ~ gamma(alpha, beta); // likelihood
}

generated quantities {
  vector<lower=0>[N] y_new;
  for (n in 1:N) {
    y_new[n] = gamma_rng(alpha[n], beta[n]);
  }
}

MCMC実行

rstanを使って事後分布からのランダムサンプリングを実行します。

library(rstan)

options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

data <- list(N = N, x = x, y = y)
fit <- stan(file = 'example01.stan', data = data, seed = 1234, iter = 2000)

divergent transitionsの警告が出ますが、警告の遷移数が少ないのでここでは無視します。

Warning message: “There were 2 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup” Warning message: “Examine the pairs() plot to diagnose sampling problems

警告の詳細は以下を参考にしてください。 messefor.hatenablog.com

収束も問題なさそうです。

f:id:messefor:20200905230148p:plain

パラメタ推定結果

どのパラメタも真の値\displaystyle{a = 0.3, b_0 = 1.0, b_1 = 1.5, \phi = 0.1} を分布の25-75%20-80%の間に含んでいるので、大きく外していないようです。

mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
a 0.2643425 0.0018063278 0.04469031 0.20230515 0.22890738 0.2563850 0.2925427 0.3675151 612.1158 1.0053146
b0 0.6953202 0.0173214550 0.43094516 0.02972518 0.33524315 0.6733556 1.0266480 1.5374601 618.9781 1.0035314
b1 1.8591395 0.0164183020 0.40759706 1.06712580 1.54369900 1.8705854 2.1966878 2.5121268 616.3191 1.0038491
phi 0.1051311 0.0003396671 0.01246250 0.08323707 0.09635969 0.1042309 0.1122532 0.1334601 1346.1805 0.9992741

f:id:messefor:20200905230237p:plain

推定したパラメタからのサンプリング

推定したパラメタからのサンプリングデータを使って、学習データと乖離していないか確認します。

# 推定パラメタからのサンプリングを抽出
is.pred <- str_detect(rownames(result.summay), 'y_new.')
data <- data.frame(result.summay[is.pred,])

colnames(data) <-
  c('mean', 'se_mean', 'sd', 'p2.5', 'p25', 'p50',
      'p75', 'p97.5', 'n_eff', 'Rhat')
data$x <- x
data$y <- y
data$mu <- mu

# 図示
g <- ggplot(data=data)
g <- g + geom_line(aes(x=x, y=mean, color='post_mean')) +
      geom_line(aes(x=x, y=mu, color='mu')) +
      geom_point(aes(x=x, y=y))

g <- g + geom_ribbon(aes(x=x, ymin=p25,ymax=p75), fill="blue", alpha=0.2) +
      geom_ribbon(aes(x=x, ymin=p2.5,ymax=p97.5), fill="blue", alpha=0.2)
g <- g + ggtitle('True and Predicted') + labs(y='y')
g

水色の線が事後分布の期待値です。また濃い青が50%予測区間、薄い青が95%予測区間です。 大きな乖離はなさそうですし、推定できているとしましょう。

f:id:messefor:20200905230330p:plain

コード全体はここに置いてあります。RのコードStanのコード

まとめ

Stanを使ってガンマ回帰を行いました。正規分布を仮定した線形回帰での推定は、目的変数の分布の対称性が仮定されていることや負の値が出てきてしまうため、今回のようなケースではガンマ回帰の方がより自然な選択肢になります。GLMでも推定は行えますが、不確実性やモデルを柔軟に構築したい場合はStanを使うのも手かもしれません。


統計モデリングでは確率分布の知識が要求されますね。日々精進

参考文献