1# Copyright (c) 2003-2004 The Regents of The University of Michigan
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met: redistributions of source code must retain the above copyright
7# notice, this list of conditions and the following disclaimer;
8# redistributions in binary form must reproduce the above copyright
9# notice, this list of conditions and the following disclaimer in the
10# documentation and/or other materials provided with the distribution;
11# neither the name of the copyright holders nor the names of its
12# contributors may be used to endorse or promote products derived from
13# this software without specific prior written permission.
14#
15# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
19# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
21# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26#
27# Authors: Nathan Binkert
28
29import MySQLdb, re, string
30
def statcmp(a, b):
    """Three-way comparison of two dotted stat names for sorting.

    The names are split on '.'.  All components before the last component
    of the shorter name are compared pairwise as strings; the first
    difference decides.  If those agree, names with the same number of
    components are ordered by their final component, otherwise the name
    with fewer components sorts first.

    Returns -1, 0 or 1, like the Python 2 ``cmp`` builtin.
    """
    def three_way(x, y):
        # Portable stand-in for cmp(), which no longer exists in Python 3.
        return (x > y) - (x < y)

    v1 = a.split('.')
    v2 = b.split('.')

    # Index of the final component of the shorter name.
    last = min(len(v1), len(v2)) - 1
    for i, j in zip(v1[:last], v2[:last]):
        if i != j:
            return three_way(i, j)

    # Special compare for last element: it is only compared when both
    # names have the same depth; otherwise the shallower name wins.
    if len(v1) == len(v2):
        return three_way(v1[last], v2[last])
    return three_way(len(v1), len(v2))
45
class RunData:
    """One run record built from a database result row.

    The first four positions of ``row`` are read, in order, as the run
    id (converted to int), the run name, the user and the project.
    Any further columns are ignored.
    """

    def __init__(self, row):
        self.run = int(row[0])
        self.name, self.user, self.project = row[1], row[2], row[3]
52
class SubData:
    """Sub-data description built from a database result row.

    ``row`` is read positionally: stat id, x index and y index (each
    converted to int), followed by a name and a description.
    """

    def __init__(self, row):
        self.stat, self.x, self.y = int(row[0]), int(row[1]), int(row[2])
        self.name = row[3]
        self.descr = row[4]
60
class Data:
    """A single data point read from a five-column result row.

    ``row`` must contain exactly (stat, run, x, y, data); the first four
    are converted to int and the last to float.

    Raises:
        ValueError: if ``row`` does not have exactly five columns.
    """

    def __init__(self, row):
        if len(row) != 5:
            # String exceptions are not valid Python; raise a real
            # exception so the failure is reported as what it is.
            raise ValueError('stat db error: expected 5 columns, got %d'
                             % len(row))
        self.stat = int(row[0])
        self.run = int(row[1])
        self.x = int(row[2])
        self.y = int(row[3])
        self.data = float(row[4])

    def __repr__(self):
        return '''Data(['%d', '%d', '%d', '%d', '%f'])''' % (
            self.stat, self.run, self.x, self.y, self.data)
74
class StatData(object):
    """Description of one stat, built from a database result row.

    Columns are read positionally: 0 stat id, 1 name, 2 description,
    3 type, 5 prerequisite, 6 precision.  Columns 4 and 7-11 are
    integer switches folded into ``self.flags``.  For 'DIST' and
    'VECTORDIST' stats, columns 12-15 supply min, max, bucket size and
    size.  For 'FORMULA' stats the formula text is looked up in
    ``self.db.allFormulas``; ``db`` is not set here, so it has to be
    attached to the class (or instance) before such a row is parsed.
    """

    def __init__(self, row):
        self.stat = int(row[0])
        self.name, self.desc, self.type = row[1], row[2], row[3]
        self.prereq = int(row[5])
        self.precision = int(row[6])

        import flags
        # (row column, attribute of the flags module) in evaluation order.
        switches = (
            (4, 'printable'),
            (7, 'nozero'),
            (8, 'nonan'),
            (9, 'total'),
            (10, 'pdf'),
            (11, 'cdf'),
        )
        self.flags = 0
        for column, flagname in switches:
            if int(row[column]):
                self.flags |= getattr(flags, flagname)

        if self.type in ('DIST', 'VECTORDIST'):
            self.min = float(row[12])
            self.max = float(row[13])
            self.bktsize = float(row[14])
            self.size = int(row[15])

        if self.type == 'FORMULA':
            self.formula = self.db.allFormulas[self.stat]
