1# Copyright (c) 2005-2006 The Regents of The University of Michigan 2# All rights reserved. 3# 4# Redistribution and use in source and binary forms, with or without 5# modification, are permitted provided that the following conditions are 6# met: redistributions of source code must retain the above copyright 7# notice, this list of conditions and the following disclaimer; 8# redistributions in binary form must reproduce the above copyright 9# notice, this list of conditions and the following disclaimer in the 10# documentation and/or other materials provided with the distribution; 11# neither the name of the copyright holders nor the names of its 12# contributors may be used to endorse or promote products derived from 13# this software without specific prior written permission. 14# 15# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 19# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 20# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 21# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 22# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 23# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# Authors: Nathan Binkert
#          Lisa Hsu

import matplotlib
import pylab
from matplotlib.font_manager import FontProperties
from matplotlib.numerix import array, arange, reshape, shape, transpose, zeros
from matplotlib.numerix import Float
from matplotlib.ticker import NullLocator

matplotlib.interactive(False)

from chart import ChartOptions


def _to_chart_layout(matrix, what):
    """Validate a 1, 2 or 3-d input matrix and rearrange it for graph().

    Returns an ``(input_array, chart_array)`` pair.  ``input_array`` is
    the input converted with ``array()``.  ``chart_array`` holds the same
    values in the 3-d ``[bar][stack][group]`` layout that
    BarChart.graph() iterates over:

    * 1-d input becomes shape ``(1, 1, n)``: one bar per group, one
      stack level, ``n`` groups.
    * 2-d input ``[group][bar]`` becomes ``(bars, 1, groups)``.
    * 3-d input ``[group][bar][stack]`` becomes
      ``(bars, stacks, groups)``.

    *what* names the matrix in the error message.

    Raises AttributeError (the exception type callers already expect)
    if the input is not 1, 2 or 3 dimensional.
    """
    matrix = array(matrix)
    dim = len(shape(matrix))
    if dim not in (1, 2, 3):
        raise AttributeError("Input %s must be a 1, 2, or 3d matrix" % what)

    if dim == 1:
        chart = array([[matrix]])
    elif dim == 2:
        chart = transpose([matrix], axes=(2, 0, 1))
    else:
        chart = transpose(matrix, axes=(1, 2, 0))
    return matrix, chart


