Skip to content
Snippets Groups Projects
Select Git revision
  • v3 default
  • update_akita
  • solve-TLB-memory-leak
  • generic_lmf
  • v2
  • dir_latency
  • 103-storage-should-be-able-to-set-granularity
  • 104-support-tracing-memory-responses
  • akita_v3
  • 102-limit-tlb-concurrency
  • domained_timing
  • master protected
  • 97-gl0_invalidate_navisim
  • 81-implement-colt-tlb-for-gpu
  • fix_timing_order_issue
  • 86-dram-detailed-modeling
  • coherency-hmg
  • round_robin
  • scheduling-research
  • dram
  • v3.0.0-alpha.10
  • v3.0.0-alpha.9
  • v3.0.0-alpha.8
  • v3.0.0-alpha.7
  • v3.0.0-alpha.6
  • v2.5.0
  • v3.0.0-alpha.5
  • v2.4.2
  • v3.0.0-alpha.4
  • v3.0.0-alpha.3
  • v2.4.1
  • v3.0.0-alpha.1
  • v2.4.0
  • v2.3.2
  • v2.3.1
  • v2.3.0
  • v2.2.0
  • v2.1.2
  • v2.1.1
  • v2.1.0
40 results

addresstranslator.go

addresstranslator.go 9.01 KiB
package addresstranslator

import (
	"log"
	"reflect"

	"gitlab.com/akita/akita"
	"gitlab.com/akita/mem"
	"gitlab.com/akita/mem/cache"
	"gitlab.com/akita/mem/vm"
	"gitlab.com/akita/util/tracing"
)

// transaction tracks one translation request sent out through TranslationPort
// together with the requests from the top that are waiting for its result.
//
//   - incomingReqs: requests received on TopPort that need this translation.
//     translate creates each transaction with a single request; forward
//     consumes them from the front and removes the transaction once empty.
//   - translationReq: the request sent to the translation provider; its ID is
//     matched against the RespondTo field of incoming translation responses.
//   - translationRsp: the response received for translationReq (its Page is
//     used to build the physical-address requests).
//   - translationDone: set once translationRsp has been recorded.
type transaction struct {
	incomingReqs    []mem.AccessReq
	translationReq  *vm.TranslationReq
	translationRsp  *vm.TranslationRsp
	translationDone bool
}

// reqToBottom pairs a request received from the top (reqFromTop) with the
// physically-addressed copy that was sent through BottomPort (reqToBottom).
// Responses from the bottom are matched on reqToBottom's ID and answered to
// reqFromTop's sender.
type reqToBottom struct {
	reqFromTop  mem.AccessReq
	reqToBottom mem.AccessReq
}

// AddressTranslator is a component that forwards the read/write requests with
// the address translated from virtual to physical.
//
// Requests arriving on TopPort have their page translated through
// TranslationPort, are re-issued with the physical address on BottomPort, and
// responses returning on BottomPort are relayed to the original requesters
// through TopPort. CtrlPort accepts flush and restart requests.
type AddressTranslator struct {
	*akita.TickingComponent

	// TopPort: receives access requests, sends their responses back.
	// BottomPort: sends translated requests, receives their responses.
	// TranslationPort: sends translation requests, receives translation
	// responses.
	// CtrlPort: receives flush/restart requests and sends their responses.
	TopPort         akita.Port
	BottomPort      akita.Port
	TranslationPort akita.Port
	CtrlPort        akita.Port

	// lowModuleFinder selects the destination of each translated request from
	// its physical address. translationProvider is the destination of every
	// translation request. log2PageSize is the page size as a power of two.
	// gpuID is attached to translation requests. numReqPerCycle bounds how many
	// times each pipeline stage runs per tick.
	lowModuleFinder     cache.LowModuleFinder
	translationProvider akita.Port
	log2PageSize        uint64
	gpuID               uint64
	numReqPerCycle      int

	// isFlushing is set by a flush request and cleared by a restart request;
	// while set, Tick does not run the request pipeline.
	isFlushing bool

	// transactions holds outstanding translations, oldest first.
	// inflightReqToBottom holds requests sent to the bottom that have not yet
	// been answered.
	transactions        []*transaction
	inflightReqToBottom []reqToBottom
}

// Tick advances the address translator by one cycle. The request pipeline is
// skipped while a flush is in effect, but control messages are examined on
// every tick. It reports whether any progress was made.
func (t *AddressTranslator) Tick(now akita.VTimeInSec) bool {
	pipelineProgress := !t.isFlushing && t.runPipeline(now)
	ctrlProgress := t.handleCtrlRequest(now)
	return pipelineProgress || ctrlProgress
}

