跳至主要內容

SVM求解幼儿园问题

abstiger大约 11 分钟playingmachine learningsvm

这是前段时间在微博上闹的还蛮火热的这样一道幼儿园问题:

kindergarden-problem

很多ML大牛都发表了他们对这一题的看法,基本认为这是一个线性回归(linear regression)的问题,有些大牛还写了些小程序用以证明,结果很漂亮也很让人兴奋~参考博客open in new window

有粗略看了下@TreapDB大牛用python写的BP神经网络求解程序,结果也很给力!也了解到一些关于支持向量机与神经网络的对比,在样本需求和训练时间上SVM似乎还是要占优的, 于是我也忍不住想尝试下用SVM来求解,当然是借助工具了,具体用的是SVM-lightopen in new window,写这篇博文的目的也就是为了记录下我艰辛的求解过程,O(∩_∩)O哈哈~:

思考过程

刚好这段时间尝试看了些机器学习Machine Learning的东东,具体是支持向量机SVMopen in new window这块的内容,因为数学已经忘的差不多了,而且自己的高等代数当时就学的很不好,所以基本上遇到公式、推导什么的都直接跳过了……

关于svm的理论介绍我就不献丑了,在上面wiki链接里有一个External links的模块,里面有很多svm相关的学习指南和工具,其中我觉得Fletcher, Tristan写的这篇SVMExplainedopen in new window还是给像我这样的初学者很多直观的感受。当然这里面还给出了很多svm相关的工具箱链接,包括SVM-light、Libsvm、Shogun……

Step 1、Solve it as a numeric multiclass problem:

说来惭愧,其实一看到这题的时候,我并不知道这是个“数圈圈”的问题,而是直观的理解为一个四维空间里的多值分类问题:根据测试数据的坐标点分布多值分类:0-9,确定w,b以判断待检测坐标点的所属分类。

于是我的训练数据train1.dat为:

1 1:7 2:1 3:1 4:1 #7111
7 1:8 2:8 3:0 4:9 #8809
1 1:2 2:1 3:7 4:2 #2172
5 1:6 2:6 3:6 4:6 #6666
1 1:1 2:1 3:1 4:1 #1111
1 1:2 2:2 3:2 4:2 #2222
3 1:7 2:6 3:6 4:2 #7662
2 1:9 2:3 3:1 4:3 #9313
5 1:0 2:0 3:0 4:0 #0000
1 1:5 2:5 3:5 4:5 #5555
4 1:8 2:1 3:9 4:3 #8193
6 1:8 2:0 3:9 4:6 #8096
4 1:4 2:3 3:9 4:8 #4398
2 1:9 2:4 3:7 4:5 #9475
5 1:9 2:0 3:3 4:8 #9038
3 1:3 2:1 3:4 4:8 #3148

测试数据我加上了训练数据的内容和一条待预测的内容test1.dat:

1 1:7 2:1 3:1 4:1 #7111
7 1:8 2:8 3:0 4:9 #8809
1 1:2 2:1 3:7 4:2 #2172
5 1:6 2:6 3:6 4:6 #6666
1 1:1 2:1 3:1 4:1 #1111
1 1:2 2:2 3:2 4:2 #2222
3 1:7 2:6 3:6 4:2 #7662
2 1:9 2:3 3:1 4:3 #9313
5 1:0 2:0 3:0 4:0 #0000
1 1:5 2:5 3:5 4:5 #5555
4 1:8 2:1 3:9 4:3 #8193
6 1:8 2:0 3:9 4:6 #8096
4 1:4 2:3 3:9 4:8 #4398
2 1:9 2:4 3:7 4:5 #9475
5 1:9 2:0 3:3 4:8 #9038
3 1:3 2:1 3:4 4:8 #3148
6 1:2 2:8 3:8 4:9 #2889

注释:因为是多值分类,所以我用的是SVM-multiclass,又因为SVM-multiclass的“target”要求必须大于等于1,于是我便将所有的target值加了1。

执行命令:

svm_multiclass_learn -c 1 -t 0 train1.dat 1.model
svm_multiclass_classify test1.dat 1.model predict1.dat

然后我依次尝试了四个核,得到的预测结果依次为:

# -t 0(线性核)
Zero/one-error on test set: 52.94% (8 correct, 9 incorrect, 17 total)

