# barchart.py revision 2180
# Copyright (c) 2005-2006 The Regents of The University of Michigan
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met: redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer;
# redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution;
# neither the name of the copyright holders nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Authors: Nathan Binkert
#          Lisa Hsu

import matplotlib, pylab
from matplotlib.font_manager import FontProperties
from matplotlib.numerix import array, arange, reshape, shape, transpose, zeros
from matplotlib.numerix import Float
from matplotlib.ticker import NullLocator

matplotlib.interactive(False)

from chart import ChartOptions


def _to_chart_layout(values, what):
    """Convert user-supplied values into the 3d layout used by graph().

    values -- a 1d, 2d or 3d sequence/array
    what   -- name used in the error message ('data' or 'err')

    Returns a tuple (input_array, chart_array).  chart_array is always
    3d and indexed as [bar within group][stack level][group], which is
    the order in which BarChart.graph() iterates over it.

    Raises AttributeError if the input is not 1, 2 or 3 dimensional.
    """
    values = array(values)
    dim = len(shape(values))

    # A 1d matrix describes a standard bar chart.
    if dim == 1:
        return values, array([[values]])

    # A 2d matrix describes a bar chart with groups: the matrix is an
    # array of groups of bars.
    if dim == 2:
        return values, transpose([values], axes=(2,0,1))

    # A 3d matrix describes an array of groups of bars, with each bar
    # being an array of stacked values.
    if dim == 3:
        return values, transpose(values, axes=(1,2,0))

    raise AttributeError("Input %s must be a 1, 2, or 3d matrix" % what)


