196,239d195
< ComputeUnit::InitializeWFContext(WFContext *wfCtx, NDRange *ndr, int cnt,
< int trueWgSize[], int trueWgSizeTotal,
< LdsChunk *ldsChunk, uint64_t origSpillMemStart)
< {
< wfCtx->cnt = cnt;
<
< VectorMask init_mask;
< init_mask.reset();
<
< for (int k = 0; k < wfSize(); ++k) {
< if (k + cnt * wfSize() < trueWgSizeTotal)
< init_mask[k] = 1;
< }
<
< wfCtx->init_mask = init_mask.to_ullong();
< wfCtx->exec_mask = init_mask.to_ullong();
<
< wfCtx->bar_cnt.resize(wfSize(), 0);
<
< wfCtx->max_bar_cnt = 0;
< wfCtx->old_barrier_cnt = 0;
< wfCtx->barrier_cnt = 0;
<
< wfCtx->privBase = ndr->q.privMemStart;
< ndr->q.privMemStart += ndr->q.privMemPerItem * wfSize();
<
< wfCtx->spillBase = ndr->q.spillMemStart;
< ndr->q.spillMemStart += ndr->q.spillMemPerItem * wfSize();
<
< wfCtx->pc = 0;
< wfCtx->rpc = UINT32_MAX;
<
< // set the wavefront context to have a pointer to this section of the LDS
< wfCtx->ldsChunk = ldsChunk;
<
< // WG state
< wfCtx->wg_id = ndr->globalWgId;
< wfCtx->barrier_id = barrier_id;
<
< // Kernel wide state
< wfCtx->ndr = ndr;
< }
<
< void
267,268c223,224
< ComputeUnit::StartWF(Wavefront *w, WFContext *wfCtx, int trueWgSize[],
< int trueWgSizeTotal)
---
> ComputeUnit::StartWF(Wavefront *w, int trueWgSize[], int trueWgSizeTotal,
> int cnt, LdsChunk *ldsChunk, NDRange *ndr)
271,272d226
< int cnt = wfCtx->cnt;
< NDRange *ndr = wfCtx->ndr;
276a231,238
> VectorMask init_mask;
> init_mask.reset();
>
> for (int k = 0; k < wfSize(); ++k) {
> if (k + cnt * wfSize() < trueWgSizeTotal)
> init_mask[k] = 1;
> }
>
279c241
< w->init_mask = wfCtx->init_mask;
---
> w->init_mask = init_mask.to_ullong();
293,294d254
< w->old_barrier_cnt = wfCtx->old_barrier_cnt;
< w->barrier_cnt = wfCtx->barrier_cnt;
297,299c257
< for (int i = 0; i < wfSize(); ++i) {
< w->bar_cnt[i] = wfCtx->bar_cnt[i];
< }
---
> w->bar_cnt.resize(wfSize(), 0);
301,303c259,261
< w->max_bar_cnt = wfCtx->max_bar_cnt;
< w->privBase = wfCtx->privBase;
< w->spillBase = wfCtx->spillBase;
---
> w->max_bar_cnt = 0;
> w->old_barrier_cnt = 0;
> w->barrier_cnt = 0;
305c263,264
< w->pushToReconvergenceStack(wfCtx->pc, wfCtx->rpc, wfCtx->exec_mask);
---
> w->privBase = ndr->q.privMemStart;
> ndr->q.privMemStart += ndr->q.privMemPerItem * wfSize();
306a266,270
> w->spillBase = ndr->q.spillMemStart;
> ndr->q.spillMemStart += ndr->q.spillMemPerItem * wfSize();
> // same width as w->init_mask: to_ulong() may not hold every lane bit
> w->pushToReconvergenceStack(0, UINT32_MAX, init_mask.to_ullong());
>
308,309c272,273
< w->wg_id = wfCtx->wg_id;
< w->dispatchid = wfCtx->ndr->dispatchId;
---
> w->wg_id = ndr->globalWgId;
> w->dispatchid = ndr->dispatchId;
314c278
< w->barrier_id = wfCtx->barrier_id;
---
> w->barrier_id = barrier_id;
317,318c281,282
< // move this from the context into the actual wavefront
< w->ldsChunk = wfCtx->ldsChunk;
---
> // point the wavefront at this section of the LDS
> w->ldsChunk = ldsChunk;
343d306
< wfCtx->bar_cnt.clear();
379d341
< uint64_t origSpillMemStart = ndr->q.spillMemStart;
406,411c368
< WFContext wfCtx;
<
< InitializeWFContext(&wfCtx, ndr, cnt, trueWgSize, trueWgSizeTotal,
< ldsChunk, origSpillMemStart);
<
< StartWF(w, &wfCtx, trueWgSize, trueWgSizeTotal);
---
> StartWF(w, trueWgSize, trueWgSizeTotal, cnt, ldsChunk, ndr);