// runPipeline runs every pipeline stage numReqPerCycle times and reports
// whether any invocation made progress.
//
// Stages are visited from the last one (respond) to the first one (translate).
// With this order, anything a stage produces during this tick is only picked
// up by the stages visited before it on the next tick.
func (t *AddressTranslator) runPipeline(now akita.VTimeInSec) bool {
	stages := [...]func(akita.VTimeInSec) bool{
		t.respond,
		t.forward,
		t.parseTranslation,
		t.translate,
	}

	progress := false
	for _, stage := range stages {
		for n := 0; n < t.numReqPerCycle; n++ {
			if stage(now) {
				progress = true
			}
		}
	}
	return progress
}

// translate takes the request at the head of TopPort, asks the translation
// provider for the page containing its virtual address, and records a new
// transaction for it. If the translation request cannot be sent, the incoming
// request is left in TopPort and false is returned.
func (t *AddressTranslator) translate(now akita.VTimeInSec) bool {
	peeked := t.TopPort.Peek()
	if peeked == nil {
		return false
	}

	req := peeked.(mem.AccessReq)
	pageID := t.addrToPageID(req.GetAddress())

	transReq := vm.TranslationReqBuilder{}.
		WithSendTime(now).
		WithSrc(t.TranslationPort).
		WithDst(t.translationProvider).
		WithPID(req.GetPID()).
		WithVAddr(pageID).
		WithGPUID(t.gpuID).
		Build()
	if sendErr := t.TranslationPort.Send(transReq); sendErr != nil {
		return false
	}

	t.transactions = append(t.transactions, &transaction{
		incomingReqs:   []mem.AccessReq{req},
		translationReq: transReq,
	})

	tracing.TraceReqReceive(req, now, t)
	tracing.TraceReqInitiate(transReq, now, t, tracing.MsgIDAtReceiver(req, t))

	// Only consume the request once the translation request is on its way.
	t.TopPort.Retrieve(now)
	return true
}

// parseTranslation consumes one translation response from TranslationPort and
// attaches it to the matching transaction. A response with no matching
// transaction (e.g. its bookkeeping was cleared by a flush) is consumed and
// ignored. Returns false only when no response was available.
func (t *AddressTranslator) parseTranslation(now akita.VTimeInSec) bool {
	msg := t.TranslationPort.Retrieve(now)
	if msg == nil {
		return false
	}

	transRsp := msg.(*vm.TranslationRsp)
	if trans := t.findTranslationByReqID(transRsp.RespondTo); trans != nil {
		trans.translationRsp = transRsp
		trans.translationDone = true
		tracing.TraceReqFinalize(trans.translationReq, now, t)
	}
	return true
}

// forward sends the next waiting request of the oldest transaction to the
// bottom with its address translated. Only transactions[0] is considered, so
// requests are forwarded in the order their transactions were created. The
// transaction is removed once all of its requests have been forwarded.
// Returns false if there is nothing ready or the send fails.
func (t *AddressTranslator) forward(now akita.VTimeInSec) bool {
	if len(t.transactions) == 0 {
		return false
	}
	head := t.transactions[0]
	if !head.translationDone {
		return false
	}

	fromTop := head.incomingReqs[0]
	toBottom := t.createTranslatedReq(fromTop, head.translationRsp.Page)
	toBottom.Meta().SendTime = now
	if sendErr := t.BottomPort.Send(toBottom); sendErr != nil {
		return false
	}

	t.inflightReqToBottom = append(t.inflightReqToBottom, reqToBottom{
		reqFromTop:  fromTop,
		reqToBottom: toBottom,
	})

	head.incomingReqs = head.incomingReqs[1:]
	if len(head.incomingReqs) == 0 {
		t.removeExistingTranslation(head)
	}

	tracing.TraceReqInitiate(toBottom, now, t,
		tracing.MsgIDAtReceiver(fromTop, t))
	return true
}

