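A topic closely related to automated machine learning is transfer learning: many general-purpose models have already been trained and can be called directly or fine-tuned on top of their existing architecture. For natural language processing tasks, OpenAI's GPT series is an option; the models are not open source (BERT is an open-source alternative), but they can still be used through the API. For image recognition, ResNet or VGG16 can be reused, which saves a lot of work. These pretrained models are all large. To adapt one to a small dataset or a specific domain, you can retrain only one or a few of its layers, or adjust the architecture slightly by appending one or two layers that capture the information in your dataset. Because all of this belongs to deep learning, it is best to catch up on the progress of the last three to five years before starting. In real research, however, the data may be neither images nor text, so no general model can be transferred directly, and the dataset is often small, so it is hard to benefit from these advances. Still, since automated machine learning now covers deep learning as well, training a model automatically is not much of a problem.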



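library(h2o)
h2o.init()  # start or connect to a local H2O cluster
# Read an example metabolomics dataset: x is all feature peaks, y is lung cancer status
mzrt <- enviGCMS::getmzrtcsv('pos.csv')
lv <- ifelse(grepl('Control', mzrt$group$sample_group), 'control', 'case')
# Randomly pick 700 of the 1005 samples for training; the rest form the test set
trainIndex <- sample(1:1005, 700)
train <- mzrt$data[, trainIndex]
train <- cbind.data.frame(Y = lv[trainIndex], t(train))
test  <- mzrt$data[, -trainIndex]
test  <- cbind.data.frame(Y = lv[-trainIndex], t(test))
y <- 'Y'
pred <- setdiff(names(train), y)
# Convert the response to a factor so AutoML treats this as classification
train[, y] <- as.factor(train[, y])
test[, y] <- as.factor(test[, y])
# Upload both data frames to the H2O cluster
train_h <- as.h2o(train)
test_h <- as.h2o(test)
# Run AutoML with at most 10 base models
aml <- h2o.automl(x = pred, y = y,
                  training_frame = train_h,
                  max_models = 10,
                  seed = 42)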
# AutoML leaderboard with all extra columns (training and prediction times, algorithm)
lb <- h2o.get_leaderboard(object = aml, extra_columns = "ALL")
lb
##                                                  model_id    auc logloss  aucpr
## 1 StackedEnsemble_BestOfFamily_1_AutoML_2_20220724_155544 0.8310  0.4942 0.8296
## 2    StackedEnsemble_AllModels_1_AutoML_2_20220724_155544 0.8264  0.5008 0.8267
## 3                          GBM_1_AutoML_2_20220724_155544 0.8182  0.5191 0.8182
## 4                          GBM_4_AutoML_2_20220724_155544 0.8135  0.5224 0.8197
## 5                          GLM_1_AutoML_2_20220724_155544 0.8106  0.5310 0.7991
## 6                          GBM_2_AutoML_2_20220724_155544 0.8063  0.5304 0.8081
##   mean_per_class_error   rmse    mse training_time_ms predict_time_per_row_ms
## 1               0.2450 0.4044 0.1635              547                 0.04390
## 2               0.2544 0.4069 0.1656              655                 0.03897
## 3               0.2843 0.4146 0.1719             1244                 0.01945
## 4               0.2677 0.4170 0.1739             2936                 0.01449
## 5               0.3011 0.4181 0.1748             5889                 0.01053
## 6               0.2782 0.4201 0.1765             2170                 0.01257
##              algo
## 1 StackedEnsemble
## 2 StackedEnsemble
## 3             GBM
## 4             GBM
## 5             GLM
## 6             GBM
## [12 rows x 10 columns]
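# Predict on the test set with the leader model (response column dropped)
pred <- h2o.predict(aml, test_h[,-1])
# Note: caret::confusionMatrix() expects (predicted, reference). The true labels
# are passed first here, so the rows labelled "Prediction" in the output below
# are the true classes and the columns labelled "Reference" are the predictions.
caret::confusionMatrix(test$Y, as.data.frame(pred)$predict)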
## Confusion Matrix and Statistics
##           Reference
## Prediction case control
##    case      99      51
##    control   18     137
##                Accuracy : 0.774        
##                  95% CI : (0.723, 0.82)
##     No Information Rate : 0.616        
##     P-Value [Acc > NIR] : 3.33e-09     
##                   Kappa : 0.546        
##  Mcnemar's Test P-Value : 0.000117     
##             Sensitivity : 0.846        
##             Specificity : 0.729        
##          Pos Pred Value : 0.660        
##          Neg Pred Value : 0.884        
##              Prevalence : 0.384        
##          Detection Rate : 0.325        
##    Detection Prevalence : 0.492        
##       Balanced Accuracy : 0.787        
##        'Positive' Class : case         
# explain the model
h2o.explain(aml, test_h)
## Leaderboard
## ===========
## > Leaderboard shows models with their metrics. When provided with H2OAutoML object, the leaderboard shows 5-fold cross-validated metrics by default (depending on the H2OAutoML settings), otherwise it shows metrics computed on the newdata. At most 20 models are shown by default.
## |  | model_id | auc | logloss | aucpr | mean_per_class_error | rmse | mse | training_time_ms | predict_time_per_row_ms | algo
## |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
## | **1** |StackedEnsemble_BestOfFamily_1_AutoML_2_20220724_155544 | 0.831017204354158 | 0.49423273504537 | 0.82957926642421 | 0.245028344811131 | 0.404395617724151 | 0.163535815634498 | 547 | 0.043897 | StackedEnsemble | 
## | **2** |StackedEnsemble_AllModels_1_AutoML_2_20220724_155544 | 0.826364376866685 | 0.500763642152662 | 0.826674410682095 | 0.254395708373444 | 0.406881363235919 | 0.16555244374872 | 655 | 0.038967 | StackedEnsemble | 
## | **3** |GBM_1_AutoML_2_20220724_155544 | 0.818198273805116 | 0.519098325610369 | 0.818161321019998 | 0.284320259340623 | 0.414569323179071 | 0.171867723721153 | 1244 | 0.019452 | GBM | 
## | **4** |GBM_4_AutoML_2_20220724_155544 | 0.813463168201153 | 0.522399583342484 | 0.819736816333827 | 0.267700079809773 | 0.416972629142305 | 0.173866173453847 | 2936 | 0.014492 | GBM | 
## | **5** |GLM_1_AutoML_2_20220724_155544 | 0.810628687088095 | 0.530963098927506 | 0.799140535379632 | 0.301088539481154 | 0.418120234490234 | 0.174824530490168 | 5889 | 0.010533 | GLM | 
## | **6** |GBM_2_AutoML_2_20220724_155544 | 0.806292630349106 | 0.530399779705598 | 0.808097284982117 | 0.278198767473815 | 0.420108424358518 | 0.176491088216996 | 2170 | 0.012573 | GBM | 
## | **7** |GBM_5_AutoML_2_20220724_155544 | 0.803375871119558 | 0.53609778763023 | 0.80905995260343 | 0.255708044331449 | 0.421280326474034 | 0.177477113474069 | 2286 | 0.012849 | GBM | 
## | **8** |GBM_3_AutoML_2_20220724_155544 | 0.801232526185011 | 0.539889449123235 | 0.808464403923565 | 0.266243757147911 | 0.424605908974098 | 0.18029017793572 | 2338 | 0.012583 | GBM | 
## | **9** |GBM_grid_1_AutoML_2_20220724_155544_model_1 | 0.78743859995557 | 0.553638512512293 | 0.794143332750029 | 0.296242358419931 | 0.431314041585845 | 0.186031802469116 | 2419 | 0.012655 | GBM | 
## | **10** |DRF_1_AutoML_2_20220724_155544 | 0.785599684052033 | 0.568249606521988 | 0.79898788403003 | 0.333711812669184 | 0.438422374043897 | 0.192214178062287 | 1056 | 0.013243 | DRF | 
## | **11** |XRT_1_AutoML_2_20220724_155544 | 0.75527608421988 | 0.590779332812731 | 0.756576018081061 | 0.30185783987033 | 0.449392172222291 | 0.20195332445467 | 862 | 0.01249 | DRF | 
## | **12** |DeepLearning_1_AutoML_2_20220724_155544 | 0.642542722911987 | 0.765522963151034 | 0.664741919283686 | 0.426007289841121 | 0.509052417909465 | 0.259134364179472 | 315 | 0.015305 | DeepLearning | 
## Confusion Matrix
## ================
## > Confusion matrix shows a predicted class vs an actual class.
## StackedEnsemble_BestOfFamily_1_AutoML_2_20220724_155544
## -------------------------------------------------------
## |  | case | control | Error | Rate
## |:---:|:---:|:---:|:---:|:---:|
## | **case** |113 | 37 | 0.246666666666667 |  =37/150 | 
## | **control** |23 | 132 | 0.148387096774194 |  =23/155 | 
## | **Totals** |136 | 169 | 0.19672131147541 |  =60/305 | 
## Variable Importance
## ===================
## > The variable importance plot shows the relative importance of the most important variables in the model.

## Variable Importance Heatmap
## ===========================
## > Variable importance heatmap shows variable importance across multiple models. Some models in H2O return variable importance for one-hot (binary indicator) encoded versions of categorical columns (e.g. Deep Learning, XGBoost). In order for the variable importance of categorical columns to be compared across all model types we compute a summarization of the the variable importance across all one-hot encoded features and return a single variable importance for the original categorical feature. By default, the models and variables are ordered by their similarity.

## Model Correlation
## =================
## > This plot shows the correlation between the predictions of the models. For classification, frequency of identical predictions is used. By default, models are ordered by their similarity (as computed by hierarchical clustering).

## Interpretable models: GLM_1_AutoML_2_20220724_155544 
## SHAP Summary
## ============
## > SHAP summary plot shows the contribution of the features for each instance (row of data). The sum of the feature contributions and the bias term is equal to the raw prediction of the model, i.e., prediction before applying inverse link function.

## Partial Dependence Plots
## ========================
## > Partial dependence plot (PDP) gives a graphical depiction of the marginal effect of a variable on the response. The effect of a variable is measured in change in the mean response. PDP assumes independence between the feature for which is the PDP computed and the rest.

# plot M264.1215T23.3

# Shut down the H2O cluster without asking for confirmation
h2o.shutdown(prompt = FALSE)


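  • Read in the data and split it into at least a training set and a test set
  • Run automated machine learning on the training set
  • Check the performance of the candidate models
  • Examine the predictions on the test set
  • Explain the model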


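Another related topic here is interpretable machine learning. In research, raising prediction accuracy by a few percentage points matters little; being able to say which variables contribute more or less to the prediction, and to show it, is far more meaningful. Global effects can be shown with a partial dependence plot (PDP), with variable importance, or with SHAP contribution values. PDP describes the model's predictions as a whole, while other methods such as LIME give variable importance for a single prediction. How to use them depends on your scientific question.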
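To make the difference between a global and a per-prediction explanation concrete, here is a minimal sketch in base R. It does not use H2O or the metabolomics data: the data are simulated, the feature names `x1`, `x2`, `x3` and the helper `partial_dependence()` are hypothetical, and a plain logistic GLM stands in for whatever model AutoML would select. The partial dependence is computed by brute force, and the per-row contributions use the fact that, for a linear model with features treated as independent, coefficient times centered feature value equals the SHAP value on the link scale.

```r
# Simulated stand-in data (an assumption, not the dataset used above):
# x1 and x2 drive the outcome, x3 is pure noise.
set.seed(1)
n <- 400
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y <- rbinom(n, 1, plogis(1.5 * dat$x1 - 1.0 * dat$x2))

# Any model with a predict() method would do; a logistic GLM keeps it simple.
fit <- glm(y ~ x1 + x2 + x3, data = dat, family = binomial)

# Global view: partial dependence of the predicted probability on one feature.
# For each grid value, set the feature to that value in every row and
# average the predictions over the whole dataset.
partial_dependence <- function(model, data, feature, grid) {
  sapply(grid, function(v) {
    tmp <- data
    tmp[[feature]] <- v
    mean(predict(model, newdata = tmp, type = "response"))
  })
}
grid <- seq(-2, 2, length.out = 9)
pdp_x1 <- partial_dependence(fit, dat, "x1", grid)
pdp_x3 <- partial_dependence(fit, dat, "x3", grid)

# Local view: per-row contributions on the link (log-odds) scale.
# Each entry is coefficient * (value - feature mean).
features <- c("x1", "x2", "x3")
centered <- sweep(as.matrix(dat[, features]), 2, colMeans(dat[, features]))
contrib <- sweep(centered, 2, coef(fit)[features], `*`)
baseline <- mean(predict(fit, type = "link"))

# Contributions for the first sample; together with the baseline they
# add up to that sample's prediction on the link scale.
round(contrib[1, ], 3)
```

Plotting `pdp_x1` and `pdp_x3` against `grid` shows how the average prediction responds to each feature, while a row of `contrib` explains one specific sample. For non-linear models the per-row contributions need a dedicated method such as SHAP or LIME; the linear shortcut used here no longer applies.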