#-t 1(多项式核)
Zero/one-error on test set: 11.76% (15 correct, 2 incorrect, 17 total)

#-t 2(径向基核)
Zero/one-error on test set: 5.88% (16 correct, 1 incorrect, 17 total)

#-t 3(sigmoid核)
Zero/one-error on test set: 94.12% (1 correct, 16 incorrect, 17 total)

结果发现采用径向基核对训练用的test数据能完全分类,可是对待测试的2889却分类错误了,期待是6,却归到了7。

虽然这是一个完全错误的方向,但是却给了我两方面启示:

  1. 机器学习想要做到完全智能并不容易,数据的分布及其含义的理解,还是需要人工分析或者进一步的数据挖掘;
  2. 虽然无法通过绘制四维空间图观察到数据坐标点的直观分布,但是通过测试结果,应该可以判定训练数据的点是线性不可分的。 而由非线性到线性空间转换的RBF径向基核函数好给力,怪不得LIBsvm用它做默认核函数了。

Step 2、Solve it as counting “O” multiclass problem:

后来按照微博里说的将此问题转换为让机器去学习数字中带圈的个数后,便参考大牛的做法:“将0-9每个数字出现的次数,作为feature vector”,继续采用多值分类的方式:

于是我的训练数据train.dat为:

1 1:3 2:0 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #7111
7 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:2 9:1 10:1 #8809
1 1:1 2:2 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #2172
5 1:0 2:0 3:0 4:0 5:0 6:4 7:0 8:0 9:0 10:0 #6666
1 1:4 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #1111
1 1:0 2:4 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #2222
3 1:0 2:1 3:0 4:0 5:0 6:2 7:1 8:0 9:0 10:0 #7662
2 1:1 2:0 3:2 4:0 5:0 6:0 7:0 8:0 9:1 10:0 #9313
5 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:4 #0000
1 1:0 2:0 3:0 4:0 5:4 6:0 7:0 8:0 9:0 10:0 #5555
4 1:1 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:0 #8193
6 1:0 2:0 3:0 4:0 5:0 6:1 7:0 8:1 9:1 10:1 #8096
4 1:0 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:1 10:0 #4398
2 1:0 2:0 3:0 4:1 5:1 6:0 7:1 8:0 9:1 10:0 #9475
5 1:0 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:1 #9038
3 1:1 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:0 10:0 #3148

测试数据test.dat为:

1 1:3 2:0 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #7111
7 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:2 9:1 10:1 #8809
1 1:1 2:2 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #2172
5 1:0 2:0 3:0 4:0 5:0 6:4 7:0 8:0 9:0 10:0 #6666
1 1:4 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #1111
1 1:0 2:4 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #2222
3 1:0 2:1 3:0 4:0 5:0 6:2 7:1 8:0 9:0 10:0 #7662
2 1:1 2:0 3:2 4:0 5:0 6:0 7:0 8:0 9:1 10:0 #9313
5 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:4 #0000
1 1:0 2:0 3:0 4:0 5:4 6:0 7:0 8:0 9:0 10:0 #5555
4 1:1 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:0 #8193
6 1:0 2:0 3:0 4:0 5:0 6:1 7:0 8:1 9:1 10:1 #8096
4 1:0 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:1 10:0 #4398
2 1:0 2:0 3:0 4:1 5:1 6:0 7:1 8:0 9:1 10:0 #9475
5 1:0 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:1 #9038
3 1:1 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:0 10:0 #3148
6 1:0 2:1 3:0 4:0 5:0 6:0 7:0 8:2 9:1 10:0 #2889

执行命令:

svm_multiclass_learn -c 1000 -t 0 train.dat model
svm_multiclass_classify test.dat model predict.dat

然后我依次尝试了四个核:

#-t 0(线性核)
Zero/one-error on test set: 11.76% (15 correct, 2 incorrect, 17 total)

#-t 1(多项式核)
Zero/one-error on test set: 5.88% (16 correct, 1 incorrect, 17 total)

#-t 2(径向基核)
Zero/one-error on test set: 5.88% (16 correct, 1 incorrect, 17 total)

#-t 3(sigmoid核)
Zero/one-error on test set: 64.71% (6 correct, 11 incorrect, 17 total)

