trie.hh revision 8959:24b06cbf2d67
1/*
2 * Copyright (c) 2012 Google
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met: redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer;
9 * redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution;
12 * neither the name of the copyright holders nor the names of its
13 * contributors may be used to endorse or promote products derived from
14 * this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
20 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 *
28 * Authors: Gabe Black
29 */
30
31#ifndef __BASE_TRIE_HH__
32#define __BASE_TRIE_HH__
33
34#include <cassert>
35
36#include "base/cprintf.hh"
37#include "base/misc.hh"
38#include "base/types.hh"
39
40// Key has to be an integral type.
41template <class Key, class Value>
42class Trie
43{
44  protected:
45    struct Node
46    {
47        Key key;
48        Key mask;
49
50        bool
51        matches(Key test)
52        {
53            return (test & mask) == key;
54        }
55
56        Value *value;
57
58        Node *parent;
59        Node *kids[2];
60
61        Node(Key _key, Key _mask, Value *_val) :
62            key(_key & _mask), mask(_mask), value(_val),
63            parent(NULL)
64        {
65            kids[0] = NULL;
66            kids[1] = NULL;
67        }
68
69        void
70        clear()
71        {
72            if (kids[1]) {
73                kids[1]->clear();
74                delete kids[1];
75                kids[1] = NULL;
76            }
77            if (kids[0]) {
78                kids[0]->clear();
79                delete kids[0];
80                kids[0] = NULL;
81            }
82        }
83
84        void
85        dump(int level)
86        {
87            for (int i = 1; i < level; i++) {
88                cprintf("|");
89            }
90            if (level == 0)
91                cprintf("Root ");
92            else
93                cprintf("+ ");
94            cprintf("(%p, %p, %#X, %#X, %p)\n", parent, this, key, mask, value);
95            if (kids[0])
96                kids[0]->dump(level + 1);
97            if (kids[1])
98                kids[1]->dump(level + 1);
99        }
100    };
101
102  protected:
103    Node head;
104
105  public:
106    typedef Node *Handle;
107
108    Trie() : head(0, 0, NULL)
109    {}
110
111    static const unsigned MaxBits = sizeof(Key) * 8;
112
113  private:
114    /**
115     * A utility method which checks whether the key being looked up lies
116     * beyond the Node being examined. If so, it returns true and advances the
117     * node being examined.
118     * @param parent The node we're currently "at", which can be updated.
119     * @param kid The node we may want to move to.
120     * @param key The key we're looking for.
121     * @param new_mask The mask to use when matching against the key.
122     * @return Whether the current Node was advanced.
123     */
124    bool
125    goesAfter(Node **parent, Node *kid, Key key, Key new_mask)
126    {
127        if (kid && kid->matches(key) && (kid->mask & new_mask) == kid->mask) {
128            *parent = kid;
129            return true;
130        } else {
131            return false;
132        }
133    }
134
135    /**
136     * A utility method which extends a mask value one more bit towards the
137     * lsb. This is almost just a signed right shift, except that the shifted
138     * in bits are technically undefined. This is also slightly complicated by
139     * the zero case.
140     * @param orig The original mask to extend.
141     * @return The extended mask.
142     */
143    Key
144    extendMask(Key orig)
145    {
146        // Just in case orig was 0.
147        const Key msb = ULL(1) << (MaxBits - 1);
148        return orig | (orig >> 1) | msb;
149    }
150
151    /**
152     * Method which looks up the Handle corresponding to a particular key. This
153     * is useful if you want to delete the Handle corresponding to a key since
154     * the "remove" function takes a Handle as its argument.
155     * @param key The key to look up.
156     * @return The first Handle matching this key, or NULL if none was found.
157     */
158    Handle
159    lookupHandle(Key key)
160    {
161        Node *node = &head;
162        while (node) {
163            if (node->value)
164                return node;
165
166            if (node->kids[0] && node->kids[0]->matches(key))
167                node = node->kids[0];
168            else if (node->kids[1] && node->kids[1]->matches(key))
169                node = node->kids[1];
170            else
171                node = NULL;
172        }
173
174        return NULL;
175    }
176
177  public:
178    /**
179     * Method which inserts a key/value pair into the trie.
180     * @param key The key which can later be used to look up this value.
181     * @param width How many bits of the key (from msb) should be used.
182     * @param val A pointer to the value to store in the trie.
183     * @return A Handle corresponding to this value.
184     */
185    Handle
186    insert(Key key, unsigned width, Value *val)
187    {
188        // Build a mask which masks off all the bits we don't care about.
189        Key new_mask = ~(Key)0;
190        if (width < MaxBits)
191            new_mask <<= (MaxBits - width);
192        // Use it to tidy up the key.
193        key &= new_mask;
194
195        // Walk past all the nodes this new node will be inserted after. They
196        // can be ignored for the purposes of this function.
197        Node *node = &head;
198        while (goesAfter(&node, node->kids[0], key, new_mask) ||
199               goesAfter(&node, node->kids[1], key, new_mask))
200        {}
201        assert(node);
202
203        Key cur_mask = node->mask;
204        // If we're already where the value needs to be...
205        if (cur_mask == new_mask) {
206            assert(!node->value);
207            node->value = val;
208            return node;
209        }
210
211        for (unsigned int i = 0; i < 2; i++) {
212            Node *&kid = node->kids[i];
213            Node *new_node;
214            if (!kid) {
215                // No kid. Add a new one.
216                new_node = new Node(key, new_mask, val);
217                new_node->parent = node;
218                kid = new_node;
219                return new_node;
220            }
221
222            // Walk down the leg until something doesn't match or we run out
223            // of bits.
224            Key last_mask;
225            bool done;
226            do {
227                last_mask = cur_mask;
228                cur_mask = extendMask(cur_mask);
229                done = ((key & cur_mask) != (kid->key & cur_mask)) ||
230                    last_mask == new_mask;
231            } while (!done);
232            cur_mask = last_mask;
233
234            // If this isn't the right leg to go down at all, skip it.
235            if (cur_mask == node->mask)
236                continue;
237
238            // At the point we walked to above, add a new node.
239            new_node = new Node(key, cur_mask, NULL);
240            new_node->parent = node;
241            kid->parent = new_node;
242            new_node->kids[0] = kid;
243            kid = new_node;
244
245            // If we ran out of bits, the value goes right here.
246            if (cur_mask == new_mask) {
247                new_node->value = val;
248                return new_node;
249            }
250
251            // Still more bits to deal with, so add a new node for that path.
252            new_node = new Node(key, new_mask, val);
253            new_node->parent = kid;
254            kid->kids[1] = new_node;
255            return new_node;
256        }
257
258        panic("Reached the end of the Trie insert function!\n");
259        return NULL;
260    }
261
262    /**
263     * Method which looks up the Value corresponding to a particular key.
264     * @param key The key to look up.
265     * @return The first Value matching this key, or NULL if none was found.
266     */
267    Value *
268    lookup(Key key)
269    {
270        Node *node = lookupHandle(key);
271        if (node)
272            return node->value;
273        else
274            return NULL;
275    }
276
277    /**
278     * Method to delete a value from the trie.
279     * @param node A Handle to remove.
280     * @return The Value pointer from the removed entry.
281     */
282    Value *
283    remove(Handle handle)
284    {
285        Node *node = handle;
286        Value *val = node->value;
287        if (node->kids[1]) {
288            assert(node->value);
289            node->value = NULL;
290            return val;
291        }
292        if (!node->parent)
293            panic("Trie: Can't remove root node.\n");
294
295        Node *parent = node->parent;
296
297        // If there's a kid, fix up it's parent pointer.
298        if (node->kids[0])
299            node->kids[0]->parent = parent;
300        // Figure out which kid we are, and update our parent's pointers.
301        if (parent->kids[0] == node)
302            parent->kids[0] = node->kids[0];
303        else if (parent->kids[1] == node)
304            parent->kids[1] = node->kids[0];
305        else
306            panic("Trie: Inconsistent parent/kid relationship.\n");
307        // Make sure if the parent only has one kid, it's kid[0].
308        if (parent->kids[1] && !parent->kids[0]) {
309            parent->kids[0] = parent->kids[1];
310            parent->kids[1] = NULL;
311        }
312
313        // If the parent has less than two kids and no cargo and isn't the
314        // root, delete it too.
315        if (!parent->kids[1] && !parent->value && parent->parent)
316            remove(parent);
317        delete node;
318        return val;
319    }
320
321    /**
322     * Method to lookup a value from the trie and then delete it.
323     * @param key The key to look up and then remove.
324     * @return The Value pointer from the removed entry, if any.
325     */
326    Value *
327    remove(Key key)
328    {
329        Handle handle = lookupHandle(key);
330        if (!handle)
331            return NULL;
332        return remove(handle);
333    }
334
335    /**
336     * A method which removes all key/value pairs from the trie. This is more
337     * efficient than trying to remove elements individually.
338     */
339    void
340    clear()
341    {
342        head.clear();
343    }
344
345    /**
346     * A debugging method which prints the contents of this trie.
347     * @param title An identifying title to put in the dump header.
348     */
349    void
350    dump(const char *title)
351    {
352        cprintf("**************************************************\n");
353        cprintf("*** Start of Trie: %s\n", title);
354        cprintf("*** (parent, me, key, mask, value pointer)\n");
355        cprintf("**************************************************\n");
356        head.dump(0);
357    }
358};
359
360#endif
361