sort_includes.py revision 8225
1#!/usr/bin/env python
2
3import os
4import re
5import sys
6
7from file_types import *
8
# Table mapping each C standard header name to its C++ counterpart.  Every
# entry follows the same pattern ('X.h' -> 'cX'), so the table is generated
# from the list of header stems.
cpp_c_headers = dict(
    (stem + '.h', 'c' + stem)
    for stem in (
        'assert', 'ctype', 'errno', 'float', 'limits', 'locale',
        'math', 'setjmp', 'signal', 'stdarg', 'stddef', 'stdio',
        'stdlib', 'string', 'time', 'wchar', 'wctype',
    )
)
28
# Generic matcher for an include/import line.  Group 3 is the path between
# the FIRST pair of delimiters on the line: the quantifiers are non-greedy so
# that quotes or angle brackets in a trailing comment cannot replace the real
# include path.
include_re = re.compile(r'([#%])(include|import).*?[<"](.*?)[">]')
def include_key(line):
    '''Return the sort key for an include/import line.

    Directory components are marked with a leading space so that
    directories are sorted before files.

    line -- a line starting with '#' or '%' followed by 'include' or
            'import' and a path delimited by <> or "".

    Raises AssertionError (carrying the offending line) if the line
    is not an include/import line.'''

    match = include_re.match(line)
    if match is None:
        # Raised explicitly (rather than via an assert statement) so the
        # check is not stripped when Python runs with -O.
        raise AssertionError(line)
    include = match.group(3)

    # Everything but the file part needs to have a space prepended
    parts = include.split('/')
    if len(parts) == 2 and parts[0] == 'dnet':
        # Don't sort the dnet includes with respect to each other, but
        # make them sorted with respect to non dnet includes.  Python
        # guarantees that sorting is stable, so just clear the
        # basename part of the filename.
        parts[1] = ' '
    parts[0:-1] = [ ' ' + s for s in parts[0:-1] ]
    key = '/'.join(parts)

    return key
51
class SortIncludes(object):
    '''Line filter that collects runs of include/import lines and
    re-emits each run grouped by include type and sorted.

    An instance is called with an iterable of lines (without line
    terminators, since blank lines are detected as empty strings) and
    yields the rewritten lines.'''

    # different types of includes for different sorting of headers
    # <Python.h>         - Python header needs to be first if it exists
    # <*.h>              - system headers (directories before files)
    # <*>                - STL headers
    # <*.(hh|hxx|hpp|H)> - C++ Headers (directories before files)
    # "*"                - M5 headers (directories before files)
    #
    # Each entry is (type name, delimiter pair, pattern).  In every
    # pattern group 1 is the keyword, group 2 the included path and the
    # last group whatever follows the closing delimiter.  All patterns
    # accept any run of spaces/tabs after the keyword.
    includes_re = (
        ('python', '<>', r'^(#include)[ \t]+<(Python.*\.h)>(.*)'),
        ('c', '<>', r'^(#include)[ \t]+<(.+\.h)>(.*)'),
        ('stl', '<>', r'^(#include)[ \t]+<([0-9A-Za-z_]+)>(.*)'),
        ('cc', '<>',
         r'^(#include)[ \t]+<([0-9A-Za-z_]+\.(hh|hxx|hpp|H))>(.*)'),
        ('m5cc', '""', r'^(#include)[ \t]+"(.+\.h{1,2})"(.*)'),
        ('swig0', '<>', r'^(%import)[ \t]+<(.+)>(.*)'),
        ('swig1', '<>', r'^(%include)[ \t]+<(.+)>(.*)'),
        ('swig2', '""', r'^(%import)[ \t]+"(.+)"(.*)'),
        ('swig3', '""', r'^(%include)[ \t]+"(.+)"(.*)'),
        )

    # compile the regexes
    includes_re = tuple((a, b, re.compile(c)) for a,b,c in includes_re)

    def __init__(self):
        self.reset()

    def reset(self):
        '''Forget all stored headers, leaving one empty list per
        include type.'''
        self.includes = {}
        for include_type,_,_ in self.includes_re:
            self.includes[include_type] = []

    def dump_block(self):
        '''Yield the stored includes.

        Groups are emitted in the order of includes_re, separated by a
        single empty line; empty groups are skipped.  Within a group the
        lines are sorted according to include_key() and every distinct
        line is emitted only once.'''
        first = True
        for include_type,_,_ in self.includes_re:
            group = self.includes[include_type]
            if not group:
                continue

            if not first:
                # print a newline between groups of
                # include types
                yield ''
            first = False

            # Track every line already emitted for this group: lines
            # with equal sort keys keep their input order, so two
            # identical lines are not necessarily adjacent after the
            # sort.
            seen = set()
            for l in sorted(group, key=include_key):
                if l not in seen:
                    seen.add(l)
                    yield l

    def __call__(self, lines, filename, language):
        '''Yield the lines of a file with its include blocks sorted.

        lines    -- iterable of lines; an empty string is a blank line
        filename -- accepted for the caller's benefit, not used here
        language -- when 'C++', C headers listed in cpp_c_headers are
                    replaced by their C++ names and filed as STL headers

        Lines outside an include block are passed through unchanged.
        Inside a block, include lines are normalised to
        "<keyword> <delimited path><rest>", blank lines are dropped, and
        the block is emitted via dump_block() as soon as a non-blank,
        non-include line (or the end of input) is reached.'''
        blanks = 0
        block = False

        for line in lines:
            if not line:
                blanks += 1
                if not block:
                    # if we're not in an include block, spit out the
                    # newline; otherwise, skip it since we're going to
                    # control newlines within the include block
                    yield ''
                continue

            # Try to match each of the include types.  The loop variable
            # is deliberately not called include_re so it does not hide
            # the module-level regex of that name.
            for include_type,(ldelim,rdelim),regex in self.includes_re:
                match = regex.match(line)
                if not match:
                    continue

                # if we've got a match, clean up the #include line,
                # fix up stl headers and store it in the proper category
                groups = match.groups()
                keyword = groups[0]
                include = groups[1]
                extra = groups[-1]
                if include_type == 'c' and language == 'C++':
                    stl_inc = cpp_c_headers.get(include, None)
                    if stl_inc:
                        include = stl_inc
                        include_type = 'stl'

                line = keyword + ' ' + ldelim + include + rdelim + extra

                self.includes[include_type].append(line)

                # We've entered a block, don't keep track of blank
                # lines while in a block
                block = True
                blanks = 0
                break
            else:
                # None of the patterns matched this line.

                # if we're not in a block and we didn't match an include
                # to enter a block, just emit the line and continue
                if not block:
                    yield line
                    continue

                # We've exited an include block.
                for block_line in self.dump_block():
                    yield block_line

                # if there are any newlines after the include block,
                # emit a single newline (removing extras)
                if blanks:
                    yield ''

                blanks = 0
                block = False
                self.reset()

                # emit the line that ended the block
                yield line

        if block:
            # The input ended while still inside an include block.
            for block_line in self.dump_block():
                yield block_line
