1/* 2 * Copyright (c) 2012 Google 3 * All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are 7 * met: redistributions of source code must retain the above copyright 8 * notice, this list of conditions and the following disclaimer; 9 * redistributions in binary form must reproduce the above copyright 10 * notice, this list of conditions and the following disclaimer in the 11 * documentation and/or other materials provided with the distribution; 12 * neither the name of the copyright holders nor the names of its 13 * contributors may be used to endorse or promote products derived from 14 * this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 20 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 21 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 22 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 * 28 * Authors: Gabe Black 29 */ 30 31#ifndef __BASE_TRIE_HH__ 32#define __BASE_TRIE_HH__ 33 34#include <cassert> 35#include <iostream> 36 37#include "base/cprintf.hh" 38#include "base/logging.hh" 39#include "base/types.hh" 40 41// Key has to be an integral type. 42template <class Key, class Value> 43class Trie 44{ 45 protected: 46 struct Node 47 { 48 Key key; 49 Key mask; 50 51 bool 52 matches(Key test) 53 { 54 return (test & mask) == key; 55 } 56 57 Value *value; 58 59 Node *parent; 60 Node *kids[2]; 61 62 Node(Key _key, Key _mask, Value *_val) : 63 key(_key & _mask), mask(_mask), value(_val), 64 parent(NULL) 65 { 66 kids[0] = NULL; 67 kids[1] = NULL; 68 } 69 70 void 71 clear() 72 { 73 if (kids[1]) { 74 kids[1]->clear(); 75 delete kids[1]; 76 kids[1] = NULL; 77 } 78 if (kids[0]) { 79 kids[0]->clear(); 80 delete kids[0]; 81 kids[0] = NULL; 82 } 83 } 84 85 void 86 dump(std::ostream &os, int level) 87 { 88 for (int i = 1; i < level; i++) { 89 ccprintf(os, "|"); 90 } 91 if (level == 0) 92 ccprintf(os, "Root "); 93 else 94 ccprintf(os, "+ "); 95 ccprintf(os, "(%p, %p, %#X, %#X, %p)\n", 96 parent, this, key, mask, value); 97 if (kids[0]) 98 kids[0]->dump(os, level + 1); 99 if (kids[1]) 100 kids[1]->dump(os, level + 1); 101 } 102 }; 103 104 protected: 105 Node head; 106 107 public: 108 typedef Node *Handle; 109 110 Trie() : head(0, 0, NULL) 111 {} 112 113 static const unsigned MaxBits = sizeof(Key) * 8; 114 115 private: 116 /** 117 * A utility method which checks whether the key being looked up lies 118 * beyond the Node being examined. If so, it returns true and advances the 119 * node being examined. 120 * @param parent The node we're currently "at", which can be updated. 121 * @param kid The node we may want to move to. 122 * @param key The key we're looking for. 123 * @param new_mask The mask to use when matching against the key. 124 * @return Whether the current Node was advanced. 125 */ 126 bool 127 goesAfter(Node **parent, Node *kid, Key key, Key new_mask) 128 { 129 if (kid && kid->matches(key) && (kid->mask & new_mask) == kid->mask) { 130 *parent = kid; 131 return true; 132 } else { 133 return false; 134 } 135 } 136 137 /** 138 * A utility method which extends a mask value one more bit towards the 139 * lsb. This is almost just a signed right shift, except that the shifted 140 * in bits are technically undefined. This is also slightly complicated by 141 * the zero case. 142 * @param orig The original mask to extend. 143 * @return The extended mask. 144 */ 145 Key 146 extendMask(Key orig) 147 { 148 // Just in case orig was 0. 149 const Key msb = ULL(1) << (MaxBits - 1); 150 return orig | (orig >> 1) | msb; 151 } 152 153 /** 154 * Method which looks up the Handle corresponding to a particular key. This 155 * is useful if you want to delete the Handle corresponding to a key since 156 * the "remove" function takes a Handle as its argument. 157 * @param key The key to look up. 158 * @return The first Handle matching this key, or NULL if none was found. 159 */ 160 Handle 161 lookupHandle(Key key) 162 { 163 Node *node = &head; 164 while (node) { 165 if (node->value) 166 return node; 167 168 if (node->kids[0] && node->kids[0]->matches(key)) 169 node = node->kids[0]; 170 else if (node->kids[1] && node->kids[1]->matches(key)) 171 node = node->kids[1]; 172 else 173 node = NULL; 174 } 175 176 return NULL; 177 } 178 179 public: 180 /** 181 * Method which inserts a key/value pair into the trie. 182 * @param key The key which can later be used to look up this value. 183 * @param width How many bits of the key (from msb) should be used. 184 * @param val A pointer to the value to store in the trie. 185 * @return A Handle corresponding to this value. 186 */ 187 Handle 188 insert(Key key, unsigned width, Value *val) 189 { 190 // We use NULL value pointers to mark internal nodes of the trie, so 191 // we don't allow inserting them as real values. 192 assert(val); 193 194 // Build a mask which masks off all the bits we don't care about. 195 Key new_mask = ~(Key)0; 196 if (width < MaxBits) 197 new_mask <<= (MaxBits - width); 198 // Use it to tidy up the key. 199 key &= new_mask; 200 201 // Walk past all the nodes this new node will be inserted after. They 202 // can be ignored for the purposes of this function. 203 Node *node = &head; 204 while (goesAfter(&node, node->kids[0], key, new_mask) || 205 goesAfter(&node, node->kids[1], key, new_mask)) 206 {} 207 assert(node); 208 209 Key cur_mask = node->mask; 210 // If we're already where the value needs to be... 211 if (cur_mask == new_mask) { 212 assert(!node->value); 213 node->value = val; 214 return node; 215 } 216 217 for (unsigned int i = 0; i < 2; i++) { 218 Node *&kid = node->kids[i]; 219 Node *new_node; 220 if (!kid) { 221 // No kid. Add a new one. 222 new_node = new Node(key, new_mask, val); 223 new_node->parent = node; 224 kid = new_node; 225 return new_node; 226 } 227 228 // Walk down the leg until something doesn't match or we run out 229 // of bits. 230 Key last_mask; 231 bool done; 232 do { 233 last_mask = cur_mask; 234 cur_mask = extendMask(cur_mask); 235 done = ((key & cur_mask) != (kid->key & cur_mask)) || 236 last_mask == new_mask; 237 } while (!done); 238 cur_mask = last_mask; 239 240 // If this isn't the right leg to go down at all, skip it. 241 if (cur_mask == node->mask) 242 continue; 243 244 // At the point we walked to above, add a new node. 245 new_node = new Node(key, cur_mask, NULL); 246 new_node->parent = node; 247 kid->parent = new_node; 248 new_node->kids[0] = kid; 249 kid = new_node; 250 251 // If we ran out of bits, the value goes right here. 252 if (cur_mask == new_mask) { 253 new_node->value = val; 254 return new_node; 255 } 256 257 // Still more bits to deal with, so add a new node for that path. 258 new_node = new Node(key, new_mask, val); 259 new_node->parent = kid; 260 kid->kids[1] = new_node; 261 return new_node; 262 } 263 264 panic("Reached the end of the Trie insert function!\n"); 265 return NULL; 266 } 267 268 /** 269 * Method which looks up the Value corresponding to a particular key. 270 * @param key The key to look up. 271 * @return The first Value matching this key, or NULL if none was found. 272 */ 273 Value * 274 lookup(Key key) 275 { 276 Node *node = lookupHandle(key); 277 if (node) 278 return node->value; 279 else 280 return NULL; 281 } 282 283 /** 284 * Method to delete a value from the trie. 285 * @param node A Handle to remove. 286 * @return The Value pointer from the removed entry. 287 */ 288 Value * 289 remove(Handle handle) 290 { 291 Node *node = handle; 292 Value *val = node->value; 293 if (node->kids[1]) { 294 assert(node->value); 295 node->value = NULL; 296 return val; 297 } 298 if (!node->parent) 299 panic("Trie: Can't remove root node.\n"); 300 301 Node *parent = node->parent; 302 303 // If there's a kid, fix up it's parent pointer. 304 if (node->kids[0]) 305 node->kids[0]->parent = parent; 306 // Figure out which kid we are, and update our parent's pointers. 307 if (parent->kids[0] == node) 308 parent->kids[0] = node->kids[0]; 309 else if (parent->kids[1] == node) 310 parent->kids[1] = node->kids[0]; 311 else 312 panic("Trie: Inconsistent parent/kid relationship.\n"); 313 // Make sure if the parent only has one kid, it's kid[0]. 314 if (parent->kids[1] && !parent->kids[0]) { 315 parent->kids[0] = parent->kids[1]; 316 parent->kids[1] = NULL; 317 } 318 319 // If the parent has less than two kids and no cargo and isn't the 320 // root, delete it too. 321 if (!parent->kids[1] && !parent->value && parent->parent) 322 remove(parent); 323 delete node; 324 return val; 325 } 326 327 /** 328 * Method to lookup a value from the trie and then delete it. 329 * @param key The key to look up and then remove. 330 * @return The Value pointer from the removed entry, if any. 331 */ 332 Value * 333 remove(Key key) 334 { 335 Handle handle = lookupHandle(key); 336 if (!handle) 337 return NULL; 338 return remove(handle); 339 } 340 341 /** 342 * A method which removes all key/value pairs from the trie. This is more 343 * efficient than trying to remove elements individually. 344 */ 345 void 346 clear() 347 { 348 head.clear(); 349 } 350 351 /** 352 * A debugging method which prints the contents of this trie. 353 * @param title An identifying title to put in the dump header. 354 */ 355 void 356 dump(const char *title, std::ostream &os=std::cout) 357 { 358 ccprintf(os, "**************************************************\n"); 359 ccprintf(os, "*** Start of Trie: %s\n", title); 360 ccprintf(os, "*** (parent, me, key, mask, value pointer)\n"); 361 ccprintf(os, "**************************************************\n"); 362 head.dump(os, 0); 363 } 364}; 365 366#endif 367