K-means算法

释放双眼,带上耳机,听听看~!

偶然接触到了K-means,在理解之后写下博客记录。

首先,K-means是一种
无监督学习的
聚类算法。什么是聚类算法,聚类就是对大量
未标注的数据集,按数据存在的内部特征特征划分为多个不同的类别。

K-means

算法接受参数k,然后将事先输入的n个数据划分为k个聚类。其中满足条件:同一聚类对象相似度高,不同聚类对象相似度较小。

算法思想

k个点为中心聚类,对靠近的对象类归类,通过迭代,逐渐更新各聚类中心。

算法描述

(1)适当选择c个类的初始中心

(2)在第k个迭代中,对任意样本,求其到各中心的距离,将样本归到距离最短的中心所在的类

(3)利用均值等方法更新中心值

(4)对于所有的聚类中心,如果利用(2)(3)步骤迭代法更新后,值不变,则迭代结束

对于以上文字来说可能不太容易理解,接下来博主会放图结合文字来说明K-means

简单的准备几个点作为数据演示:x1(1,1),x2(3,2),x3(1,3),x4(4,7),x5(6,7),x6(5,10),x7(10,4),x8(11,6),x9(12,4)

将这些点绘制出来(左上角为原点):

K-means算法

将这些点用矩阵表示,每一列表示一个点,第0行表示X值,第一行表示Y值,第二行表示数据所属类别:

K-means算法

执行算法描述的第一步:

适当选择c个类的初始中心,我们打算将数据分为三个类,则需要
任意选取三个点作为聚类中心,我们取(1,1),(3,2),(1,3)作为初始聚类中心

执行算法描述第二步:

在第k个迭代中,对任意样本,求其到各中心的距离,将样本归到距离最短的中心所在的类,即需要求出事先准备好的9个点到每个聚类中心点的距离,我们构建D矩阵用于存储距离

K-means算法

D[i][j]表示第(i+1)个聚类点到第(j+1)个数据点的距离,例如:D[1][4] = 5.831 表示第2个聚类点(3,2)到第5个数据点(6,7)的距离

K-means算法

D矩阵每一行表示所有点到同一聚类中心的距离,每一列表示同一点到不同聚类中心的距离,我们将每一列进行比较,得到最小的距离,即可判断该点属于哪一类。我们构建G矩阵,将D矩阵按列进行比较,找到最小值位置,对应G矩阵位置 置为1,例如:D矩阵第一列0,2.2361,2,最小值为0,对应下标为D[0][0],则G[0][0] = 1, D矩阵第四列,6.7082,5.099,5,最小值为5,对应下标D[2][3],则G[2][3] = 1,得到G矩阵,

K-means算法

每一列表示一个数据,第几行为1表示该数据属于哪一类,例如x1的第一行为1,表示x1当前属于第一类,x2,x5,x7,x8,x9的第二行为1,则这些数据当前属于第二类,x3,x4,x6的第三行为1,则这些数据属于第三类。

执行算法描述第三步:

利用均值方法更新中心值,我们将同一类的数据取均值代替原来的聚类点。当前第一类只有一个数据点(1,1),所以聚类点1为(1/1,1/1)=(1,1),第二类数据有x2,x5,x7,x8,x9,则新的聚类点为((3+6+10+11+12)/5,2+7+4+6+4)/5)=(8.4,4.6),同理,新的第三个聚类点为((1+4+5)/3, (3+7+10)/3)=(3.333,6.666)

执行算法描述第四步:

对于所有的聚类中心,如果利用(2)(3)步骤迭代法更新后,值不变,则迭代结束

由于新的聚类点改变了,需要继续迭代

继续构建距离矩阵D,将所有点对新的聚类中心求距离

K-means算法

继续构建G矩阵,找到数据点所属类别

K-means算法

这次x1,x2,x3属于第一类,x4,x5,x6属于第三类,x7,x8,x9属于第二类,利用均值方法更新中心值,新的聚类点为(1.666 ,2.000)、 ( 11.000,  4.666)、(  5.000,  8.000),发现聚类点再次变化,继续迭代

继续构建距离矩阵D,将所有点对新的聚类中心求距离

K-means算法

继续构建G矩阵,找到数据点所属类别

K-means算法

新的聚类点为(1.666 ,2.000)、 ( 11.000,  4.666)、(  5.000,  8.000),发现聚类点停止改变,停止迭代。

所以我们将x1,x2,x3分为一类,x4,x5,x6分为一类,x7,x8,x9分为一类,将其绘制出来:

K-means算法

以上就是本人对K-means的理解