177
178
179
# language types our sorting rules are applied to by default
default_languages = frozenset([
    'C',
    'C++',
    'isa',
    'python',
    'scons',
    'swig',
])
182
def options():
    '''Build and return the optparse parser for the command line.

    The list-valued options (-d, -f, -l) take comma separated strings;
    their defaults are the comma-joined default_dir_ignore,
    default_file_ignore and default_languages collections.
    NOTE(review): default_dir_ignore and default_file_ignore are not
    defined in this file -- presumably they come from the file_types
    star import; confirm.'''
    import optparse

    parser = optparse.OptionParser()

    parser.add_option('-d', '--dir_ignore', metavar="DIR[,DIR]",
                      type='string',
                      default=','.join(default_dir_ignore),
                      help="ignore directories")
    parser.add_option('-f', '--file_ignore', metavar="FILE[,FILE]",
                      type='string',
                      default=','.join(default_file_ignore),
                      help="ignore files")
    parser.add_option('-l', '--languages', metavar="LANG[,LANG]",
                      type='string',
                      default=','.join(default_languages),
                      help="languages")
    parser.add_option('-n', '--dry-run', action='store_true',
                      help="don't overwrite files")

    return parser
200
def parse_args(parser):
    '''Parse the command line with parser and return (opts, args).

    The comma separated dir_ignore, file_ignore and languages option
    strings are each replaced on opts by a frozenset of their items.'''
    opts, args = parser.parse_args()

    for name in ('dir_ignore', 'file_ignore', 'languages'):
        items = getattr(opts, name).split(',')
        setattr(opts, name, frozenset(items))

    return opts, args
209
if __name__ == '__main__':
    # Script entry point: every positional argument is a base path that
    # is searched for files of the selected languages.
    # NOTE(review): find_files and update_file are not defined in this
    # file -- presumably they come from the file_types star import.
    parser = options()
    opts, args = parse_args(parser)

    for base in args:
        for filename,language in find_files(base, languages=opts.languages,
                file_ignore=opts.file_ignore, dir_ignore=opts.dir_ignore):
            if opts.dry_run:
                # With --dry-run only report what would be processed.  A
                # single parenthesised argument prints the same text with
                # the Python 2 print statement and the Python 3 function.
                print("%s: %s" % (filename, language))
            else:
                # Rewrite the file in place (same name passed twice),
                # using a fresh sorter for each file.
                update_file(filename, filename, language, SortIncludes())
221