KNN算法JAVA实现

合集下载
  1. 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
  2. 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
  3. 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
super(); this.index = index; this.distance = distance; this.c = c; }
public int getIndex() { return index; } public void setIndex(int index) { this.index = index;
} catch (Exception e) { e.printStackTrace(); } } }
datafile.txt: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
} public double getDistance() { return distance; } public void setDistance(double distance) { this.distance = distance; } public String getC() { return c; } public void setC(String c) { this.c = c; } }
classCount.put(c, classCount.get(c) + 1); } else { classCount.put(c, 1); } } int maxIndex = -1; int maxCount = 0; Object[] classes = classCount.keySet().toArray(); for (int i = 0; i < classes.length; i++) { if (classCount.get(classes[i]) > maxCount) { maxIndex = i; maxCount = classCount.get(classes[i]); } } return classes[maxIndex].toString();
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
testfile.txt: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
public List getRandKNum(int k, int max) { List rand = new ArrayList(k); for (int i = 0; i < k; i++) { int temp = (int) (Math.random() * max); if (!rand.contains(temp)) {
public String knn(List> datas, List testData, int k) { PriorityQueue pq = new PriorityQueue(k, comparator); List randNum = getRandKNum(k, datas.size()); for (int i = 0; i < k; i++) { int index = randNum.get(i); List currData = datas.get(index); String c = currData.get(currData.size() - 1).toString(); KNNNode node = new KNNNode(index, calDistance(testData, currData), c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List t = datas.get(i); double distance = calDistance(testData, t); KNNNode top = pq.peek(); if (top.getDistance() > distance) {
}
public static void main(String[] args) { TestKNN t = new TestKNN(); String datafile = new File("").getAbsolutePath() + File.separator + "datafile"; String testfile = new File("").getAbsolutePath() + File.separator + "testfile"; try { List> datas = new ArrayList>(); List> testDatas = new ArrayList>(); t.read(datas, datafile); t.read(testDatas, testfile); KNN knn = new KNN(); for (int i = 0; i < testDatas.size(); i++) {
private Comparator comparator = new Comparator() { public int compare(KNNNode o1, KNNNode o2) { if (o1.getDistance() >= o2.getDistance()) {
return 1; } else { return 0; } } };
pq.remove(); pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString())); } }
return getMostClass(pq); }
private String getMostClass(PriorityQueue pq) { Map classCount = new HashMap(); for (int i = 0; i < pq.size(); i++) { KNNNode node = pq.remove(); String c = node.getC(); if (classCount.containsKey(c)) {
String t[] = data.split(" "); l = new ArrayList(); for (int i = 0; i < t.length; i++) { l.add(Double.parseDouble(t[i])); } datas.add(l); data = br.readLine(); } } catch (Exception e) { e.printStackTrace(); }
public class TestKNN {
public void read(List> datas, String path){ try { BufferedReader br = new BufferedReader(new FileReader(new File("f:/KNNdata/datafile.txt"))); String data = br.readLine(); List l = null; while (data != null) {
} }
KNNNode 类: package KNN; public class KNNNode { private int index; // 元组标号 private double distance; // 与测试元组的距离 private String c; // 所属类别 public KNNNode(int index, double distance, String c) {
rand.add(temp); } else { i--; } } return rand; }
public double calDistance(List d1, List d2) { double distance = 0.00; for (int i = 0; i < d1.size(); i++) { distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); } return distance; }
List test = testDatas.get(i); System.out.print("测试元组: "); for (int j = 0; j < test.size(); j++) { System.out.print(test.get(j) + " "); } System.out.print("类别为: "); System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3))))); }
TestKNN 类: package KNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.List;
knn 算法 java 实现
(2源自文库15-08-25 17:52:03)
转载▼ 标签: 分类: 机器学习 股票
KNN 类: package KNN; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; public class KNN {
相关文档
最新文档