sort_includes.py revision 10275
1#!/usr/bin/env python
2
3import os
4import re
5import sys
6
7from file_types import *
8
# C standard library headers and their C++ <cXXX> counterparts, keyed by
# the C name (e.g. 'stdio.h' -> 'cstdio').  SortIncludes consults this when
# processing C++ so such headers are rewritten and grouped with the STL.
cpp_c_headers = dict(
    (stem + '.h', 'c' + stem)
    for stem in ('assert', 'ctype', 'errno', 'float', 'limits', 'locale',
                 'math', 'setjmp', 'signal', 'stdarg', 'stddef', 'stdio',
                 'stdlib', 'string', 'time', 'wchar', 'wctype'))
28
# Matches an include/import directive and captures the header name: the
# text between the first '<' or '"' after the keyword and the next closing
# delimiter.  The negated classes stop a trailing comment that contains
# quotes or angle brackets from being mistaken for the header name.
include_re = re.compile(r'([#%])(include|import)[^<"]*[<"]([^">]*)[">]')

def include_key(line):
    '''Return the sort key for an include/import line.

    Every path component except the last gets a leading space, so headers
    inside directories sort before plain files at the same level.  For a
    two-component "dnet/..." include the file name is also blanked.  Python's
    sort is stable, so dnet includes keep their original relative order
    while still sorting correctly against non-dnet includes.

    Only the header name affects the key.  Text after the closing
    delimiter (e.g. a trailing comment) is ignored.
    '''
    match = include_re.match(line)
    assert match, line
    include = match.group(3)

    parts = include.split('/')
    if len(parts) == 2 and parts[0] == 'dnet':
        # Keep dnet includes in their original order relative to each other.
        parts[1] = ' '
    # Prefix directory components so directories sort before files.
    parts[:-1] = [' ' + part for part in parts[:-1]]
    return '/'.join(parts)
51
class SortIncludes(object):
    '''Filter that sorts and groups contiguous blocks of include lines.

    Calling an instance with an iterable of lines yields the rewritten
    lines.  Each contiguous run of recognized include directives (blank
    lines between them allowed) is buffered by include type.  It is then
    emitted group by group in includes_re order, sorted by include_key(),
    with exact duplicates removed and one blank line between groups.
    '''

    # different types of includes for different sorting of headers
    # <Python.h>         - Python header needs to be first if it exists
    # <*.h>              - system headers (directories before files)
    # <*>                - STL headers (no directory component)
    # <*.(hh|hxx|hpp|H)> - C++ Headers (no directory component)
    # "*"                - M5 headers (directories before files)
    # Each entry is (include type, delimiter pair, pattern).  The tuple order
    # is both the matching priority and the output group order.  Captures
    # use [^>] / [^"] so they never run past the closing delimiter into a
    # trailing comment.
    includes_re = (
        ('python', '<>', r'^(#include)[ \t]+<(Python[^>]*\.h)>(.*)'),
        ('c', '<>', r'^(#include)[ \t]+<([^>]+\.h)>(.*)'),
        ('stl', '<>', r'^(#include)[ \t]+<([0-9A-Za-z_]+)>(.*)'),
        ('cc', '<>', r'^(#include)[ \t]+<([0-9A-Za-z_]+\.(hh|hxx|hpp|H))>(.*)'),
        ('m5cc', '""', r'^(#include)[ \t]+"([^"]+\.h{1,2})"(.*)'),
        ('swig0', '<>', r'^(%import)[ \t]+<([^>]+)>(.*)'),
        ('swig1', '<>', r'^(%include)[ \t]+<([^>]+)>(.*)'),
        ('swig2', '""', r'^(%import)[ \t]+"([^"]+)"(.*)'),
        ('swig3', '""', r'^(%include)[ \t]+"([^"]+)"(.*)'),
        )

    # compile the regexes
    includes_re = tuple((a, b, re.compile(c)) for a, b, c in includes_re)

    def __init__(self):
        # Start with empty buckets so dump_block() is usable immediately.
        self.reset()

    def reset(self):
        '''Clear all stored headers, leaving one empty list per type.'''
        self.includes = {}
        for include_type, _, _ in self.includes_re:
            self.includes[include_type] = []

    def dump_block(self):
        '''Yield the buffered includes, grouped by type and sorted.

        Groups are emitted in includes_re order and separated by a single
        blank line.  Empty groups are skipped.  Within a group, lines are
        ordered by include_key() and exact duplicate lines are dropped.
        '''
        first = True
        for include_type, _, _ in self.includes_re:
            group = self.includes[include_type]
            if not group:
                continue

            if not first:
                # blank line between groups of include types
                yield ''
            first = False

            # include_key() maps different lines to the same key (dnet
            # includes, lines differing only in trailing text), so equal
            # lines need not be adjacent after sorting.  Remember every
            # emitted line rather than only the previous one.
            seen = set()
            for line in sorted(group, key=include_key):
                if line not in seen:
                    seen.add(line)
                    yield line

    def _store_include(self, line, language):
        '''Buffer *line* if it is a recognized include; return True if so.

        The stored form is normalized to a single space after the keyword.
        For C++, C headers listed in cpp_c_headers are rewritten to their
        <cXXX> name and filed with the STL headers.
        '''
        for include_type, (ldelim, rdelim), pattern in self.includes_re:
            match = pattern.match(line)
            if not match:
                continue

            groups = match.groups()
            keyword = groups[0]
            include = groups[1]
            extra = groups[-1]
            if include_type == 'c' and language == 'C++':
                stl_inc = cpp_c_headers.get(include, None)
                if stl_inc:
                    include = stl_inc
                    include_type = 'stl'

            self.includes[include_type].append(
                keyword + ' ' + ldelim + include + rdelim + extra)
            return True
        return False

    def __call__(self, lines, filename, language):
        '''Yield *lines* with every contiguous include block sorted.

        lines    -- iterable of lines (compared against '' to detect blanks)
        filename -- name of the file being processed (not used here)
        language -- language name; 'C++' enables the cpp_c_headers rewrite

        Blank lines inside an include block are dropped.  If any blank lines
        separated the block from the next non-include line, exactly one is
        emitted after the block.
        '''
        self.reset()
        blanks = 0
        block = False

        for line in lines:
            if not line:
                blanks += 1
                if not block:
                    # Outside an include block blank lines pass through;
                    # inside one, dump_block() controls the spacing.
                    yield ''
                continue

            if self._store_include(line, language):
                # We're in a block; don't track blank lines seen so far.
                block = True
                blanks = 0
                continue

            # this line did not match any include type
            if block:
                # We've exited an include block.
                for block_line in self.dump_block():
                    yield block_line

                # collapse any blank lines after the block into one
                if blanks:
                    yield ''

                blanks = 0
                block = False
                self.reset()

            # emit the (non-include) line itself
            yield line

        if block:
            # The input ended inside an include block.
            for block_line in self.dump_block():
                yield block_line
