StateMachine.py revision 6793:bc8c8617c4f0
1# Copyright (c) 1999-2008 Mark D. Hill and David A. Wood
2# Copyright (c) 2009 The Hewlett-Packard Development Company
3# All rights reserved.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met: redistributions of source code must retain the above copyright
8# notice, this list of conditions and the following disclaimer;
9# redistributions in binary form must reproduce the above copyright
10# notice, this list of conditions and the following disclaimer in the
11# documentation and/or other materials provided with the distribution;
12# neither the name of the copyright holders nor the names of its
13# contributors may be used to endorse or promote products derived from
14# this software without specific prior written permission.
15#
16# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
20# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
28from m5.util import code_formatter, orderdict
29
30from slicc.symbols.Symbol import Symbol
31from slicc.symbols.Var import Var
32import slicc.generate.html as html
33
34class StateMachine(Symbol):
35    def __init__(self, symtab, ident, location, pairs, config_parameters):
36        super(StateMachine, self).__init__(symtab, ident, location, pairs)
37        self.table = None
38        self.config_parameters = config_parameters
39        for param in config_parameters:
40            var = Var(symtab, param.name, location, param.type_ast.type,
41                      "m_%s" % param.name, {}, self)
42            self.symtab.registerSym(param.name, var)
43
44        self.states = orderdict()
45        self.events = orderdict()
46        self.actions = orderdict()
47        self.transitions = []
48        self.in_ports = []
49        self.functions = []
50        self.objects = []
51
52        self.message_buffer_names = []
53
54    def __repr__(self):
55        return "[StateMachine: %s]" % self.ident
56
57    def addState(self, state):
58        assert self.table is None
59        self.states[state.ident] = state
60
61    def addEvent(self, event):
62        assert self.table is None
63        self.events[event.ident] = event
64
65    def addAction(self, action):
66        assert self.table is None
67
68        # Check for duplicate action
69        for other in self.actions.itervalues():
70            if action.ident == other.ident:
71                action.warning("Duplicate action definition: %s" % action.ident)
72                action.error("Duplicate action definition: %s" % action.ident)
73            if action.short == other.short:
74                other.warning("Duplicate action shorthand: %s" % other.ident)
75                other.warning("    shorthand = %s" % other.short)
76                action.warning("Duplicate action shorthand: %s" % action.ident)
77                action.error("    shorthand = %s" % action.short)
78
79        self.actions[action.ident] = action
80
81    def addTransition(self, trans):
82        assert self.table is None
83        self.transitions.append(trans)
84
85    def addInPort(self, var):
86        self.in_ports.append(var)
87
88    def addFunc(self, func):
89        # register func in the symbol table
90        self.symtab.registerSym(str(func), func)
91        self.functions.append(func)
92
93    def addObject(self, obj):
94        self.objects.append(obj)
95
96    # Needs to be called before accessing the table
97    def buildTable(self):
98        assert self.table is None
99
100        table = {}
101
102        for trans in self.transitions:
103            # Track which actions we touch so we know if we use them
104            # all -- really this should be done for all symbols as
105            # part of the symbol table, then only trigger it for
106            # Actions, States, Events, etc.
107
108            for action in trans.actions:
109                action.used = True
110
111            index = (trans.state, trans.event)
112            if index in table:
113                table[index].warning("Duplicate transition: %s" % table[index])
114                trans.error("Duplicate transition: %s" % trans)
115            table[index] = trans
116
117        # Look at all actions to make sure we used them all
118        for action in self.actions.itervalues():
119            if not action.used:
120                error_msg = "Unused action: %s" % action.ident
121                if "desc" in action:
122                    error_msg += ", "  + action.desc
123                action.warning(error_msg)
124        self.table = table
125
126    def writeCodeFiles(self, path):
127        self.printControllerHH(path)
128        self.printControllerCC(path)
129        self.printCSwitch(path)
130        self.printCWakeup(path)
131        self.printProfilerCC(path)
132        self.printProfilerHH(path)
133
134        for func in self.functions:
135            func.writeCodeFiles(path)
136
    def printControllerHH(self, path):
        '''Write the C++ header declaring this machine's controller class.

        The output is written under path as "<ident>_Controller.hh".  As
        a side effect self.message_buffer_names is reset and refilled
        with the member names ("m_<c_ident>_ptr") of every object whose
        type ident is "MessageBuffer".

        NOTE(review): the templates below refer to local names of this
        method ($ident, $c_ident, $th, $proto, ${{var...}}, ...), so
        code_formatter presumably resolves them from this scope -- do
        not rename these locals without updating the templates.
        '''
        code = code_formatter()
        ident = self.ident
        c_ident = "%s_Controller" % self.ident

        # Start from a clean list; it is refilled in the object loop below
        self.message_buffer_names = []

        # File banner, include guard and fixed includes
        code('''
/** \\file $ident.hh
 *
 * Auto generated C++ code started by $__file__:$__line__
 * Created by slicc definition of Module "${{self.short}}"
 */

#ifndef ${ident}_CONTROLLER_H
#define ${ident}_CONTROLLER_H

#include "mem/ruby/common/Global.hh"
#include "mem/ruby/common/Consumer.hh"
#include "mem/ruby/slicc_interface/AbstractController.hh"
#include "mem/protocol/TransitionResult.hh"
#include "mem/protocol/Types.hh"
#include "mem/protocol/${ident}_Profiler.hh"
''')

        # Include the header of each non-primitive object type, once per
        # type ident
        seen_types = set()
        for var in self.objects:
            if var.type.ident not in seen_types and not var.type.isPrimitive:
                code('#include "mem/protocol/${{var.type.c_ident}}.hh"')
            seen_types.add(var.type.ident)

        # for adding information to the protocol debug trace
        code('''
extern stringstream ${ident}_transitionComment;

class $c_ident : public AbstractController {
#ifdef CHECK_COHERENCE
#endif /* CHECK_COHERENCE */
public:
    $c_ident(const string & name);
    static int getNumControllers();
    void init(Network* net_ptr, const vector<string> & argv);
    MessageBuffer* getMandatoryQueue() const;
    const int & getVersion() const;
    const string toString() const;
    const string getName() const;
    const MachineType getMachineType() const;
    void print(ostream& out) const;
    void printConfig(ostream& out) const;
    void wakeup();
    void set_atomic(Address addr);
    void started_writes();
    void clear_atomic();
    void printStats(ostream& out) const { s_profiler.dumpStats(out); }
    void clearStats() { s_profiler.clearStats(); }
private:
''')

        code.indent()
        # added by SS
        # One member per config parameter; always declared as int here,
        # whatever the parameter's declared type is
        for param in self.config_parameters:
            code('int m_${{param.ident}};')

        # Extra bookkeeping members, emitted only for the machine named
        # "L1Cache"
        if self.ident == "L1Cache":
            code('''
int servicing_atomic;
bool started_receiving_writes;
Address locked_read_request1;
Address locked_read_request2;
Address locked_read_request3;
Address locked_read_request4;
int read_counter;
''')

        # Members common to every controller
        code('''
int m_number_of_TBEs;

TransitionResult doTransition(${ident}_Event event, ${ident}_State state, const Address& addr); // in ${ident}_Transitions.cc
TransitionResult doTransitionWorker(${ident}_Event event, ${ident}_State state, ${ident}_State& next_state, const Address& addr); // in ${ident}_Transitions.cc
string m_name;
int m_transitions_per_cycle;
int m_buffer_size;
int m_recycle_latency;
map< string, string > m_cfg;
NodeID m_version;
Network* m_net_ptr;
MachineID m_machineID;
${ident}_Profiler s_profiler;
static int m_num_controllers;
// Internal functions
''')

        # Declare each function that has a non-empty prototype
        for func in self.functions:
            proto = func.prototype
            if proto:
                code('$proto')

        code('''

// Actions
''')
        # One void member function per action, taking the address
        for action in self.actions.itervalues():
            code('/** \\brief ${{action.desc}} */')
            code('void ${{action.ident}}(const Address& addr);')

        # the controller internal variables
        code('''

// Object
''')
        for var in self.objects:
            # Optional suffix appended to the type name
            th = var.get("template_hack", "")
            code('${{var.type.c_ident}}$th* m_${{var.c_ident}}_ptr;')

            # Remember message buffer members; printControllerCC() sets
            # the recycle latency on each of them
            if var.type.ident == "MessageBuffer":
                self.message_buffer_names.append("m_%s_ptr" % var.c_ident)

        code.dedent()
        code('};')
        code('#endif // ${ident}_CONTROLLER_H')
        code.write(path, '%s.hh' % c_ident)
