服务器之家:专注于服务器技术及软件下载分享
分类导航

PHP教程|ASP.NET教程|JAVA教程|ASP教程|编程技术|正则表达式|C/C++|IOS|C#|Swift|Android|JavaScript|易语言|

服务器之家 - 编程语言 - JAVA教程 - 详解Java实现的k-means聚类算法

详解Java实现的k-means聚类算法

2021-03-19 12:02tianshl JAVA教程

这篇文章主要介绍了详解Java实现的k-means聚类算法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

需求

对MySQL数据库中某个表的某个字段执行k-means算法,将处理后的数据写入新表中。

源码及驱动

kmeans.rar

源码

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import java.sql.*;
import java.util.*;
 
/**
 * @author tianshl
 * @version 2018/1/13 上午11:13
 */
public class Kmeans {
  // 源数据
  private List<Integer> origins = new ArrayList<>();
 
  // 分组数据
  private Map<Double, List<Integer>> grouped;
 
  // 初始质心列表
  private List<Double> cores;
 
  // 数据源
  private String tableName;
  private String colName;
 
  /**
   * 构造方法
   *
   * @param tableName 源数据表名称
   * @param colName  源数据列名称
   * @param cores   质心列表
   */
  private Kmeans(String tableName, String colName,List<Double> cores){
    this.cores = cores;
    this.tableName = tableName;
    this.colName = colName;
  }
 
  /**
   * 重新计算质心
   *
   * @return 新的质心列表
   */
  private List<Double> newCores(){
    List<Double> newCores = new ArrayList<>();
 
    for(List<Integer> v: grouped.values()){
      newCores.add(v.stream().reduce(0, (sum, num) -> sum + num) / (v.size() + 0.0));
    }
 
    Collections.sort(newCores);
    return newCores;
  }
 
  /**
   * 判断是否结束
   *
   * @return bool
   */
  private Boolean isOver(){
    List<Double> _cores = newCores();
    for(int i=0, len=cores.size(); i<len; i++){
      if(!cores.get(i).toString().equals(_cores.get(i).toString())){
        // 使用新质心
        cores = _cores;
        return false;
      }
    }
    return true;
  }
 
  /**
   * 数据分组
   */
  private void setGrouped(){
    grouped = new HashMap<>();
 
    Double core;
    for (Integer origin: origins) {
      core = getCore(origin);
 
      if (!grouped.containsKey(core)) {
        grouped.put(core, new ArrayList<>());
      }
 
      grouped.get(core).add(origin);
    }
  }
 
  /**
   * 选择质心
   *
   * @param num  要分组的数据
   * @return   质心
   */
  private Double getCore(Integer num){
 
    // 差 列表
    List<Double> diffs = new ArrayList<>();
 
    // 计算差
    for(Double core: cores){
      diffs.add(Math.abs(num - core));
    }
 
    // 最小差 -> 索引 -> 对应的质心
    return cores.get(diffs.indexOf(Collections.min(diffs)));
  }
 
  /**
   * 建立数据库连接
   * @return connection
   */
  private Connection getConn(){
    try {
      // URL指向要访问的数据库名mydata
      String url = "jdbc:mysql://localhost:3306/data_analysis_dev";
      // MySQL配置时的用户名
      String user = "root";
      // MySQL配置时的密码
      String password = "root";
 
      // 加载驱动
      Class.forName("com.mysql.jdbc.Driver");
 
      //声明Connection对象
      Connection conn = DriverManager.getConnection(url, user, password);
 
      if(conn.isClosed()){
        System.out.println("连接数据库失败!");
        return null;
      }
      System.out.println("连接数据库成功!");
 
      return conn;
 
    } catch (Exception e) {
      System.out.println("连接数据库失败!");
      e.printStackTrace();
    }
 
    return null;
  }
 
  /**
   * 关闭数据库连接
   *
   * @param conn 连接
   */
  private void close(Connection conn){
    try {
      if(conn != null && !conn.isClosed()) conn.close();
    } catch (Exception e){
      e.printStackTrace();
    }
  }
 
  /**
   * 获取源数据
   */
  private void getOrigins(){
 
    Connection conn = null;
    try {
      conn = getConn();
      if(conn == null) return;
 
      Statement statement = conn.createStatement();
 
      ResultSet rs = statement.executeQuery(String.format("select %s from %s", colName, tableName));
 
      while(rs.next()){
        origins.add(rs.getInt(1));
      }
      conn.close();
    } catch (Exception e){
      e.printStackTrace();
    } finally {
     close(conn);
    }
  }
 
  /**
   * 向新表中写数据
   */
  private void write(){
 
    Connection conn = null;
    try {
      conn = getConn();
      if(conn == null) return;
      
      // 创建表
      Statement statement = conn.createStatement();
 
      // 删除旧数据表
      statement.execute("DROP TABLE IF EXISTS k_means; ");
      // 创建新表
      statement.execute("CREATE TABLE IF NOT EXISTS k_means(`core` DECIMAL(11, 7), `col` INTEGER(11));");
 
      // 禁止自动提交
      conn.setAutoCommit(false);
 
      PreparedStatement ps = conn.prepareStatement("INSERT INTO k_means VALUES (?, ?)");
 
      for(Map.Entry<Double, List<Integer>> entry: grouped.entrySet()){
        Double core = entry.getKey();
        for(Integer value: entry.getValue()){
          ps.setDouble(1, core);
          ps.setInt(2, value);
          ps.addBatch();
        }
      }
 
      // 批量执行
      ps.executeBatch();
 
      // 提交事务
      conn.commit();
 
      // 关闭连接
      conn.close();
    } catch (Exception e){
      e.printStackTrace();
    } finally {
      close(conn);
    }
  }
 
  /**
   * 处理数据
   */
  private void run(){
    System.out.println("获取源数据");
    // 获取源数据
    getOrigins();
 
    // 停止分组
    Boolean isOver = false;
 
    System.out.println("数据分组处理");
    while(!isOver) {
      // 数据分组
      setGrouped();
      // 判断是否停止分组
      isOver = isOver();
    }
 
    System.out.println("将处理好的数据写入数据库");
    // 将分组数据写入新表
    write();
 
    System.out.println("写数据完毕");
  }
 
  public static void main(String[] args){
    List<Double> cores = new ArrayList<>();
    cores.add(260.0);
    cores.add(600.0);
    // 表名, 列名, 质心列表
    new Kmeans("attributes", "attr_length", cores).run();
  }
}

源文件

?
1
Kmeans.java

编译

?
1
javac Kmeans.java

运行

?
1
2
# 指定依赖库
java -Djava.ext.dirs=./lib Kmeans

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。

原文链接:https://my.oschina.net/tianshl/blog/1606526

延伸 · 阅读

精彩推荐