class BarChart(ChartOptions):
    """Bar chart supporting plain, grouped and grouped+stacked bars.

    Assign the values to plot to the `data` property (and optionally the
    matching error magnitudes to `err`), set any chart options, then call
    graph() followed by savefig() and/or savecsv().
    """

    def __init__(self, default=None, **kwargs):
        super(BarChart, self).__init__(default, **kwargs)
        self.inputdata = None   # data exactly as supplied (as an array)
        self.chartdata = None   # data rearranged into the 3d plot layout
        self.inputerr = None    # error values exactly as supplied
        self.charterr = None    # error values rearranged like chartdata

    def gen_colors(self, count):
        """Return `count` colors sampled from the configured colormap."""
        cmap = matplotlib.cm.get_cmap(self.colormap)
        if count == 1:
            return cmap([ 0.5 ])

        # With fewer than five colors, sample five evenly spaced points
        # and keep the first `count` of them.
        if count < 5:
            return cmap(arange(5) / float(4))[:count]

        return cmap(arange(count) / float(count - 1))

    # The input data format does not match the data format that the
    # graph function takes, because the input format is the intuitive
    # one.  The conversion from input data format to chart data format
    # depends on the dimensionality of the input data, which is checked
    # (along with its validity) during the conversion.
    def set_data(self, data):
        """Set the values to plot; None clears any previous data.

        Raises AttributeError if data is not a 1, 2 or 3d matrix.
        """
        if data is None:
            self.inputdata = None
            self.chartdata = None
            return

        self.inputdata, self.chartdata = _to_chart_layout(data, 'data')

    def get_data(self):
        """Return the data as it was supplied (converted to an array)."""
        return self.inputdata

    data = property(get_data, set_data)

    def set_err(self, err):
        """Set the error values passed as yerr; None clears them.

        Raises AttributeError if err is not a 1, 2 or 3d matrix.
        """
        if err is None:
            self.inputerr = None
            self.charterr = None
            return

        self.inputerr, self.charterr = _to_chart_layout(err, 'err')

    def get_err(self):
        """Return the error values as supplied (converted to an array)."""
        return self.inputerr

    err = property(get_err, set_err)

    # Graph the chart data.
    # Input is a 3d matrix that describes a plot that has multiple
    # groups, multiple bars in each group, and multiple values stacked
    # in each bar.  The underlying bar() function expects a sequence of
    # bars in the same stack location and same group location, so the
    # organization of the matrix is that the inner most sequence
    # represents one of these bar groups, then those are grouped
    # together to make one full stack of bars in each group, and then
    # the outer most layer describes the groups.  Here is an example
    # data set and how it gets plotted as a result.
    #
    # e.g. data = [[[10,11,12], [13,14,15], [16,17,18], [19,20,21]],
    #              [[22,23,24], [25,26,27], [28,29,30], [31,32,33]]]
    #
    # will plot like this:
    #
    #    19 31    20 32    21 33
    #    16 28    17 29    18 30
    #    13 25    14 26    15 27
    #    10 22    11 23    12 24
    #
    # Because this arrangement is rather counterintuitive, the data and
    # err setters take the various input matrices and rearrange them to
    # fit this profile.
    #
    # This code deals with one of the dimensions in the matrix being
    # one wide.
    #
    def graph(self):
        """Render the chart into self.figure / self.axes.

        Raises AttributeError if no data has been set, or if error
        values were supplied whose shape differs from the data.
        """
        if self.chartdata is None:
            raise AttributeError("Data not set for bar chart!")

        dim = len(shape(self.inputdata))
        cshape = shape(self.chartdata)
        if self.charterr is not None and shape(self.charterr) != cshape:
            raise AttributeError('Dimensions of error and data do not match')

        # Build a color array indexed like chartdata.  Which axis the
        # colors vary along depends on the input dimensionality.
        if dim == 1:
            colors = self.gen_colors(cshape[2])
            colors = [ [ colors ] * cshape[1] ] * cshape[0]

        if dim == 2:
            colors = self.gen_colors(cshape[0])
            colors = [ [ [ c ] * cshape[2] ] * cshape[1] for c in colors ]

        if dim == 3:
            colors = self.gen_colors(cshape[1])
            colors = [ [ [ c ] * cshape[2] for c in colors ] ] * cshape[0]

        colors = array(colors)

        self.figure = pylab.figure(figsize=self.chart_size)

        have_subticks = self.xsubticks is not None

        outer_axes = None
        inner_axes = None
        if have_subticks:
            # Two sets of axes are used: a frameless outer one (metaaxes)
            # carrying the x label and x ticks, and a shorter inner one
            # holding the bars and the rotated sub-tick labels.
            color = self.figure.get_facecolor()
            self.metaaxes = self.figure.add_axes(self.figure_size,
                                                 axisbg=color, frameon=False)
            for tick in self.metaaxes.xaxis.majorTicks:
                tick.tick1On = False
                tick.tick2On = False
            self.metaaxes.set_yticklabels([])
            self.metaaxes.set_yticks([])
            size = [0] * 4
            size[0] = self.figure_size[0]
            size[1] = self.figure_size[1] + .12
            size[2] = self.figure_size[2]
            size[3] = self.figure_size[3] - .12
            self.axes = self.figure.add_axes(size)
            outer_axes = self.metaaxes
            inner_axes = self.axes
        else:
            self.axes = self.figure.add_axes(self.figure_size)
            outer_axes = self.axes
            inner_axes = self.axes

        bars_in_group = len(self.chartdata)

        # Each group is one unit wide; leave one bar-width of space
        # between adjacent groups.
        width = 1.0 / ( bars_in_group + 1)
        center = width / 2

        bars = []
        for i,stackdata in enumerate(self.chartdata):
            bottom = array([0.0] * len(stackdata[0]), Float)
            stack = []
            for j,bardata in enumerate(stackdata):
                bardata = array(bardata)
                ind = arange(len(bardata)) + i * width + center
                yerr = None
                if self.charterr is not None:
                    yerr = self.charterr[i][j]
                bar = self.axes.bar(ind, bardata, width, bottom=bottom,
                                    color=colors[i][j], yerr=yerr)
                if have_subticks:
                    # Zero-height bars at the same x positions on the
                    # outer axes.
                    self.metaaxes.bar(ind, [0] * len(bardata), width)
                stack.append(bar)
                bottom += bardata
            bars.append(stack)

        if self.xlabel is not None:
            outer_axes.set_xlabel(self.xlabel)

        if self.ylabel is not None:
            inner_axes.set_ylabel(self.ylabel)

        if self.yticks is not None:
            # Spread the supplied labels evenly over the current y range.
            ymin, ymax = self.axes.get_ylim()
            nticks = float(len(self.yticks))
            ticks = arange(nticks) / (nticks - 1) * (ymax - ymin) + ymin
            inner_axes.set_yticks(ticks)
            inner_axes.set_yticklabels(self.yticks)
        elif self.ylim is not None:
            # inner_axes is a local variable; it is never stored on self.
            inner_axes.set_ylim(self.ylim)

        if self.xticks is not None:
            outer_axes.set_xticks(arange(cshape[2]) + .5)
            outer_axes.set_xticklabels(self.xticks)

        if have_subticks:
            inner_axes.set_xticks(arange((cshape[0] + 1)*cshape[2])*width + 2*center)
            # One blank label for the gap slot after each group.  Work
            # on a copy so the caller's list is not modified and
            # repeated graph() calls produce the same labels.
            sublabels = list(self.xsubticks) + ['']
            inner_axes.set_xticklabels(sublabels * cshape[2], fontsize=7,
                                       rotation=90)

        if self.legend is not None:
            # Pick the bar artists that represent each legend entry.
            if dim == 1:
                lbars = bars[0][0]
            if dim == 2:
                lbars = [ bars[i][0][0] for i in range(len(bars))]
            if dim == 3:
                # Top of the stack first.
                number = len(bars[0])
                lbars = [ bars[0][number - j - 1][0] for j in range(number)]

            if self.fig_legend:
                self.figure.legend(lbars, self.legend, self.legend_loc,
                                   prop=FontProperties(size=self.legend_size))
            else:
                self.axes.legend(lbars, self.legend, self.legend_loc,
                                 prop=FontProperties(size=self.legend_size))

        if self.title is not None:
            self.axes.set_title(self.title)

    def savefig(self, name):
        """Save the figure created by graph() to the file `name`."""
        self.figure.savefig(name)

    def savecsv(self, name):
        """Write the input data to the file `name` as comma separated values.

        1d data is written as a single row and 2d data as one row per
        outer element.  No header row or row labels are written.  3d
        data is not supported; a one-line notice is written instead.
        """
        f = open(name, 'w')
        try:
            data = array(self.inputdata)
            dim = len(data.shape)

            if dim == 1:
                f.write(', '.join([ '%f' % val for val in data]) + '\n')
            if dim == 2:
                for row in data:
                    f.write(', '.join([ '%f' % val for val in row]) + '\n')
            if dim == 3:
                f.write("don't do 3D csv files\n")
        finally:
            # Close the file even if formatting or writing fails.
            f.close()

