00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00025 #ifndef __LSHKIT_TOPK__
00026 #define __LSHKIT_TOPK__
00027
#include <algorithm>
#include <fstream>
#include <limits>
#include <stdexcept>
#include <vector>
00032
00057 namespace lshkit {
00058
00060
/// One candidate in a top-K list: a key together with its distance.
/**
 * Entries are ordered by distance only (see operator<). A default-constructed
 * or reset() entry has dist == std::numeric_limits<float>::max(); Topk uses
 * such entries as "unfilled slot" placeholders.
 */
template <typename KEY>
struct TopkEntry
{
    KEY key;
    float dist;

    /// True when both entries carry the same key (distance is ignored).
    bool match (const TopkEntry &e) const { return key == e.key; }
    /// True when this entry carries the given key.
    bool match (KEY e) const { return key == e; }

    TopkEntry (KEY key_, float dist_) : key(key_), dist(dist_) {}

    /// Builds a placeholder entry.
    /// The key is value-initialized so that it is never read while
    /// indeterminate (the original left scalar keys uninitialized).
    TopkEntry () : key(), dist(std::numeric_limits<float>::max()) { }

    /// Turns the entry back into a placeholder; the key is left untouched.
    void reset () { dist = std::numeric_limits<float>::max(); }

    /// Orders entries by ascending distance.
    friend bool operator < (const TopkEntry &e1, const TopkEntry &e2)
    {
        return e1.dist < e2.dist;
    }
};

/// Sorted list of the best candidates seen so far.
/**
 * Two modes, selected by reset():
 *  - K > 0: the vector always holds exactly K slots sorted by ascending
 *    distance; slots not filled yet are placeholders whose distance is
 *    std::numeric_limits<float>::max().
 *  - K == 0: every candidate closer than the radius is appended, unsorted.
 *
 * In both modes a candidate is accepted only if its distance is strictly
 * below threshold().
 */
template <class KEY>
class Topk: public std::vector<TopkEntry<KEY> >
{
    unsigned K;
    float R;
    float th;
public:
    typedef TopkEntry<KEY> Element;
    typedef typename std::vector<TopkEntry<KEY> > Base;

    /// Creates an empty list in radius mode with an unbounded radius, so the
    /// object is in a defined state even before the first reset().
    Topk ()
        : K(0),
          R(std::numeric_limits<float>::max()),
          th(std::numeric_limits<float>::max()) {}

    ~Topk () {}

    /// Prepares a K-nearest search with k placeholder slots and radius r.
    /// @throws std::invalid_argument if k is zero.
    void reset (unsigned k, float r = std::numeric_limits<float>::max()) {
        if (k == 0) throw std::invalid_argument("K MUST BE POSITIVE");
        R = th = r;
        K = k;
        this->resize(k);
        for (typename Base::iterator it = this->begin(); it != this->end(); ++it) it->reset();
    }

    /// Same as reset(k, r), additionally stamping every placeholder slot
    /// with the given key.
    /// @throws std::invalid_argument if k is zero.
    void reset (unsigned k, KEY key, float r = std::numeric_limits<float>::max()) {
        if (k == 0) throw std::invalid_argument("K MUST BE POSITIVE");
        R = th = r;
        K = k;
        this->resize(k);
        for (typename Base::iterator it = this->begin(); it != this->end(); ++it) {
            it->reset();
            it->key = key;
        }
    }

    /// Prepares a pure radius search: the list is emptied and every later
    /// candidate with distance below r is appended.
    void reset (float r) {
        K = 0;
        R = th = r;
        this->clear();
    }

    /// Distance a candidate must stay strictly below to be accepted.
    float threshold () const {
        return th;
    }

    /// Offers a candidate to the list.
    /**
     * The candidate is ignored when its distance is not below threshold(),
     * or (K mode) when an already stored entry with the same key is met
     * while searching for the insertion point. Otherwise it is inserted in
     * sorted position, the last slot is dropped, and the threshold shrinks
     * to the new last distance without ever exceeding the radius.
     */
    Topk &operator << (Element t)
    {
        if (!(t.dist < th)) return *this;
        if (K == 0) {
            this->push_back(t);
            return *this;
        }

        typedef typename Base::size_type size_type;
        const float unfilled = std::numeric_limits<float>::max();

        // Walk from the tail towards the head until an entry strictly closer
        // than t is found; pos ends up as the slot t will occupy.
        size_type pos = this->size() - 1;
        while (pos != 0)
        {
            const Element &prev = this->at(pos - 1);
            // Placeholder slots carry no real key, so they must not be able
            // to reject a genuine candidate as a "duplicate".
            if (prev.dist < unfilled && prev.match(t)) return *this;
            if (prev < t) break;
            --pos;
        }

        // Shift [pos, size-1) one slot towards the tail; the last entry
        // falls off the end.
        std::copy_backward(this->begin() + pos, this->end() - 1, this->end());
        this->at(pos) = t;

        // While placeholders remain back().dist is FLT_MAX; clamp to R so
        // the radius constraint keeps applying.
        th = std::min(R, this->back().dist);
        return *this;
    }

    /// Fraction of this list's entries whose key also appears in topk.
    /**
     * *this is the reference list and topk the list being evaluated.
     * Placeholder slots take part in the comparison like any other entry.
     * The result is NaN when this list is empty (0/0).
     */
    float recall (const Topk<KEY> &topk ) const
    {
        unsigned matched = 0;
        for (typename Base::const_iterator ii = this->begin(); ii != this->end(); ++ii)
        {
            for (typename Base::const_iterator jj = topk.begin(); jj != topk.end(); ++jj)
            {
                if (ii->match(*jj))
                {
                    matched++;
                    break;
                }
            }
        }
        return float(matched)/float(this->size());
    }

    /// The K given to the last reset(); 0 means radius mode.
    unsigned getK () const {
        return K;
    }
};
00193
00195
00198 template <typename ACCESSOR, typename METRIC>
00199 class TopkScanner {
00200 public:
00202 typedef typename ACCESSOR::Key Key;
00204 typedef typename ACCESSOR::Value Value;
00205
00207
00222 TopkScanner(const ACCESSOR &accessor, const METRIC &metric, unsigned K, float R = std::numeric_limits<float>::max())
00223 : accessor_(accessor), metric_(metric), K_(K), R_(R) {
00224 }
00225
00227
00230 void reset (Value query) {
00231 query_ = query;
00232 accessor_.reset();
00233 topk_.reset(K_, R_);
00234 cnt_ = 0;
00235 }
00236
00238 unsigned cnt () const {
00239 return cnt_;
00240 }
00241
00243 const Topk<Key> &topk () const {
00244 return topk_;
00245 }
00246
00248
00251 void operator () (Key key) {
00252 if (accessor_.mark(key)) {
00253 ++cnt_;
00254 topk_ << typename Topk<Key>::Element(key, metric_(query_, accessor_(key)));
00255 }
00256 }
00257 private:
00258 ACCESSOR accessor_;
00259 METRIC metric_;
00260 unsigned K_;
00261 float R_;
00262 Topk<Key> topk_;
00263 Value query_;
00264 unsigned cnt_;
00265 };
00266
00267 }
00268
00269 #endif
00270