# sorteddict.py revision 8223:394cb2dc3f7c
# Copyright (c) 2006-2009 Nathan Binkert <nate@binkert.org>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met: redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer;
# redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution;
# neither the name of the copyright holders nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27from bisect import bisect_left, bisect_right
28
class SortedDict(dict):
    """A ``dict`` whose key, value and item listings come out in key order.

    Entries live in the underlying ``dict``; the ordered key list is built
    lazily by ``_keys`` and cached as ``self._sorted_keys``.  Every mutating
    method overridden here drops that cache so the next read re-sorts.

    The ordering function can be replaced by assigning a callable to
    ``sorted``; it receives an iterable of the keys and should return a
    list.  The bisect-based helpers (``_left_eq``, ``keyrange`` and friends)
    assume that list is in ascending order.

    Only the methods defined below are order- and cache-aware; plain
    iteration (``for k in d``) is inherited unchanged from ``dict``.
    """

    def _get_sorted(self):
        # Inside a method the name ``sorted`` resolves to the builtin, which
        # is the default ordering when none has been assigned.
        return getattr(self, '_sorted', sorted)

    def _set_sorted(self, val):
        self._sorted = val
        # A new ordering makes the cached key list stale.
        self._del_keys()

    sorted = property(_get_sorted, _set_sorted)

    @property
    def _keys(self):
        """Sorted list of keys, computed on first use and cached until changed.

        This is the cache itself; callers must not mutate it.
        """
        try:
            return self._sorted_keys
        except AttributeError:
            # dict.__iter__ yields the raw keys on both Python 2 and 3
            # (dict.iterkeys does not exist on Python 3).
            keys = self._sorted_keys = self.sorted(dict.__iter__(self))
            return keys

    def _left_eq(self, key):
        """Return the index of ``key`` in ``_keys``; KeyError if absent."""
        index = self._left_ge(key)
        if self._keys[index] != key:
            raise KeyError(key)
        return index

    def _right_eq(self, key):
        """Return the index of ``key`` in ``_keys``; KeyError if absent."""
        index = self._right_le(key)
        if self._keys[index] != key:
            raise KeyError(key)
        return index

    def _right_lt(self, key):
        """Index of the largest key < ``key``; KeyError if there is none."""
        index = bisect_left(self._keys, key)
        if index:
            return index - 1
        raise KeyError(key)

    def _right_le(self, key):
        """Index of the largest key <= ``key``; KeyError if there is none."""
        index = bisect_right(self._keys, key)
        if index:
            return index - 1
        raise KeyError(key)

    def _left_gt(self, key):
        """Index of the smallest key > ``key``; KeyError if there is none."""
        index = bisect_right(self._keys, key)
        if index != len(self._keys):
            return index
        raise KeyError(key)

    def _left_ge(self, key):
        """Index of the smallest key >= ``key``; KeyError if there is none."""
        index = bisect_left(self._keys, key)
        if index != len(self._keys):
            return index
        raise KeyError(key)

    def _del_keys(self):
        """Drop the cached sorted key list (no-op if nothing is cached)."""
        try:
            del self._sorted_keys
        except AttributeError:
            pass

    def __repr__(self):
        return 'SortedDict({%s})' % ', '.join('%r: %r' % item
                                              for item in self.iteritems())

    def __setitem__(self, key, item):
        dict.__setitem__(self, key, item)
        self._del_keys()

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        self._del_keys()

    def clear(self):
        """Remove every entry."""
        # Entries are stored in the inherited dict itself; there is no
        # separate ``data`` container to clear.
        dict.clear(self)
        self._del_keys()

    def copy(self):
        """Return a shallow copy of the same type, keeping any custom ``sorted``."""
        new = type(self)(self)
        try:
            ordering = self._sorted
        except AttributeError:
            return new  # default ordering: nothing to carry over
        new.sorted = ordering
        return new

    def keys(self):
        """Return a new list of the keys in sorted order."""
        # Copy so callers cannot corrupt the cache.
        return self._keys[:]

    def values(self):
        """Return a list of the values, ordered by their keys."""
        return list(self.itervalues())

    def items(self):
        """Return a list of ``(key, value)`` pairs in key order."""
        return list(self.iteritems())

    def iterkeys(self):
        """Iterate over the keys in sorted order."""
        return iter(self._keys)

    def itervalues(self):
        """Iterate over the values, ordered by their keys."""
        for k in self._keys:
            yield self[k]

    def iteritems(self):
        """Iterate over ``(key, value)`` pairs in key order."""
        for k in self._keys:
            yield k, self[k]

    def keyrange(self, start=None, end=None, inclusive=False):
        """Iterate, in sorted order, over the keys k with ``start <= k < end``.

        A ``start`` or ``end`` of None leaves that side unbounded.  With
        ``inclusive=True`` the upper bound becomes ``k <= end``.  A range
        that contains no keys simply yields nothing.
        """
        keys = self._keys
        lo = 0 if start is None else bisect_left(keys, start)
        if end is None:
            hi = len(keys)
        elif inclusive:
            hi = bisect_right(keys, end)
        else:
            hi = bisect_left(keys, end)
        # The slice is a copy, so the iterator is a snapshot of the range.
        return iter(keys[lo:hi])

    def valuerange(self, *args, **kwargs):
        """Iterate over the values whose keys fall in ``keyrange(...)``."""
        for k in self.keyrange(*args, **kwargs):
            yield self[k]

    def itemrange(self, *args, **kwargs):
        """Iterate over ``(key, value)`` pairs whose keys fall in ``keyrange(...)``."""
        for k in self.keyrange(*args, **kwargs):
            yield k, self[k]

    def update(self, *args, **kwargs):
        """Merge entries like ``dict.update`` and invalidate the key cache."""
        # dict.update does not route through the overridden __setitem__,
        # so the cache is dropped explicitly.
        dict.update(self, *args, **kwargs)
        self._del_keys()

    def setdefault(self, key, _failobj=None):
        """Return ``self[key]``, first inserting ``_failobj`` if it is missing."""
        try:
            return self[key]
        except KeyError:
            self[key] = _failobj
            return _failobj

    def pop(self, key, *args):
        """Remove ``key`` and return its value.

        If ``key`` is missing, return the default when one is given,
        otherwise raise KeyError.  More than one default raises TypeError,
        matching ``dict.pop``.
        """
        if len(args) > 1:
            raise TypeError('pop expected at most 2 arguments, got %d'
                            % (len(args) + 1))
        try:
            value = dict.pop(self, key)
        except KeyError:
            if not args:
                raise
            return args[0]
        self._del_keys()
        return value

    def popitem(self):
        """Remove and return the ``(key, value)`` pair with the first sorted key.

        Raises KeyError if the dictionary is empty.
        """
        keys = self._keys
        if not keys:
            raise KeyError('popitem(): dictionary is empty')
        key = keys[0]
        value = dict.pop(self, key)
        self._del_keys()
        return key, value

    @classmethod
    def fromkeys(cls, seq, value=None):
        """Build a new instance mapping every key in ``seq`` to ``value``."""
        d = cls()
        for key in seq:
            d[key] = value
        return d
181
if __name__ == '__main__':
    # Small self-demonstration.  print() with a single, pre-formatted
    # argument produces the same output under Python 2 and Python 3.
    def display(d):
        print(d)
        print(d.keys())
        print(list(d.iterkeys()))
        print(d.values())
        print(list(d.itervalues()))
        print(d.items())
        print(list(d.iteritems()))

    d = SortedDict(x=24, e=5, j=4, b=2, z=26, d=4)
    display(d)

    print('popitem %s' % (d.popitem(),))
    display(d)

    print('pop j')
    d.pop('j')
    display(d)

    d.setdefault('a', 1)
    d.setdefault('g', 7)
    d.setdefault('_')
    display(d)

    d.update({'b': 2, 'h': 8})
    display(d)

    del d['x']
    display(d)
    d['y'] = 26
    display(d)

    # repr() replaces the backtick syntax, which Python 3 removed.
    print(repr(d))

    print(d.copy())

    for k, v in d.itemrange('d', 'z', inclusive=True):
        print('%s %s' % (k, v))
221