00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #ifndef __LSHKIT_FOREST__
00022 #define __LSHKIT_FOREST__
00023
00055 #include <algorithm>
00056 #include <lshkit/common.h>
00057 #include <lshkit/topk.h>
00058
00059 namespace lshkit {
00060
00062
00066 template <typename LSH, typename KEY>
00067 class ForestIndex
00068 {
00069 BOOST_CONCEPT_ASSERT((LshConcept<LSH>));
00070 public:
00071 typedef typename LSH::Parameter Parameter;
00072 typedef typename LSH::Domain Domain;
00073 typedef KEY Key;
00074
00075 private:
00076
00077 struct Tree
00078 {
00079 std::vector<LSH> lsh;
00080
00081 struct Node
00082 {
00083 size_t size;
00084 std::vector<Node *> children;
00085 std::vector<Key> data;
00086
00087 Node () : size(0) {
00088 }
00089
00090 ~Node () {
00091 BOOST_FOREACH(Node *n, children) {
00092 if (n != 0) delete n;
00093 }
00094 }
00095
00096 bool empty () const {
00097 return size == 0;
00098 }
00099
00100 template <typename ACCESSOR>
00101 void insert (Tree *tree, unsigned depth, Key key, ACCESSOR &acc) {
00102 ++size;
00103 if (children.empty()) {
00104 data.push_back(key);
00105 if (depth < tree->lsh.size() && data.size() > 1) {
00106
00107 LSH &lsh = tree->lsh[depth];
00108 if (lsh.getRange() == 0) throw std::logic_error("LSH WITH UNLIMITED HASH VALUE CANNOT BE USED IN LSH FOREST.");
00109 children.resize(lsh.getRange());
00110 BOOST_FOREACH(Key key, data) {
00111 unsigned h = lsh(acc(key));
00112 if (children[h] == 0) {
00113 children[h] = new Node();
00114 }
00115 children[h]->insert(tree, depth+1, key, acc);
00116 }
00117 data.clear();
00118 }
00119 }
00120 else {
00121 unsigned h = tree->lsh[depth](acc(key));
00122 if (children[h] == 0) {
00123 children[h] = new Node();
00124 }
00125 children[h]->insert(tree, depth+1, key, acc);
00126 }
00127 }
00128
00129 template <typename SCANNER>
00130 void scan (Domain val, SCANNER &scanner) const {
00131 if (!children.empty()) {
00132 BOOST_FOREACH(const Node *n, children) {
00133 if (n != 0) {
00134 n->scan(val, scanner);
00135 }
00136 }
00137 }
00138 if (!data.empty()) {
00139 BOOST_FOREACH(Key key, data) {
00140 scanner(key);
00141 }
00142 }
00143 }
00144 } *root;
00145
00146 public:
00147
00148 Tree (): root(0) {
00149 }
00150
00151 template <typename ENGINE>
00152 void reset (const Parameter ¶m, ENGINE &engine, unsigned depth)
00153 {
00154 lsh.resize(depth);
00155 BOOST_FOREACH(LSH &h, lsh) {
00156 h.reset(param, engine);
00157 }
00158 root = new Node();
00159 }
00160
00161 ~Tree ()
00162 {
00163 if (root != 0) delete root;
00164 }
00165
00166 template <typename ACCESSOR>
00167 void insert (Key key, ACCESSOR &acc)
00168 {
00169 root->insert(this, 0, key, acc);
00170 }
00171
00172 void lookup (Domain val, std::vector<const Node *> *nodes) const {
00173 const Node *cur = root;
00174 unsigned depth = 0;
00175 nodes->clear();
00176 for (;;) {
00177 nodes->push_back(cur);
00178 if (cur->children.empty()) break;
00179 unsigned h = lsh[depth](val);
00180 cur = cur->children[h];
00181 if (cur == 0) break;
00182 ++depth;
00183 }
00184 }
00185 };
00186
00187 friend struct Tree;
00188
00189 std::vector<Tree> trees;
00190
00191
00192 public:
00193 ForestIndex()
00194 {
00195 }
00196
00198
00204 template <typename Engine>
00205 void init(const Parameter ¶m, Engine &engine, unsigned L, unsigned depth)
00206 {
00207 trees.resize(L);
00208 BOOST_FOREACH(Tree &t, trees) {
00209 t.reset(param, engine, depth);
00210 }
00211 }
00212
00214
00221 template <typename ACCESSOR>
00222 void insert (Key key, ACCESSOR &acc)
00223 {
00224 BOOST_FOREACH(Tree &t, trees) {
00225 t.insert(key, acc);
00226 }
00227 }
00228
00230
00235 template <typename SCANNER>
00236 void query (Domain val, unsigned M, SCANNER &scanner) const
00237 {
00238 std::vector<std::vector<const typename Tree::Node *> > list(trees.size());
00239 for (unsigned i = 0; i < trees.size(); ++i) {
00240 trees[i].lookup(val, &list[i]);
00241 }
00242
00243 unsigned d = 0;
00244 for (;;) {
00245 unsigned s = 0;
00246 for (unsigned i = 0; i < list.size(); ++i) {
00247 if (d < list[i].size()) {
00248 s += list[i][d]->size;
00249 }
00250 }
00251 if (s < M) break;
00252 ++d;
00253 }
00254
00255 if (d > 0) --d;
00256
00257
00258 for (unsigned i = 0; i < list.size(); ++i) {
00259 if (d < list[i].size()) {
00260 list[i][d]->scan(val, scanner);
00261 }
00262 }
00263 }
00264 };
00265
00266
00267 }
00268
00269 #endif
00270