[Stan]ロジスティック回帰の階層ベイズモデルとk-foldsクロスバリデーション

はじめに

stanは意思決定のための分析などでのパラメータ推定に使うことが多く、機械学習のために扱うことはありませんでした。ただ、過去にリク面などでお話したデータサイエンティストの方はstanで機械学習していて、クロスバリデーションもしているとの発言をされていました。
先日、記事を漁っていたらstanでクロスバリデーションを行うためのコードを書いている方を見つけたので、その方のコードをもとに大人のirisであるwineデータを用いて、質の高いワインかどうかを分類するために階層ベイズでのロジスティック回帰モデルを回してみたいと思います。

データについて

UCI Machine Learning Repositoryにある、赤ワインの評価と成分のデータです。データに関する説明はワインの味(美味しさのグレード)は予測できるか?(1)で丁寧になされていますので、ご確認ください。今回は6点以上であれば1を、そうでなければ0を取るものを教師データとしています。

分析方針

今回は階層ベイズモデルを扱うことから、グループごとにロジスティック回帰のパラメータが異なるという仮定を置きます。そのために、citric.acidというデータを3つのカテゴリデータに変換して、それをグループとして扱います。モデルでは、今回のデータセットの変数を全て回帰項として使います。もちろん、回帰用の式からはcitric.acidは除外します。
全体の80%を訓練データに、20%をテストデータとして、10foldsクロスバリデーションでのstanによる予測結果の平均AUCを評価指標とします。最後に、テストデータを用いた予測のAUCを確かめます。また、階層ベイズモデルではないモデルでの10foldsクロスバリデーションでのAUCとも比較します

分析概要

・データ整形
・訓練データとテストデータの分割
・クロスバリデーション用のデータの作成
・stanの実行
・クロスバリデーション結果の出力
・テストデータでの予測
・非階層モデルとの比較

全体のコード以下のリンクにあります。
kick_logistic_regression_allowing_k_hold_cross_validation_hierachical.R
logistic_regression_allowing_k_fold_cross_validation_hierachical.stan

データ整形

階層ベイズで扱うグループをcitric.acidから作っています。

訓練データとテストデータの分割

ワインの質に関するバイナリーデータをこちらで作成し、80%を訓練データに、20%をテストデータに分割しています。

クロスバリデーション用のデータの作成

こちらのコードでは任意の数でクロスバリデーション用のデータを作成し、stanで扱う訓練用データのlistに追加しています。
また、参考にしているブログより転用したstan_kfoldという関数を定義しています。k分割した際のstanの推定結果をリストに格納するための関数です。

stanの実行

こちらのstanのコードでは、M個のグループごとにパラメータが異なるというモデルを書いています。modelブロックの途中でholdoutを入れることで一部のデータを推定に使わないようにしています。

こちらはstanをキックするためのコードです。いつもと違い、先程定義したstan_kfoldを用いています。

クロスバリデーション結果の出力

以下は、k個ずつ手に入ったクロスバリデーションでの推定結果から、各パラメータの平均値を計算し、ロジスティック回帰モデルで2値の予測を行い、平均AUCを計算するコードです。

平均AUCは0.675となりました。すごくいいわけではないですが、手抜きモデルとしてはまずまずと言ったところでしょうか。

テストデータでの予測

以下のコードで最初に分けていたテストデータでの予測結果を返します。

実行の結果、AUCは0.665と、クロスバリデーションでの平均AUCと比べてあまり下がりませんでした。

非階層モデルとの比較

非階層モデルでも同様に10foldsクロスバリデーションの平均AUCを計算しました。非階層モデルよりもAUCが1%ポイントくらいは高いようです。

おわりに

現時点において、stanでの柔軟なモデリングを機械学習に活かす作法について紹介されている文献はあまりなく、選手人口はどれくらいいるのか気になるところですが、発見したブログのやり方でクロスバリデーションをカジュアルに行えるので、より多くの方がstanでの機械学習にチャレンジしうるものだなと思いました。ただ、このレベルの階層ベイズだとrstanarmで簡単にできてしまうので、より深く分析してモデルをカスタムしていきたいですね。

