1/*
2 * Copyright (c) 2014-2015 Advanced Micro Devices, Inc.
3 * All rights reserved.
4 *
5 * For use for simulation and test purposes only
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright notice,
11 * this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright notice,
14 * this list of conditions and the following disclaimer in the documentation
15 * and/or other materials provided with the distribution.
16 *
17 * 3. Neither the name of the copyright holder nor the names of its
18 * contributors may be used to endorse or promote products derived from this
19 * software without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31 * POSSIBILITY OF SUCH DAMAGE.
32 *
33 * Authors: John Kalamatianos,
34 *          Joe Gross
35 */
36
37#ifndef __LDS_STATE_HH__
38#define __LDS_STATE_HH__
39
#include <array>
#include <cstring>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "enums/MemType.hh"
#include "gpu-compute/misc.hh"
#include "mem/port.hh"
#include "params/LdsState.hh"
#include "sim/clocked_object.hh"
52
53class ComputeUnit;
54
55/**
56 * this represents a slice of the overall LDS, intended to be associated with an
57 * individual workgroup
58 */
59class LdsChunk
60{
61  public:
62    LdsChunk(const uint32_t x_size):
63        chunk(x_size)
64    {
65    }
66
67    LdsChunk() {}
68
69    /**
70     * a read operation
71     */
72    template<class T>
73    T
74    read(const uint32_t index)
75    {
76        fatal_if(!chunk.size(), "cannot read from an LDS chunk of size 0");
77        fatal_if(index >= chunk.size(), "out-of-bounds access to an LDS chunk");
78        T *p0 = (T *) (&(chunk.at(index)));
79        return *p0;
80    }
81
82    /**
83     * a write operation
84     */
85    template<class T>
86    void
87    write(const uint32_t index, const T value)
88    {
89        fatal_if(!chunk.size(), "cannot write to an LDS chunk of size 0");
90        fatal_if(index >= chunk.size(), "out-of-bounds access to an LDS chunk");
91        T *p0 = (T *) (&(chunk.at(index)));
92        *p0 = value;
93    }
94
95    /**
96     * get the size of this chunk
97     */
98    std::vector<uint8_t>::size_type
99    size() const
100    {
101        return chunk.size();
102    }
103
104  protected:
105    // the actual data store for this slice of the LDS
106    std::vector<uint8_t> chunk;
107};
108
109// Local Data Share (LDS) State per Wavefront (contents of the LDS region
110// allocated to the WorkGroup of this Wavefront)
111class LdsState: public ClockedObject
112{
113  protected:
114
115    /**
116     * an event to allow event-driven execution
117     */
118    class TickEvent: public Event
119    {
120      protected:
121
122        LdsState *ldsState = nullptr;
123
124        Tick nextTick = 0;
125
126      public:
127
128        TickEvent(LdsState *_ldsState) :
129            ldsState(_ldsState)
130        {
131        }
132
133        virtual void
134        process();
135
136        void
137        schedule(Tick when)
138        {
139            mainEventQueue[0]->schedule(this, when);
140        }
141
142        void
143        deschedule()
144        {
145            mainEventQueue[0]->deschedule(this);
146        }
147    };
148
149    /**
150     * CuSidePort is the LDS Port closer to the CU side
151     */
152    class CuSidePort: public SlavePort
153    {
154      public:
155        CuSidePort(const std::string &_name, LdsState *_ownerLds) :
156                SlavePort(_name, _ownerLds), ownerLds(_ownerLds)
157        {
158        }
159
160      protected:
161        LdsState *ownerLds;
162
163        virtual bool
164        recvTimingReq(PacketPtr pkt);
165
166        virtual Tick
167        recvAtomic(PacketPtr pkt)
168        {
169          return 0;
170        }
171
172        virtual void
173        recvFunctional(PacketPtr pkt);
174
175        virtual void
176        recvRangeChange()
177        {
178        }
179
180        virtual void
181        recvRetry();
182
183        virtual void
184        recvRespRetry();
185
186        virtual AddrRangeList
187        getAddrRanges() const
188        {
189          AddrRangeList ranges;
190          ranges.push_back(ownerLds->getAddrRange());
191          return ranges;
192        }
193
194        template<typename T>
195        void
196        loadData(PacketPtr packet);
197
198        template<typename T>
199        void
200        storeData(PacketPtr packet);
201
202        template<typename T>
203        void
204        atomicOperation(PacketPtr packet);
205    };
206
207  protected:
208
209    // the lds reference counter
210    // The key is the workgroup ID and dispatch ID
211    // The value is the number of wavefronts that reference this LDS, as
212    // wavefronts are launched, the counter goes up for that workgroup and when
213    // they return it decreases, once it reaches 0 then this chunk of the LDS is
214    // returned to the available pool. However,it is deallocated on the 1->0
215    // transition, not whenever the counter is 0 as it always starts with 0 when
216    // the workgroup asks for space
217    std::unordered_map<uint32_t,
218                       std::unordered_map<uint32_t, int32_t>> refCounter;
219
220    // the map that allows workgroups to access their own chunk of the LDS
221    std::unordered_map<uint32_t,
222                       std::unordered_map<uint32_t, LdsChunk>> chunkMap;
223
224    // an event to allow the LDS to wake up at a specified time
225    TickEvent tickEvent;
226
227    // the queue of packets that are going back to the CU after a
228    // read/write/atomic op
229    // TODO need to make this have a maximum size to create flow control
230    std::queue<std::pair<Tick, PacketPtr>> returnQueue;
231
232    // whether or not there are pending responses
233    bool retryResp = false;
234
235    bool
236    process();
237
238    GPUDynInstPtr
239    getDynInstr(PacketPtr packet);
240
241    bool
242    processPacket(PacketPtr packet);
243
244    unsigned
245    countBankConflicts(PacketPtr packet, unsigned *bankAccesses);
246
247    unsigned
248    countBankConflicts(GPUDynInstPtr gpuDynInst,
249                       unsigned *numBankAccesses);
250
251  public:
252    typedef LdsStateParams Params;
253
254    LdsState(const Params *params);
255
256    // prevent copy construction
257    LdsState(const LdsState&) = delete;
258
259    ~LdsState()
260    {
261        parent = nullptr;
262    }
263
264    const Params *
265    params() const
266    {
267        return dynamic_cast<const Params *>(_params);
268    }
269
270    bool
271    isRetryResp() const
272    {
273        return retryResp;
274    }
275
276    void
277    setRetryResp(const bool value)
278    {
279        retryResp = value;
280    }
281
282    // prevent assignment
283    LdsState &
284    operator=(const LdsState &) = delete;
285
286    /**
287     * use the dynamic wave id to create or just increase the reference count
288     */
289    int
290    increaseRefCounter(const uint32_t dispatchId, const uint32_t wgId)
291    {
292        int refCount = getRefCounter(dispatchId, wgId);
293        fatal_if(refCount < 0,
294                 "reference count should not be below zero");
295        return ++refCounter[dispatchId][wgId];
296    }
297
298    /**
299     * decrease the reference count after making sure it is in the list
300     * give back this chunk if the ref counter has reached 0
301     */
302    int
303    decreaseRefCounter(const uint32_t dispatchId, const uint32_t wgId)
304    {
305      int refCount = getRefCounter(dispatchId, wgId);
306
307      fatal_if(refCount <= 0,
308              "reference count should not be below zero or at zero to"
309              "decrement");
310
311      refCounter[dispatchId][wgId]--;
312
313      if (refCounter[dispatchId][wgId] == 0) {
314        releaseSpace(dispatchId, wgId);
315        return 0;
316      } else {
317        return refCounter[dispatchId][wgId];
318      }
319    }
320
321    /**
322     * return the current reference count for this workgroup id
323     */
324    int
325    getRefCounter(const uint32_t dispatchId, const uint32_t wgId) const
326    {
327      auto dispatchIter = chunkMap.find(dispatchId);
328      fatal_if(dispatchIter == chunkMap.end(),
329               "could not locate this dispatch id [%d]", dispatchId);
330
331      auto workgroup = dispatchIter->second.find(wgId);
332      fatal_if(workgroup == dispatchIter->second.end(),
333               "could not find this workgroup id within this dispatch id"
334               " did[%d] wgid[%d]", dispatchId, wgId);
335
336      auto refCountIter = refCounter.find(dispatchId);
337      if (refCountIter == refCounter.end()) {
338        fatal("could not locate this dispatch id [%d]", dispatchId);
339      } else {
340        auto workgroup = refCountIter->second.find(wgId);
341        if (workgroup == refCountIter->second.end()) {
342          fatal("could not find this workgroup id within this dispatch id"
343                  " did[%d] wgid[%d]", dispatchId, wgId);
344        } else {
345          return refCounter.at(dispatchId).at(wgId);
346        }
347      }
348
349      fatal("should not reach this point");
350      return 0;
351    }
352
353    /**
354     * assign a parent and request this amount of space be set aside
355     * for this wgid
356     */
357    LdsChunk *
358    reserveSpace(const uint32_t dispatchId, const uint32_t wgId,
359            const uint32_t size)
360    {
361        if (chunkMap.find(dispatchId) != chunkMap.end()) {
362            fatal_if(
363                chunkMap[dispatchId].find(wgId) != chunkMap[dispatchId].end(),
364                "duplicate workgroup ID asking for space in the LDS "
365                "did[%d] wgid[%d]", dispatchId, wgId);
366        }
367
368        fatal_if(bytesAllocated + size > maximumSize,
369                 "request would ask for more space than is available");
370
371        bytesAllocated += size;
372
373        chunkMap[dispatchId].emplace(wgId, LdsChunk(size));
374        // make an entry for this workgroup
375        refCounter[dispatchId][wgId] = 0;
376
377        return &chunkMap[dispatchId][wgId];
378    }
379
380    bool
381    returnQueuePush(std::pair<Tick, PacketPtr> thePair);
382
383    Tick
384    earliestReturnTime() const
385    {
386        // TODO set to max(lastCommand+1, curTick())
387        return returnQueue.empty() ? curTick() : returnQueue.back().first;
388    }
389
390    void
391    setParent(ComputeUnit *x_parent);
392
393    // accessors
394    ComputeUnit *
395    getParent() const
396    {
397        return parent;
398    }
399
400    std::string
401    getName()
402    {
403        return _name;
404    }
405
406    int
407    getBanks() const
408    {
409        return banks;
410    }
411
412    ComputeUnit *
413    getComputeUnit() const
414    {
415        return parent;
416    }
417
418    int
419    getBankConflictPenalty() const
420    {
421        return bankConflictPenalty;
422    }
423
424    /**
425     * get the allocated size for this workgroup
426     */
427    std::size_t
428    ldsSize(const uint32_t x_wgId)
429    {
430        return chunkMap[x_wgId].size();
431    }
432
433    AddrRange
434    getAddrRange() const
435    {
436        return range;
437    }
438
439    Port &
440    getPort(const std::string &if_name, PortID idx)
441    {
442        if (if_name == "cuPort") {
443            // TODO need to set name dynamically at this point?
444            return cuPort;
445        } else {
446            fatal("cannot resolve the port name " + if_name);
447        }
448    }
449
450    /**
451     * can this much space be reserved for a workgroup?
452     */
453    bool
454    canReserve(uint32_t x_size) const
455    {
456      return bytesAllocated + x_size <= maximumSize;
457    }
458
459  private:
460    /**
461     * give back the space
462     */
463    bool
464    releaseSpace(const uint32_t x_dispatchId, const uint32_t x_wgId)
465    {
466        auto dispatchIter = chunkMap.find(x_dispatchId);
467
468        if (dispatchIter == chunkMap.end()) {
469          fatal("dispatch id not found [%d]", x_dispatchId);
470        } else {
471          auto workgroupIter = dispatchIter->second.find(x_wgId);
472          if (workgroupIter == dispatchIter->second.end()) {
473            fatal("workgroup id [%d] not found in dispatch id [%d]",
474                    x_wgId, x_dispatchId);
475          }
476        }
477
478        fatal_if(bytesAllocated < chunkMap[x_dispatchId][x_wgId].size(),
479                 "releasing more space than was allocated");
480
481        bytesAllocated -= chunkMap[x_dispatchId][x_wgId].size();
482        chunkMap[x_dispatchId].erase(chunkMap[x_dispatchId].find(x_wgId));
483        return true;
484    }
485
486    // the port that connects this LDS to its owner CU
487    CuSidePort cuPort;
488
489    ComputeUnit* parent = nullptr;
490
491    std::string _name;
492
493    // the number of bytes currently reserved by all workgroups
494    int bytesAllocated = 0;
495
496    // the size of the LDS, the most bytes available
497    int maximumSize;
498
499    // Address range of this memory
500    AddrRange range;
501
502    // the penalty, in cycles, for each LDS bank conflict
503    int bankConflictPenalty = 0;
504
505    // the number of banks in the LDS underlying data store
506    int banks = 0;
507};
508
509#endif // __LDS_STATE_HH__
510