178
179
180
# Languages whose files have their include blocks sorted when no
# --languages option is given on the command line.
default_languages = frozenset(['C', 'C++', 'isa', 'python', 'scons', 'swig'])
183
def options():
    '''Return an optparse parser describing this script's command line.

    -d, -f and -l each take a comma-separated list.  The default for each is
    the matching default_* collection joined with commas.  -n/--dry-run is a
    boolean flag.
    '''
    import optparse

    parser = optparse.OptionParser()
    parser.add_option('-d', '--dir_ignore', metavar="DIR[,DIR]",
                      type='string', default=','.join(default_dir_ignore),
                      help="ignore directories")
    parser.add_option('-f', '--file_ignore', metavar="FILE[,FILE]",
                      type='string', default=','.join(default_file_ignore),
                      help="ignore files")
    parser.add_option('-l', '--languages', metavar="LANG[,LANG]",
                      type='string', default=','.join(default_languages),
                      help="languages")
    parser.add_option('-n', '--dry-run', action='store_true',
                      help="don't overwrite files")
    return parser
201
def parse_args(parser):
    '''Parse the command line with *parser* and normalize the list options.

    dir_ignore, file_ignore and languages arrive as comma-separated strings
    and are converted to frozensets.  Whitespace around each entry is
    stripped, and empty entries (from a trailing comma or an empty string)
    are dropped.

    Returns the (opts, args) pair produced by parser.parse_args().
    '''
    opts, args = parser.parse_args()

    def split_list(value):
        # "C, C++," -> frozenset(['C', 'C++'])
        return frozenset(item.strip() for item in value.split(',')
                         if item.strip())

    opts.dir_ignore = split_list(opts.dir_ignore)
    opts.file_ignore = split_list(opts.file_ignore)
    opts.languages = split_list(opts.languages)

    return opts, args
210
if __name__ == '__main__':
    parser = options()
    opts, args = parse_args(parser)

    # Each positional argument is handed to find_files() as a search root.
    for base in args:
        for filename, language in find_files(base, languages=opts.languages,
                                             file_ignore=opts.file_ignore,
                                             dir_ignore=opts.dir_ignore):
            if opts.dry_run:
                # Only report what would be processed.  The parenthesized
                # single-argument form works under Python 2 and Python 3.
                print("%s: %s" % (filename, language))
            else:
                # NOTE(review): update_file() presumably reads the first path
                # and writes the second; passing the same name rewrites the
                # file in place -- confirm against its definition.
                update_file(filename, filename, language, SortIncludes())
222