class BarChart(ChartOptions):
    """Grouped and/or stacked bar chart drawn with matplotlib.

    Assign ``data`` (and optionally ``err``) as a 1, 2 or 3-d matrix.
    Call graph() to draw it, then savefig() or savecsv() to write it out.

    Presentation attributes are read from ``self`` but never set here,
    presumably supplied by ChartOptions.  They include xlabel, ylabel,
    xticks, xsubticks, yticks, ylim, legend, colormap, chart_size,
    figure_size, fig_legend, legend_loc, legend_size and title.
    """

    def __init__(self, default=None, **kwargs):
        super(BarChart, self).__init__(default, **kwargs)
        self.inputdata = None   # data exactly as the caller supplied it
        self.chartdata = None   # data in graph()'s [bar][stack][group] layout
        self.inputerr = None    # error bars exactly as the caller supplied them
        self.charterr = None    # error bars in the same layout as chartdata

    def gen_colors(self, count):
        """Return *count* colors sampled from ``self.colormap``.

        A single color comes from the middle of the map.  Fewer than five
        colors are the first *count* entries of a five-point sampling, so
        small palettes keep the same spacing.  Otherwise the map is
        sampled evenly from end to end.
        """
        cmap = matplotlib.cm.get_cmap(self.colormap)
        if count == 1:
            return cmap([0.5])

        if count < 5:
            return cmap(arange(5) / float(4))[:count]

        return cmap(arange(count) / float(count - 1))

    # The input data format is chosen to be intuitive, so it does not
    # match the layout that graph() consumes.  The conversion depends on
    # the dimensionality of the input; _to_chart_layout checks that
    # dimensionality and performs the rearrangement.
    def set_data(self, data):
        """Set the values to plot (1, 2 or 3-d), or clear them with None.

        Raises AttributeError if *data* is not 1, 2 or 3 dimensional.
        """
        if data is None:
            self.inputdata = None
            self.chartdata = None
            return
        self.inputdata, self.chartdata = _to_chart_layout(data, 'data')

    def get_data(self):
        """Return the data as originally supplied (or None)."""
        return self.inputdata

    data = property(get_data, set_data)

    def set_err(self, err):
        """Set error-bar magnitudes shaped like ``data``, or clear with None.

        Raises AttributeError if *err* is not 1, 2 or 3 dimensional.
        Shape agreement with ``data`` is checked later, in graph().
        """
        if err is None:
            self.inputerr = None
            self.charterr = None
            return
        self.inputerr, self.charterr = _to_chart_layout(err, 'err')

    def get_err(self):
        """Return the error bars as originally supplied (or None)."""
        return self.inputerr

    err = property(get_err, set_err)

    def _bar_colors(self, dim, cshape):
        """Return the colors used by graph(), indexed ``[bar][stack]``.

        Each entry is a sequence of ``cshape[2]`` colors, one per group:
        * 1-d data gives every bar its own color.
        * 2-d data colors by bar position within a group.
        * 3-d data colors by stack level.
        """
        if dim == 1:
            colors = self.gen_colors(cshape[2])
            colors = [[colors] * cshape[1]] * cshape[0]
        elif dim == 2:
            colors = self.gen_colors(cshape[0])
            colors = [[[c] * cshape[2]] * cshape[1] for c in colors]
        else:
            colors = self.gen_colors(cshape[1])
            colors = [[[c] * cshape[2] for c in colors]] * cshape[0]
        return array(colors)

    def _make_axes(self):
        """Create the plotting axes on ``self.figure``.

        Without ``xsubticks``, one axes (``self.axes``) is used for
        everything.

        With ``xsubticks``, a frameless, tickless "meta" axes
        (``self.metaaxes``) spans ``figure_size`` and carries the group
        labels.  The data axes is raised by .12, with its height reduced
        by the same amount, to leave room below it for the sub-tick
        labels.

        Returns ``(outer_axes, inner_axes)``: the axes that receives the
        x label and group ticks, and the one that receives the y label
        and sub-ticks.
        """
        if self.xsubticks is None:
            self.axes = self.figure.add_axes(self.figure_size)
            return self.axes, self.axes

        color = self.figure.get_facecolor()
        self.metaaxes = self.figure.add_axes(self.figure_size,
                                             axisbg=color, frameon=False)
        for tick in self.metaaxes.xaxis.majorTicks:
            tick.tick1On = False
            tick.tick2On = False
        self.metaaxes.set_yticklabels([])
        self.metaaxes.set_yticks([])

        fsize = self.figure_size
        size = [fsize[0], fsize[1] + .12, fsize[2], fsize[3] - .12]
        self.axes = self.figure.add_axes(size)
        return self.metaaxes, self.axes

    def _legend_handles(self, dim, bars):
        """Pick the bar artists that the legend entries refer to.

        *bars* is indexed ``[bar][stack]``; each entry is what axes.bar()
        returned.
        * 1-d: the whole single series, one artist per bar.
        * 2-d: the first artist of each bar position.
        * 3-d: the first artist of each stack level, top level first, so
          the legend reads in the same order as the stack.
        """
        if dim == 1:
            return bars[0][0]
        if dim == 2:
            return [stack[0][0] for stack in bars]
        return [level[0] for level in reversed(bars[0])]

    # Graph the chart data.
    # The chart data is a 3d matrix that describes a plot that has
    # multiple groups, multiple bars in each group, and multiple values
    # stacked in each bar.  The underlying bar() function expects a
    # sequence of bars in the same stack location and same group
    # location.
    #
    # The matrix is therefore organized as follows:
    # - the innermost sequence represents one of these bar groups;
    # - those are grouped together to make one full stack of bars in
    #   each group;
    # - the outermost layer describes the groups.
    #
    # Here is an example data set and how it gets plotted as a result.
    #
    # e.g. data = [[[10,11,12], [13,14,15], [16,17,18], [19,20,21]],
    #              [[22,23,24], [25,26,27], [28,29,30], [31,32,33]]]
    #
    # will plot like this:
    #
    #    19 31    20 32    21 33
    #    16 28    17 29    18 30
    #    13 25    14 26    15 27
    #    10 22    11 23    12 24
    #
    # Because this arrangement is rather counterintuitive, the data and
    # err setters take various matrices and arrange them to fit this
    # profile.
    #
    # This code deals with one of the dimensions in the matrix being
    # one wide.
    #
    def graph(self):
        """Draw the chart into a new pylab figure (``self.figure``).

        Raises AttributeError if no data has been set, or if error bars
        are set with a shape different from the data's.
        """
        if self.chartdata is None:
            raise AttributeError("Data not set for bar chart!")

        dim = len(shape(self.inputdata))
        cshape = shape(self.chartdata)
        if self.charterr is not None and shape(self.charterr) != cshape:
            raise AttributeError('Dimensions of error and data do not match')

        colors = self._bar_colors(dim, cshape)

        self.figure = pylab.figure(figsize=self.chart_size)
        outer_axes, inner_axes = self._make_axes()

        # Each group occupies one x unit: bars_in_group bars plus one
        # bar-width gap.
        bars_in_group = len(self.chartdata)
        width = 1.0 / (bars_in_group + 1)
        center = width / 2

        bars = []
        for i, stackdata in enumerate(self.chartdata):
            bottom = array([0.0] * len(stackdata[0]), Float)
            stack = []
            for j, bardata in enumerate(stackdata):
                bardata = array(bardata)
                ind = arange(len(bardata)) + i * width + center
                yerr = None
                if self.charterr is not None:
                    yerr = self.charterr[i][j]
                bar = self.axes.bar(ind, bardata, width, bottom=bottom,
                                    color=colors[i][j], yerr=yerr)
                if self.xsubticks is not None:
                    # Zero-height bars give the meta axes the same x extent.
                    self.metaaxes.bar(ind, [0] * len(bardata), width)
                stack.append(bar)
                # Rebind instead of '+=' so the array already handed to
                # bar() is not mutated after the call.
                bottom = bottom + bardata
            bars.append(stack)

        if self.xlabel is not None:
            outer_axes.set_xlabel(self.xlabel)

        if self.ylabel is not None:
            inner_axes.set_ylabel(self.ylabel)

        if self.yticks is not None:
            # Spread the given labels evenly over the current y range.
            ymin, ymax = self.axes.get_ylim()
            nticks = len(self.yticks)
            if nticks > 1:
                ticks = (arange(nticks) / float(nticks - 1) *
                         (ymax - ymin) + ymin)
            else:
                # A lone label cannot be spread (the old code divided by
                # zero here); pin it to the bottom of the range.
                ticks = [ymin] * nticks
            inner_axes.set_yticks(ticks)
            inner_axes.set_yticklabels(self.yticks)
        elif self.ylim is not None:
            inner_axes.set_ylim(self.ylim)

        if self.xticks is not None:
            # One label centered under each group.
            outer_axes.set_xticks(arange(cshape[2]) + .5)
            outer_axes.set_xticklabels(self.xticks)

        if self.xsubticks is not None:
            # One tick per bar plus an unlabeled tick for each group's gap.
            numticks = (cshape[0] + 1) * cshape[2]
            inner_axes.set_xticks(arange(numticks) * width + 2 * center)
            xsubticks = list(self.xsubticks) + ['']
            inner_axes.set_xticklabels(xsubticks * cshape[2], fontsize=7,
                                       rotation=30)

        if self.legend is not None:
            lbars = self._legend_handles(dim, bars)
            if self.fig_legend:
                self.figure.legend(lbars, self.legend, self.legend_loc,
                                   prop=FontProperties(size=self.legend_size))
            else:
                self.axes.legend(lbars, self.legend, self.legend_loc,
                                 prop=FontProperties(size=self.legend_size))

        if self.title is not None:
            self.axes.set_title(self.title)

    def savefig(self, name):
        """Save the figure drawn by graph() to the file *name*."""
        self.figure.savefig(name)

    def savecsv(self, name):
        """Write the input data to *name* as comma-separated ``%f`` values.

        1-d data is written as a single row and 2-d data as one row per
        outer element.  For 3-d data only a notice line is written.
        Other dimensionalities write nothing.  The file is closed even if
        writing fails.
        """
        data = array(self.inputdata)
        dim = len(data.shape)

        f = open(name, 'w')
        try:
            if dim == 1:
                f.write(', '.join(['%f' % val for val in data]) + '\n')
            elif dim == 2:
                for row in data:
                    f.write(', '.join(['%f' % v for v in row]) + '\n')
            elif dim == 3:
                f.write("don't do 3D csv files\n")
        finally:
            f.close()

