服务器之家:专注于服务器技术及软件下载分享
分类导航

PHP教程|ASP.NET教程|Java教程|ASP教程|编程技术|正则表达式|C/C++|IOS|C#|Swift|Android|JavaScript|易语言|

服务器之家 - 编程语言 - Java教程 - Java实现的KNN算法示例

Java实现的KNN算法示例

2021-05-10 11:22带头大哥不是我 Java教程

这篇文章主要介绍了Java实现的KNN算法,结合实例形式分析了KNN算法的原理及Java定义与使用KNN算法流程、训练数据相关操作技巧,需要的朋友可以参考下

本文实例讲述了java实现的knn算法。分享给大家供大家参考,具体如下:

提起knn算法大家应该都不会陌生,对于数据挖掘来说算是十大经典算法之一。

算法的思想是:对于训练数据集中已经归类的分组,来对于未知的数据进行分组归类。其中是根据该未知点与其训练数据中的点计算距离,求出距离最短的点,并将其归入该点的那一类。

看看算法的工程吧:

1. 准备数据,对数据进行预处理
2. 选用合适的数据结构存储训练数据和测试元组
3. 设定参数,如k
4.维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离l 与优先级队列中的最大距离lmax
6. 进行比较。若l>=lmax,则舍弃该元组,遍历下一个元组。若l < lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队                  列。
7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。

根据算法的过程我们进行java语言实现:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
package knn;
/**
 * 点的坐标 x 、y
 * @author administrator
 *
 */
public class pointbean {
int x;
int y;
public int getx() {
  return x;
}
public void setx(int x) {
  this.x = x;
}
public int gety() {
  return y;
}
public void sety(int y) {
  this.y = y;
}
public pointbean(int x, int y) {
  super();
  this.x = x;
  this.y = y;
}
public pointbean() {
  super();
}
@override
public string tostring() {
  return "pointbean [x=" + x + ", y=" + y + "]";
}
}

knn算法

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package knn;
import java.util.arraylist;
/**
 * knn实现的方法
 * @author administrator
 *
 */
public class knnmain {
  public double getpointlength(arraylist<pointbean> list,pointbean bb){
    int b_x=bb.getx();
    int b_y=bb.gety();
    double temp=(b_x -list.get(0).getx())*(b_x -list.get(0).getx())+
        (b_y -list.get(0).gety())*(b_y -list.get(0).gety());
    // 找出最小的距离
    for(int i=1;i<list.size();i++){
      if(temp<((b_x -list.get(i).getx())*(b_x -list.get(i).getx())+
          (b_y -list.get(i).gety())*(b_y -list.get(i).gety()))){
        temp=(b_x -list.get(i).getx())*(b_x -list.get(i).getx())+
            (b_y -list.get(i).gety())*(b_y -list.get(i).gety());
      }
    }
    return math.sqrt(temp);
  }
  /**
   * 获取长度,找出最小的一个进行归类
   * @param list1
   * @param list2
   * @param list3
   * @param bb
   */
  public void getcontent(arraylist<pointbean> list1,arraylist<pointbean> list2,
      arraylist<pointbean> list3,pointbean bb){
    double a=getpointlength(list1,bb);
    double b=getpointlength(list2,bb);
    double c=getpointlength(list3,bb);
    //做出比较
    if(a>b){
      if(b>c){
        system.out.println("这个点:"+bb.getx()+" , "+bb.gety()+" " +"属于c");
      }else {
        system.out.println("这个点:"+bb.getx()+" , "+bb.gety()+" " +"属于b");
      }
    }else {
      if(a>c){
        system.out.println("这个点:"+bb.getx()+" , "+bb.gety()+" " +"属于c");
      }else {
        system.out.println("这个点:"+bb.getx()+" , "+bb.gety()+" " +"属于a");
      }
    }
  }
}

主函数

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package knn;
import java.util.arraylist;
/*
 * 主函数 knn
 */
public class testjava {
  static arraylist< pointbean> lista;
  static arraylist< pointbean> listb;
  static arraylist< pointbean> listc;
  static arraylist< pointbean> listd;
  public static void main(string[] args) {
    //创佳arraylist
    lista=new arraylist<pointbean>();
    listb=new arraylist<pointbean>();
    listc=new arraylist<pointbean>();
    listd=new arraylist<pointbean>();
    //写入数据
    setdate();
    gettestresult();
  }
  /**
   * 得到结果
   */
  private static void gettestresult() {
    //创建对象
    knnmain km=new knnmain();
    for(int i=0;i<listd.size();i++){
      km.getcontent(lista, listb, listc, listd.get(i));
    }
  }
  /**
   * 写入数据
   */
  private static void setdate() {
    //a的坐标点
    int a_x[]={1,1,2,2,1};
    int a_y[]={0,1,1,0,2};
    //b的坐标点
    int b_x[]={2,3,3,3,4};
    int b_y[]={4,4,3,2,3};
    //c的坐标点
    int c_x[]={4,5,5,6,6};
    int c_y[]={1,2,0,2,1};
    // 测试数据
    //b的坐标点
    int d_x[]={3,3,3,0,5};
    int d_y[]={0,1,5,0,1};
    //
    pointbean ba;
    for(int i=0;i<5;i++){
      ba=new pointbean(a_x[i], a_y[i]);
      lista.add(ba);
    }
    //
    pointbean bb ;
    for(int i=0;i<5;i++){
      bb=new pointbean(b_x[i], b_y[i]);
      listb.add(bb);
    }
    //
    pointbean bc ;
    for(int i=0;i<5;i++){
      bc=new pointbean(c_x[i], c_y[i]);
      listc.add(bc);
    }
    //
    pointbean bd ;
    for(int i=0;i<5;i++){
      bd=new pointbean(d_x[i], d_y[i]);
      listd.add(bd);
    }
  }
}

测试的结果:

这个点:3 , 1 属于a
这个点:3 , 5 属于b
这个点:0 , 0 属于a
这个点:5 , 1 属于c

到此简单的knn算法已经实现对于未知点的划分,有助于大家对于knn算法的理解。对于改进knn的一些算法java实现会在后面进行贴出。共同学习共同进步!

希望本文所述对大家java程序设计有所帮助。

原文链接:https://blog.csdn.net/u011015260/article/details/53392194

延伸 · 阅读

精彩推荐