// respond relays one response from BottomPort back to the original requester
// through TopPort. Only *mem.DataReadyRsp and *mem.WriteDoneRsp are accepted;
// any other message type panics. A response that does not match a tracked
// request to the bottom is consumed and dropped. If the reply to the top
// cannot be sent, the bottom response stays in BottomPort and false is
// returned.
func (t *AddressTranslator) respond(now akita.VTimeInSec) bool {
	msg := t.BottomPort.Peek()
	if msg == nil {
		return false
	}

	var respondTo string
	switch m := msg.(type) {
	case *mem.DataReadyRsp:
		respondTo = m.RespondTo
	case *mem.WriteDoneRsp:
		respondTo = m.RespondTo
	default:
		log.Panicf("cannot handle respond of type %s", reflect.TypeOf(m))
	}

	if !t.isReqInBottomByID(respondTo) {
		// Not tracked (anymore): consume and discard.
		t.BottomPort.Retrieve(now)
		return true
	}

	combo := t.findReqToBottomByID(respondTo)
	fromTop := combo.reqFromTop

	var toTop mem.AccessRsp
	if dataReady, isDataReady := msg.(*mem.DataReadyRsp); isDataReady {
		toTop = mem.DataReadyRspBuilder{}.
			WithSendTime(now).
			WithSrc(t.TopPort).
			WithDst(fromTop.Meta().Src).
			WithRspTo(fromTop.Meta().ID).
			WithData(dataReady.Data).
			Build()
	} else {
		toTop = mem.WriteDoneRspBuilder{}.
			WithSendTime(now).
			WithSrc(t.TopPort).
			WithDst(fromTop.Meta().Src).
			WithRspTo(fromTop.Meta().ID).
			Build()
	}

	if sendErr := t.TopPort.Send(toTop); sendErr != nil {
		return false
	}

	t.removeReqToBottomByID(msg.(mem.AccessRsp).GetRespondTo())

	tracing.TraceReqFinalize(combo.reqToBottom, now, t)
	tracing.TraceReqComplete(combo.reqFromTop, now, t)

	t.BottomPort.Retrieve(now)
	return true
}

// createTranslatedReq builds the physically-addressed counterpart of req,
// using page to map the virtual address. Only *mem.ReadReq and *mem.WriteReq
// are supported; any other request type panics.
func (t *AddressTranslator) createTranslatedReq(
	req mem.AccessReq,
	page vm.Page,
) mem.AccessReq {
	if read, isRead := req.(*mem.ReadReq); isRead {
		return t.createTranslatedReadReq(read, page)
	}
	if write, isWrite := req.(*mem.WriteReq); isWrite {
		return t.createTranslatedWriteReq(write, page)
	}

	log.Panicf("cannot translate request of type %s", reflect.TypeOf(req))
	return nil
}

// createTranslatedReadReq builds a read request to the physical address
// page.PAddr plus req's offset within the page. It keeps req's byte size and
// CanWaitForCoalesce flag, uses PID 0, and targets the module that
// lowModuleFinder returns for the physical address.
func (t *AddressTranslator) createTranslatedReadReq(
	req *mem.ReadReq,
	page vm.Page,
) *mem.ReadReq {
	pageSize := uint64(1) << t.log2PageSize
	pAddr := page.PAddr + req.Address%pageSize

	translated := mem.ReadReqBuilder{}.
		WithSrc(t.BottomPort).
		WithDst(t.lowModuleFinder.Find(pAddr)).
		WithAddress(pAddr).
		WithByteSize(req.AccessByteSize).
		WithPID(0).
		Build()
	translated.CanWaitForCoalesce = req.CanWaitForCoalesce
	return translated
}

// createTranslatedWriteReq builds a write request to the physical address
// page.PAddr plus req's offset within the page. It carries over req's data,
// dirty mask and CanWaitForCoalesce flag, uses PID 0, and targets the module
// that lowModuleFinder returns for the physical address.
func (t *AddressTranslator) createTranslatedWriteReq(
	req *mem.WriteReq,
	page vm.Page,
) *mem.WriteReq {
	pageSize := uint64(1) << t.log2PageSize
	pAddr := page.PAddr + req.Address%pageSize

	translated := mem.WriteReqBuilder{}.
		WithSrc(t.BottomPort).
		WithDst(t.lowModuleFinder.Find(pAddr)).
		WithData(req.Data).
		WithDirtyMask(req.DirtyMask).
		WithAddress(pAddr).
		WithPID(0).
		Build()
	translated.CanWaitForCoalesce = req.CanWaitForCoalesce
	return translated
}

// addrToPageID returns addr rounded down to the start of its page by clearing
// the low log2PageSize bits.
func (t *AddressTranslator) addrToPageID(addr uint64) uint64 {
	offsetMask := uint64(1)<<t.log2PageSize - 1
	return addr &^ offsetMask
}
// findTranslationByReqID returns the outstanding transaction whose translation
// request has the given ID, or nil if no such transaction is tracked (e.g.
// after the bookkeeping was cleared by a flush).
func (t *AddressTranslator) findTranslationByReqID(id string) *transaction {
	// The loop variable is deliberately not named t, to avoid shadowing the
	// receiver.
	for _, trans := range t.transactions {
		if trans.translationReq.ID == id {
			return trans
		}
	}
	return nil
}