if __name__ == '__main__':
    import sys

    args = sys.argv[1:]
    if len(args) > 3:
        sys.exit("invalid number of arguments")
    elif len(args) > 0:
        myshape = [int(x) for x in args]
    else:
        myshape = [3, 4, 8]

    # generate a data matrix of the given shape
    size = 1
    for extent in myshape:
        size *= extent
    data = [float(i) / 100.0 for i in range(size)]
    data = reshape(data, myshape)

    # setup some test bar charts
    if True:
        chart1 = BarChart()
        chart1.data = data

        chart1.xlabel = 'Benchmark'
        chart1.ylabel = 'Bandwidth (GBps)'
        chart1.legend = ['x%d' % x for x in range(myshape[-1])]
        chart1.xticks = ['xtick%d' % x for x in range(myshape[0])]
        chart1.title = 'this is the title'
        if len(myshape) > 2:
            chart1.xsubticks = ['%d' % x for x in range(myshape[1])]
        chart1.graph()
        chart1.savefig('/tmp/test1.png')
        chart1.savefig('/tmp/test1.ps')
        chart1.savefig('/tmp/test1.eps')
        chart1.savecsv('/tmp/test1.csv')

    if False:
        chart2 = BarChart()
        chart2.data = data
        chart2.colormap = 'gray'
        chart2.graph()
        chart2.savefig('/tmp/test2.png')
        chart2.savefig('/tmp/test2.ps')