参考情報

[1]Lionel Hertzog (2018), “K-fold cross-validation in Stan,datascienceplus.com”
[2]Alex Pavlakis (2018), “Making Predictions from Stan models in R”, Medium
[3]Richard McElreath (2016), “Statistical Rethinking: A Bayesian Course with Examples in R and Stan (Chapman & Hall/CRC Texts in Statistical Science)”, Chapman and Hall/CRC
[4]松浦 健太郎 (2016), 『StanとRでベイズ統計モデリング (Wonderful R)』, 共立出版
[5]馬場真哉 (2019), 『実践Data Scienceシリーズ RとStanではじめる ベイズ統計モデリングによるデータ分析入門』, 講談社

R Advent Calendar 2017 rvestを用いてポケモンデータをスクレイピング&分析してみた

R Advent Calendar 2017の11日目を担当するMr_Sakaueです。
今回はrvestパッケージを用いて、友人がハマっているポケモンの情報を集めてみようと思います。
もっとも、業務でWebスクレイピングする際はPythonでBeautifulSoupやSeleniumを使うことがほとんどなのですが、たまにはRでやってみようと思います。

目次
・やりたいこと
・rvestについて
・データの取得と集計と可視化と分析
・まとめ
・参考情報

やりたいこと

今回はポケモンたちのデータを集めた上で、以下の内容を行いたいと思います。

  • ポケモンのサイトから種族値を取得
  • ポケモンの種族値を標準化して再度ランキング
  • ポケモンのレア度や経験値に関する情報を取得
  • レア度や経験値と相関しそうな種族値を探る

今回扱った全てのコードはこちらに載せております。
https://github.com/KamonohashiPerry/r_advent_calendar_2017/tree/master

※種族値はゲームにおける隠しパラメータとして設定されている、ポケモンの能力値とされている。

rvestについて

rvestはRでWebスクレイピングを簡単に行えるパッケージです。ここでの説明は不要に思われますが、今回はread_html()、html_nodes()、html_text()、html_attr()の4つ関数を用いました。

基本的に以下の3ステップでWebの情報を取得することができます。

  • STEP1
    read_html()でHTMLからソースコードを取得する。(Pythonでいう、requestとBeautifulSoup)
  • STEP2
    html_nodes()でソースコードから指定した要素を抽出する。(PythonでいうところのfindAll)
  • STEP3
    html_text()やhtml_attr()で抽出した要素からテキストやリンクを抽出する。(Pythonでいうところのget(‘href’)など)

データの取得と集計と可視化

検索エンジンで検索してだいたい1位のサイトがあったので、そちらのWebサイトに載っているポケモンの種族値の一覧をスクレイピング対象とさせていただきます。

  • ポケモンのサイトから種族値を取得

以上のコードを実行すれば、こんな感じでポケモンの種族値一覧を得る事ができます。

とりあえず、種族値合計(Total Tribal Value 以下、TTV)のランキングの上位を確認してみます。知らないんですが、メガミュウツーとかいうイカつそうなポケモンが上位にいるようです。昭和の世代には縁のなさそうなポケモンばかりですねぇ。

■TTVランキング

取得した種族値を項目別に集計したり、Boxプロットを描いてみます。どうやら、攻撃の平均が高く、ヒットポイントや素早さの平均は低いようです。

■種族値のサマリー

■種族値のBoxプロット

  • ポケモンの種族値を標準化して再度ランキング

さて、攻撃の平均が高かったり、ヒットポイントと素早さの平均が低かったりしたので、各々の項目を標準化した上で、再度ランキングを作ってみたいと思います。

■標準化した種族値のサマリー

平均0、分散1にできているようです。

■標準化した種族値のBoxプロット

他よりも低かったヒットポイントと、高かった攻撃がならされていることが確認できます。

■標準化前後でのTTVランキングのギャップが大きかったものをピックアップ

