This repository has been archived by the owner on Feb 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKD_Tree.cpp
134 lines (134 loc) · 4.68 KB
/
KD_Tree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
/// Binary-tree node storing one point and its payload.
///
/// Nodes are allocated and freed by KDTree, which owns both child pointers.
template <typename Point, typename Value>
struct KDNode {
    Point point;    ///< Coordinates, indexed with operator[].
    Value data;     ///< Payload returned by KDTree::nearestNeighbor.
    KDNode* left;   ///< Points whose coordinate on this node's split axis is smaller.
    KDNode* right;  ///< Points whose coordinate on this node's split axis is equal or larger.
    KDNode(const Point& p, const Value& d) : point(p), data(d), left(nullptr), right(nullptr) {}
};

/// k-d tree mapping multi-dimensional points to values.
///
/// Point must provide size() and operator[](std::size_t). Coordinates must
/// support operator< and a subtraction whose result converts to double.
/// The number of dimensions is fixed by the first point inserted into an
/// empty tree. Every later point and query must have the same size().
///
/// The tree owns its nodes:
/// - copies are deep;
/// - moves transfer the nodes and leave the source empty;
/// - the destructor frees every node.
/// There is no internal synchronization.
template <typename Point, typename Value>
class KDTree {
    using Node = KDNode<Point, Value>;

public:
    /// Creates an empty tree.
    KDTree() : root(nullptr), dimensions(0) {}

    /// Deep-copies @p other; the new tree shares no nodes with it.
    KDTree(const KDTree& other) : root(cloneTree(other.root)), dimensions(other.dimensions) {}

    /// Takes over @p other's nodes and leaves @p other empty.
    KDTree(KDTree&& other) noexcept : root(other.root), dimensions(other.dimensions) {
        other.root = nullptr;
        other.dimensions = 0;
    }

    /// Copy- and move-assignment through a by-value parameter (copy-and-swap).
    /// If making the copy throws, *this is left untouched.
    KDTree& operator=(KDTree other) noexcept {
        std::swap(root, other.root);
        std::swap(dimensions, other.dimensions);
        return *this;
    }

    /// Frees every node.
    ~KDTree() { destroyTree(root); }

    /// Adds @p point with payload @p data. Duplicate points are kept.
    /// @throws std::invalid_argument if the tree is empty and @p point has no
    ///         coordinates, or if @p point's size differs from the tree's.
    void insert(const Point& point, const Value& data) {
        if (root == nullptr) {
            if (point.size() == 0) {
                throw std::invalid_argument("KD-Tree point must have at least one dimension");
            }
            dimensions = point.size();
        } else {
            requireDimensions(point);
        }
        // Walk down iteratively: sorted input produces a list-shaped tree, and
        // recursion per level could then overflow the stack.
        Node** link = &root;
        for (std::size_t depth = 0; *link != nullptr; ++depth) {
            Node* node = *link;
            const std::size_t axis = depth % dimensions;
            link = (point[axis] < node->point[axis]) ? &node->left : &node->right;
        }
        *link = new Node(point, data);
    }

    /// Returns the payload of the point closest to @p query by Euclidean distance.
    /// On ties, the first equally-close node visited wins.
    /// @throws std::runtime_error if the tree is empty.
    /// @throws std::invalid_argument if @p query's size differs from the tree's,
    ///         or if no visited node yields a comparable (non-NaN) distance,
    ///         e.g. because of NaN coordinates.
    Value nearestNeighbor(const Point& query) const {
        if (root == nullptr) {
            throw std::runtime_error("KD-Tree is empty");
        }
        requireDimensions(query);
        const Node* nearest = nullptr;
        double minDistance = std::numeric_limits<double>::infinity();
        nearestNeighbor(root, query, 0, nearest, minDistance);
        if (nearest == nullptr) {
            throw std::invalid_argument("KD-Tree query produced no comparable distance");
        }
        return nearest->data;
    }

    /// Prints the points in in-order (left, node, right), then a newline.
    void inOrderTraversal() const {
        inOrderTraversal(root);
        std::cout << std::endl;
    }

    /// Prints the points in pre-order (node, left, right), then a newline.
    void preOrderTraversal() const {
        preOrderTraversal(root);
        std::cout << std::endl;
    }

    /// Prints the points in post-order (left, right, node), then a newline.
    void postOrderTraversal() const {
        postOrderTraversal(root);
        std::cout << std::endl;
    }

private:
    Node* root;              ///< Owned; nullptr when the tree is empty.
    std::size_t dimensions;  ///< Coordinate count, fixed by the first insert into an empty tree.

    /// Throws std::invalid_argument unless @p p has exactly `dimensions` coordinates.
    void requireDimensions(const Point& p) const {
        if (p.size() != dimensions) {
            throw std::invalid_argument("KD-Tree point has wrong number of dimensions");
        }
    }