if __name__ == '__main__':
    from random import randrange
    import random, sys

    dim = 3
    number = 5

    args = sys.argv[1:]
    if len(args) > 3:
        sys.exit("invalid number of arguments")
    elif len(args) > 0:
        myshape = [ int(x) for x in args ]
    else:
        myshape = [ 3, 4, 8 ]

    # generate a data matrix of the given shape
    size = 1
    for extent in myshape:
        size *= extent
    #data = [ random.randrange(size - i) + 10 for i in range(size) ]
    data = [ float(i)/100.0 for i in range(size) ]
    data = reshape(data, myshape)

    # setup some test bar charts
    if True:
        chart1 = BarChart()
        chart1.data = data

        chart1.xlabel = 'Benchmark'
        chart1.ylabel = 'Bandwidth (GBps)'
        chart1.legend = [ 'x%d' % x for x in range(myshape[-1]) ]
        chart1.xticks = [ 'xtick%d' % x for x in range(myshape[0]) ]
        chart1.title = 'this is the title'
        if len(myshape) > 2:
            chart1.xsubticks = [ '%d' % x for x in range(myshape[1]) ]
        chart1.graph()
        chart1.savefig('/tmp/test1.png')
        chart1.savefig('/tmp/test1.ps')
        chart1.savefig('/tmp/test1.eps')
        chart1.savecsv('/tmp/test1.csv')

    if False:
        chart2 = BarChart()
        chart2.data = data
        chart2.colormap = 'gray'
        chart2.graph()
        chart2.savefig('/tmp/test2.png')
        chart2.savefig('/tmp/test2.ps')

#    pylab.show()