结果发现用多项式核和径向基核时,训练数据都能准确归类,而2889也都被错误的归到7类中了,而期待应该是6(加1的关系)。

这与@TreapDB大牛开始用BPNN把此问题当成分类问题得出的解是一样的……

Step 3、Solve it as Linear regression problem:

后来又听说这好像是个线性回归的问题……,便尝试用SVM-light的Regression来做(通过-z r参数指定):

-z {c,r,p}  - select between classification (c), regression (r), and preference ranking (p) (see [Joachims, 2002c]) (default classification)

于是便将训练数据train.regression.dat改为:

0 1:3 2:0 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #7111
6 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:2 9:1 10:1 #8809
0 1:1 2:2 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #2172
4 1:0 2:0 3:0 4:0 5:0 6:4 7:0 8:0 9:0 10:0 #6666
0 1:4 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #1111
0 1:0 2:4 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #2222
2 1:0 2:1 3:0 4:0 5:0 6:2 7:1 8:0 9:0 10:0 #7662
1 1:1 2:0 3:2 4:0 5:0 6:0 7:0 8:0 9:1 10:0 #9313
4 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:4 #0000
0 1:0 2:0 3:0 4:0 5:4 6:0 7:0 8:0 9:0 10:0 #5555
3 1:1 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:0 #8193
5 1:0 2:0 3:0 4:0 5:0 6:1 7:0 8:1 9:1 10:1 #8096
3 1:0 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:1 10:0 #4398
1 1:0 2:0 3:0 4:1 5:1 6:0 7:1 8:0 9:1 10:0 #9475
4 1:0 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:1 #9038
2 1:1 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:0 10:0 #3148

测试数据test.regression.dat又增加了两个2467,1893:

0 1:3 2:0 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #7111
6 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:2 9:1 10:1 #8809
0 1:1 2:2 3:0 4:0 5:0 6:0 7:1 8:0 9:0 10:0 #2172
4 1:0 2:0 3:0 4:0 5:0 6:4 7:0 8:0 9:0 10:0 #6666
0 1:4 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #1111
0 1:0 2:4 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 #2222
2 1:0 2:1 3:0 4:0 5:0 6:2 7:1 8:0 9:0 10:0 #7662
1 1:1 2:0 3:2 4:0 5:0 6:0 7:0 8:0 9:1 10:0 #9313
4 1:0 2:0 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:4 #0000
0 1:0 2:0 3:0 4:0 5:4 6:0 7:0 8:0 9:0 10:0 #5555
3 1:1 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:0 #8193
5 1:0 2:0 3:0 4:0 5:0 6:1 7:0 8:1 9:1 10:1 #8096
3 1:0 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:1 10:0 #4398
1 1:0 2:0 3:0 4:1 5:1 6:0 7:1 8:0 9:1 10:0 #9475
4 1:0 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:1 #9038
2 1:1 2:0 3:1 4:1 5:0 6:0 7:0 8:1 9:0 10:0 #3148
5 1:0 2:1 3:0 4:0 5:0 6:0 7:0 8:2 9:1 10:0 #2889
1 1:0 2:1 3:0 4:1 5:0 6:1 7:1 8:0 9:0 10:0 #2467
3 1:1 2:0 3:1 4:0 5:0 6:0 7:0 8:1 9:1 10:0 #1893

注释: 因为是回归,所以我用的是SVM-light,而回归的target便没有了必须大于等于1限制,所以便将target值都改回正确了。

执行命令:

svm_learn -c 2.0 -t 0 -z r train.regression.dat regression.model
svm_classify test.regression.dat regression.model predict.regression.dat

得到预测结果predict.regression.dat为:

0.041134824
5.9
0.10000005
3.9
-0.017730418
0.10000004
2.0294327
1.1000001
4.1
0.10000004
2.9852837
4.9139628
3.0573582
1.1
4.0147164
2.1000001
4.9
1.1220745
2.9852837

与上面test.regression.dat第一列的target值对比便可发现,四舍五入后结果是完全吻合的。给力!

问题总结

最大的感受就是:机器学习或者说数据挖掘的首要事情是给问题定性,确定数据的分布情况和求解方式!

由于本人的机器学习及SVM理论知识还有对SVM-light工具还都不是很了解,整个实验的过程基本靠蒙和猜测,错误或者不合理之处,还请大家指正!