trie.hh revision 8952:6188362beee1
1/*
2 * Copyright (c) 2012 Google
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met: redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer;
9 * redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution;
12 * neither the name of the copyright holders nor the names of its
13 * contributors may be used to endorse or promote products derived from
14 * this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
20 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 *
28 * Authors: Gabe Black
29 */
30
31#ifndef __BASE_TRIE_HH__
32#define __BASE_TRIE_HH__
33
34#include "base/cprintf.hh"
35#include "base/misc.hh"
36#include "base/types.hh"
37
38// Key has to be an integral type.
39template <class Key, class Value>
40class Trie
41{
42  protected:
43    struct Node
44    {
45        Key key;
46        Key mask;
47
48        bool
49        matches(Key test)
50        {
51            return (test & mask) == key;
52        }
53
54        Value *value;
55
56        Node *parent;
57        Node *kids[2];
58
59        Node(Key _key, Key _mask, Value *_val) :
60            key(_key & _mask), mask(_mask), value(_val),
61            parent(NULL)
62        {
63            kids[0] = NULL;
64            kids[1] = NULL;
65        }
66
67        void
68        clear()
69        {
70            if (kids[1]) {
71                kids[1]->clear();
72                delete kids[1];
73                kids[1] = NULL;
74            }
75            if (kids[0]) {
76                kids[0]->clear();
77                delete kids[0];
78                kids[0] = NULL;
79            }
80        }
81
82        void
83        dump(int level)
84        {
85            for (int i = 1; i < level; i++) {
86                cprintf("|");
87            }
88            if (level == 0)
89                cprintf("Root ");
90            else
91                cprintf("+ ");
92            cprintf("(%p, %p, %#X, %#X, %p)\n", parent, this, key, mask, value);
93            if (kids[0])
94                kids[0]->dump(level + 1);
95            if (kids[1])
96                kids[1]->dump(level + 1);
97        }
98    };
99
100  protected:
101    Node head;
102
103  public:
104    typedef Node *Handle;
105
106    Trie() : head(0, 0, NULL)
107    {}
108
109    static const unsigned MaxBits = sizeof(Key) * 8;
110
111  private:
112    /**
113     * A utility method which checks whether the key being looked up lies
114     * beyond the Node being examined. If so, it returns true and advances the
115     * node being examined.
116     * @param parent The node we're currently "at", which can be updated.
117     * @param kid The node we may want to move to.
118     * @param key The key we're looking for.
119     * @param new_mask The mask to use when matching against the key.
120     * @return Whether the current Node was advanced.
121     */
122    bool
123    goesAfter(Node **parent, Node *kid, Key key, Key new_mask)
124    {
125        if (kid && kid->matches(key) && (kid->mask & new_mask) == kid->mask) {
126            *parent = kid;
127            return true;
128        } else {
129            return false;
130        }
131    }
132
133    /**
134     * A utility method which extends a mask value one more bit towards the
135     * lsb. This is almost just a signed right shift, except that the shifted
136     * in bits are technically undefined. This is also slightly complicated by
137     * the zero case.
138     * @param orig The original mask to extend.
139     * @return The extended mask.
140     */
141    Key
142    extendMask(Key orig)
143    {
144        // Just in case orig was 0.
145        const Key msb = ULL(1) << (MaxBits - 1);
146        return orig | (orig >> 1) | msb;
147    }
148
149    /**
150     * Method which looks up the Handle corresponding to a particular key. This
151     * is useful if you want to delete the Handle corresponding to a key since
152     * the "remove" function takes a Handle as its argument.
153     * @param key The key to look up.
154     * @return The first Handle matching this key, or NULL if none was found.
155     */
156    Handle
157    lookupHandle(Key key)
158    {
159        Node *node = &head;
160        while (node) {
161            if (node->value)
162                return node;
163
164            if (node->kids[0] && node->kids[0]->matches(key))
165                node = node->kids[0];
166            else if (node->kids[1] && node->kids[1]->matches(key))
167                node = node->kids[1];
168            else
169                node = NULL;
170        }
171
172        return NULL;
173    }
174
175  public:
176    /**
177     * Method which inserts a key/value pair into the trie.
178     * @param key The key which can later be used to look up this value.
179     * @param width How many bits of the key (from msb) should be used.
180     * @param val A pointer to the value to store in the trie.
181     * @return A Handle corresponding to this value.
182     */
183    Handle
184    insert(Key key, unsigned width, Value *val)
185    {
186        // Build a mask which masks off all the bits we don't care about.
187        Key new_mask = ~(Key)0;
188        if (width < MaxBits)
189            new_mask <<= (MaxBits - width);
190        // Use it to tidy up the key.
191        key &= new_mask;
192
193        // Walk past all the nodes this new node will be inserted after. They
194        // can be ignored for the purposes of this function.
195        Node *node = &head;
196        while (goesAfter(&node, node->kids[0], key, new_mask) ||
197               goesAfter(&node, node->kids[1], key, new_mask))
198        {}
199        assert(node);
200
201        Key cur_mask = node->mask;
202        // If we're already where the value needs to be...
203        if (cur_mask == new_mask) {
204            assert(!node->value);
205            node->value = val;
206            return node;
207        }
208
209        for (unsigned int i = 0; i < 2; i++) {
210            Node *&kid = node->kids[i];
211            Node *new_node;
212            if (!kid) {
213                // No kid. Add a new one.
214                new_node = new Node(key, new_mask, val);
215                new_node->parent = node;
216                kid = new_node;
217                return new_node;
218            }
219
220            // Walk down the leg until something doesn't match or we run out
221            // of bits.
222            Key last_mask;
223            bool done;
224            do {
225                last_mask = cur_mask;
226                cur_mask = extendMask(cur_mask);
227                done = ((key & cur_mask) != (kid->key & cur_mask)) ||
228                    last_mask == new_mask;
229            } while (!done);
230            cur_mask = last_mask;
231
232            // If this isn't the right leg to go down at all, skip it.
233            if (cur_mask == node->mask)
234                continue;
235
236            // At the point we walked to above, add a new node.
237            new_node = new Node(key, cur_mask, NULL);
238            new_node->parent = node;
239            kid->parent = new_node;
240            new_node->kids[0] = kid;
241            kid = new_node;
242
243            // If we ran out of bits, the value goes right here.
244            if (cur_mask == new_mask) {
245                new_node->value = val;
246                return new_node;
247            }
248
249            // Still more bits to deal with, so add a new node for that path.
250            new_node = new Node(key, new_mask, val);
251            new_node->parent = kid;
252            kid->kids[1] = new_node;
253            return new_node;
254        }
255
256        panic("Reached the end of the Trie insert function!\n");
257        return NULL;
258    }
259
260    /**
261     * Method which looks up the Value corresponding to a particular key.
262     * @param key The key to look up.
263     * @return The first Value matching this key, or NULL if none was found.
264     */
265    Value *
266    lookup(Key key)
267    {
268        Node *node = lookupHandle(key);
269        if (node)
270            return node->value;
271        else
272            return NULL;
273    }
274
275    /**
276     * Method to delete a value from the trie.
277     * @param node A Handle to remove.
278     * @return The Value pointer from the removed entry.
279     */
280    Value *
281    remove(Handle handle)
282    {
283        Node *node = handle;
284        Value *val = node->value;
285        if (node->kids[1]) {
286            assert(node->value);
287            node->value = NULL;
288            return val;
289        }
290        if (!node->parent)
291            panic("Trie: Can't remove root node.\n");
292
293        Node *parent = node->parent;
294
295        // If there's a kid, fix up it's parent pointer.
296        if (node->kids[0])
297            node->kids[0]->parent = parent;
298        // Figure out which kid we are, and update our parent's pointers.
299        if (parent->kids[0] == node)
300            parent->kids[0] = node->kids[0];
301        else if (parent->kids[1] == node)
302            parent->kids[1] = node->kids[0];
303        else
304            panic("Trie: Inconsistent parent/kid relationship.\n");
305        // Make sure if the parent only has one kid, it's kid[0].
306        if (parent->kids[1] && !parent->kids[0]) {
307            parent->kids[0] = parent->kids[1];
308            parent->kids[1] = NULL;
309        }
310
311        // If the parent has less than two kids and no cargo and isn't the
312        // root, delete it too.
313        if (!parent->kids[1] && !parent->value && parent->parent)
314            remove(parent);
315        delete node;
316        return val;
317    }
318
319    /**
320     * Method to lookup a value from the trie and then delete it.
321     * @param key The key to look up and then remove.
322     * @return The Value pointer from the removed entry, if any.
323     */
324    Value *
325    remove(Key key)
326    {
327        Handle handle = lookupHandle(key);
328        if (!handle)
329            return NULL;
330        return remove(handle);
331    }
332
333    /**
334     * A method which removes all key/value pairs from the trie. This is more
335     * efficient than trying to remove elements individually.
336     */
337    void
338    clear()
339    {
340        head.clear();
341    }
342
343    /**
344     * A debugging method which prints the contents of this trie.
345     * @param title An identifying title to put in the dump header.
346     */
347    void
348    dump(const char *title)
349    {
350        cprintf("**************************************************\n");
351        cprintf("*** Start of Trie: %s\n", title);
352        cprintf("*** (parent, me, key, mask, value pointer)\n");
353        cprintf("**************************************************\n");
354        head.dump(0);
355    }
356};
357
358#endif
359