1/*
2 * Copyright (c) 2012 Google
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met: redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer;
9 * redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution;
12 * neither the name of the copyright holders nor the names of its
13 * contributors may be used to endorse or promote products derived from
14 * this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
20 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 *
28 * Authors: Gabe Black
29 */
30
31#ifndef __BASE_TRIE_HH__
32#define __BASE_TRIE_HH__
33
34#include <cassert>
35#include <iostream>
36
37#include "base/cprintf.hh"
38#include "base/logging.hh"
39#include "base/types.hh"
40
41// Key has to be an integral type.
42template <class Key, class Value>
43class Trie
44{
45  protected:
46    struct Node
47    {
48        Key key;
49        Key mask;
50
51        bool
52        matches(Key test)
53        {
54            return (test & mask) == key;
55        }
56
57        Value *value;
58
59        Node *parent;
60        Node *kids[2];
61
62        Node(Key _key, Key _mask, Value *_val) :
63            key(_key & _mask), mask(_mask), value(_val),
64            parent(NULL)
65        {
66            kids[0] = NULL;
67            kids[1] = NULL;
68        }
69
70        void
71        clear()
72        {
73            if (kids[1]) {
74                kids[1]->clear();
75                delete kids[1];
76                kids[1] = NULL;
77            }
78            if (kids[0]) {
79                kids[0]->clear();
80                delete kids[0];
81                kids[0] = NULL;
82            }
83        }
84
85        void
86        dump(std::ostream &os, int level)
87        {
88            for (int i = 1; i < level; i++) {
89                ccprintf(os, "|");
90            }
91            if (level == 0)
92                ccprintf(os, "Root ");
93            else
94                ccprintf(os, "+ ");
95            ccprintf(os, "(%p, %p, %#X, %#X, %p)\n",
96                     parent, this, key, mask, value);
97            if (kids[0])
98                kids[0]->dump(os, level + 1);
99            if (kids[1])
100                kids[1]->dump(os, level + 1);
101        }
102    };
103
104  protected:
105    Node head;
106
107  public:
108    typedef Node *Handle;
109
110    Trie() : head(0, 0, NULL)
111    {}
112
113    static const unsigned MaxBits = sizeof(Key) * 8;
114
115  private:
116    /**
117     * A utility method which checks whether the key being looked up lies
118     * beyond the Node being examined. If so, it returns true and advances the
119     * node being examined.
120     * @param parent The node we're currently "at", which can be updated.
121     * @param kid The node we may want to move to.
122     * @param key The key we're looking for.
123     * @param new_mask The mask to use when matching against the key.
124     * @return Whether the current Node was advanced.
125     */
126    bool
127    goesAfter(Node **parent, Node *kid, Key key, Key new_mask)
128    {
129        if (kid && kid->matches(key) && (kid->mask & new_mask) == kid->mask) {
130            *parent = kid;
131            return true;
132        } else {
133            return false;
134        }
135    }
136
137    /**
138     * A utility method which extends a mask value one more bit towards the
139     * lsb. This is almost just a signed right shift, except that the shifted
140     * in bits are technically undefined. This is also slightly complicated by
141     * the zero case.
142     * @param orig The original mask to extend.
143     * @return The extended mask.
144     */
145    Key
146    extendMask(Key orig)
147    {
148        // Just in case orig was 0.
149        const Key msb = ULL(1) << (MaxBits - 1);
150        return orig | (orig >> 1) | msb;
151    }
152
153    /**
154     * Method which looks up the Handle corresponding to a particular key. This
155     * is useful if you want to delete the Handle corresponding to a key since
156     * the "remove" function takes a Handle as its argument.
157     * @param key The key to look up.
158     * @return The first Handle matching this key, or NULL if none was found.
159     */
160    Handle
161    lookupHandle(Key key)
162    {
163        Node *node = &head;
164        while (node) {
165            if (node->value)
166                return node;
167
168            if (node->kids[0] && node->kids[0]->matches(key))
169                node = node->kids[0];
170            else if (node->kids[1] && node->kids[1]->matches(key))
171                node = node->kids[1];
172            else
173                node = NULL;
174        }
175
176        return NULL;
177    }
178
179  public:
180    /**
181     * Method which inserts a key/value pair into the trie.
182     * @param key The key which can later be used to look up this value.
183     * @param width How many bits of the key (from msb) should be used.
184     * @param val A pointer to the value to store in the trie.
185     * @return A Handle corresponding to this value.
186     */
187    Handle
188    insert(Key key, unsigned width, Value *val)
189    {
190        // We use NULL value pointers to mark internal nodes of the trie, so
191        // we don't allow inserting them as real values.
192        assert(val);
193
194        // Build a mask which masks off all the bits we don't care about.
195        Key new_mask = ~(Key)0;
196        if (width < MaxBits)
197            new_mask <<= (MaxBits - width);
198        // Use it to tidy up the key.
199        key &= new_mask;
200
201        // Walk past all the nodes this new node will be inserted after. They
202        // can be ignored for the purposes of this function.
203        Node *node = &head;
204        while (goesAfter(&node, node->kids[0], key, new_mask) ||
205               goesAfter(&node, node->kids[1], key, new_mask))
206        {}
207        assert(node);
208
209        Key cur_mask = node->mask;
210        // If we're already where the value needs to be...
211        if (cur_mask == new_mask) {
212            assert(!node->value);
213            node->value = val;
214            return node;
215        }
216
217        for (unsigned int i = 0; i < 2; i++) {
218            Node *&kid = node->kids[i];
219            Node *new_node;
220            if (!kid) {
221                // No kid. Add a new one.
222                new_node = new Node(key, new_mask, val);
223                new_node->parent = node;
224                kid = new_node;
225                return new_node;
226            }
227
228            // Walk down the leg until something doesn't match or we run out
229            // of bits.
230            Key last_mask;
231            bool done;
232            do {
233                last_mask = cur_mask;
234                cur_mask = extendMask(cur_mask);
235                done = ((key & cur_mask) != (kid->key & cur_mask)) ||
236                    last_mask == new_mask;
237            } while (!done);
238            cur_mask = last_mask;
239
240            // If this isn't the right leg to go down at all, skip it.
241            if (cur_mask == node->mask)
242                continue;
243
244            // At the point we walked to above, add a new node.
245            new_node = new Node(key, cur_mask, NULL);
246            new_node->parent = node;
247            kid->parent = new_node;
248            new_node->kids[0] = kid;
249            kid = new_node;
250
251            // If we ran out of bits, the value goes right here.
252            if (cur_mask == new_mask) {
253                new_node->value = val;
254                return new_node;
255            }
256
257            // Still more bits to deal with, so add a new node for that path.
258            new_node = new Node(key, new_mask, val);
259            new_node->parent = kid;
260            kid->kids[1] = new_node;
261            return new_node;
262        }
263
264        panic("Reached the end of the Trie insert function!\n");
265        return NULL;
266    }
267
268    /**
269     * Method which looks up the Value corresponding to a particular key.
270     * @param key The key to look up.
271     * @return The first Value matching this key, or NULL if none was found.
272     */
273    Value *
274    lookup(Key key)
275    {
276        Node *node = lookupHandle(key);
277        if (node)
278            return node->value;
279        else
280            return NULL;
281    }
282
283    /**
284     * Method to delete a value from the trie.
285     * @param node A Handle to remove.
286     * @return The Value pointer from the removed entry.
287     */
288    Value *
289    remove(Handle handle)
290    {
291        Node *node = handle;
292        Value *val = node->value;
293        if (node->kids[1]) {
294            assert(node->value);
295            node->value = NULL;
296            return val;
297        }
298        if (!node->parent)
299            panic("Trie: Can't remove root node.\n");
300
301        Node *parent = node->parent;
302
303        // If there's a kid, fix up it's parent pointer.
304        if (node->kids[0])
305            node->kids[0]->parent = parent;
306        // Figure out which kid we are, and update our parent's pointers.
307        if (parent->kids[0] == node)
308            parent->kids[0] = node->kids[0];
309        else if (parent->kids[1] == node)
310            parent->kids[1] = node->kids[0];
311        else
312            panic("Trie: Inconsistent parent/kid relationship.\n");
313        // Make sure if the parent only has one kid, it's kid[0].
314        if (parent->kids[1] && !parent->kids[0]) {
315            parent->kids[0] = parent->kids[1];
316            parent->kids[1] = NULL;
317        }
318
319        // If the parent has less than two kids and no cargo and isn't the
320        // root, delete it too.
321        if (!parent->kids[1] && !parent->value && parent->parent)
322            remove(parent);
323        delete node;
324        return val;
325    }
326
327    /**
328     * Method to lookup a value from the trie and then delete it.
329     * @param key The key to look up and then remove.
330     * @return The Value pointer from the removed entry, if any.
331     */
332    Value *
333    remove(Key key)
334    {
335        Handle handle = lookupHandle(key);
336        if (!handle)
337            return NULL;
338        return remove(handle);
339    }
340
341    /**
342     * A method which removes all key/value pairs from the trie. This is more
343     * efficient than trying to remove elements individually.
344     */
345    void
346    clear()
347    {
348        head.clear();
349    }
350
351    /**
352     * A debugging method which prints the contents of this trie.
353     * @param title An identifying title to put in the dump header.
354     */
355    void
356    dump(const char *title, std::ostream &os=std::cout)
357    {
358        ccprintf(os, "**************************************************\n");
359        ccprintf(os, "*** Start of Trie: %s\n", title);
360        ccprintf(os, "*** (parent, me, key, mask, value pointer)\n");
361        ccprintf(os, "**************************************************\n");
362        head.dump(os, 0);
363    }
364};
365
366#endif
367