Regurarized Greedy Forest
最近、決定木ベースの手法ではxgboost が主流となってきています。実際、xgboost やrandomForest は手軽に結構良い精度が出るので、まずはじめに試すとしたらこのあたりの手法かなと思います。
Regurarized Greedy Forest (以下、RGF と略す)は、C++ で書かれていることとトレーニングに時間がかかるため、あまり普及はしていないように感じます。ただ精度に関してはxgboost より良いことも多い印象があります。
Regularized greedy forest (RGF) in C++
こちらからダウンロードすることができます。RGF の使い方やアルゴリズムについては付属のpdf に詳しく書かれています。アルゴリズムについては時間があるときに追記しようと思います。
- コンペでの使用状況
- metric とloss の目安
- 使用上の注意点
- パラメータ
- python のwrapper
- R のwrapper
コンペでの使用状況
- 実際に、kaggle のコンペでもときどき使用されていて、2 値分類ではWest Nile Virus Prediction のコンペで2 位の方がRGF を用いていました。
- またHiggsBoson のコンペの2位の方も使用されていました。
GitHub - TimSalimans/HiggsML: My second place solution to the Higgs Boson Machine Learning Challenge
- 多クラス分類でも使用されている方がいました。otto のコンペで、商品を9 つのクラスに分類する問題で、stacked generalization の1st level に使用されていました。コードも公開されていますので、参考になると思います。
otto_2015/model.rgf.stack.R at master · diefimov/otto_2015 · GitHub
metric とloss の目安
以後、RGFの使い方を中心に見ていきます。まず評価関数に対してどのようにfit させるかについてです。
rgf のloss はLog, LS, Exp の3つから選択することになりますが、毎回3 つとも試すのは面倒なので第一候補の目安としては、経験的に次の表のように選べばよさそうです。
metirc | loss |
---|---|
logloss | LS (or Log) |
auc | Log(or LS) |
rmse | LS |
使用上の注意点:
クラス分類では、そのクラスに属するかどうかが +1 と-1 で表現されています。(ただし、wrapper を使う分には気にしなくてもOK。)
perl を使うため事前にインストールする必要があります。
パラメータ
reg_L2: 正則化パラメータ
1, 0.1, 0.01 を基本的に試す。
loss: 損失関数(LS, Log, Expo)
LS: square loss, (p-y)^2/2;
Expo: exponential loss, exp(-py);
Log: logistic loss, log(1 + exp(-py));
test_interval: 100
test_interval ごとにモデルをセーブします
max_leaf_forest: 500
葉の数がmax_leaf_forest に到達するまで実行します。
Verbose: 進捗
実装関連
train_x_fn: 訓練データのデータ点
train_y_fn: 訓練データのラベル
test_x_fn: 検証用データのデータ点
model_fn: 予測に用いるモデルのファイル
prediction_fn: 予測したデータを格納するファイル名
SaveLastModelOnly: 最後に実行したモデルだけがセーブされる
model_fn_for_warmstart: 指定したファイルから訓練の続きを行える。(early stopping の実装などに使える)
python のwrapper
python では、便利なwrapper が以下の2 つあります。
GitHub - MLWave/RGF-sklearn: Scikit-learn API toy wrapper for Regularized Greedy Forests
と
GitHub - fukatani/rgf_python: Python Wrapper of Regularized Greedy Forest.
です。後者の方を試してみます。
git clone https://github.com/fukatani/rgf_python.git python setup.py install
なのですが、自分の環境では失敗したため、rgf.py をそのまま読み込むことにしました。実行する前にrgf.py の
loc_exec = 'C:\\Users\\rf\\Documents\\python\\rgf1.2\\bin\\rgf.exe' loc_temp = 'temp/'
この部分を修正する必要があります。
from sklearn import datasets from sklearn.utils.validation import check_random_state from sklearn.cross_validation import StratifiedKFold from rgf import RGFClassifier iris = datasets.load_iris() rng = check_random_state(0) perm = rng.permutation(iris.target.size) iris.data = iris.data[perm] iris.target = iris.target[perm] rgf = RGFClassifier(max_leaf=400, algorithm="RGF_Sib", test_interval=100,) # cross validation rgf_score = 0 n_folds = 3 for train_idx, test_idx in StratifiedKFold(iris.target, n_folds): xs_train = iris.data[train_idx] y_train = iris.target[train_idx] xs_test = iris.data[test_idx] y_test = iris.target[test_idx] rgf.fit(xs_train, y_train) rgf_score += rgf.score(xs_test, y_test) rgf_score /= n_folds print('score: {0}'.format(rgf_score)) # score: 0.959967320261
となって動作していることが確認できます。こちらでは、多クラス分類にも対応しています。(original は2値分類のみ。)
R のwrapper
R でのwrapper はまだ見られない感じで、とりあえず自分のを貼っておきます。
GitHub - puyokw/RGF-R: R wrapper for Regularized Greedy Forest
これをrgf-src.R として使用しています。
path にはrgf1.2 があるディレクトリを指すようにしてください。今のところ、2値分類でMetircs がlogloss の場合のみになっています。(夏休みにもう少し更新していきたいと思っています。auc とrmse(regression) と多クラス分類への対応)
たとえばnumerai のデータを用いてみます。feature はfeature1 からfeature21 まででmetrics はlogloss の2 値分類となっています。
path <- 'C:/Users/KawaseYoshiaki/Desktop/tmp/' source(paste0(path,'rgf-src.R')) train <- read_csv(paste0(path,"numerai_training_data.csv")) test <- read_csv(paste0(path,"numerai_tournament_data.csv")) testId <- test$t_id target <- train$target test$t_id <- NULL train$target <- NULL (tmp <- RGFCV(train,target,nround=500,lambda=1,nfold=5) )# 200, 0.6916646 pred <- RGF(train,test,target,nround=tmp$bestNum*100,lambda=1) submission <- data.frame(t_id=testId, probability=pred$prediction) write_csv(submission,paste0(path,'rgf.csv'))
検証の間隔が100 ごとになっているため、今のところbestNum の値が1/100 で出力しているため、pred のときのnround は100 倍しています。
時間があるときにまた追記していきたいと思っています。