    /// Depth-first search below @p node. Updates @p nearest and @p minDistance
    /// whenever a strictly closer point is found.
    void nearestNeighbor(const Node* node, const Point& query, std::size_t depth,
                         const Node*& nearest, double& minDistance) const {
        if (node == nullptr) {
            return;
        }
        const double distance = calculateDistance(query, node->point);
        // The first non-NaN distance always becomes the candidate, even if it is +inf.
        if (distance < minDistance || (nearest == nullptr && !std::isnan(distance))) {
            nearest = node;
            minDistance = distance;
        }
        const std::size_t axis = depth % dimensions;
        // Search the query's side of the split first. Visit the far side only if
        // the splitting plane lies within the best distance found so far.
        if (query[axis] < node->point[axis]) {
            nearestNeighbor(node->left, query, depth + 1, nearest, minDistance);
            if (query[axis] + minDistance >= node->point[axis]) {
                nearestNeighbor(node->right, query, depth + 1, nearest, minDistance);
            }
        } else {
            nearestNeighbor(node->right, query, depth + 1, nearest, minDistance);
            if (query[axis] - minDistance <= node->point[axis]) {
                nearestNeighbor(node->left, query, depth + 1, nearest, minDistance);
            }
        }
    }

    /// Euclidean distance between @p p1 and @p p2 over the tree's dimensions.
    double calculateDistance(const Point& p1, const Point& p2) const {
        double sum = 0.0;
        for (std::size_t i = 0; i < dimensions; ++i) {
            const double diff = p1[i] - p2[i];
            sum += diff * diff;  // cheaper than std::pow(diff, 2)
        }
        return std::sqrt(sum);
    }

    /// Returns a deep copy of the subtree at @p source (nullptr for nullptr).
    /// Uses an explicit stack instead of recursion. If an allocation throws,
    /// every node copied so far is freed before rethrowing.
    static Node* cloneTree(const Node* source) {
        if (source == nullptr) {
            return nullptr;
        }
        Node* copy = new Node(source->point, source->data);
        try {
            std::vector<std::pair<const Node*, Node*> > pending;
            pending.emplace_back(source, copy);
            while (!pending.empty()) {
                const Node* from = pending.back().first;
                Node* to = pending.back().second;
                pending.pop_back();
                if (from->left != nullptr) {
                    to->left = new Node(from->left->point, from->left->data);
                    pending.emplace_back(from->left, to->left);
                }
                if (from->right != nullptr) {
                    to->right = new Node(from->right->point, from->right->data);
                    pending.emplace_back(from->right, to->right);
                }
            }
        } catch (...) {
            destroyTree(copy);
            throw;
        }
        return copy;
    }

    /// Frees the subtree at @p node without recursion. Left children are
    /// rotated upward until the current node has none; that node is then deleted.
    static void destroyTree(Node* node) noexcept {
        while (node != nullptr) {
            if (node->left != nullptr) {
                Node* child = node->left;
                node->left = child->right;
                child->right = node;
                node = child;
            } else {
                Node* next = node->right;
                delete node;
                node = next;
            }
        }
    }

    void inOrderTraversal(const Node* node) const {
        if (node != nullptr) {
            inOrderTraversal(node->left);
            printNode(node);
            inOrderTraversal(node->right);
        }
    }

    void preOrderTraversal(const Node* node) const {
        if (node != nullptr) {
            printNode(node);
            preOrderTraversal(node->left);
            preOrderTraversal(node->right);
        }
    }

    void postOrderTraversal(const Node* node) const {
        if (node != nullptr) {
            postOrderTraversal(node->left);
            postOrderTraversal(node->right);
            printNode(node);
        }
    }

    /// Prints one point as "(x, y, ...) " to std::cout.
    void printNode(const Node* node) const {
        std::cout << "(";
        for (std::size_t i = 0; i < dimensions; ++i) {
            if (i != 0) {
                std::cout << ", ";
            }
            std::cout << node->point[i];
        }
        std::cout << ") ";
    }
};
int main() {
KDTree<std::vector<double>, std::string> kdTree;
kdTree.insert({2.0, 3.0}, "A");
kdTree.insert({5.0, 4.0}, "B");
kdTree.insert({9.0, 6.0}, "C");
kdTree.insert({4.0, 7.0}, "D");
kdTree.insert({8.0, 1.0}, "E");
std::cout << "In-Order Traversal: ";
kdTree.inOrderTraversal();
std::cout << "Pre-Order Traversal: ";
kdTree.preOrderTraversal();
std::cout << "Post-Order Traversal: ";
kdTree.postOrderTraversal();
std::vector<double> queryPoint = {6.0, 3.0};
std::cout << "Nearest neighbor to query point (6.0, 3.0): " << kdTree.nearestNeighbor(queryPoint) << std::endl;
return 0;
}