// removeExistingTranslation deletes trans from t.transactions, keeping the
// remaining transactions in order. It panics if trans is not tracked.
func (t *AddressTranslator) removeExistingTranslation(trans *transaction) {
	for i, tr := range t.transactions {
		if tr != trans {
			continue
		}

		last := len(t.transactions) - 1
		copy(t.transactions[i:], t.transactions[i+1:])
		// Clear the vacated slot so the backing array no longer keeps the
		// removed transaction (and the requests it references) reachable.
		t.transactions[last] = nil
		t.transactions = t.transactions[:last]
		return
	}
	panic("translation not found")
}

// isReqInBottomByID reports whether a request sent to the bottom with the
// given message ID is still awaiting its response.
func (t *AddressTranslator) isReqInBottomByID(id string) bool {
	found := false
	for i := 0; i < len(t.inflightReqToBottom) && !found; i++ {
		found = t.inflightReqToBottom[i].reqToBottom.Meta().ID == id
	}
	return found
}

// findReqToBottomByID returns (a copy of) the in-flight entry whose bottom
// request has the given message ID. It panics if there is no such entry.
func (t *AddressTranslator) findReqToBottomByID(id string) reqToBottom {
	for i := range t.inflightReqToBottom {
		if entry := t.inflightReqToBottom[i]; entry.reqToBottom.Meta().ID == id {
			return entry
		}
	}
	panic("req to bottom not found")
}

// removeReqToBottomByID deletes the in-flight entry whose bottom request has
// the given message ID, keeping the remaining entries in order. It panics if
// there is no such entry.
func (t *AddressTranslator) removeReqToBottomByID(id string) {
	for i, r := range t.inflightReqToBottom {
		if r.reqToBottom.Meta().ID != id {
			continue
		}

		last := len(t.inflightReqToBottom) - 1
		copy(t.inflightReqToBottom[i:], t.inflightReqToBottom[i+1:])
		// Zero the vacated slot so the backing array no longer references
		// the removed requests.
		t.inflightReqToBottom[last] = reqToBottom{}
		t.inflightReqToBottom = t.inflightReqToBottom[:last]
		return
	}
	panic("req to bottom not found")
}

// handleCtrlRequest processes at most one flush or restart request from
// CtrlPort. The request is removed from the port only after its handler
// succeeds. If the handler cannot finish (e.g. the flush response could not
// be sent), the request stays in CtrlPort and is retried on a later tick
// instead of being silently lost. Unknown message types panic.
func (t *AddressTranslator) handleCtrlRequest(now akita.VTimeInSec) bool {
	msg := t.CtrlPort.Peek()
	if msg == nil {
		return false
	}

	handled := false
	switch req := msg.(type) {
	case *AddressTranslatorFlushReq:
		handled = t.handleFlushReq(now, req)
	case *AddressTranslatorRestartReq:
		handled = t.handleRestartReq(now, req)
	default:
		log.Panicf("cannot process request %s", reflect.TypeOf(req))
	}

	if handled {
		t.CtrlPort.Retrieve(now)
	}
	return handled
}
// handleFlushReq answers a flush request and, once the response has been
// sent, drops all outstanding transactions and in-flight bottom requests and
// enters the flushing state (which stops the request pipeline). It returns
// false, leaving the state untouched, if the response cannot be sent.
func (t *AddressTranslator) handleFlushReq(
	now akita.VTimeInSec,
	req *AddressTranslatorFlushReq,
) bool {
	flushRsp := AddressTranslatorFlushRspBuilder{}.
		WithSrc(t.CtrlPort).
		WithDst(req.Src).
		WithSendTime(now).
		Build()
	if sendErr := t.CtrlPort.Send(flushRsp); sendErr != nil {
		return false
	}

	t.transactions, t.inflightReqToBottom = nil, nil
	t.isFlushing = true
	return true
}

// handleRestartReq discards every message pending on TopPort, BottomPort and
// TranslationPort, answers the restart request on CtrlPort, and leaves the
// flushing state so the request pipeline runs again. It panics if the
// restart response cannot be sent; otherwise it returns true.
func (t *AddressTranslator) handleRestartReq(
	now akita.VTimeInSec,
	req *AddressTranslatorRestartReq,
) bool {
	// Drain each port with a single Retrieve per message; Retrieve returns
	// nil once the port is empty. (The previous loops called Retrieve twice
	// per iteration and discarded the second result.)
	for _, port := range []akita.Port{t.TopPort, t.BottomPort, t.TranslationPort} {
		for msg := port.Retrieve(now); msg != nil; msg = port.Retrieve(now) {
			// Discarded: traffic left over from before the restart.
		}
	}

	rsp := AddressTranslatorRestartRspBuilder{}.
		WithSrc(t.CtrlPort).
		WithDst(req.Src).
		WithSendTime(now).
		Build()

	err := t.CtrlPort.Send(rsp)

	if err != nil {
		log.Panicf("AT failed to send restart rsp to Ctrl component")
	}

	t.isFlushing = false

	return true
}