From dbcb06afe1c76d1129cb6d264949322a34c37185 Mon Sep 17 00:00:00 2001 From: John Ingalls <43973001+ingallsj@users.noreply.github.com> Date: Wed, 20 Mar 2024 14:58:50 -0700 Subject: [PATCH] PTW Hypervisor bug fixes: check GPA bits higher than HGATP.Mode (#3591) * util/package: VecToAugmentedVec extract * traverse check GPA bits higher than HGATP * check upper bits of VSATP addr when starting PTW * remove pte_cache_addr, just use pte_addr --- src/main/scala/rocket/PTW.scala | 48 ++++++++++++++----------------- src/main/scala/rocket/TLB.scala | 6 ++-- src/main/scala/util/package.scala | 8 ++++++ 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/main/scala/rocket/PTW.scala b/src/main/scala/rocket/PTW.scala index f0338807d..91b9a7bb2 100644 --- a/src/main/scala/rocket/PTW.scala +++ b/src/main/scala/rocket/PTW.scala @@ -278,7 +278,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( val aux_count = Reg(UInt(log2Ceil(pgLevels).W)) /** pte for 2-stage translation */ val aux_pte = Reg(new PTE) - val aux_ppn_hi = (pgLevels > 4 && r_req.addr.getWidth > aux_pte.ppn.getWidth).option(Reg(UInt((r_req.addr.getWidth - aux_pte.ppn.getWidth).W))) val gpa_pgoff = Reg(UInt(pgIdxBits.W)) // only valid in resp_gf case val stage2 = Reg(Bool()) val stage2_final = Reg(Bool()) @@ -301,7 +300,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( } } // construct pte from mem.resp - val (pte, invalid_paddr) = { + val (pte, invalid_paddr, invalid_gpa) = { val tmp = mem_resp_data.asTypeOf(new PTE()) val res = WireDefault(tmp) res.ppn := Mux(do_both_stages && !stage2, tmp.ppn(vpnBits.min(tmp.ppn.getWidth)-1, 0), tmp.ppn(ppnBits-1, 0)) @@ -310,10 +309,12 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( for (i <- 0 until pgLevels-1) when (count <= i.U && tmp.ppn((pgLevels-1-i)*pgLevelBits-1, (pgLevels-2-i)*pgLevelBits) =/= 0.U) { res.v := false.B } } - (res, Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0.U, (tmp.ppn >> ppnBits) =/= 0.U)) + (res, + Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0.U, (tmp.ppn >> ppnBits) =/= 0.U), + do_both_stages && !stage2 && checkInvalidHypervisorGPA(r_hgatp, tmp.ppn)) } // find non-leaf PTE, need traverse - val traverse = pte.table() && !invalid_paddr && count < (pgLevels-1).U + val traverse = pte.table() && !invalid_paddr && !invalid_gpa && count < (pgLevels-1).U /** address send to mem for enquerry */ val pte_addr = if (!usingVM) 0.U else { val vpn_idxs = (0 until pgLevels).map { i => @@ -328,19 +329,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( //use vpn slice as offset raw_pte_addr.apply(size.min(raw_pte_addr.getWidth) - 1, 0) } - /** pte_cache input addr */ - val pte_cache_addr = if (!usingHypervisor) pte_addr else { - val vpn_idxs = (0 until pgLevels-1).map { i => - val ext_aux_pte_ppn = aux_ppn_hi match { - case None => aux_pte.ppn - case Some(hi) => Cat(hi, aux_pte.ppn) - } - (ext_aux_pte_ppn >> (pgLevels - i - 1) * pgLevelBits)(pgLevelBits - 1, 0) - } - val vpn_idx = vpn_idxs(count) - val raw_pte_cache_addr = Cat(r_pte.ppn, vpn_idx) << log2Ceil(xLen/8) - raw_pte_cache_addr(vaddrBits.min(raw_pte_cache_addr.getWidth)-1, 0) - } /** stage2_pte_cache input addr */ val stage2_pte_cache_addr = if (!usingHypervisor) 0.U else { val vpn_idxs = (0 until pgLevels - 1).map { i => @@ -373,7 +361,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( else can_hit val tag = if (s2) Cat(true.B, stage2_pte_cache_addr.padTo(vaddrBits)) - else Cat(r_req.vstage1, pte_cache_addr.padTo(if (usingHypervisor) vaddrBits else paddrBits)) + else Cat(r_req.vstage1, pte_addr.padTo(if (usingHypervisor) vaddrBits else paddrBits)) val hits = tags.map(_ === tag).asUInt & valid val hit = hits.orR && can_hit @@ -539,7 +527,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( io.mem.req.bits.data := DontCare io.mem.req.bits.mask := DontCare - io.mem.s1_kill := l2_hit || state =/= s_wait1 + io.mem.s1_kill := l2_hit || (state =/= s_wait1) || resp_gf io.mem.s1_data := DontCare io.mem.s2_kill := false.B @@ -607,12 +595,11 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( count := Mux(arb.io.out.bits.bits.stage2, hgatp_initial_count, satp_initial_count) aux_count := Mux(arb.io.out.bits.bits.vstage1, vsatp_initial_count, 0.U) aux_pte.ppn := aux_ppn - aux_ppn_hi.foreach { _ := aux_ppn >> aux_pte.ppn.getWidth } aux_pte.reserved_for_future := 0.U resp_ae_ptw := false.B resp_ae_final := false.B resp_pf := false.B - resp_gf := false.B + resp_gf := checkInvalidHypervisorGPA(io.dpath.hgatp, aux_ppn) && arb.io.out.bits.bits.stage2 resp_hr := true.B resp_hw := true.B resp_hx := true.B @@ -630,7 +617,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( when (stage2_pte_cache_hit) { aux_count := aux_count + 1.U aux_pte.ppn := stage2_pte_cache_data - aux_ppn_hi.foreach { _ := 0.U } aux_pte.reserved_for_future := 0.U pte_hit := true.B }.elsewhen (pte_cache_hit) { @@ -639,6 +625,10 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( }.otherwise { next_state := Mux(io.mem.req.ready, s_wait1, s_req) } + when(resp_gf) { + next_state := s_ready + resp_valid(r_req_dest) := true.B + } } is (s_wait1) { // This Mux is for the l2_error case; the l2_hit && !l2_error case is overriden below @@ -676,7 +666,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( r_pte := OptimizationBarrier( // l2tlb hit->find a leaf PTE(l2_pte), respond to L1TLB - Mux(l2_hit && !l2_error, l2_pte, + Mux(l2_hit && !l2_error && !resp_gf, l2_pte, // S2 PTE cache hit -> proceed to the next level of walking, update the r_pte with hgatp Mux(state === s_req && stage2_pte_cache_hit, makeHypervisorRootPTE(r_hgatp, stage2_pte_cache_data, l2_pte), // pte cache hit->find a non-leaf PTE(pte_cache),continue to request mem @@ -691,7 +681,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( Mux(arb.io.out.fire, Mux(arb.io.out.bits.bits.stage2, makeHypervisorRootPTE(io.dpath.hgatp, io.dpath.vsatp.ppn, r_pte), makePTE(satp.ppn, r_pte)), r_pte)))))))) - when (l2_hit && !l2_error) { + when (l2_hit && !l2_error && !resp_gf) { assert(state === s_req || state === s_wait1) next_state := s_ready resp_valid(r_req_dest) := true.B @@ -753,14 +743,14 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( val s1_ppns = (0 until pgLevels-1).map(i => Cat(pte.ppn(pte.ppn.getWidth-1, (pgLevels-i-1)*pgLevelBits), r_req.addr(((pgLevels-i-1)*pgLevelBits min vpnBits)-1,0).padTo((pgLevels-i-1)*pgLevelBits))) :+ pte.ppn makePTE(s1_ppns(count), pte) }) - aux_ppn_hi.foreach { _ := 0.U } stage2 := true.B } for (i <- 0 until pgLevels) { val leaf = mem_resp_valid && !traverse && count === i.U - ccover(leaf && pte.v && !invalid_paddr && pte.reserved_for_future === 0.U, s"L$i", s"successful page-table access, level $i") + ccover(leaf && pte.v && !invalid_paddr && !invalid_gpa && pte.reserved_for_future === 0.U, s"L$i", s"successful page-table access, level $i") ccover(leaf && pte.v && invalid_paddr, s"L${i}_BAD_PPN_MSB", s"PPN too large, level $i") + ccover(leaf && pte.v && invalid_gpa, s"L${i}_BAD_GPA_MSB", s"GPA too large, level $i") ccover(leaf && pte.v && pte.reserved_for_future =/= 0.U, s"L${i}_BAD_RSV_MSB", s"reserved MSBs set, level $i") ccover(leaf && !mem_resp_data(0), s"L${i}_INVALID_PTE", s"page not present, level $i") if (i != pgLevels-1) @@ -790,6 +780,12 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()( pte.ppn := Cat(hgatp.ppn >> maxHypervisorExtraAddrBits, lsbs) pte } + /** use hgatp and vpn to check for gpa out of range */ + private def checkInvalidHypervisorGPA(hgatp: PTBR, vpn: UInt) = { + val count = pgLevels.U - minPgLevels.U - hgatp.additionalPgLevels + val idxs = (0 to pgLevels-minPgLevels).map(i => (vpn >> ((pgLevels-i)*pgLevelBits)+maxHypervisorExtraAddrBits)) + idxs.extract(count) =/= 0.U + } } /** Mix-ins for constructing tiles that might have a PTW */ diff --git a/src/main/scala/rocket/TLB.scala b/src/main/scala/rocket/TLB.scala index f0f0b1a4a..c1f04aef7 100644 --- a/src/main/scala/rocket/TLB.scala +++ b/src/main/scala/rocket/TLB.scala @@ -589,9 +589,9 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T val pf_ld_array = Mux(cmd_read, ((~Mux(cmd_readx, x_array, r_array) & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array, 0.U) val pf_st_array = Mux(cmd_write_perms, ((~w_array & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array, 0.U) val pf_inst_array = ((~x_array & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array - val gf_ld_array = Mux(priv_v && cmd_read, ~Mux(cmd_readx, hx_array, hr_array) & ~ptw_ae_array, 0.U) - val gf_st_array = Mux(priv_v && cmd_write_perms, ~hw_array & ~ptw_ae_array, 0.U) - val gf_inst_array = Mux(priv_v, ~hx_array & ~ptw_ae_array, 0.U) + val gf_ld_array = Mux(priv_v && cmd_read, (~Mux(cmd_readx, hx_array, hr_array) | ptw_gf_array) & ~ptw_ae_array, 0.U) + val gf_st_array = Mux(priv_v && cmd_write_perms, (~hw_array | ptw_gf_array) & ~ptw_ae_array, 0.U) + val gf_inst_array = Mux(priv_v, (~hx_array | ptw_gf_array) & ~ptw_ae_array, 0.U) val gpa_hits = { val need_gpa_mask = if (instruction) gf_inst_array else gf_ld_array | gf_st_array diff --git a/src/main/scala/util/package.scala b/src/main/scala/util/package.scala index bacaf5b7b..6fb8895fb 100644 --- a/src/main/scala/util/package.scala +++ b/src/main/scala/util/package.scala @@ -18,6 +18,12 @@ package object util { def isOneOf(u1: UInt, u2: UInt*): Bool = isOneOf(u1 +: u2.toSeq) } + implicit class VecToAugmentedVec[T <: Data](private val x: Vec[T]) extends AnyVal { + + /** Like Vec.apply(idx), but tolerates indices of mismatched width */ + def extract(idx: UInt): T = x((idx | 0.U(log2Ceil(x.size).W)).extract(log2Ceil(x.size) - 1, 0)) + } + implicit class SeqToAugmentedSeq[T <: Data](private val x: Seq[T]) extends AnyVal { def apply(idx: UInt): T = { if (x.size <= 1) { @@ -34,6 +40,8 @@ package object util { } } + def extract(idx: UInt): T = VecInit(x).extract(idx) + def asUInt: UInt = Cat(x.map(_.asUInt).reverse) def rotate(n: Int): Seq[T] = x.drop(n) ++ x.take(n)