KNN–K近邻

1、KNN的步骤

  • (1)计算输入数据与训练数据的距离(一般欧几里得距离);
  • (2)从训练集中,选取距离输入数据点最近的k个数据;
  • (3)对于分类任务【常见】,取这k个训练数据类别的众数;对于回归任务,取这k个训练数据值的平均数。
img
特点
  • (1)如上步骤,KNN没有模型训练的过程。需要预测数据时,直接与训练数据集进行计算即可。
  • (2)KNN算法中最重要的超参数就是K的选择,会在下面具体操作中介绍。
  • (3)因为需要计算距离,所以需要进行数值变量标准化,以及类别变量转化(如果有分类变量的话)。
  • (4)KNN在数据量小或者维度较小的情况下效果很好,但不适用于大规模的数据(计算量大)。

关于距离,欧几里得距离,归一化(中心化)

KNN在训练阶段不进行任何计算,直到进入预测阶段之后才进行具体的计算。这在机器学习中是比较少见的,又称为“懒惰”的学习。

2、mlr3建模

1
2
library(mlr3verse)
library(tidyverse)

2.1 糖尿病数据

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20

data(diabetes, package = "mclust")

head(diabetes)
#    class glucose insulin sspg
# 1 Normal      80     356  124
# 2 Normal      97     289  117
# 3 Normal     105     319  143
# 4 Normal      90     356  199
# 5 Normal      90     323  240
# 6 Normal      86     381  157

summary(diabetes)
#   class       glucose       insulin            sspg      
# Chemical:36   Min.   : 70   Min.   :  45.0   Min.   : 10.0  
# Normal  :76   1st Qu.: 90   1st Qu.: 352.0   1st Qu.:118.0  
# Overt   :33   Median : 97   Median : 403.0   Median :156.0  
# 				Mean   :122   Mean   : 540.8   Mean   :186.1  
# 				3rd Qu.:112   3rd Qu.: 558.0   3rd Qu.:221.0  
# 				Max.   :353   Max.   :1568.0   Max.   :748.0

2.2 确定预测目标与训练方法

  • (1)确定预测目的:根据三个指标insulin, sspg, glucose 对糖尿病状态class进行诊断
1
2
3
4
5
6
7
8
task_classif = as_task_classif(diabetes, target = "class")
task_classif$col_roles$stratum = "class"
# <TaskClassif:diabetes> (145 x 4)
# * Target: class
# * Properties: multiclass, strata
# * Features (3):
#   - dbl (3): glucose, insulin, sspg
# * Strata: class
  • (2)确定预测方法:使用KNN分类算法,超参数设为4
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
learner = lrn("classif.kknn", predict_type="prob")
learner$param_set
# <ParamSet>
#             id    class lower upper nlevels default value
# 1:           k ParamInt     1   Inf     Inf       7     7
# 2:    distance ParamDbl     0   Inf     Inf       2      
# 3:      kernel ParamFct    NA    NA      10 optimal      
# 4:       scale ParamLgl    NA    NA       2    TRUE      
# 5:     ykernel ParamUty    NA    NA     Inf              
# 6: store_model ParamLgl    NA    NA       2   FALSE    

##如上默认会对数据进行归一化、默认k=7

2.3 模型训练、预测

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
## 如下使用60%数据训练、40%数据验证
split = partition(task_classif, ratio = 0.6, stratify = T)
learner$train(task_classif, row_ids = split$train)
prediction = learner$predict(task_classif, row_ids = split$test)
prediction$confusion
#           truth
# response   Chemical Normal Overt
#   Chemical       12      1     2
#   Normal          2     29     0
#   Overt           0      0    11

as.data.table(prediction) %>% head
#    row_ids  truth response prob.Chemical prob.Normal prob.Overt
# 1:       1 Normal   Normal    0.00000000   1.0000000          0
# 2:       6 Normal   Normal    0.00000000   1.0000000          0
# 3:      11 Normal   Normal    0.00000000   1.0000000          0
# 4:      14 Normal   Normal    0.01729054   0.9827095          0
# 5:      15 Normal   Normal    0.00000000   1.0000000          0
# 6:      16 Normal Chemical    0.52837210   0.4716279          0

prediction$score(msr("classif.acc"))
#classif.acc 
#  0.9122807 



##确定最终模型及模型预测
learner$train(task_classif)
learner$model
new_data = data.frame(glucose=c(200,300),
					  insulin=c(500,1000),
					  sspg=c(100,50))
learner$predict_newdata(new_data)

2.4 交叉验证优化超参数

  • 交叉验证是将数据分为两部分:训练集+测试集。在训练集中训练模型,在测试集中评估模型的性能,从而避免过拟合的情况。
  • 如果对交叉验证的结果满意,最后就可以使用所有数据(训练集+测试集)来训练模型。
  • 有3种常见的交叉验证方法:(1)留出法(如上2.3);(2)K-折;(3)留一法。K折交叉验证更常用,演示如下。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
search_space = ps(                        #候选超参数
  k = p_int(lower = 3, upper = 20)
)
resampling = rsmp("cv")                   #交叉验证方式
measure = msr("classif.acc")              #评价比较指标,可以多个指标
terminator = trm("none")                  #是否设置提前终止

instance = TuningInstanceSingleCrit$new(
  task = task_classif,
  learner = learner,
  resampling = resampling,
  measure = measure,
  terminator = terminator,
  search_space = search_space
)
#遍历超参数方式
tuner = tnr("grid_search", resolution=search_space$nlevels)
tuner$optimize(instance)
as.data.table(instance$archive)[,c(1,2)]
#      k classif.acc
#  1: 14   0.8858974
#  2:  6   0.9117308
#  3: 12   0.9060897
#  4: 13   0.8983974
#  5: 17   0.8715385
#  6:  7   0.9183974
#  7: 16   0.8792308
#  8: 19   0.8715385
#  9: 20   0.8715385
# 10:  4   0.9040385
# 11:  3   0.9107051
# 12:  8   0.9183974
# 13: 18   0.8715385
# 14:  5   0.9117308
# 15: 10   0.9060897
# 16: 11   0.9060897
# 17: 15   0.8792308
# 18:  9   0.9260897

instance$result_learner_param_vals      #最佳超参数
# $k
# [1] 9
instance$result_y                       #最佳超参数的CV结果
# classif.acc 
# 0.9260897



#使用最佳超参数训练模型
learner$param_set$values$k = instance$result_learner_param_vals$k
learner$train(task_classif)
learner$model