include/lshkit/mplsh.h

Go to the documentation of this file.
00001 /* 
00002     Copyright (C) 2008 Wei Dong <wdong@princeton.edu>. All Rights Reserved.
00003   
00004     This file is part of LSHKIT.
00005   
00006     LSHKIT is free software: you can redistribute it and/or modify
00007     it under the terms of the GNU General Public License as published by
00008     the Free Software Foundation, either version 3 of the License, or
00009     (at your option) any later version.
00010 
00011     LSHKIT is distributed in the hope that it will be useful,
00012     but WITHOUT ANY WARRANTY; without even the implied warranty of
00013     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014     GNU General Public License for more details.
00015 
00016     You should have received a copy of the GNU General Public License
00017     along with LSHKIT.  If not, see <http://www.gnu.org/licenses/>.
00018 */
00019 
00127 #ifndef __LSHKIT_PROBE__
00128 #define __LSHKIT_PROBE__
00129 
00130 #include <lshkit/common.h>
00131 #include <lshkit/lsh.h>
00132 #include <lshkit/composite.h>
00133 #include <lshkit/metric.h>
00134 #include <lshkit/lsh-index.h>
00135 #include <lshkit/mplsh-model.h>
00136 #include <lshkit/topk.h>
00137 
00138 namespace lshkit
00139 {
00140 
/// A single probe (perturbation) used to build multi-probe sequences.
/// Probes are ordered by score (operator<) and can be merged with operator+.
struct Probe
{
    unsigned mask;     ///< Bit mask; two probes conflict when their masks overlap (see conflict()).
    unsigned shift;    ///< Bit mask merged with bitwise OR by operator+.
    float score;       ///< Probe score; operator< orders by it and operator+ adds it.
    unsigned reserve;  ///< Extra field not interpreted by this struct; operator+ sets it to 0.

    /// Orders probes by ascending score.
    bool operator < (const Probe &p) const { return score < p.score; }

    /// Merges two probes: masks and shifts are OR-ed and scores are summed.
    /// The result's reserve field is zeroed so that the returned value is
    /// fully initialised (copying or reading an indeterminate member is UB).
    Probe operator + (const Probe &m) const
    {
        Probe ret;
        ret.mask = mask | m.mask;
        ret.shift = shift | m.shift;
        ret.score = score + m.score;
        ret.reserve = 0;
        return ret;
    }

    /// Returns true when this probe and m have at least one mask bit in common.
    bool conflict (const Probe &m) const
    {
        return (mask & m.mask) != 0;
    }

    /// Presumably the largest number of component hashes M for which probe
    /// templates are generated -- NOTE(review): confirm against mplsh.cpp.
    static const unsigned MAX_M = 20;
    /// Probe-sequence length requested by MultiProbeLshIndex::query_recall.
    static const unsigned MAX_T = 200;
}; 
00164 
00166 typedef std::vector<Probe> ProbeSequence;
00167 
00169 void GenProbeSequenceTemplate (ProbeSequence &seq, unsigned M, unsigned T);
00170 
00172 class ProbeSequenceTemplates: public std::vector<ProbeSequence>
00173 {
00174 public:
00175     ProbeSequenceTemplates(unsigned max_M, unsigned max_T)
00176         : std::vector<ProbeSequence>(max_M + 1)
00177     {
00178         for (unsigned i = 1; i <= max_M; ++i)
00179         {
00180             GenProbeSequenceTemplate(at(i), i, max_T);
00181         }
00182     }
00183 };
00184 
00185 extern ProbeSequenceTemplates __probeSequenceTemplates;
00186 
00188 class MultiProbeLsh: public RepeatHash<GaussianLsh> 
00189 {
00190     unsigned H_;
00191 public:
00192     typedef RepeatHash<GaussianLsh> Super;
00193     typedef Super::Domain Domain;
00194 
00205     struct Parameter : public Super::Parameter {
00206 
00207         unsigned range;
00208 
00209         template<class Archive>
00210         void serialize(Archive & ar, const unsigned int version)
00211         {
00212             ar & range;
00213             ar & repeat;
00214             ar & dim;
00215             ar & W;
00216         }
00217     };
00218 
00219     MultiProbeLsh () {}
00220 
00221     template <typename RNG>
00222     void reset(const Parameter &param, RNG &rng)
00223     {
00224         H_ = param.range;
00225         Super::reset(param, rng);
00226     }
00227 
00228     template <typename RNG>
00229     MultiProbeLsh(const Parameter &param, RNG &rng)
00230     {
00231         H_ = param.range;
00232         Super::reset(param, rng);
00233     }
00234 
00235     unsigned getRange () const
00236     {
00237         return H_;
00238     }
00239 
00240     unsigned operator () (Domain obj) const
00241     {
00242         return Super::operator ()(obj) % H_;
00243     }
00244 
00245     template<class Archive>
00246     void serialize(Archive & ar, const unsigned int version)
00247     {
00248         Super::serialize(ar, version);
00249         ar & H_;
00250     }
00251 
00252     void genProbeSequence (Domain obj, std::vector<unsigned> &seq, unsigned T) const;
00253 };
00254 
00255 
00257 template <typename KEY>
00258 class MultiProbeLshIndex: public LshIndex<MultiProbeLsh, KEY>
00259 {
00260 public:
00261     typedef LshIndex<MultiProbeLsh, KEY> Super;
00265     typedef typename Super::Parameter Parameter;
00266 
00267 private:
00268 
00269     Parameter param_;
00270     MultiProbeLshRecallTable recall_;
00271 
00272 public: 
00273     typedef typename Super::Domain Domain;
00274     typedef KEY Key;
00275 
00277     MultiProbeLshIndex() {
00278     } 
00279 
00281 
00288     template <typename Engine>
00289     void init (const Parameter &param, Engine &engine, unsigned L) {
00290         Super::init(param, engine, L);
00291         param_ = param;
00292         // we are going to normalize the distance by window size, so here we pass W = 1.0.
00293         // We tune adaptive probing for KNN distance range [0.0001W, 20W].
00294         recall_.reset(MultiProbeLshModel(Super::lshs_.size(), 1.0, param_.repeat, Probe::MAX_T), 200, 0.0001, 20.0);
00295     }
00296 
00298     void load (std::istream &ar) 
00299     {
00300         Super::load(ar);
00301         param_.serialize(ar, 0);
00302         recall_.load(ar);
00303         verify(ar);
00304     }
00305 
00307     void save (std::ostream &ar)
00308     {
00309         Super::save(ar);
00310         param_.serialize(ar, 0);
00311         recall_.save(ar);
00312         verify(ar);
00313     }
00314 
00316 
00320     template <typename SCANNER>
00321     void query (Domain obj, unsigned T, SCANNER &scanner)
00322     {
00323         std::vector<unsigned> seq;
00324         for (unsigned i = 0; i < Super::lshs_.size(); ++i) {
00325             Super::lshs_[i].genProbeSequence(obj, seq, T);
00326             for (unsigned j = 0; j < seq.size(); ++j) {
00327                 typename Super::Bin &bin = Super::tables_[i][seq[j]];
00328                 BOOST_FOREACH(Key key, bin) {
00329                     scanner(key);
00330                 }
00331             }
00332         }
00333     }
00334 
00336 
00342     template <typename SCANNER>
00343     void query_recall (Domain obj, float recall, SCANNER &scanner) const
00344     {
00345         unsigned K = scanner.topk().getK();
00346         if (K == 0) throw std::logic_error("CANNOT ACCEPT R-NN QUERY");
00347         if (scanner.topk().size() < K) throw std::logic_error("ERROR");
00348         unsigned L = Super::lshs_.size();
00349         std::vector<std::vector<unsigned> > seqs(L);
00350         for (unsigned i = 0; i < L; ++i) {
00351             Super::lshs_[i].genProbeSequence(obj, seqs[i], Probe::MAX_T);
00352         }
00353 
00354         for (unsigned j = 0; j < Probe::MAX_T; ++j) {
00355             if (j >= seqs[0].size()) break;
00356             for (unsigned i = 0; i < L; ++i) {
00357                 BOOST_FOREACH(Key key, Super::tables_[i][seqs[i][j]]) {
00358                     scanner(key);
00359                 }
00360             }
00361             float r = 0.0;
00362             for (unsigned i = 0; i < K; ++i) {
00363                 r += recall_.lookup(scanner.topk()[i].dist / param_.W, j + 1);
00364             }
00365             r /= K;
00366             if (r >= recall) break;
00367         }
00368     }
00369 };
00370 
00371 }
00372 
00373 
00374 #endif
00375 

Get LSHKIT at SourceForge.net. Fast, secure and Free Open Source software downloads. Generated by doxygen.