101
class Node(object):
    """Named placeholder object; its string form is the stored name."""

    def __init__(self, name):
        # Kept verbatim; __str__ hands it back unchanged.
        self.name = name

    def __str__(self):
        return self.name
107
class Result(object):
    """Per-run result matrices of a fixed x-by-y size.

    Indexing by a run creates, on first access, an ``x`` by ``y`` matrix
    (list of row lists) filled with 0.0 and returns it; later accesses
    return the same matrix.  ``run in result`` reports whether a matrix
    has already been created for that run.
    """

    def __init__(self, x, y):
        self.data = {}
        self.x = x
        self.y = y

    def __contains__(self, run):
        return run in self.data

    def __getitem__(self, run):
        if run not in self.data:
            # Build each row separately so rows do not alias one another.
            # range() is used instead of the Python-2-only xrange().
            self.data[run] = [[0.0] * self.y for _ in range(self.x)]
        return self.data[run]
121
class Database(object):
    """Access layer for an m5 statistics MySQL database.

    connect() loads runs, sub-data descriptions, formulas and stats into
    the ``all*`` containers.  Every stat is additionally attached to the
    instance as a dotted attribute path (see append()), and the top-level
    components are available through ``db[name]``.

    The ``method`` attribute selects the aggregation used by data() and
    must be one of 'sum', 'avg' or 'stdev'.
    """

    # Values accepted for the ``method`` attribute; each is also the name
    # of the method that builds the corresponding SQL.
    _methods = ('sum', 'avg', 'stdev')

    def __init__(self):
        self.host = 'zizzer.pool'
        self.user = ''
        self.passwd = ''
        self.db = 'm5stats'
        self.cursor = None

        self.allStats = []
        self.allStatIds = {}
        self.allStatNames = {}

        self.allSubData = {}

        self.allRuns = []
        self.allRunIds = {}
        self.allRunNames = {}

        self.allFormulas = {}

        self.stattop = {}
        self.statdict = {}
        self.statlist = []

        self.mode = 'sum'
        self.runs = None
        self.ticks = None
        # Assigning ``method`` also installs ``_method`` (see __setattr__).
        self.method = 'sum'

    def get(self, job, stat, system=None):
        """Return the value(s) of ``stat`` for the run named ``str(job)``.

        If ``system`` is not given and ``job`` has a ``system`` attribute,
        that attribute is used.  When a system is known, ``stat.system``
        is set to ``self[system]`` before evaluation.

        Returns None when the run is unknown, when the stat is neither
        scalar nor vector, or when evaluation raises ``info.ProxyError``.
        """
        run = self.allRunNames.get(str(job), None)
        if run is None:
            return None

        from info import ProxyError, scalar, vector, value, values
        if system is None and hasattr(job, 'system'):
            system = job.system

        if system is not None:
            stat.system = self[system]
        try:
            if scalar(stat):
                return value(stat, run.run)
            if vector(stat):
                return values(stat, run.run)
        except ProxyError:
            return None

        return None

    def query(self, *args, **kwargs):
        """Execute one SQL statement on the current cursor.

        ``query(sql)`` runs ``sql`` and returns None; results are read
        from ``self.cursor`` afterwards.

        This class used to define ``query`` twice, so the SQL-building
        variant silently replaced this one and every ``self.query(sql)``
        call failed.  The builder now lives in _select(); for backward
        compatibility any call that is not a single positional argument
        (``query(op, stat, ticks[, group])``) is forwarded to it and the
        generated SQL string is returned.
        """
        if len(args) == 1 and not kwargs:
            self.cursor.execute(args[0])
            return None
        return self._select(*args, **kwargs)

    def update_dict(self, dict):
        """Copy the top-level stat entries (``self.stattop``) into ``dict``."""
        dict.update(self.stattop)

    def append(self, stat):
        """Register ``stat`` under its dotted name.

        ':' in ``stat.name`` is replaced by '__' and the result is split
        on '.'.  Intermediate components become Node objects stored in the
        ``__dict__`` of their parent (starting at this Database) and the
        final component is bound to ``stat`` itself.  The top-level entry
        is recorded in ``stattop`` and the full name in ``statdict`` and
        ``statlist``.
        """
        statname = stat.name.replace(':', '__')
        path = statname.split('.')
        pathtop = path[0]

        parent = self
        prefix = ''
        for name in path[:-1]:
            # Written through __dict__ directly, so __setattr__ is bypassed.
            if name not in parent.__dict__:
                parent.__dict__[name] = Node(prefix + name)
            parent = parent.__dict__[name]
            prefix = '%s%s.' % (prefix, name)

        parent.__dict__[path[-1]] = stat

        self.stattop[pathtop] = self.__dict__[pathtop]
        self.statdict[statname] = stat
        self.statlist.append(statname)

    def connect(self):
        """Open the MySQL connection and load all metadata tables.

        Fills allRuns/allRunIds/allRunNames, allSubData, allFormulas and
        the stat containers, and sets ``StatData.db`` to this instance so
        that formula stats can look up their text.
        """
        # connect
        self.thedb = MySQLdb.connect(db=self.db,
                                     host=self.host,
                                     user=self.user,
                                     passwd=self.passwd)

        # create a cursor
        self.cursor = self.thedb.cursor()

        # RunData reads its row as (id, name, user, project), so only
        # those four columns are selected; an extra rn_sample column in
        # third position would end up in RunData.user.
        self.query('select rn_id,rn_name,rn_user,rn_project from runs')
        for row in self.cursor.fetchall():
            run = RunData(row)
            self.allRuns.append(run)
            self.allRunIds[run.run] = run
            self.allRunNames[run.name] = run

        self.query('select sd_stat,sd_x,sd_y,sd_name,sd_descr from subdata')
        for row in self.cursor.fetchall():
            subdata = SubData(row)
            self.allSubData.setdefault(subdata.stat, []).append(subdata)

        self.query('select * from formulas')
        for stat_id, formula in self.cursor.fetchall():
            # NOTE(review): assumes the driver returns the formula column
            # as an object with tostring() (e.g. an array) - confirm.
            self.allFormulas[int(stat_id)] = formula.tostring()

        # StatData reads self.db.allFormulas for FORMULA stats.
        StatData.db = self
        self.query('select * from stats')
        import info
        for row in self.cursor.fetchall():
            stat = info.NewStat(self, StatData(row))
            self.append(stat)
            self.allStats.append(stat)
            self.allStatIds[stat.stat] = stat
            self.allStatNames[stat.name] = stat

    def listRuns(self, user=None):
        """Print all runs, or only those whose user equals ``user``."""
        print('%-40s %-10s %-5s' % ('run name', 'user', 'id'))
        print('-' * 62)
        for run in self.allRuns:
            if user is None or user == run.user:
                print('%-40s %-10s %-10d' % (run.name, run.user, run.run))

    def _tick_sql(self, runs=None):
        """Build the SQL that selects the distinct ticks for ``runs``.

        With no runs (None or empty) the run restriction is left out
        instead of emitting an unterminated ``and (`` clause.
        """
        # NOTE(review): 1180 is a hard-coded stat id; presumably the stat
        # whose samples define the ticks in the original database -
        # confirm before using against another database.
        sql = 'select distinct dt_tick from data where dt_stat=1180'
        clauses = ['dt_run=%s' % run.run for run in runs or ()]
        if clauses:
            sql += ' and (%s)' % ' or '.join(clauses)
        return sql

    def listTicks(self, runs=None):
        """Print the distinct sample ticks recorded for ``runs``.

        ``runs`` is an iterable of objects with a ``run`` attribute; None
        or an empty iterable means no restriction by run.
        """
        print("tick")
        print("----------------------------------------")
        self.query(self._tick_sql(runs))
        for row in self.cursor.fetchall():
            print(row[0])

    def retTicks(self, runs=None):
        """Return the distinct sample ticks recorded for ``runs`` as a list.

        ``runs`` is interpreted as in listTicks().
        """
        self.query(self._tick_sql(runs))
        return [row[0] for row in self.cursor.fetchall()]

    def _matching_stats(self, regex=None):
        """Return the stats ordered by statcmp() on their names.

        If ``regex`` is given, only stats whose name matches it at the
        start (``re.match``) are returned.
        """
        # Function-scope import: the module header does not import
        # functools.  cmp_to_key adapts the cmp-style statcmp() to key=.
        from functools import cmp_to_key

        rx = None
        if regex is not None:
            rx = re.compile(regex)

        names = sorted([stat.name for stat in self.allStats],
                       key=cmp_to_key(statcmp))
        stats = [self.allStatNames[name] for name in names]
        if rx is None:
            return stats
        return [stat for stat in stats if rx.match(stat.name)]

    def listStats(self, regex=None):
        """Print name, id and type of every stat in the database.

        The optional ``regex`` prunes the listing to stats whose name
        matches it.
        """
        print('%-60s %-8s %-10s' % ('stat name', 'id', 'type'))
        print('-' * 80)

        for stat in self._matching_stats(regex):
            print('%-60s %-8s %-10s' % (stat.name, stat.stat, stat.type))

    def listFormulas(self, regex=None):
        """Print name and formula text of every FORMULA stat.

        The optional ``regex`` prunes the listing to stats whose name
        matches it.
        """
        print('%-60s %s' % ('formula name', 'formula'))
        print('-' * 80)

        for stat in self._matching_stats(regex):
            if stat.type == 'FORMULA':
                print('%-60s %s' % (stat.name, self.allFormulas[stat.stat]))

    def getStat(self, stats):
        """Resolve stat ids and name patterns to stat objects.

        ``stats`` is a single spec or a list/tuple of specs.  An int spec
        is looked up in ``allStatIds`` (KeyError if unknown); a str spec
        is a regular expression and selects every stat whose name matches
        it at the start.  Specs of any other type are ignored.
        """
        if not isinstance(stats, (list, tuple)):
            stats = [stats]

        ret = []
        for spec in stats:
            if isinstance(spec, int):
                ret.append(self.allStatIds[spec])
            elif isinstance(spec, str):
                rx = re.compile(spec)
                ret.extend([s for s in self.allStats if rx.match(s.name)])
        return ret

    #########################################
    # get the data
    #
    def _select(self, op, stat, ticks, group=False):
        """Build the SQL that aggregates data with the SQL function ``op``.

        ``stat`` is one stat object or a list of them (their ``stat``
        attribute is the id).  The query is restricted to ``self.runs``
        and to ``ticks`` when those are non-empty.  With ``group`` set,
        ``dt_tick`` is added to both the selected columns and the
        grouping.  Returns the SQL string; nothing is executed.
        """
        columns = ['dt_stat as stat', 'dt_run as run',
                   'dt_x as x', 'dt_y as y']
        if group:
            columns.append('dt_tick as tick')
        columns.append('%s(dt_data) as data' % op)

        if isinstance(stat, list):
            ids = ' or '.join(['dt_stat=%d' % s.stat for s in stat])
            where = ['(%s)' % ids]
        else:
            where = ['dt_stat=%d' % stat.stat]

        if self.runs:
            where.append('(%s)' % ' or '.join(
                ['dt_run=%d' % r for r in self.runs]))

        if ticks:
            where.append('(%s)' % ' or '.join(
                ['dt_tick=%d' % t for t in ticks]))

        grouping = 'dt_stat,dt_run,dt_x,dt_y'
        if group:
            grouping += ',dt_tick'

        return 'select %s from data where %s group by %s' % (
            ', '.join(columns), ' and '.join(where), grouping)

    def sum(self, *args, **kwargs):
        """Return the SQL that totals the samples (see _select())."""
        return self._select('sum', *args, **kwargs)

    def avg(self, stat, ticks):
        """Return the SQL that averages the samples (see _select())."""
        return self._select('avg', stat, ticks)

    def stdev(self, stat, ticks):
        """Return the SQL for the samples' standard deviation.

        Uses the SQL ``stddev`` aggregate (see _select()).
        """
        return self._select('stddev', stat, ticks)

    def __setattr__(self, attr, value):
        """Set an attribute; assigning ``method`` also updates ``_method``.

        Raises:
            AttributeError: if ``method`` is set to anything other than
                'sum', 'avg' or 'stdev'.  The previous value is kept.
        """
        if attr == 'method':
            if value not in self._methods:
                raise AttributeError(
                    "can only set method to: sum | avg | stdev")
            # data() calls self._method(self, stat, ticks), so store the
            # class-level function rather than a bound method.
            super(Database, self).__setattr__(
                '_method', getattr(type(self), value))
        super(Database, self).__setattr__(attr, value)

    def data(self, stat, ticks=None):
        """Run the current aggregation for ``stat`` and return a Result.

        ``ticks`` defaults to ``self.ticks``.  The Result is sized to the
        largest x and y index seen (plus one) and holds one matrix per run
        that returned rows.
        """
        if ticks is None:
            ticks = self.ticks
        sql = self._method(self, stat, ticks)
        self.query(sql)

        points = [Data(row) for row in self.cursor.fetchall()]

        xmax = 0
        ymax = 0
        for point in points:
            xmax = max(xmax, point.x)
            ymax = max(ymax, point.y)

        results = Result(xmax + 1, ymax + 1)
        for point in points:
            results[point.run][point.x][point.y] = point.data
        return results

    def __getitem__(self, key):
        """Return the top-level stat tree entry named ``key``."""
        return self.stattop[key]
437