ラッキーが144位ほど出世しています。攻撃が低く、ヒットポイントの高いラッキーが標準化により優遇されるようになったと考える事ができます。ポケモン大会の上位ランカーである後輩社員もラッキーは手強いですと言っていたのでまんざらでもないのでしょう。

  • ポケモンのレア度や経験値に関する情報を取得

今回のサイトには、個別にポケモン別のページが用意されており、そちらから、ゲットしやすさや経験値に関する情報を抽出します。

以上のコードを実行すれば、やや時間がかかりますが、全ポケモンのゲットしやすさや経験値のデータを抽出する事ができます。それらの情報がゲットできたら、まずは可視化します。

■ゲットしやすさのヒストグラム

ゲットのしやすさは、小さいほど捕まえる難易度が高くなっています。難易度の高いポケモンである0が多過ぎるので、このデータは欠損値が0になっているのではないかと疑われます。

■経験値のヒストグラム

経験値は、レベル100になるまでに要する経験値をさしています。ほとんどが100万程度となっているようです。

■ゲットしやすさと標準化TTVの散布図

やはり、ゲットしやすさに関してはデータに不備があるようで、コラッタ(アローラの姿)のような雑魚ポケのゲットのしやすさが0だったり、伝説のポケモンであるネクロズマが255だったりします。ただ、上限と下限のデータを間引けば右下がりの傾向が見られそうです。

■経験値と標準化TTVの散布図

経験値が多く必要にも関わらず、TTVが低い集団があります。どうやらこの集団に属するのは、「キノガッサ」・「マクノシタ」・「イルミーゼ」・「ゴクリン」・「シザリガー」などで、一回しか進化しないポケモンのようです。これらのポケモンは育てにくく、TTVの低い、コスパの悪そうなポケモンと考えることができるのではないでしょうか。(技や特性によってはバリューあるかもしれませんが。)

  • レア度や経験値と相関しそうな種族値を探る

先ほどのレア度に関しては、データがおかしそうだったので、レア度0と255に関しては除外してみます。

■ゲットしやすさと標準化TTVの散布図

やはり除外する事で、理想的な右下がりの傾向を示す散布図が得られたと思います。
さて、各種族値がレア度にどれだけ相関しているのかを分析したいのですが、その前にレア度を表す二項変数を作成します。

■ゲットしやすさが50以下であれば1、それ以外を0にする変数を作成

続いて、各種族値を説明変数として、レア度を目的変数としたロジスティック回帰モデルの推定をrstanで実行させます。

■stanコード

■rstanでロジスティック回帰を行い、推定結果を可視化するコード

■MCMCのシミュレーション結果のトレースプロット

どうやら収束してそうです。

■ロジスティック回帰の推定結果

見にくいので、推定結果を松浦さんの「StanとRでベイズ統計モデリング」にあるコードを用いて可視化します。

■推定結果の可視化

どうやら、0を含まない係数について見てみると、b3(攻撃)、b5(特殊攻撃)、b6(特殊防御)が高いほど、レア度が増す傾向があるようです。珍しいポケモンは攻撃が強いという傾向があると言えるのではないでしょうか。

まとめ

  • rvestは簡単にスクレイピングできて便利。
  • ポケモンデータは色々整備されてそうで今後も分析したら面白そう。
  • 珍しいポケモンは「攻撃」、「特殊攻撃」、「特殊防御」が高い傾向がある。
  • 経験値が必要なのにTTVの低い、コスパの悪そうなポケモンたちがいる。

それでは、どうか良い年末をお過ごし下さい!
メリークリスマス!

参考情報

データサイエンティストのための最新知識と実践 Rではじめよう! [モダン]なデータ分析
StanとRでベイズ統計モデリング (Wonderful R)
【R言語】rvestパッケージによるウェブスクレイピング その2
Receiving NAs when scraping links (href) with rvest

顧客生涯価値(CLV)の計算 with R

顧客生涯価値(CLV:Customer Lifetime Value)を計算してくれるRのコード(Calculating Customer Lifetime Value with Recency, Frequency, and Monetary (RFM))があったので、今更感がありますが取り上げたいと思います。

