MFEM  v4.6.0
Finite element discretization library
kdtree.hpp
Go to the documentation of this file.
1 // Copyright (c) 2010-2023, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details
11 
12 #ifndef MFEM_KDTREE_HPP
13 #define MFEM_KDTREE_HPP
14 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <tuple>
#include <vector>
21 
22 namespace mfem
23 {
24 
25 namespace KDTreeNorms
26 {
27 
28 /// Evaluates l1 norm of a vector.
29 template <typename Tfloat, int ndim>
30 struct Norm_l1
31 {
32  Tfloat operator() (const Tfloat* xx)
33  {
34  Tfloat tm=abs(xx[0]);
35  for (int i=1; i<ndim; i++)
36  {
37  tm=tm+abs(xx[i]);
38  }
39  return tm;
40  }
41 };
42 
43 /// Evaluates l2 norm of a vector.
44 template<typename Tfloat,int ndim>
45 struct Norm_l2
46 {
47  Tfloat operator() (const Tfloat* xx)
48  {
49  Tfloat tm;
50  tm=xx[0]*xx[0];
51  for (int i=1; i<ndim; i++)
52  {
53  tm=tm+xx[i]*xx[i];
54  }
55  return sqrt(tm);
56  }
57 };
58 
59 /// Finds the max absolute value of a vector.
60 template<typename Tfloat,int ndim>
61 struct Norm_li
62 {
63  Tfloat operator() (const Tfloat* xx)
64  {
65  Tfloat tm;
66  if (xx[0]<Tfloat(0.0)) { tm=-xx[0];}
67  else { tm=xx[0];}
68  for (int i=1; i<ndim; i++)
69  {
70  if (xx[i]<Tfloat(0.0))
71  {
72  if (tm<(-xx[i])) {tm=-xx[i];}
73  }
74  else
75  {
76  if (tm<xx[i]) {tm=xx[i];}
77  }
78  }
79  return tm;
80  }
81 };
82 
83 }
84 
85 /// Template class for build KDTree with template parameters Tindex
86 /// specifying the type utilized for indexing the points, Tfloat
87 /// specifying a float type for representing the coordinates of the
88 /// points, integer parameter ndim specifying the dimensionality of the
89 /// space and template function Tnorm for evaluating the distance
90 /// between two points. The KDTree class implements the standard k-d
91 /// tree data structure that can be used to transfer a ParGridFunction
92 /// defined on one MPI communicator to a ParGridFunction/GridFunction
93 /// defined on different MPI communicator. This can be useful when
94 /// comparing a solution computed on m ranks against a solution
95 /// computed with n or 1 rank(s).
96 template <typename Tindex, typename Tfloat, size_t ndim=3,
97  typename Tnorm=KDTreeNorms::Norm_l2<Tfloat,ndim> >
98 class KDTree
99 {
100 public:
101 
102  /// Structure defining a geometric point in the ndim-dimensional
103  /// space. The coordinate type (Tfloat) can be any floating or
104  /// integer type. It can be even a character if necessary. For
105  /// such types users should redefine the norms.
106  struct PointND
107  {
108  /// Geometric point constructor
109  PointND() { std::fill(xx,xx+ndim,Tfloat(0.0));}
110 
111  /// Coordinates of the point
112  Tfloat xx[ndim];
113  };
114 
115  /// Structure defining a node in the KDTree.
116  struct NodeND
117  {
118  /// Defines a point in the ndim-dimensional space
120 
121  /// Defines the attached index
122  Tindex ind;
123  };
124 
125  /// Default constructor
126  KDTree() = default;
127 
128  /// Returns the spatial dimension of the points
130  {
131  return ndim;
132  }
133 
134  /// Data iterator
135  typedef typename std::vector<NodeND>::iterator iterator;
136 
137  /// Returns iterator to beginning of the point cloud
139  {
140  return data.begin();
141  }
142 
143  /// Returns iterator to the end of the point cloud
145  {
146  return data.end();
147  }
148 
149  /// Returns the size of the point cloud
150  size_t size()
151  {
152  return data.size();
153  }
154 
155  /// Clears the point cloud
156  void clear()
157  {
158  data.clear();
159  }
160 
161  /// Builds the KDTree. If the point cloud is modified the tree
162  /// needs to be rebuild by a new call to Sort().
163  void Sort()
164  {
165  SortInPlace(data.begin(),data.end(),0);
166  }
167 
168  /// Adds a new node to the point cloud
169  void AddPoint(PointND& pt, Tindex ii)
170  {
171  NodeND nd;
172  nd.pt=pt;
173  nd.ind=ii;
174  data.push_back(nd);
175  }
176 
177  /// Adds a new node by coordinates and an associated index
178  void AddPoint(Tfloat* xx,Tindex ii)
179  {
180  NodeND nd;
181  for (size_t i=0; i<ndim; i++)
182  {
183  nd.pt.xx[i]=xx[i];
184  }
185  nd.ind=ii;
186  data.push_back(nd);
187  }
188 
189  /// Finds the nearest neighbour index
190  Tindex FindClosestPoint(PointND& pt)
191  {
192  PointS best_candidate;
193  best_candidate.sp=pt;
194  //initialize the best candidate
195  best_candidate.pos =0;
196  best_candidate.dist=Dist(data[0].pt, best_candidate.sp);
197  best_candidate.level=0;
198  PSearch(data.begin(), data.end(), 0, best_candidate);
199  return data[best_candidate.pos].ind;
200  }
201 
202  /// Finds the nearest neighbour index and return the clossest poitn in clp
203  Tindex FindClosestPoint(PointND& pt, PointND& clp)
204  {
205  PointS best_candidate;
206  best_candidate.sp=pt;
207  //initialize the best candidate
208  best_candidate.pos =0;
209  best_candidate.dist=Dist(data[0].pt, best_candidate.sp);
210  best_candidate.level=0;
211  PSearch(data.begin(), data.end(), 0, best_candidate);
212 
213  clp=data[best_candidate.pos].pt;
214  return data[best_candidate.pos].ind;
215  }
216 
217  /// Returns the closest point and the distance to the input point pt.
218  void FindClosestPoint(PointND& pt, Tindex& ind, Tfloat& dist)
219  {
220  PointND clp;
221  FindClosestPoint(pt,ind,dist,clp);
222 
223  }
224 
225  /// Returns the closest point and the distance to the input point pt.
226  void FindClosestPoint(PointND& pt, Tindex& ind, Tfloat& dist, PointND& clp)
227  {
228  PointS best_candidate;
229  best_candidate.sp=pt;
230  //initialize the best candidate
231  best_candidate.pos =0;
232  best_candidate.dist=Dist(data[0].pt, best_candidate.sp);
233  best_candidate.level=0;
234  PSearch(data.begin(), data.end(), 0, best_candidate);
235 
236  ind=data[best_candidate.pos].ind;
237  dist=best_candidate.dist;
238  clp=data[best_candidate.pos].pt;
239  }
240 
241 
242  /// Brute force search - please, use it only for debuging purposes
243  void FindClosestPointSlow(PointND& pt, Tindex& ind, Tfloat& dist)
244  {
245  PointS best_candidate;
246  best_candidate.sp=pt;
247  //initialize the best candidate
248  best_candidate.pos =0;
249  best_candidate.dist=Dist(data[0].pt, best_candidate.sp);
250  Tfloat dd;
251  for (auto iti=data.begin()+1; iti!=data.end(); iti++)
252  {
253  dd=Dist(iti->pt, best_candidate.sp);
254  if (dd<best_candidate.dist)
255  {
256  best_candidate.pos=iti-data.begin();
257  best_candidate.dist=dd;
258  }
259  }
260 
261  ind=data[best_candidate.pos].ind;
262  dist=best_candidate.dist;
263  }
264 
265  /// Finds all points within a distance R from point pt. The indices are
266  /// returned in the vector res and the correponding distances in vector dist.
267  void FindNeighborPoints(PointND& pt,Tfloat R, std::vector<Tindex> & res,
268  std::vector<Tfloat> & dist)
269  {
270  FindNeighborPoints(pt,R,data.begin(),data.end(),0,res,dist);
271  }
272 
273  /// Finds all points within a distance R from point pt. The indices are
274  /// returned in the vector res and the correponding distances in vector dist.
275  void FindNeighborPoints(PointND& pt,Tfloat R, std::vector<Tindex> & res)
276  {
277  FindNeighborPoints(pt,R,data.begin(),data.end(),0,res);
278  }
279 
280  /// Brute force search - please, use it only for debuging purposes
281  void FindNeighborPointsSlow(PointND& pt,Tfloat R, std::vector<Tindex> & res,
282  std::vector<Tfloat> & dist)
283  {
284  Tfloat dd;
285  for (auto iti=data.begin(); iti!=data.end(); iti++)
286  {
287  dd=Dist(iti->pt, pt);
288  if (dd<R)
289  {
290  res.push_back(iti->ind);
291  dist.push_back(dd);
292  }
293  }
294  }
295 
296  /// Brute force search - please, use it only for debuging purposes
297  void FindNeighborPointsSlow(PointND& pt,Tfloat R, std::vector<Tindex> & res)
298  {
299  Tfloat dd;
300  for (auto iti=data.begin(); iti!=data.end(); iti++)
301  {
302  dd=Dist(iti->pt, pt);
303  if (dd<R)
304  {
305  res.push_back(iti->ind);
306  }
307  }
308  }
309 
310 private:
311 
312  /// Functor utilized in the coordinate comparison
313  /// for building the KDTree
314  struct CompN
315  {
316  /// Current coordinate index
317  std::uint8_t dim;
318 
319  /// Constructor for the comparison
320  CompN(std::uint8_t dd):dim(dd) {}
321 
322  /// Compares two points p1 and p2
323  bool operator() (const PointND& p1, const PointND& p2)
324  {
325  return p1.xx[dim]<p2.xx[dim];
326  }
327 
328  /// Compares two nodes n1 and n2
329  bool operator() (const NodeND& n1, const NodeND& n2)
330  {
331  return n1.pt.xx[dim]<n2.pt.xx[dim];
332  }
333  };
334 
335  /// Point for storing tmp data
336  PointND tp;
337  Tnorm fnorm;
338 
339  /// Computes the distance between two nodes
340  Tfloat Dist(const PointND& pt1,const PointND& pt2)
341  {
342  for (size_t i=0; i<ndim; i++)
343  {
344  tp.xx[i]=pt1.xx[i]-pt2.xx[i];
345  }
346  return fnorm(tp.xx);
347  }
348 
349  /// The point cloud is stored in a vector.
350  std::vector<NodeND> data;
351 
352  /// Finds the median for a sequence of nodes starting with itb
353  /// and ending with ite. The current coordinate index is set by cdim.
354  Tfloat FindMedian(typename std::vector<NodeND>::iterator itb,
355  typename std::vector<NodeND>::iterator ite,
356  std::uint8_t cdim)
357  {
358  size_t siz=ite-itb;
359  std::nth_element(itb, itb+siz/2, ite, CompN(cdim));
360  return itb->pt.xx[cdim];
361  }
362 
363  /// Sorts the point cloud
364  void SortInPlace(typename std::vector<NodeND>::iterator itb,
365  typename std::vector<NodeND>::iterator ite,
366  size_t level)
367  {
368  std::uint8_t cdim=(std::uint8_t)(level%ndim);
369  size_t siz=ite-itb;
370  if (siz>2)
371  {
372  std::nth_element(itb, itb+siz/2, ite, CompN(cdim));
373  level=level+1;
374  SortInPlace(itb, itb+siz/2, level);
375  SortInPlace(itb+siz/2+1,ite, level);
376  }
377  }
378 
379  /// Structure utilized for nearest neighbor search (NNS)
380  struct PointS
381  {
382  Tfloat dist;
383  size_t pos;
384  size_t level;
385  PointND sp;
386  };
387 
388  /// Finds the closest point to bc.sp in the point cloud
389  /// bounded between [itb,ite).
390  void PSearch(typename std::vector<NodeND>::iterator itb,
391  typename std::vector<NodeND>::iterator ite,
392  size_t level, PointS& bc)
393  {
394  std::uint8_t dim=(std::uint8_t) (level%ndim);
395  size_t siz=ite-itb;
396  typename std::vector<NodeND>::iterator mtb=itb+siz/2;
397  if (siz>2)
398  {
399  // median is at itb+siz/2
400  level=level+1;
401  if ((bc.sp.xx[dim]-bc.dist)>mtb->pt.xx[dim]) // look on the right only
402  {
403  PSearch(itb+siz/2+1, ite, level, bc);
404  }
405  else if ((bc.sp.xx[dim]+bc.dist)<mtb->pt.xx[dim]) // look on the left only
406  {
407  PSearch(itb,itb+siz/2, level, bc);
408  }
409  else // check all
410  {
411  if (bc.sp.xx[dim]<mtb->pt.xx[dim])
412  {
413  // start with the left portion
414  PSearch(itb,itb+siz/2, level, bc);
415  // and continue to the right
416  if (!((bc.sp.xx[dim]+bc.dist)<mtb->pt.xx[dim]))
417  {
418  PSearch(itb+siz/2+1, ite, level, bc);
419  {
420  // check central one
421  Tfloat dd=Dist(mtb->pt, bc.sp);
422  if (dd<bc.dist)
423  {
424  bc.dist=dd; bc.pos=mtb-data.begin(); bc.level=level;
425  }
426  } // end central point check
427  }
428  }
429  else
430  {
431  // start with the right portion
432  PSearch(itb+siz/2+1, ite, level, bc);
433  // and continue with left
434  if (!((bc.sp.xx[dim]-bc.dist)>mtb->pt.xx[dim]))
435  {
436  PSearch(itb, itb+siz/2, level, bc);
437  {
438  // check central one
439  Tfloat dd=Dist(mtb->pt, bc.sp);
440  if (dd<bc.dist)
441  {
442  bc.dist=dd; bc.pos=mtb-data.begin(); bc.level=level;
443  }
444  } // end central point check
445  }
446  }
447  }
448  }
449  else
450  {
451  // check the nodes
452  Tfloat dd;
453  for (auto it=itb; it!=ite; it++)
454  {
455  dd=Dist(it->pt, bc.sp);
456  if (dd<bc.dist) // update bc
457  {
458  bc.pos=it-data.begin();
459  bc.dist=dd;
460  bc.level=level;
461  }
462  }
463  }
464  }
465 
466  /// Returns distances and indices of the n closest points to a point pt.
467  void NNS(PointND& pt,const int& npoints,
468  typename std::vector<NodeND>::iterator itb,
469  typename std::vector<NodeND>::iterator ite,
470  size_t level,
471  std::vector< std::tuple<Tfloat,Tindex> > & res)
472  {
473  std::uint8_t dim=(std::uint8_t) (level%ndim);
474  size_t siz=ite-itb;
475  typename std::vector<NodeND>::iterator mtb=itb+siz/2;
476  if (siz>2)
477  {
478  // median is at itb+siz/2
479  level=level+1;
480  Tfloat R=std::get<0>(res[npoints-1]);
481  // check central one
482  Tfloat dd=Dist(mtb->pt, pt);
483  if (dd<R)
484  {
485  res[npoints-1]=std::make_tuple(dd,mtb->ind);
486  std::nth_element(res.begin(), res.end()-1, res.end());
487  R=std::get<0>(res[npoints-1]);
488  }
489  if ((pt.xx[dim]-R)>mtb->pt.xx[dim]) // look to the right only
490  {
491  NNS(pt, npoints, itb+siz/2+1, ite, level, res);
492  }
493  else if ((pt.xx[dim]+R)<mtb->pt.xx[dim]) // look to the left only
494  {
495  NNS(pt, npoints, itb, itb+siz/2, level, res);
496  }
497  else // check all
498  {
499  NNS(pt,npoints, itb+siz/2+1, ite, level, res); // right
500  NNS(pt,npoints, itb, itb+siz/2, level, res); // left
501  }
502  }
503  else
504  {
505  Tfloat dd;
506  for (auto it=itb; it!=ite; it++)
507  {
508  dd=Dist(it->pt, pt);
509  if (dd< std::get<0>(res[npoints-1])) // update the list
510  {
511  res[npoints-1]=std::make_tuple(dd,it->ind);
512  std::nth_element(res.begin(), res.end()-1, res.end());
513  }
514  }
515  }
516  }
517 
518  /// Finds the set of indices of points within a distance R of a point pt.
519  void FindNeighborPoints(PointND& pt, Tfloat R,
520  typename std::vector<NodeND>::iterator itb,
521  typename std::vector<NodeND>::iterator ite,
522  size_t level,
523  std::vector<Tindex> & res)
524  {
525  std::uint8_t dim=(std::uint8_t) (level%ndim);
526  size_t siz=ite-itb;
527  typename std::vector<NodeND>::iterator mtb=itb+siz/2;
528  if (siz>2)
529  {
530  // median is at itb+siz/2
531  level=level+1;
532  if ((pt.xx[dim]-R)>mtb->pt.xx[dim]) // look to the right only
533  {
534  FindNeighborPoints(pt, R, itb+siz/2+1, ite, level, res);
535  }
536  else if ((pt.xx[dim]+R)<mtb->pt.xx[dim]) // look to the left only
537  {
538  FindNeighborPoints(pt,R, itb, itb+siz/2, level, res);
539  }
540  else //check all
541  {
542  FindNeighborPoints(pt,R, itb+siz/2+1, ite, level, res); // right
543  FindNeighborPoints(pt,R, itb, itb+siz/2, level, res); // left
544 
545  // check central one
546  Tfloat dd=Dist(mtb->pt, pt);
547  if (dd<R)
548  {
549  res.push_back(mtb->ind);
550  }
551  }
552  }
553  else
554  {
555  Tfloat dd;
556  for (auto it=itb; it!=ite; it++)
557  {
558  dd=Dist(it->pt, pt);
559  if (dd<R) // update bc
560  {
561  res.push_back(it->ind);
562  }
563  }
564  }
565  }
566 
567  /// Finds the set of indices of points within a distance R of a point pt.
568  void FindNeighborPoints(PointND& pt, Tfloat R,
569  typename std::vector<NodeND>::iterator itb,
570  typename std::vector<NodeND>::iterator ite,
571  size_t level,
572  std::vector<Tindex> & res, std::vector<Tfloat> & dist)
573  {
574  std::uint8_t dim=(std::uint8_t) (level%ndim);
575  size_t siz=ite-itb;
576  typename std::vector<NodeND>::iterator mtb=itb+siz/2;
577  if (siz>2)
578  {
579  // median is at itb+siz/2
580  level=level+1;
581  if ((pt.xx[dim]-R)>mtb->pt.xx[dim]) // look to the right only
582  {
583  FindNeighborPoints(pt, R, itb+siz/2+1, ite, level, res, dist);
584  }
585  else if ((pt.xx[dim]+R)<mtb->pt.xx[dim]) // look to the left only
586  {
587  FindNeighborPoints(pt,R, itb, itb+siz/2, level, res, dist);
588  }
589  else // check all
590  {
591  FindNeighborPoints(pt,R, itb+siz/2+1, ite, level, res, dist); // right
592  FindNeighborPoints(pt,R, itb, itb+siz/2, level, res, dist); // left
593 
594  // check central one
595  Tfloat dd=Dist(mtb->pt, pt);
596  if (dd<R)
597  {
598  res.push_back(mtb->ind);
599  dist.push_back(dd);
600  }
601  }
602  }
603  else
604  {
605  Tfloat dd;
606  for (auto it=itb; it!=ite; it++)
607  {
608  dd=Dist(it->pt, pt);
609  if (dd<R) // update bc
610  {
611  res.push_back(it->ind);
612  dist.push_back(dd);
613  }
614  }
615  }
616  }
617 };
618 
619 /// Defines KDTree in 3D
621 
622 /// Defines KDTree in 2D
624 
625 /// Defines KDTree in 1D
627 
628 } // namespace mfem
629 
630 #endif // MFEM_KDTREE_HPP
Tfloat xx[ndim]
Coordinates of the point.
Definition: kdtree.hpp:112
Finds the max absolute value of a vector.
Definition: kdtree.hpp:61
size_t size()
Returns the size of the point cloud.
Definition: kdtree.hpp:150
std::vector< NodeND >::iterator iterator
Data iterator.
Definition: kdtree.hpp:135
void AddPoint(Tfloat *xx, Tindex ii)
Adds a new node by coordinates and an associated index.
Definition: kdtree.hpp:178
KDTree< int, double, 3 > KDTree3D
Defines KDTree in 3D.
Definition: kdtree.hpp:620
void FindNeighborPointsSlow(PointND &pt, Tfloat R, std::vector< Tindex > &res)
Brute force search - please use it only for debugging purposes.
Definition: kdtree.hpp:297
KDTree< int, double, 2 > KDTree2D
Defines KDTree in 2D.
Definition: kdtree.hpp:623
void FindClosestPoint(PointND &pt, Tindex &ind, Tfloat &dist, PointND &clp)
Returns the closest point and the distance to the input point pt.
Definition: kdtree.hpp:226
void FindNeighborPoints(PointND &pt, Tfloat R, std::vector< Tindex > &res)
Definition: kdtree.hpp:275
Evaluates l1 norm of a vector.
Definition: kdtree.hpp:30
void AddPoint(PointND &pt, Tindex ii)
Adds a new node to the point cloud.
Definition: kdtree.hpp:169
Tindex FindClosestPoint(PointND &pt, PointND &clp)
Finds the nearest neighbour index and returns the closest point in clp.
Definition: kdtree.hpp:203
Structure defining a node in the KDTree.
Definition: kdtree.hpp:116
KDTree()=default
Default constructor.
void Sort()
Definition: kdtree.hpp:163
void clear()
Clears the point cloud.
Definition: kdtree.hpp:156
Evaluates l2 norm of a vector.
Definition: kdtree.hpp:45
Tindex ind
Defines the attached index.
Definition: kdtree.hpp:122
iterator end()
Returns iterator to the end of the point cloud.
Definition: kdtree.hpp:144
PointND()
Geometric point constructor.
Definition: kdtree.hpp:109
void FindNeighborPointsSlow(PointND &pt, Tfloat R, std::vector< Tindex > &res, std::vector< Tfloat > &dist)
Brute force search - please use it only for debugging purposes.
Definition: kdtree.hpp:281
int dim
Definition: ex24.cpp:53
void FindClosestPointSlow(PointND &pt, Tindex &ind, Tfloat &dist)
Brute force search - please use it only for debugging purposes.
Definition: kdtree.hpp:243
void FindClosestPoint(PointND &pt, Tindex &ind, Tfloat &dist)
Returns the closest point and the distance to the input point pt.
Definition: kdtree.hpp:218
KDTree< int, double, 1 > KDTree1D
Defines KDTree in 1D.
Definition: kdtree.hpp:626
Tfloat operator()(const Tfloat *xx)
Definition: kdtree.hpp:63
PointND pt
Defines a point in the ndim-dimensional space.
Definition: kdtree.hpp:119
int SpaceDimension()
Returns the spatial dimension of the points.
Definition: kdtree.hpp:129
Tfloat operator()(const Tfloat *xx)
Definition: kdtree.hpp:47
void FindNeighborPoints(PointND &pt, Tfloat R, std::vector< Tindex > &res, std::vector< Tfloat > &dist)
Definition: kdtree.hpp:267
iterator begin()
Returns iterator to beginning of the point cloud.
Definition: kdtree.hpp:138
Tindex FindClosestPoint(PointND &pt)
Finds the nearest neighbour index.
Definition: kdtree.hpp:190
Tfloat operator()(const Tfloat *xx)
Definition: kdtree.hpp:32