下面的代码是 VP 树（vantage-point tree）的 C++ 实现（保存为 VpTree.h）。递归的 search() 函数决定是搜索左孩子、右孩子，还是两个孩子都搜索。为了高效地维护结果列表，我们使用优先级队列。
// A VP-Tree implementation, by Steve Hanov. (steve.hanov@gmail.com)
// Released to the Public Domain
// Based on "Data Structures and Algorithms for Nearest Neighbor Search" by Peter N. Yianilos
#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <stdio.h>
#include <queue>
#include <limits>
template<typename T, double (*distance)( const T&, const T& )>
class VpTree
{
public:
VpTree() : _root(0) {}
~VpTree() {
delete _root;
}
void create( const std::vector& items ) {
delete _root;
_items = items;
_root = buildFromPoints(0, items.size());
}
void search( const T& target, int k, std::vector* results,
std::vector<double>* distances)
{
std::priority_queue<HeapItem> heap;
_tau = std::numeric_limits::max();
search( _root, target, k, heap );
results->clear(); distances->clear();
while( !heap.empty() ) {
results->push_back( _items[heap.top().index] );
distances->push_back( heap.top().dist );
heap.pop();
}
std::reverse( results->begin(), results->end() );
std::reverse( distances->begin(), distances->end() );
}
private:
// Copy of the items passed to create(). buildFromPoints() reorders it in
// place, and Node::index / HeapItem::index are positions in this vector.
std::vector<T> _items;
// Search radius for the current query; search() resets it to the largest
// double before each query. NOTE(review): presumably shrunk by the recursive
// search helper (not shown in this paste) to the current k-th best
// distance — confirm.
double _tau;
// A single vantage point of the tree. Each Node owns its two children, so
// deleting a Node frees the entire subtree below it.
struct Node
{
    int index;          // position of this node's vantage point in _items
    double threshold;   // distance from the vantage point to the median item (0 for leaves)
    Node* left;         // items no farther than threshold (NULL if none)
    Node* right;        // remaining items, median included (NULL if none)

    Node() : index(0), threshold(0.0), left(NULL), right(NULL) {}

    ~Node()
    {
        delete left;
        delete right;
    }
};
// Root of the tree; NULL until create() builds a non-empty tree.
Node* _root;
// Entry of the result heap: an item position paired with its distance from
// the query target.
struct HeapItem {
    HeapItem( int idx, double d ) : index( idx ), dist( d ) {}

    int index;    // position of the item in _items
    double dist;  // distance from the query target

    // Orders by distance, which makes std::priority_queue<HeapItem> keep the
    // farthest entry on top.
    bool operator<( const HeapItem& other ) const {
        return other.dist > dist;
    }
};
// Strict-weak-ordering predicate for std::nth_element: ranks items by their
// distance from a fixed vantage point.
struct DistanceComparator
{
    // Vantage point all distances are measured from (held by reference).
    const T& item;

    DistanceComparator( const T& vantage ) : item( vantage ) {}

