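// Package addresstranslator provides the AddressTranslator component, which
// forwards read/write requests with the address translated from virtual to
// physical. (Source file: addresstranslator.go)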
package addresstranslator
import (
"log"
"reflect"
"gitlab.com/akita/akita"
"gitlab.com/akita/mem"
"gitlab.com/akita/mem/cache"
"gitlab.com/akita/mem/vm"
"gitlab.com/akita/util/tracing"
)
// transaction tracks one address translation together with the requests
// from the top that are waiting for its result.
type transaction struct {
// incomingReqs holds the requests received on TopPort that still have to be
// forwarded to the bottom once the translation is available.
incomingReqs []mem.AccessReq
// translationReq is the request that was sent out through TranslationPort.
translationReq *vm.TranslationReq
// translationRsp is the response matched to translationReq; it is filled in
// by parseTranslation.
translationRsp *vm.TranslationRsp
// translationDone is set by parseTranslation once translationRsp has been
// recorded; forward only proceeds when it is true.
translationDone bool
}
// reqToBottom pairs a request received from the top with the translated
// request that was sent to the bottom on its behalf, so that a response from
// the bottom can be routed back to the original requester.
type reqToBottom struct {
// reqFromTop is the original request that arrived on TopPort.
reqFromTop mem.AccessReq
// reqToBottom is the translated request that was sent through BottomPort.
reqToBottom mem.AccessReq
}
// AddressTranslator is a component that forwards the read/write requests with
// the address translated from virtual to physical.
type AddressTranslator struct {
*akita.TickingComponent
// TopPort receives memory access requests and sends the responses back to
// the original requesters.
TopPort akita.Port
// BottomPort sends the translated requests and receives their responses.
BottomPort akita.Port
// TranslationPort sends translation requests and receives translation
// responses.
TranslationPort akita.Port
// CtrlPort receives flush and restart requests and sends their responses.
CtrlPort akita.Port
// lowModuleFinder selects the destination of a translated request from its
// physical address.
lowModuleFinder cache.LowModuleFinder
// translationProvider is the destination of every translation request.
translationProvider akita.Port
// log2PageSize is the base-2 logarithm of the page size; it is used to split
// an address into its page-aligned part and its in-page offset.
log2PageSize uint64
// gpuID is attached to every translation request.
gpuID uint64
// numReqPerCycle is how many times each pipeline stage is attempted per
// tick.
numReqPerCycle int
// isFlushing is set by a flush request and cleared by a restart request;
// while it is true the pipeline does not run.
isFlushing bool
// transactions lists the translations in flight, in arrival order.
transactions []*transaction
// inflightReqToBottom lists the requests sent to the bottom that have not
// been answered yet.
inflightReqToBottom []reqToBottom
}
// Tick advances the translator by one cycle. The request pipeline only runs
// while the translator is not flushing; control requests are handled on every
// tick, after the pipeline. It returns true if either part made progress.
func (t *AddressTranslator) Tick(now akita.VTimeInSec) bool {
	// The short-circuit keeps the pipeline from running during a flush.
	pipelineProgress := !t.isFlushing && t.runPipeline(now)
	ctrlProgress := t.handleCtrlRequest(now)

	return pipelineProgress || ctrlProgress
}
// runPipeline attempts every pipeline stage numReqPerCycle times, from the
// last stage to the first (respond, forward, parseTranslation, translate).
// It returns true if any attempt made progress.
func (t *AddressTranslator) runPipeline(now akita.VTimeInSec) bool {
	progressed := false

	for n := 0; n < t.numReqPerCycle; n++ {
		if t.respond(now) {
			progressed = true
		}
	}

	for n := 0; n < t.numReqPerCycle; n++ {
		if t.forward(now) {
			progressed = true
		}
	}

	for n := 0; n < t.numReqPerCycle; n++ {
		if t.parseTranslation(now) {
			progressed = true
		}
	}

	for n := 0; n < t.numReqPerCycle; n++ {
		if t.translate(now) {
			progressed = true
		}
	}

	return progressed
}
// translate takes the request at the head of TopPort and issues a translation
// request for the page-aligned virtual address it touches. On success a new
// transaction is recorded and the request is removed from TopPort.
//
// It returns false, leaving the request on TopPort, when TopPort is empty or
// when TranslationPort cannot accept the translation request.
func (t *AddressTranslator) translate(now akita.VTimeInSec) bool {
	head := t.TopPort.Peek()
	if head == nil {
		return false
	}

	accessReq := head.(mem.AccessReq)
	pageAddr := t.addrToPageID(accessReq.GetAddress())

	transReq := vm.TranslationReqBuilder{}.
		WithSendTime(now).
		WithSrc(t.TranslationPort).
		WithDst(t.translationProvider).
		WithPID(accessReq.GetPID()).
		WithVAddr(pageAddr).
		WithGPUID(t.gpuID).
		Build()

	if sendErr := t.TranslationPort.Send(transReq); sendErr != nil {
		return false
	}

	t.transactions = append(t.transactions, &transaction{
		incomingReqs:   []mem.AccessReq{accessReq},
		translationReq: transReq,
	})

	tracing.TraceReqReceive(accessReq, now, t)
	tracing.TraceReqInitiate(
		transReq, now, t, tracing.MsgIDAtReceiver(accessReq, t))

	// Only consume the request once the translation request is on its way.
	t.TopPort.Retrieve(now)

	return true
}
// parseTranslation retrieves one translation response from TranslationPort
// and attaches it to the transaction whose translation request it answers.
// It returns false only when no response is waiting.
func (t *AddressTranslator) parseTranslation(now akita.VTimeInSec) bool {
	msg := t.TranslationPort.Retrieve(now)
	if msg == nil {
		return false
	}

	transRsp := msg.(*vm.TranslationRsp)

	pending := t.findTranslationByReqID(transRsp.RespondTo)
	if pending == nil {
		// No transaction is waiting for this response (for instance, the
		// transaction list was cleared by a flush); the response is dropped.
		return true
	}

	pending.translationRsp = transRsp
	pending.translationDone = true
	tracing.TraceReqFinalize(pending.translationReq, now, t)

	return true
}
// forward sends the next waiting request of the oldest transaction to the
// bottom, with its address translated. Only the transaction at the head of
// the list is considered, and only once its translation is done.
//
// It returns false when there is nothing ready to forward or when BottomPort
// cannot accept the translated request; no state changes in that case.
func (t *AddressTranslator) forward(now akita.VTimeInSec) bool {
	if len(t.transactions) == 0 {
		return false
	}

	head := t.transactions[0]
	if !head.translationDone {
		return false
	}

	origReq := head.incomingReqs[0]
	physReq := t.createTranslatedReq(origReq, head.translationRsp.Page)
	physReq.Meta().SendTime = now

	if sendErr := t.BottomPort.Send(physReq); sendErr != nil {
		return false
	}

	// Remember the pairing so the response can be routed back to the top.
	t.inflightReqToBottom = append(t.inflightReqToBottom, reqToBottom{
		reqFromTop:  origReq,
		reqToBottom: physReq,
	})

	head.incomingReqs = head.incomingReqs[1:]
	if len(head.incomingReqs) == 0 {
		t.removeExistingTranslation(head)
	}

	tracing.TraceReqInitiate(
		physReq, now, t, tracing.MsgIDAtReceiver(origReq, t))

	return true
}
// respond handles the response at the head of BottomPort. When the response
// answers a request recorded in inflightReqToBottom, an equivalent response
// addressed to the original requester is sent through TopPort and the record
// is removed. A response with no matching record is consumed and discarded.
//
// It returns false when BottomPort is empty or when TopPort cannot accept the
// response; in the latter case the response stays on BottomPort. It panics on
// a response type other than DataReadyRsp or WriteDoneRsp.
func (t *AddressTranslator) respond(now akita.VTimeInSec) bool {
	msg := t.BottomPort.Peek()
	if msg == nil {
		return false
	}

	var (
		matched  bool
		inflight reqToBottom
		reply    mem.AccessRsp
	)

	switch bottomRsp := msg.(type) {
	case *mem.DataReadyRsp:
		matched = t.isReqInBottomByID(bottomRsp.RespondTo)
		if matched {
			inflight = t.findReqToBottomByID(bottomRsp.RespondTo)
			origin := inflight.reqFromTop.Meta()
			reply = mem.DataReadyRspBuilder{}.
				WithSendTime(now).
				WithSrc(t.TopPort).
				WithDst(origin.Src).
				WithRspTo(origin.ID).
				WithData(bottomRsp.Data).
				Build()
		}
	case *mem.WriteDoneRsp:
		matched = t.isReqInBottomByID(bottomRsp.RespondTo)
		if matched {
			inflight = t.findReqToBottomByID(bottomRsp.RespondTo)
			origin := inflight.reqFromTop.Meta()
			reply = mem.WriteDoneRspBuilder{}.
				WithSendTime(now).
				WithSrc(t.TopPort).
				WithDst(origin.Src).
				WithRspTo(origin.ID).
				Build()
		}
	default:
		log.Panicf("cannot handle respond of type %s", reflect.TypeOf(bottomRsp))
	}

	if matched {
		if sendErr := t.TopPort.Send(reply); sendErr != nil {
			return false
		}

		t.removeReqToBottomByID(msg.(mem.AccessRsp).GetRespondTo())
		tracing.TraceReqFinalize(inflight.reqToBottom, now, t)
		tracing.TraceReqComplete(inflight.reqFromTop, now, t)
	}

	// The response is consumed whether or not it matched a record.
	t.BottomPort.Retrieve(now)

	return true
}
// createTranslatedReq builds the request to send to the bottom for req, using
// the physical page in page. Read and write requests are supported; any other
// request type causes a panic.
func (t *AddressTranslator) createTranslatedReq(
	req mem.AccessReq,
	page vm.Page,
) mem.AccessReq {
	if read, isRead := req.(*mem.ReadReq); isRead {
		return t.createTranslatedReadReq(read, page)
	}

	if write, isWrite := req.(*mem.WriteReq); isWrite {
		return t.createTranslatedWriteReq(write, page)
	}

	log.Panicf("cannot translate request of type %s", reflect.TypeOf(req))

	return nil
}
// createTranslatedReadReq builds a new read request that targets the physical
// address corresponding to req. The physical address is the page's PAddr plus
// the offset of req.Address within its page. The new request is sent from
// BottomPort to the module chosen by lowModuleFinder, carries PID 0, and
// copies the access size and CanWaitForCoalesce flag from req.
func (t *AddressTranslator) createTranslatedReadReq(
	req *mem.ReadReq,
	page vm.Page,
) *mem.ReadReq {
	inPageOffset := req.Address % (1 << t.log2PageSize)
	physAddr := page.PAddr + inPageOffset
	dst := t.lowModuleFinder.Find(physAddr)

	translated := mem.ReadReqBuilder{}.
		WithSrc(t.BottomPort).
		WithDst(dst).
		WithAddress(physAddr).
		WithByteSize(req.AccessByteSize).
		WithPID(0).
		Build()
	translated.CanWaitForCoalesce = req.CanWaitForCoalesce

	return translated
}
// createTranslatedWriteReq builds a new write request that targets the
// physical address corresponding to req. The physical address is the page's
// PAddr plus the offset of req.Address within its page. The new request is
// sent from BottomPort to the module chosen by lowModuleFinder, carries PID 0,
// and copies the data, dirty mask and CanWaitForCoalesce flag from req.
func (t *AddressTranslator) createTranslatedWriteReq(
	req *mem.WriteReq,
	page vm.Page,
) *mem.WriteReq {
	inPageOffset := req.Address % (1 << t.log2PageSize)
	physAddr := page.PAddr + inPageOffset
	dst := t.lowModuleFinder.Find(physAddr)

	translated := mem.WriteReqBuilder{}.
		WithSrc(t.BottomPort).
		WithDst(dst).
		WithData(req.Data).
		WithDirtyMask(req.DirtyMask).
		WithAddress(physAddr).
		WithPID(0).
		Build()
	translated.CanWaitForCoalesce = req.CanWaitForCoalesce

	return translated
}
// addrToPageID returns addr with its in-page offset bits cleared, i.e. the
// address of the first byte of the page that contains addr. Note that the
// result is a page-aligned address, not a page index.
func (t *AddressTranslator) addrToPageID(addr uint64) uint64 {
	offsetMask := (uint64(1) << t.log2PageSize) - 1

	return addr &^ offsetMask
}
// findTranslationByReqID returns the first transaction whose translation
// request has the given ID, or nil when there is none.
func (t *AddressTranslator) findTranslationByReqID(id string) *transaction {
	for i := range t.transactions {
		if candidate := t.transactions[i]; candidate.translationReq.ID == id {
			return candidate
		}
	}

	return nil
}
// removeExistingTranslation removes trans (compared by pointer) from the
// transaction list, keeping the order of the remaining entries. It panics if
// trans is not in the list.
func (t *AddressTranslator) removeExistingTranslation(trans *transaction) {
	idx := -1
	for i := range t.transactions {
		if t.transactions[i] == trans {
			idx = i
			break
		}
	}

	if idx < 0 {
		panic("translation not found")
	}

	// Shift the tail left by one slot and drop the now-duplicated last entry.
	copy(t.transactions[idx:], t.transactions[idx+1:])
	t.transactions = t.transactions[:len(t.transactions)-1]
}
// isReqInBottomByID reports whether a request with the given ID has been sent
// to the bottom and is still recorded as in flight.
func (t *AddressTranslator) isReqInBottomByID(id string) bool {
	for i := range t.inflightReqToBottom {
		if t.inflightReqToBottom[i].reqToBottom.Meta().ID == id {
			return true
		}
	}

	return false
}
// findReqToBottomByID returns the first in-flight record whose request to the
// bottom has the given ID. It panics if no such record exists.
func (t *AddressTranslator) findReqToBottomByID(id string) reqToBottom {
	for i := range t.inflightReqToBottom {
		if entry := t.inflightReqToBottom[i]; entry.reqToBottom.Meta().ID == id {
			return entry
		}
	}

	panic("req to bottom not found")
}
// removeReqToBottomByID removes the first in-flight record whose request to
// the bottom has the given ID, keeping the order of the remaining records.
// It panics if no such record exists.
func (t *AddressTranslator) removeReqToBottomByID(id string) {
	idx := -1
	for i := range t.inflightReqToBottom {
		if t.inflightReqToBottom[i].reqToBottom.Meta().ID == id {
			idx = i
			break
		}
	}

	if idx < 0 {
		panic("req to bottom not found")
	}

	// Shift the tail left by one slot and drop the now-duplicated last entry.
	copy(t.inflightReqToBottom[idx:], t.inflightReqToBottom[idx+1:])
	t.inflightReqToBottom = t.inflightReqToBottom[:len(t.inflightReqToBottom)-1]
}
// handleCtrlRequest processes at most one control message waiting on
// CtrlPort. Flush and restart requests are dispatched to their handlers; any
// other message type causes a panic.
//
// The message is only removed from CtrlPort after its handler reports
// success. A handler that could not complete (handleFlushReq returns false
// when its response cannot be sent) therefore leaves the request on the port
// so that it is retried on a later tick instead of being lost.
//
// It returns true if a control message was handled, false otherwise.
func (t *AddressTranslator) handleCtrlRequest(now akita.VTimeInSec) bool {
	msg := t.CtrlPort.Peek()
	if msg == nil {
		return false
	}

	handled := false

	switch req := msg.(type) {
	case *AddressTranslatorFlushReq:
		handled = t.handleFlushReq(now, req)
	case *AddressTranslatorRestartReq:
		handled = t.handleRestartReq(now, req)
	default:
		log.Panicf("cannot process request %s", reflect.TypeOf(req))
	}

	if handled {
		// Consume the request only now that it has been fully processed.
		t.CtrlPort.Retrieve(now)
	}

	return handled
}
// handleFlushReq acknowledges a flush request and puts the translator into
// the flushing state: all transactions and in-flight records are discarded
// and isFlushing is set.
//
// It returns false, without changing any state, when the acknowledgement
// cannot be sent through CtrlPort.
func (t *AddressTranslator) handleFlushReq(
	now akita.VTimeInSec,
	req *AddressTranslatorFlushReq,
) bool {
	flushRsp := AddressTranslatorFlushRspBuilder{}.
		WithSrc(t.CtrlPort).
		WithDst(req.Src).
		WithSendTime(now).
		Build()

	if sendErr := t.CtrlPort.Send(flushRsp); sendErr != nil {
		return false
	}

	t.isFlushing = true
	t.transactions = nil
	t.inflightReqToBottom = nil

	return true
}
// handleRestartReq ends the flushing state. It first discards every message
// still waiting on TopPort, BottomPort and TranslationPort, then acknowledges
// the restart through CtrlPort and clears isFlushing.
//
// It always returns true, and panics if the acknowledgement cannot be sent.
func (t *AddressTranslator) handleRestartReq(
	now akita.VTimeInSec,
	req *AddressTranslatorRestartReq,
) bool {
	portsToDrain := []akita.Port{t.TopPort, t.BottomPort, t.TranslationPort}
	for _, port := range portsToDrain {
		for port.Retrieve(now) != nil {
			// Discard the message; the loop ends once the port is empty.
		}
	}

	rsp := AddressTranslatorRestartRspBuilder{}.
		WithSrc(t.CtrlPort).
		WithDst(req.Src).
		WithSendTime(now).
		Build()

	if err := t.CtrlPort.Send(rsp); err != nil {
		log.Panicf("AT failed to send restart rsp to Ctrl component")
	}

	t.isFlushing = false

	return true
}