目次

・顧客生涯価値の数式
・データセット
・関数
・データセットの読み込みと加工
・再購買率とRFMとの関係
・再購買率の推定
・顧客生涯価値の計算
・参考情報

顧客生涯価値の数式

まず、顧客生涯価値の数式は以下の通りです。
customer_lifetime_value

t:年や月などの期間
n:顧客が解約するまでの期間合計
r:保持率(1-解約率)
P(t):t期に顧客から得られる収益
d:割引率

rは数式上では固定ですが、実際にはデモグラ属性(年齢、地理情報、職種など)や行動(RFMなど)や在職中かどうかなどの要因により変わりうるものだと考えられます。参考文献のブログでは、このrのロジスティック回帰による推定がなされています。

データセット

データ名:CDNow
概要:1997年の第一四半期をスタート時点とした顧客の購買行動データ
期間:1997年1月〜1998年6月
顧客数:23570
取引レコード数:69659
変数:顧客ID、購入日、購入金額
入手方法:DatasetsでCDNOW dataset (full dataset)をダウンロード

関数

参考文献にはgetDataFrame関数、getPercentages関数、getCLV関数の三つの関数が出てきますが、CLVの計算に必要なのはgetDataFrame関数、getCLV関数の二つです。getPercentages関数はRecencyなどに応じて細かく分析する際に用います。

getDataFrame関数・・・生のデータセットから、指定した期間に応じたRecencyのデータを作成する関数です。

getPercentages関数・・・Recencyなどの回数に応じて、購入した顧客の割合などを計算するための関数です。

getCLV関数・・・Recency、Frequency、Monetary、購入者の数(1人と置く)、コスト(0としている)、期間、割引率、推定したモデルをもとにCLVを計算する関数です。

データセットの読み込みと加工

再購買率とRFMとの関係

まず初めにデータセットを加工します。ロジットの推定における説明変数用のデータとして19970101〜19980228のデータを用い、被説明変数にあたる購入したかどうかのデータを19980301〜19980430のデータを用いて作ります。

データを確認します。

60日以内に購入した顧客(Recency=0)のうち、45%が再び購入しているようです。

100ドル購入した顧客(Monetary=10)のうち、19%が再び購入しているようです。

10回購入したことのある顧客(Frequency=10)は49%が再び購入しているようです。

再購買率の推定

RFM(Recency、Frequency、Monetary)のデータに基づいて、再購買率をロジスティック回帰によって推定し、予測確率を用いて顧客生涯価値を計算します。

rfm_analysis_est

RecencyとFrequencyによるロジスティック回帰

顧客生涯価値の計算

推定したロジットを用いて、生涯価値を計算します。

それではさっそく、1998年5月〜6月のデータを用いて、今回推定した顧客生涯価値が妥当なのかどうかを確かめたいと思います。

予測したCLVと実際の取引金額データを散布図で描き、回帰線を引く。

life_time_value_estimation

CLVが上がれば、1998年5月1日〜6月30日の間(未来)の取引金額が増すような傾向が出ています。

参考情報

RFM Customer Analysis with R Language
Calculating Customer Lifetime Value with Recency, Frequency, and Monetary (RFM)

ロジスティック回帰分析に関する参考文献

ロジスティック回帰分析に関する参考文献を載せています。
限界効果についてや、多項ロジットなどについての文献もあります。

  • Rのパッケージ
  • Package ‘mfx’
    http://cran.r-project.org/web/packages/mfx/mfx.pdf
    限界効果を計算できるmfxパッケージ。

    Package ‘mlogit’
    http://cran.r-project.org/web/packages/mlogit/mlogit.pdf
    多項ロジットを計算できる。

  • 実行例
  • R Data Analysis Examples: Logit Regression
    http://www.ats.ucla.edu/stat/r/dae/logit.htm

    R Data Analysis Examples: Multinomial Logistic Regression
    http://www.ats.ucla.edu/stat/r/dae/mlogit.htm