259
260    def printControllerCC(self, path):
261        '''Output the actions for performing the actions'''
262
263        code = code_formatter()
264        ident = self.ident
265        c_ident = "%s_Controller" % self.ident
266
267        code('''
268/** \\file $ident.cc
269 *
270 * Auto generated C++ code started by $__file__:$__line__
271 * Created by slicc definition of Module "${{self.short}}"
272 */
273
274#include "mem/ruby/common/Global.hh"
275#include "mem/ruby/slicc_interface/RubySlicc_includes.hh"
276#include "mem/protocol/${ident}_Controller.hh"
277#include "mem/protocol/${ident}_State.hh"
278#include "mem/protocol/${ident}_Event.hh"
279#include "mem/protocol/Types.hh"
280#include "mem/ruby/system/System.hh"
281''')
282
283        # include object classes
284        seen_types = set()
285        for var in self.objects:
286            if var.type.ident not in seen_types and not var.type.isPrimitive:
287                code('#include "mem/protocol/${{var.type.c_ident}}.hh"')
288            seen_types.add(var.type.ident)
289
290        code('''
291int $c_ident::m_num_controllers = 0;
292
293stringstream ${ident}_transitionComment;
294#define APPEND_TRANSITION_COMMENT(str) (${ident}_transitionComment << str)
295/** \\brief constructor */
296$c_ident::$c_ident(const string &name)
297    : m_name(name)
298{
299''')
300        code.indent()
301        if self.ident == "L1Cache":
302            code('''
303servicing_atomic = 0;
304started_receiving_writes = false;
305locked_read_request1 = Address(-1);
306locked_read_request2 = Address(-1);
307locked_read_request3 = Address(-1);
308locked_read_request4 = Address(-1);
309read_counter = 0;
310''')
311
312        code('m_num_controllers++;')
313        for var in self.objects:
314            if var.ident.find("mandatoryQueue") >= 0:
315                code('m_${{var.c_ident}}_ptr = new ${{var.type.c_ident}}();')
316
317        code.dedent()
318        code('''
319}
320
321void $c_ident::init(Network *net_ptr, const vector<string> &argv)
322{
323    for (size_t i = 0; i < argv.size(); i += 2) {
324        if (argv[i] == "version")
325            m_version = atoi(argv[i+1].c_str());
326        else if (argv[i] == "transitions_per_cycle")
327            m_transitions_per_cycle = atoi(argv[i+1].c_str());
328        else if (argv[i] == "buffer_size")
329            m_buffer_size = atoi(argv[i+1].c_str());
330        else if (argv[i] == "recycle_latency")
331            m_recycle_latency = atoi(argv[i+1].c_str());
332        else if (argv[i] == "number_of_TBEs")
333            m_number_of_TBEs = atoi(argv[i+1].c_str());
334''')
335
336        code.indent()
337        code.indent()
338        for param in self.config_parameters:
339            code('else if (argv[i] == "${{param.name}}")')
340            if param.type_ast.type.ident == "int":
341                code('    m_${{param.name}} = atoi(argv[i+1].c_str());')
342            elif param.type_ast.type.ident == "bool":
343                code('    m_${{param.name}} = string_to_bool(argv[i+1]);')
344            else:
345                self.error("only int and bool parameters are "\
346                           "currently supported")
347        code.dedent()
348        code.dedent()
349        code('''
350    }
351
352    m_net_ptr = net_ptr;
353    m_machineID.type = MachineType_${ident};
354    m_machineID.num = m_version;
355    for (size_t i = 0; i < argv.size(); i += 2) {
356        if (argv[i] != "version")
357            m_cfg[argv[i]] = argv[i+1];
358    }
359
360    // Objects
361    s_profiler.setVersion(m_version);
362''')
363
364        code.indent()
365        for var in self.objects:
366            vtype = var.type
367            vid = "m_%s_ptr" % var.c_ident
368            if "network" not in var:
369                # Not a network port object
370                if "primitive" in vtype:
371                    code('$vid = new ${{vtype.c_ident}};')
372                    if "default" in var:
373                        code('(*$vid) = ${{var["default"]}};')
374                else:
375                    # Normal Object
376                    # added by SS
377                    if "factory" in var:
378                        code('$vid = ${{var["factory"]}};')
379                    elif var.ident.find("mandatoryQueue") < 0:
380                        th = var.get("template_hack", "")
381                        expr = "%s  = new %s%s" % (vid, vtype.c_ident, th)
382
383                        args = ""
384                        if "non_obj" not in vtype and not vtype.isEnumeration:
385                            if expr.find("TBETable") >= 0:
386                                args = "m_number_of_TBEs"
387                            else:
388                                args = var.get("constructor_hack", "")
389                            args = "(%s)" % args
390
391                        code('$expr$args;')
392                    else:
393                        code(';')
394
395                    code('assert($vid != NULL);')
396
397                    if "default" in var:
398                        code('(*$vid) = ${{var["default"]}}; // Object default')
399                    elif "default" in vtype:
400                        code('(*$vid) = ${{vtype["default"]}}; // Type ${{vtype.ident}} default')
401
402                    # Set ordering
403                    if "ordered" in var and "trigger_queue" not in var:
404                        # A buffer
405                        code('$vid->setOrdering(${{var["ordered"]}});')
406
407                    # Set randomization
408                    if "random" in var:
409                        # A buffer
410                        code('$vid->setRandomization(${{var["random"]}});')
411
412                    # Set Priority
413                    if vtype.isBuffer and \
414                           "rank" in var and "trigger_queue" not in var:
415                        code('$vid->setPriority(${{var["rank"]}});')
416            else:
417                # Network port object
418                network = var["network"]
419                ordered =  var["ordered"]
420                vnet = var["virtual_network"]
421
422                assert var.machine is not None
423                code('''
424$vid = m_net_ptr->get${network}NetQueue(m_version+MachineType_base_number(string_to_MachineType("${{var.machine.ident}}")), $ordered, $vnet);
425''')
426
427                code('assert($vid != NULL);')
428
429                # Set ordering
430                if "ordered" in var:
431                    # A buffer
432                    code('$vid->setOrdering(${{var["ordered"]}});')
433
434                # Set randomization
435                if "random" in var:
436                    # A buffer
437                    code('$vid->setRandomization(${{var["random"]}})')
438
439                # Set Priority
440                if "rank" in var:
441                    code('$vid->setPriority(${{var["rank"]}})')
442
443                # Set buffer size
444                if vtype.isBuffer:
445                    code('''
446if (m_buffer_size > 0) {
447    $vid->setSize(m_buffer_size);
448}
449''')
450
451                # set description (may be overriden later by port def)
452                code('$vid->setDescription("[Version " + int_to_string(m_version) + ", ${ident}, name=${{var.c_ident}}]");')
453
454        # Set the queue consumers
455        code.insert_newline()
456        for port in self.in_ports:
457            code('${{port.code}}.setConsumer(this);')
458
459        # Set the queue descriptions
460        code.insert_newline()
461        for port in self.in_ports:
462            code('${{port.code}}.setDescription("[Version " + int_to_string(m_version) + ", $ident, $port]");')
463
464        # Initialize the transition profiling
465        code.insert_newline()
466        for trans in self.transitions:
467            # Figure out if we stall
468            stall = False
469            for action in trans.actions:
470                if action.ident == "z_stall":
471                    stall = True
472
473            # Only possible if it is not a 'z' case
474            if not stall:
475                state = "%s_State_%s" % (self.ident, trans.state.ident)
476                event = "%s_Event_%s" % (self.ident, trans.event.ident)
477                code('s_profiler.possibleTransition($state, $event);')
478
479        # added by SS to initialize recycle_latency of message buffers
480        for buf in self.message_buffer_names:
481            code("$buf->setRecycleLatency(m_recycle_latency);")
482
483        code.dedent()
484        code('}')
485
486        has_mandatory_q = False
487        for port in self.in_ports:
488            if port.code.find("mandatoryQueue_ptr") >= 0:
489                has_mandatory_q = True
490
491        if has_mandatory_q:
492            mq_ident = "m_%s_mandatoryQueue_ptr" % self.ident
493        else:
494            mq_ident = "NULL"
495
496        code('''
497int $c_ident::getNumControllers() {
498    return m_num_controllers;
499}
500
501MessageBuffer* $c_ident::getMandatoryQueue() const {
502    return $mq_ident;
503}
504
505const int & $c_ident::getVersion() const{
506    return m_version;
507}
508
509const string $c_ident::toString() const{
510    return "$c_ident";
511}
512
513const string $c_ident::getName() const{
514    return m_name;
515}
516const MachineType $c_ident::getMachineType() const{
517    return MachineType_${ident};
518}
519
520void $c_ident::print(ostream& out) const { out << "[$c_ident " << m_version << "]"; }
521
522void $c_ident::printConfig(ostream& out) const {
523    out << "$c_ident config: " << m_name << endl;
524    out << "  version: " << m_version << endl;
525    for (map<string, string>::const_iterator it = m_cfg.begin(); it != m_cfg.end(); it++) {
526        out << "  " << (*it).first << ": " << (*it).second << endl;
527    }
528}
529
530// Actions
531''')
532
533        for action in self.actions.itervalues():
534            if "c_code" not in action:
535                continue
536
537            code('''
538/** \\brief ${{action.desc}} */
539void $c_ident::${{action.ident}}(const Address& addr)
540{
541    DEBUG_MSG(GENERATED_COMP, HighPrio, "executing");
542    ${{action["c_code"]}}
543}
544
545''')
546        code.write(path, "%s.cc" % c_ident)
547
548    def printCWakeup(self, path):
549        '''Output the wakeup loop for the events'''
550
551        code = code_formatter()
552        ident = self.ident
553
554        code('''
555// Auto generated C++ code started by $__file__:$__line__
556// ${ident}: ${{self.short}}
557
558#include "mem/ruby/common/Global.hh"
559#include "mem/ruby/slicc_interface/RubySlicc_includes.hh"
560#include "mem/protocol/${ident}_Controller.hh"
561#include "mem/protocol/${ident}_State.hh"
562#include "mem/protocol/${ident}_Event.hh"
563#include "mem/protocol/Types.hh"
564#include "mem/ruby/system/System.hh"
565
566void ${ident}_Controller::wakeup()
567{
568
569    int counter = 0;
570    while (true) {
571        // Some cases will put us into an infinite loop without this limit
572        assert(counter <= m_transitions_per_cycle);
573        if (counter == m_transitions_per_cycle) {
574            g_system_ptr->getProfiler()->controllerBusy(m_machineID); // Count how often we\'re fully utilized
575            g_eventQueue_ptr->scheduleEvent(this, 1); // Wakeup in another cycle and try again
576            break;
577        }
578''')
579
580        code.indent()
581        code.indent()
582
583        # InPorts
584        #
585        # Find the position of the mandatory queue in the vector so
586        # that we can print it out first
587
588        mandatory_q = None
589        if self.ident == "L1Cache":
590            for i,port in enumerate(self.in_ports):
591                assert "c_code_in_port" in port
592                if str(port).find("mandatoryQueue_in") >= 0:
593                    assert mandatory_q is None
594                    mandatory_q = port
595
596            assert mandatory_q is not None
597
598            # print out the mandatory queue here
599            port = mandatory_q
600            code('// ${ident}InPort $port')
601            output = port["c_code_in_port"]
602
603            pos = output.find("TransitionResult result = doTransition((L1Cache_mandatory_request_type_to_event(((*in_msg_ptr)).m_Type)), L1Cache_getState(addr), addr);")
604            assert pos >= 0
605            atomics_string = '''
606if ((((*in_msg_ptr)).m_Type) == CacheRequestType_ATOMIC) {
607    if (servicing_atomic == 0) {
608        if (locked_read_request1 == Address(-1)) {
609            assert(read_counter == 0);
610            locked_read_request1 = addr;
611            assert(read_counter == 0);
612            read_counter++;
613        }
614        else if (addr == locked_read_request1) {
615            ; // do nothing
616        }
617        else {
618            assert(0); // should never be here if servicing one request at a time
619        }
620    }
621    else if (!started_receiving_writes) {
622        if (servicing_atomic == 1) {
623            if (locked_read_request2 == Address(-1)) {
624                assert(locked_read_request1 != Address(-1));
625                assert(read_counter == 1);
626                locked_read_request2 = addr;
627                assert(read_counter == 1);
628                read_counter++;
629            }
630            else if (addr == locked_read_request2) {
631                ; // do nothing
632            }
633            else {
634                assert(0); // should never be here if servicing one request at a time
635            }
636        }
637        else if (servicing_atomic == 2) {
638            if (locked_read_request3 == Address(-1)) {
639                assert(locked_read_request1 != Address(-1));
640                assert(locked_read_request2 != Address(-1));
641                assert(read_counter == 1);
642                locked_read_request3 = addr;
643                assert(read_counter == 2);
644                read_counter++;
645            }
646            else if (addr == locked_read_request3) {
647                ; // do nothing
648            }
649            else {
650                assert(0); // should never be here if servicing one request at a time
651            }
652        }
653        else if (servicing_atomic == 3) {
654            if (locked_read_request4 == Address(-1)) {
655                assert(locked_read_request1 != Address(-1));
656                assert(locked_read_request2 != Address(-1));
657                assert(locked_read_request3 != Address(-1));
658                assert(read_counter == 1);
659                locked_read_request4 = addr;
660                assert(read_counter == 3);
661                read_counter++;
662            }
663            else if (addr == locked_read_request4) {
664                ; // do nothing
665            }
666            else {
667                assert(0); // should never be here if servicing one request at a time
668            }
669        }
670        else {
671            assert(0);
672        }
673    }
674}
675else {
676    if (servicing_atomic > 0) {
677        // reset
678        servicing_atomic = 0;
679        read_counter = 0;
680        started_receiving_writes = false;
681        locked_read_request1 = Address(-1);
682        locked_read_request2 = Address(-1);
683        locked_read_request3 = Address(-1);
684        locked_read_request4 = Address(-1);
685    }
686}
687'''
688
689            output = output[:pos] + atomics_string + output[pos:]
690            code('$output')
691
692        for port in self.in_ports:
693            # don't print out mandatory queue twice
694            if port == mandatory_q:
695                continue
696
697            if ident == "L1Cache":
698                if str(port).find("forwardRequestNetwork_in") >= 0:
699                    code('''
700bool postpone = false;
701if ((((*m_L1Cache_forwardToCache_ptr)).isReady())) {
702    const RequestMsg* in_msg_ptr;
703    in_msg_ptr = dynamic_cast<const RequestMsg*>(((*m_L1Cache_forwardToCache_ptr)).peek());
704    if ((((servicing_atomic == 1)  && (locked_read_request1 == ((*in_msg_ptr)).m_Address)) ||
705         ((servicing_atomic == 2)  && (locked_read_request1 == ((*in_msg_ptr)).m_Address || locked_read_request2 == ((*in_msg_ptr)).m_Address)) ||
706         ((servicing_atomic == 3)  && (locked_read_request1 == ((*in_msg_ptr)).m_Address || locked_read_request2 == ((*in_msg_ptr)).m_Address || locked_read_request3 == ((*in_msg_ptr)).m_Address)) ||
707         ((servicing_atomic == 4)  && (locked_read_request1 == ((*in_msg_ptr)).m_Address || locked_read_request2 == ((*in_msg_ptr)).m_Address || locked_read_request3 == ((*in_msg_ptr)).m_Address || locked_read_request1 == ((*in_msg_ptr)).m_Address)))) {
708    postpone = true;
709    }
710}
711if (!postpone) {
712''')
713            code.indent()
714            code('// ${ident}InPort $port')
715            code('${{port["c_code_in_port"]}}')
716            code.dedent()
717
718            if ident == "L1Cache":
719                if str(port).find("forwardRequestNetwork_in") >= 0:
720                    code.dedent()
721                    code('}')
722                    code.indent()
723            code('')
724
725        code.dedent()
726        code.dedent()
727        code('''
728        break;  // If we got this far, we have nothing left todo
729    }
730}
731''')
732
733        if self.ident == "L1Cache":
734            code('''
735void ${ident}_Controller::set_atomic(Address addr)
736{
737    servicing_atomic++;
738}
739
740void ${ident}_Controller::started_writes()
741{
742    started_receiving_writes = true;
743}
744
745void ${ident}_Controller::clear_atomic()
746{
747    assert(servicing_atomic > 0);
748    read_counter--;
749    servicing_atomic--;
750    if (read_counter == 0) {
751        servicing_atomic = 0;
752        started_receiving_writes = false;
753        locked_read_request1 = Address(-1);
754        locked_read_request2 = Address(-1);
755        locked_read_request3 = Address(-1);
756        locked_read_request4 = Address(-1);
757    }
758}
759''')
760        else:
761            code('''
762void ${ident}_Controller::started_writes()
763{
764    assert(0);
765}
766
767void ${ident}_Controller::set_atomic(Address addr)
768{
769    assert(0);
770}
771
772void ${ident}_Controller::clear_atomic()
773{
774    assert(0);
775}
776''')
777
778
779        code.write(path, "%s_Wakeup.cc" % self.ident)
780
781    def printCSwitch(self, path):
782        '''Output switch statement for transition table'''
783
784        code = code_formatter()
785        ident = self.ident
786
787        code('''
788// Auto generated C++ code started by $__file__:$__line__
789// ${ident}: ${{self.short}}
790
791#include "mem/ruby/common/Global.hh"
792#include "mem/protocol/${ident}_Controller.hh"
793#include "mem/protocol/${ident}_State.hh"
794#include "mem/protocol/${ident}_Event.hh"
795#include "mem/protocol/Types.hh"
796#include "mem/ruby/system/System.hh"
797
798#define HASH_FUN(state, event)  ((int(state)*${ident}_Event_NUM)+int(event))
799
800#define GET_TRANSITION_COMMENT() (${ident}_transitionComment.str())
801#define CLEAR_TRANSITION_COMMENT() (${ident}_transitionComment.str(""))
802
803TransitionResult ${ident}_Controller::doTransition(${ident}_Event event, ${ident}_State state, const Address& addr
804)
805{
806    ${ident}_State next_state = state;
807
808    DEBUG_NEWLINE(GENERATED_COMP, MedPrio);
809    DEBUG_MSG(GENERATED_COMP, MedPrio, *this);
810    DEBUG_EXPR(GENERATED_COMP, MedPrio, g_eventQueue_ptr->getTime());
811    DEBUG_EXPR(GENERATED_COMP, MedPrio,state);
812    DEBUG_EXPR(GENERATED_COMP, MedPrio,event);
813    DEBUG_EXPR(GENERATED_COMP, MedPrio,addr);
814
815    TransitionResult result = doTransitionWorker(event, state, next_state, addr);
816
817    if (result == TransitionResult_Valid) {
818        DEBUG_EXPR(GENERATED_COMP, MedPrio, next_state);
819        DEBUG_NEWLINE(GENERATED_COMP, MedPrio);
820        s_profiler.countTransition(state, event);
821        if (Debug::getProtocolTrace()) {
822            g_system_ptr->getProfiler()->profileTransition("${ident}", m_version, addr,
823                    ${ident}_State_to_string(state),
824                    ${ident}_Event_to_string(event),
825                    ${ident}_State_to_string(next_state), GET_TRANSITION_COMMENT());
826        }
827    CLEAR_TRANSITION_COMMENT();
828    ${ident}_setState(addr, next_state);
829
830    } else if (result == TransitionResult_ResourceStall) {
831        if (Debug::getProtocolTrace()) {
832            g_system_ptr->getProfiler()->profileTransition("${ident}", m_version, addr,
833                   ${ident}_State_to_string(state),
834                   ${ident}_Event_to_string(event),
835                   ${ident}_State_to_string(next_state),
836                   "Resource Stall");
837        }
838    } else if (result == TransitionResult_ProtocolStall) {
839        DEBUG_MSG(GENERATED_COMP, HighPrio, "stalling");
840        DEBUG_NEWLINE(GENERATED_COMP, MedPrio);
841        if (Debug::getProtocolTrace()) {
842            g_system_ptr->getProfiler()->profileTransition("${ident}", m_version, addr,
843                   ${ident}_State_to_string(state),
844                   ${ident}_Event_to_string(event),
845                   ${ident}_State_to_string(next_state),
846                   "Protocol Stall");
847        }
848    }
849
850    return result;
851}
852
853TransitionResult ${ident}_Controller::doTransitionWorker(${ident}_Event event, ${ident}_State state, ${ident}_State& next_state, const Address& addr
854)
855{
856    switch(HASH_FUN(state, event)) {
857''')
858
859        # This map will allow suppress generating duplicate code
860        cases = orderdict()
861
862        for trans in self.transitions:
863            case_string = "%s_State_%s, %s_Event_%s" % \
864                (self.ident, trans.state.ident, self.ident, trans.event.ident)
865
866            case = code_formatter()
867            # Only set next_state if it changes
868            if trans.state != trans.nextState:
869                ns_ident = trans.nextState.ident
870                case('next_state = ${ident}_State_${ns_ident};')
871
872            actions = trans.actions
873
874            # Check for resources
875            case_sorter = []
876            res = trans.resources
877            for key,val in res.iteritems():
878                if key.type.ident != "DNUCAStopTable":
879                    val = '''
880if (!%s.areNSlotsAvailable(%s)) {
881    return TransitionResult_ResourceStall;
882}
883''' % (key.code, val)
884                case_sorter.append(val)
885
886
887            # Emit the code sequences in a sorted order.  This makes the
888            # output deterministic (without this the output order can vary,
889            # since Map's keys() on a vector of pointers is not deterministic).
890            for c in sorted(case_sorter):
891                case("$c")
892
893            # Figure out if we stall
894            stall = False
895            for action in actions:
896                if action.ident == "z_stall":
897                    stall = True
898                    break
899
900            if stall:
901                case('return TransitionResult_ProtocolStall;')
902            else:
903                for action in actions:
904                    case('${{action.ident}}(addr);')
905                case('return TransitionResult_Valid;')
906
907            case = str(case)
908
909            # Look to see if this transition code is unique.
910            if case not in cases:
911                cases[case] = []
912
913            cases[case].append(case_string)
914
915        # Walk through all of the unique code blocks and spit out the
916        # corresponding case statement elements
917        for case,transitions in cases.iteritems():
918            # Iterate over all of the transitions that share
919            # the same code
920            for trans in transitions:
921                code('  case HASH_FUN($trans):')
922            code('  {')
923            code('    $case')
924            code('  }')
925
926        code('''
927      default:
928        WARN_EXPR(m_version);
929        WARN_EXPR(g_eventQueue_ptr->getTime());
930        WARN_EXPR(addr);
931        WARN_EXPR(event);
932        WARN_EXPR(state);
933        ERROR_MSG(\"Invalid transition\");
934    }
935    return TransitionResult_Valid;
936}
937''')
938        code.write(path, "%s_Transitions.cc" % self.ident)
939
940    def printProfilerHH(self, path):
941        code = code_formatter()
942        ident = self.ident
943
944        code('''
945// Auto generated C++ code started by $__file__:$__line__
946// ${ident}: ${{self.short}}
947
948#ifndef ${ident}_PROFILER_H
949#define ${ident}_PROFILER_H
950
951#include "mem/ruby/common/Global.hh"
952#include "mem/protocol/${ident}_State.hh"
953#include "mem/protocol/${ident}_Event.hh"
954
955class ${ident}_Profiler {
956  public:
957    ${ident}_Profiler();
958    void setVersion(int version);
959    void countTransition(${ident}_State state, ${ident}_Event event);
960    void possibleTransition(${ident}_State state, ${ident}_Event event);
961    void dumpStats(ostream& out) const;
962    void clearStats();
963
964  private:
965    int m_counters[${ident}_State_NUM][${ident}_Event_NUM];
966    int m_event_counters[${ident}_Event_NUM];
967    bool m_possible[${ident}_State_NUM][${ident}_Event_NUM];
968    int m_version;
969};
970
971#endif // ${ident}_PROFILER_H
972''')
973        code.write(path, "%s_Profiler.hh" % self.ident)
974
975    def printProfilerCC(self, path):
976        code = code_formatter()
977        ident = self.ident
978
979        code('''
980// Auto generated C++ code started by $__file__:$__line__
981// ${ident}: ${{self.short}}
982
983#include "mem/protocol/${ident}_Profiler.hh"
984
985${ident}_Profiler::${ident}_Profiler()
986{
987    for (int state = 0; state < ${ident}_State_NUM; state++) {
988        for (int event = 0; event < ${ident}_Event_NUM; event++) {
989            m_possible[state][event] = false;
990            m_counters[state][event] = 0;
991        }
992    }
993    for (int event = 0; event < ${ident}_Event_NUM; event++) {
994        m_event_counters[event] = 0;
995    }
996}
997void ${ident}_Profiler::setVersion(int version)
998{
999    m_version = version;
1000}
1001void ${ident}_Profiler::clearStats()
1002{
1003    for (int state = 0; state < ${ident}_State_NUM; state++) {
1004        for (int event = 0; event < ${ident}_Event_NUM; event++) {
1005            m_counters[state][event] = 0;
1006        }
1007    }
1008
1009    for (int event = 0; event < ${ident}_Event_NUM; event++) {
1010        m_event_counters[event] = 0;
1011    }
1012}
1013void ${ident}_Profiler::countTransition(${ident}_State state, ${ident}_Event event)
1014{
1015    assert(m_possible[state][event]);
1016    m_counters[state][event]++;
1017    m_event_counters[event]++;
1018}
1019void ${ident}_Profiler::possibleTransition(${ident}_State state, ${ident}_Event event)
1020{
1021    m_possible[state][event] = true;
1022}
1023void ${ident}_Profiler::dumpStats(ostream& out) const
1024{
1025    out << " --- ${ident} " << m_version << " ---" << endl;
1026    out << " - Event Counts -" << endl;
1027    for (int event = 0; event < ${ident}_Event_NUM; event++) {
1028        int count = m_event_counters[event];
1029        out << (${ident}_Event) event << "  " << count << endl;
1030    }
1031    out << endl;
1032    out << " - Transitions -" << endl;
1033    for (int state = 0; state < ${ident}_State_NUM; state++) {
1034        for (int event = 0; event < ${ident}_Event_NUM; event++) {
1035            if (m_possible[state][event]) {
1036                int count = m_counters[state][event];
1037                out << (${ident}_State) state << "  " << (${ident}_Event) event << "  " << count;
1038                if (count == 0) {
1039                    out << " <-- ";
1040                }
1041                out << endl;
1042            }
1043        }
1044        out << endl;
1045    }
1046}
1047''')
1048        code.write(path, "%s_Profiler.cc" % self.ident)
1049
1050    # **************************
1051    # ******* HTML Files *******
1052    # **************************
1053    def frameRef(self, click_href, click_target, over_href, over_target_num,
1054                 text):
1055        code = code_formatter(fix_newlines=False)
1056        code("""<A href=\"$click_href\" target=\"$click_target\" onMouseOver=\"if (parent.frames[$over_target_num].location != parent.location + '$over_href') { parent.frames[$over_target_num].location='$over_href' }\" >${{html.formatShorthand(text)}}</A>""")
1057        return str(code)
1058
1059    def writeHTMLFiles(self, path):
1060        # Create table with no row hilighted
1061        self.printHTMLTransitions(path, None)
1062
1063        # Generate transition tables
1064        for state in self.states.itervalues():
1065            self.printHTMLTransitions(path, state)
1066
1067        # Generate action descriptions
1068        for action in self.actions.itervalues():
1069            name = "%s_action_%s.html" % (self.ident, action.ident)
1070            code = html.createSymbol(action, "Action")
1071            code.write(path, name)
1072
1073        # Generate state descriptions
1074        for state in self.states.itervalues():
1075            name = "%s_State_%s.html" % (self.ident, state.ident)
1076            code = html.createSymbol(state, "State")
1077            code.write(path, name)
1078
1079        # Generate event descriptions
1080        for event in self.events.itervalues():
1081            name = "%s_Event_%s.html" % (self.ident, event.ident)
1082            code = html.createSymbol(event, "Event")
1083            code.write(path, name)
1084
1085    def printHTMLTransitions(self, path, active_state):
1086        code = code_formatter()
1087
1088        code('''
1089<HTML><BODY link="blue" vlink="blue">
1090
1091<H1 align="center">${{html.formatShorthand(self.short)}}:
1092''')
1093        code.indent()
1094        for i,machine in enumerate(self.symtab.getAllType(StateMachine)):
1095            mid = machine.ident
1096            if i != 0:
1097                extra = " - "
1098            else:
1099                extra = ""
1100            if machine == self:
1101                code('$extra$mid')
1102            else:
1103                code('$extra<A target="Table" href="${mid}_table.html">$mid</A>')
1104        code.dedent()
1105
1106        code("""
1107</H1>
1108
1109<TABLE border=1>
1110<TR>
1111  <TH> </TH>
1112""")
1113
1114        for event in self.events.itervalues():
1115            href = "%s_Event_%s.html" % (self.ident, event.ident)
1116            ref = self.frameRef(href, "Status", href, "1", event.short)
1117            code('<TH bgcolor=white>$ref</TH>')
1118
1119        code('</TR>')
1120        # -- Body of table
1121        for state in self.states.itervalues():
1122            # -- Each row
1123            if state == active_state:
1124                color = "yellow"
1125            else:
1126                color = "white"
1127
1128            click = "%s_table_%s.html" % (self.ident, state.ident)
1129            over = "%s_State_%s.html" % (self.ident, state.ident)
1130            text = html.formatShorthand(state.short)
1131            ref = self.frameRef(click, "Table", over, "1", state.short)
1132            code('''
1133<TR>
1134  <TH bgcolor=$color>$ref</TH>
1135''')
1136
1137            # -- One column for each event
1138            for event in self.events.itervalues():
1139                trans = self.table.get((state,event), None)
1140                if trans is None:
1141                    # This is the no transition case
1142                    if state == active_state:
1143                        color = "#C0C000"
1144                    else:
1145                        color = "lightgrey"
1146
1147                    code('<TD bgcolor=$color>&nbsp;</TD>')
1148                    continue
1149
1150                next = trans.nextState
1151                stall_action = False
1152
1153                # -- Get the actions
1154                for action in trans.actions:
1155                    if action.ident == "z_stall" or \
1156                       action.ident == "zz_recycleMandatoryQueue":
1157                        stall_action = True
1158
1159                # -- Print out "actions/next-state"
1160                if stall_action:
1161                    if state == active_state:
1162                        color = "#C0C000"
1163                    else:
1164                        color = "lightgrey"
1165
1166                elif active_state and next.ident == active_state.ident:
1167                    color = "aqua"
1168                elif state == active_state:
1169                    color = "yellow"
1170                else:
1171                    color = "white"
1172
1173                fix = code.nofix()
1174                code('<TD bgcolor=$color>')
1175                for action in trans.actions:
1176                    href = "%s_action_%s.html" % (self.ident, action.ident)
1177                    ref = self.frameRef(href, "Status", href, "1",
1178                                        action.short)
1179                    code('  $ref\n')
1180                if next != state:
1181                    if trans.actions:
1182                        code('/')
1183                    click = "%s_table_%s.html" % (self.ident, next.ident)
1184                    over = "%s_State_%s.html" % (self.ident, next.ident)
1185                    ref = self.frameRef(click, "Table", over, "1", next.short)
1186                    code("$ref")
1187                code("</TD>\n")
1188                code.fix(fix)
1189
1190            # -- Each row
1191            if state == active_state:
1192                color = "yellow"
1193            else:
1194                color = "white"
1195
1196            click = "%s_table_%s.html" % (self.ident, state.ident)
1197            over = "%s_State_%s.html" % (self.ident, state.ident)
1198            ref = self.frameRef(click, "Table", over, "1", state.short)
1199            code('''
1200  <TH bgcolor=$color>$ref</TH>
1201</TR>
1202''')
1203        code('''
1204<TR>
1205  <TH> </TH>
1206''')
1207
1208        for event in self.events.itervalues():
1209            href = "%s_Event_%s.html" % (self.ident, event.ident)
1210            ref = self.frameRef(href, "Status", href, "1", event.short)
1211            code('<TH bgcolor=white>$ref</TH>')
1212        code('''
1213</TR>
1214</TABLE>
1215</BODY></HTML>
1216''')
1217
1218
1219        if active_state:
1220            name = "%s_table_%s.html" % (self.ident, active_state.ident)
1221        else:
1222            name = "%s_table.html" % self.ident
1223        code.write(path, name)
1224
# Names exported by a star-import of this module.
__all__ = ["StateMachine"]
1226