接下来为C++、opencv对K-means测试的代码


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
1#include<iostream>
2#include<opencv2\opencv.hpp>
3#include<math.h>
4
5using namespace std;
6using namespace cv;
7#define K 3
8//找到两点之间距离
9float getDistance(Point2f A, Point2f B)
10{
11  float distance = 0.0;
12  distance = sqrt(pow(A.x - B.x, 2) + pow(A.y - B.y,2));
13  return distance;
14}
15//找出vector中的最小值
16int getMinIndex(vector<float>data)
17{
18  float index = 0;
19  float min = 10000;
20  for (int i = 0; i < K; i++)
21  {
22      if (data[i] < min)
23      {
24          min = data[i];
25          index = i;
26      }
27  }
28  return index;
29}
30//找出vector中值为1的下标
31vector<int> getIndexIsOne(vector<float>data)
32{
33  vector<int>index;
34  for (int i = 0; i < data.size(); i++)
35  {
36      if (data[i] == 1)
37          index.push_back(i);
38  }
39  return index;
40}
41//是否停止迭代
42bool shouldStop(vector<Point2f>oldCentroids, vector<Point2f>centroids, int iterations, int maxIt)
43{
44  if (iterations > maxIt)
45      return true;
46  return oldCentroids == centroids;
47
48}
49//更新数据类别
50void updateLabels(Mat &dataset, vector<Point2f>points, vector<Point2f>&centroids)
51{
52  //构建D0矩阵  K行N列,用与记录每个点与聚类点的距离
53  Mat D0 = Mat::zeros(Size(points.size(), K), CV_32F);
54  //计算每个点与聚类点之间的距离,D0[i][j]表示第i个数据点与第j个聚类点的距离
55  for (int i = 0; i < K; i++)
56  {
57      for (int j = 0; j < points.size(); j++)
58      {
59          D0.at<float>(i, j) = getDistance(points[j], centroids[i]);
60      }
61  }
62  //构建G矩阵  K行N列,按列进行比较,找到最小值位置,对应G矩阵位置 置为1
63  Mat G0 = Mat::zeros(Size(points.size(), K), CV_32F);
64  for (int i = 0; i < points.size(); i++)
65  {
66      Mat col;
67      //获取每一列数据后,使用reshape转换成vector便于计算
68      D0.colRange(i, i + 1).copyTo(col);
69      //reshape(cn,row)
70      vector<float>colsVec(col.reshape(1, 1));
71      G0.at<float>(getMinIndex(colsVec), i) = 1;
72  }
73  for (int i = 0; i < K; i++)
74  {
75      Mat row;
76      //获取每一行数据后,使用reshape转换成vector便于计算
77      G0.rowRange(i, i + 1).copyTo(row);
78      vector<float>rowsVec(row.reshape(1, 1));
79      vector<int>indexVec;
80      indexVec = getIndexIsOne(rowsVec);
81      int xSum = 0.0;
82      int ySum = 0.0;
83      //利用均值更新聚类点
84      for (int j = 0; j < indexVec.size(); j++)
85      {
86          dataset.at<float>(2, indexVec[j]) = i*1.0;//bug
87          xSum += points[indexVec[j]].x;
88          ySum += points[indexVec[j]].y;
89      }
90      centroids[i].x = xSum*1.0 / indexVec.size();
91      centroids[i].y = ySum*1.0 / indexVec.size();
92  }
93}
94Mat Kmeans(vector<Point2f>points, int classification, int maxIt)
95{
96  //创建一个 3行,N列的数据,多出来的一行用于表示数据类别
97  Mat dataset = Mat::zeros(Size(points.size(), 3), CV_32FC1);
98  for (int i = 0; i < points.size(); i++)
99  {
100     dataset.at<float>(0, i) = points[i].x;
101     dataset.at<float>(1, i) = points[i].y;
102 }
103 vector<Point2f>centroids(3);
104 //初始化聚类点,任意取K个数据作为初始数据
105 for (int i = 0; i < K; i++)
106 {
107     centroids[i] = points[i];
108 }
109 int iterations = 0;
110 //用于比较聚类点是否发生变化
111 vector<Point2f>oldCentroids(3,Point(0,0));
112
113 while (!shouldStop(oldCentroids, centroids, iterations, maxIt))
114 {
115     iterations++;
116     oldCentroids.assign(centroids.begin(), centroids.end());
117
118     updateLabels(dataset, points, centroids);
119 }
120 return dataset;
121}
122//绘图
123void DrawMat(Mat dataset,Mat &drawingBoard)
124{
125
126 for (int i = 0; i < dataset.cols; i++)
127 {
128     if (dataset.at<float>(2, i) == 0)
129         //circle(drawingBoard,Point(dataset.at<float>(1, i),)
130         drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(0, 0, 255);
131     else if (dataset.at<float>(2, i) == 1)
132         drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(0, 255, 0);
133     else if (dataset.at<float>(2, i) == 2)
134         drawingBoard.at<Vec3b>(dataset.at<float>(0, i), dataset.at<float>(1, i)) = Vec3b(255, 0, 0);
135 }
136 imshow("散点分类图",drawingBoard);
137}
138int main()
139{
140
141 Point2f x1 = Point2f(1, 1);
142 Point2f x2 = Point2f(3, 2);
143 Point2f x3 = Point2f(1, 3);
144 Point2f x4 = Point2f(4, 7);
145 Point2f x5 = Point2f(6, 7);
146 Point2f x6 = Point2f(5, 10);
147 Point2f x7 = Point2f(10, 4);
148 Point2f x8 = Point2f(11, 6);
149 Point2f x9 = Point2f(12, 4);
150 //将数据放入容器中,便于计算
151 vector<Point2f>points;
152 points.push_back(x1);
153 points.push_back(x2);
154 points.push_back(x3);
155 points.push_back(x4);
156 points.push_back(x5);
157 points.push_back(x6);
158 points.push_back(x7);
159 points.push_back(x8);
160 points.push_back(x9);
161
162
163 Mat dataset1 = Mat::zeros(20, 20, CV_8UC3);
164 for (int i = 0; i < points.size(); i++)
165 {
166     dataset1.at<Vec3b>(points[i]) = Vec3b(255, 255, 255);
167 }
168
169 Mat dataset = Kmeans(points, K, 200);
170 Mat drawingBoard = Mat::zeros(20, 20, CV_8UC3);
171 DrawMat(dataset, drawingBoard);
172 waitKey(0);
173 return 0;
174}
175

 

给TA打赏
共{{data.count}}人
人已打赏
安全运维

基于spring boot和mongodb打造一套完整的权限架构(四)【完全集成security】

2021-12-11 11:36:11

安全运维

Ubuntu上NFS的安装配置

2021-12-19 17:36:11

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索