00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00127 #ifndef __LSHKIT_PROBE__
00128 #define __LSHKIT_PROBE__
00129
00130 #include <lshkit/common.h>
00131 #include <lshkit/lsh.h>
00132 #include <lshkit/composite.h>
00133 #include <lshkit/metric.h>
00134 #include <lshkit/lsh-index.h>
00135 #include <lshkit/mplsh-model.h>
00136 #include <lshkit/topk.h>
00137
00138 namespace lshkit
00139 {
00140
/// One perturbation entry of a multi-probe LSH probe sequence.
///
/// - `mask`: bit set of perturbed hash components. Two probes conflict when
///   their masks overlap (see conflict()).
/// - `shift`: bit set OR-combined alongside `mask`.
///   NOTE(review): presumably encodes the perturbation direction of each masked
///   component; confirm against the generator.
/// - `score`: the cost probes are ordered by (see operator<).
/// - `reserve`: auxiliary field for the sequence generator; zero after operator+.
struct Probe
{
    unsigned mask;
    unsigned shift;
    float score;
    unsigned reserve;

    /// Order probes by score only.
    bool operator < (const Probe &p) const { return score < p.score; }

    /// Combine two probes: union of masks and shifts, sum of scores.
    ///
    /// The result is value-initialized first so that `reserve` is 0 instead
    /// of an indeterminate value.
    Probe operator + (const Probe &m) const
    {
        Probe ret = Probe();
        ret.mask = mask | m.mask;
        ret.shift = shift | m.shift;
        ret.score = score + m.score;
        return ret;
    }

    /// True when both probes perturb at least one common component
    /// (their masks share a set bit).
    bool conflict (const Probe &m) const
    {
        return (mask & m.mask) != 0;
    }

    // Upper bounds used by the probe-sequence machinery.
    // NOTE(review): MAX_M presumably bounds the number of hash components
    // covered by precomputed templates; confirm where the templates are built.
    static const unsigned MAX_M = 20;
    // MAX_T is the probe budget that query_recall() requests per hash table.
    static const unsigned MAX_T = 200;
};
00164
00166 typedef std::vector<Probe> ProbeSequence;
00167
00169 void GenProbeSequenceTemplate (ProbeSequence &seq, unsigned M, unsigned T);
00170
00172 class ProbeSequenceTemplates: public std::vector<ProbeSequence>
00173 {
00174 public:
00175 ProbeSequenceTemplates(unsigned max_M, unsigned max_T)
00176 : std::vector<ProbeSequence>(max_M + 1)
00177 {
00178 for (unsigned i = 1; i <= max_M; ++i)
00179 {
00180 GenProbeSequenceTemplate(at(i), i, max_T);
00181 }
00182 }
00183 };
00184
00185 extern ProbeSequenceTemplates __probeSequenceTemplates;
00186
00188 class MultiProbeLsh: public RepeatHash<GaussianLsh>
00189 {
00190 unsigned H_;
00191 public:
00192 typedef RepeatHash<GaussianLsh> Super;
00193 typedef Super::Domain Domain;
00194
00205 struct Parameter : public Super::Parameter {
00206
00207 unsigned range;
00208
00209 template<class Archive>
00210 void serialize(Archive & ar, const unsigned int version)
00211 {
00212 ar & range;
00213 ar & repeat;
00214 ar & dim;
00215 ar & W;
00216 }
00217 };
00218
00219 MultiProbeLsh () {}
00220
00221 template <typename RNG>
00222 void reset(const Parameter ¶m, RNG &rng)
00223 {
00224 H_ = param.range;
00225 Super::reset(param, rng);
00226 }
00227
00228 template <typename RNG>
00229 MultiProbeLsh(const Parameter ¶m, RNG &rng)
00230 {
00231 H_ = param.range;
00232 Super::reset(param, rng);
00233 }
00234
00235 unsigned getRange () const
00236 {
00237 return H_;
00238 }
00239
00240 unsigned operator () (Domain obj) const
00241 {
00242 return Super::operator ()(obj) % H_;
00243 }
00244
00245 template<class Archive>
00246 void serialize(Archive & ar, const unsigned int version)
00247 {
00248 Super::serialize(ar, version);
00249 ar & H_;
00250 }
00251
00252 void genProbeSequence (Domain obj, std::vector<unsigned> &seq, unsigned T) const;
00253 };
00254
00255
00257 template <typename KEY>
00258 class MultiProbeLshIndex: public LshIndex<MultiProbeLsh, KEY>
00259 {
00260 public:
00261 typedef LshIndex<MultiProbeLsh, KEY> Super;
00265 typedef typename Super::Parameter Parameter;
00266
00267 private:
00268
00269 Parameter param_;
00270 MultiProbeLshRecallTable recall_;
00271
00272 public:
00273 typedef typename Super::Domain Domain;
00274 typedef KEY Key;
00275
00277 MultiProbeLshIndex() {
00278 }
00279
00281
00288 template <typename Engine>
00289 void init (const Parameter ¶m, Engine &engine, unsigned L) {
00290 Super::init(param, engine, L);
00291 param_ = param;
00292
00293
00294 recall_.reset(MultiProbeLshModel(Super::lshs_.size(), 1.0, param_.repeat, Probe::MAX_T), 200, 0.0001, 20.0);
00295 }
00296
00298 void load (std::istream &ar)
00299 {
00300 Super::load(ar);
00301 param_.serialize(ar, 0);
00302 recall_.load(ar);
00303 verify(ar);
00304 }
00305
00307 void save (std::ostream &ar)
00308 {
00309 Super::save(ar);
00310 param_.serialize(ar, 0);
00311 recall_.save(ar);
00312 verify(ar);
00313 }
00314
00316
00320 template <typename SCANNER>
00321 void query (Domain obj, unsigned T, SCANNER &scanner)
00322 {
00323 std::vector<unsigned> seq;
00324 for (unsigned i = 0; i < Super::lshs_.size(); ++i) {
00325 Super::lshs_[i].genProbeSequence(obj, seq, T);
00326 for (unsigned j = 0; j < seq.size(); ++j) {
00327 typename Super::Bin &bin = Super::tables_[i][seq[j]];
00328 BOOST_FOREACH(Key key, bin) {
00329 scanner(key);
00330 }
00331 }
00332 }
00333 }
00334
00336
00342 template <typename SCANNER>
00343 void query_recall (Domain obj, float recall, SCANNER &scanner) const
00344 {
00345 unsigned K = scanner.topk().getK();
00346 if (K == 0) throw std::logic_error("CANNOT ACCEPT R-NN QUERY");
00347 if (scanner.topk().size() < K) throw std::logic_error("ERROR");
00348 unsigned L = Super::lshs_.size();
00349 std::vector<std::vector<unsigned> > seqs(L);
00350 for (unsigned i = 0; i < L; ++i) {
00351 Super::lshs_[i].genProbeSequence(obj, seqs[i], Probe::MAX_T);
00352 }
00353
00354 for (unsigned j = 0; j < Probe::MAX_T; ++j) {
00355 if (j >= seqs[0].size()) break;
00356 for (unsigned i = 0; i < L; ++i) {
00357 BOOST_FOREACH(Key key, Super::tables_[i][seqs[i][j]]) {
00358 scanner(key);
00359 }
00360 }
00361 float r = 0.0;
00362 for (unsigned i = 0; i < K; ++i) {
00363 r += recall_.lookup(scanner.topk()[i].dist / param_.W, j + 1);
00364 }
00365 r /= K;
00366 if (r >= recall) break;
00367 }
00368 }
00369 };
00370
00371 }
00372
00373
00374 #endif
00375