    // True when `a` is strictly closer to the vantage point than `b`.
    bool operator()(const T& a, const T& b) {
        const double toA = distance( item, a );
        const double toB = distance( item, b );
        return toA < toB;
    }
};
// Recursively builds the subtree for the half-open range _items[lower, upper)
// and returns its root, or NULL for an empty range. Reorders _items in place:
// a random vantage point is swapped to position `lower`, and the rest of the
// range is partitioned around the median distance to it. Each child range is
// at most about half the parent's, so recursion depth is roughly log2(n).
Node* buildFromPoints( int lower, int upper )
{
if ( upper == lower ) {
return NULL;
}
Node* node = new Node();
node->index = lower;
if ( upper - lower > 1 ) {
// Choose a random vantage point in [lower, upper-1] and move it to the start.
// rand() is never seeded here, so the layout repeats between runs unless the
// caller calls srand(); index upper-1 is only picked when rand() == RAND_MAX.
int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
std::swap( _items[lower], _items[i] );
int median = ( upper + lower ) / 2;
// Partition [lower+1, upper) around the median distance: afterwards
// _items[median] is the median item and everything before it is no farther
// from the vantage point.
std::nth_element(
_items.begin() + lower + 1,
_items.begin() + median,
_items.begin() + upper,
DistanceComparator( _items[lower] ));
// Radius separating the children: distance from the vantage point to the median item.
node->threshold = distance( _items[lower], _items[median] );
node->index = lower; // redundant: already set above
// Left child: [lower+1, median), items no farther than threshold.
node->left = buildFromPoints( lower + 1, median );
// Right child: [median, upper), the median item and everything farther.
node->right = buildFromPoints( median, upper );
}
return node;
// NOTE(review): the paste ends here. The closing brace of buildFromPoints is
// missing, as are the private recursive overload search( _root, target, k, heap )
// called by the public search() and the class's closing "};". Restore them from
// the original source.
下面是具体的调用示例，运行时需要用到文末提供的城市数据文件 cities.txt。
#include "VpTree.h"
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <stdint.h>
#include <string>
#include <string.h>
#include <math.h>
// NOTE(review): DIM and NUM are not referenced anywhere else in this file;
// they look like leftovers from another benchmark — confirm before removing.
#define DIM 200
#define NUM 32000
// Stores the current wall-clock time, in microseconds since the Unix epoch,
// into *val. NOTE(review): shares its name with the Win32 API but is built on
// POSIX gettimeofday(), which is not monotonic, so a measured interval can be
// skewed if the system clock is adjusted mid-measurement.
void QueryPerformanceCounter( uint64_t* val )
{
    timeval tv;
    // POSIX leaves the behaviour unspecified when the timezone argument is non-null.
    gettimeofday( &tv, NULL );
    // Widen before multiplying: tv_sec * 1000000 overflows a 32-bit time_t/long.
    *val = static_cast<uint64_t>( tv.tv_sec ) * 1000000u
         + static_cast<uint64_t>( tv.tv_usec );
}
// A city record. In main() `city` holds the whole input line the point was
// parsed from, and latitude/longitude come from its last two comma-separated fields.
struct Point {
    std::string city;
    double latitude;
    double longitude;
    // Zero the coordinates so a Point whose fields were never parsed is not
    // later read while indeterminate (e.g. by distance()).
    Point() : latitude(0.0), longitude(0.0) {}
};
// Straight-line (Euclidean) distance between two points, treating latitude
// and longitude as plane coordinates (presumably in degrees). NOTE(review):
// this is not a great-circle distance.
double distance( const Point& p1, const Point& p2 )
{
    const double dLat = p1.latitude - p2.latitude;
    const double dLon = p1.longitude - p2.longitude;
    const double squared = dLat * dLat + dLon * dLon;
    return sqrt( squared );
}
// Candidate kept by linear_search(): an item position and its distance from
// the target. Ordered by distance, so std::priority_queue<HeapItem> keeps the
// farthest candidate on top.
struct HeapItem {
    HeapItem( int i, double d ) : index( i ), dist( d ) {}
    int index;
    double dist;
    bool operator<( const HeapItem& rhs ) const { return rhs.dist > dist; }
};
void linear_search( const std::vector<Point>& items, const Point& target, int k, std::vector<Point>* results,
std::vector<double>* distances)
{
std::priority_queue<HeapItem> heap;
for ( int i = 0; i < items.size(); i++ ) {
double dist = distance( target, items[i] );
if ( heap.size() < k || dist < heap.top().dist ) {
if (heap.size() == k) heap.pop();
heap.push( HeapItem( i, dist ) );
}
}
results->clear();
distances->clear();
while( !heap.empty() ) {
results->push_back( items[heap.top().index] );
distances->push_back( heap.top().dist );
heap.pop();
}
std::reverse( results->begin(), results->end() );
std::reverse( distances->begin(), distances->end() );
}
int main( int argc, char* argv[] ) {
std::vector<Point> points;
printf("Reading cities database...\n");
FILE* file = fopen("cities.txt", "rt");
for(;;) {
char buffer[1000];
Point point;
if ( !fgets(buffer, 1000, file ) ) {
fclose( file );
break;
}
point.city = buffer;
size_t comma = point.city.rfind(",");
sscanf(buffer + comma + 1, "%lg", &point.longitude);
comma = point.city.rfind(",", comma-1);
sscanf(buffer + comma + 1, "%lg", &point.latitude);
//printf("%lg, %lg\n", point.latitude, point.longitude);
points.push_back(point);
//if(points.size()>50000)break;
}
VpTree<Point, distance> tree;
uint64_t start, end;
QueryPerformanceCounter( &start );
tree.create( points );
QueryPerformanceCounter( &end );
printf("Create took %d\n", (int)(end-start));
Point point;
point.latitude = 43.466438;
point.longitude = -80.519185;
std::vector<Point> results;
std::vector<double> distances;
QueryPerformanceCounter( &start );
tree.search( point, 8, &results, &distances );
QueryPerformanceCounter( &end );
printf("Search took %d\n", (int)(end-start));
for( int i = 0; i < results.size(); i++ ) {
printf("%s %lg\n", results[i].city.c_str(), distances[i]);
}
printf("---\n");
QueryPerformanceCounter( &start );
linear_search( points, point, 8, &results, &distances );
QueryPerformanceCounter( &end );
printf("Linear search took %d\n", (int)(end-start));
for( int i = 0; i < results.size(); i++ ) {
printf("%s %lg\n", results[i].city.c_str(), distances[i]);
}
return 0;
}
上面的程序运行时需要读取当前目录下的数据文件 cities.txt，可从以下地址下载（gzip 压缩包，需先解压）：
http://stevehanov.ca/blog/cities.txt.gz