kvm: x86: mmu: Verify that restored PTE has needed perms in fast page fault

Before fast page fault restores an access track PTE back to a regular PTE,
it now also verifies that the restored PTE would grant the necessary
permissions for the faulting access to succeed. If not, it falls back
to the slow page fault path.

Signed-off-by: Junaid Shahid <junaids@google.com>
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 437d162..2fd7586 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -373,6 +373,11 @@ static int is_last_spte(u64 pte, int level)
 	return 0;
 }
 
+static bool is_executable_pte(u64 spte)
+{
+	return (spte & (shadow_x_mask | shadow_nx_mask)) == shadow_x_mask;
+}
+
 static kvm_pfn_t spte_to_pfn(u64 pte)
 {
 	return (pte & PT64_BASE_ADDR_MASK) >> PAGE_SHIFT;
@@ -728,6 +733,23 @@ static u64 mark_spte_for_access_track(u64 spte)
 	return spte;
 }
 
+/* Restore an acc-track PTE back to a regular PTE */
+static u64 restore_acc_track_spte(u64 spte)
+{
+	u64 new_spte = spte;
+	u64 saved_bits = (spte >> shadow_acc_track_saved_bits_shift)
+			 & shadow_acc_track_saved_bits_mask;
+
+	WARN_ON_ONCE(!is_access_track_spte(spte));
+
+	new_spte &= ~shadow_acc_track_mask;
+	new_spte &= ~(shadow_acc_track_saved_bits_mask <<
+		      shadow_acc_track_saved_bits_shift);
+	new_spte |= saved_bits;
+
+	return new_spte;
+}
+
 /* Returns the Accessed status of the PTE and resets it at the same time. */
 static bool mmu_spte_age(u64 *sptep)
 {
@@ -3019,27 +3041,12 @@ static bool page_fault_can_be_fast(u32 error_code)
  */
 static bool
 fast_pf_fix_direct_spte(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-			u64 *sptep, u64 old_spte,
-			bool remove_write_prot, bool remove_acc_track)
+			u64 *sptep, u64 old_spte, u64 new_spte)
 {
 	gfn_t gfn;
-	u64 new_spte = old_spte;
 
 	WARN_ON(!sp->role.direct);
 
-	if (remove_acc_track) {
-		u64 saved_bits = (old_spte >> shadow_acc_track_saved_bits_shift)
-				 & shadow_acc_track_saved_bits_mask;
-
-		new_spte &= ~shadow_acc_track_mask;
-		new_spte &= ~(shadow_acc_track_saved_bits_mask <<
-			      shadow_acc_track_saved_bits_shift);
-		new_spte |= saved_bits;
-	}
-
-	if (remove_write_prot)
-		new_spte |= PT_WRITABLE_MASK;
-
 	/*
 	 * Theoretically we could also set dirty bit (and flush TLB) here in
 	 * order to eliminate unnecessary PML logging. See comments in
@@ -3055,7 +3062,7 @@ fast_pf_fix_direct_spte(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
 	if (cmpxchg64(sptep, old_spte, new_spte) != old_spte)
 		return false;
 
-	if (remove_write_prot) {
+	if (is_writable_pte(new_spte) && !is_writable_pte(old_spte)) {
 		/*
 		 * The gfn of direct spte is stable since it is
 		 * calculated by sp->gfn.
@@ -3067,6 +3074,18 @@ fast_pf_fix_direct_spte(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
 	return true;
 }
 
+static bool is_access_allowed(u32 fault_err_code, u64 spte)
+{
+	if (fault_err_code & PFERR_FETCH_MASK)
+		return is_executable_pte(spte);
+
+	if (fault_err_code & PFERR_WRITE_MASK)
+		return is_writable_pte(spte);
+
+	/* Fault was on Read access */
+	return spte & PT_PRESENT_MASK;
+}
+
 /*
  * Return value:
  * - true: let the vcpu to access on the same address again.
@@ -3090,8 +3109,7 @@ static bool fast_page_fault(struct kvm_vcpu *vcpu, gva_t gva, int level,
 	walk_shadow_page_lockless_begin(vcpu);
 
 	do {
-		bool remove_write_prot = false;
-		bool remove_acc_track;
+		u64 new_spte;
 
 		for_each_shadow_entry_lockless(vcpu, gva, iterator, spte)
 			if (!is_shadow_present_pte(spte) ||
@@ -3112,52 +3130,44 @@ static bool fast_page_fault(struct kvm_vcpu *vcpu, gva_t gva, int level,
 		 * Need not check the access of upper level table entries since
 		 * they are always ACC_ALL.
 		 */
-
-		if (error_code & PFERR_FETCH_MASK) {
-			if ((spte & (shadow_x_mask | shadow_nx_mask))
-			    == shadow_x_mask) {
-				fault_handled = true;
-				break;
-			}
-		} else if (error_code & PFERR_WRITE_MASK) {
-			if (is_writable_pte(spte)) {
-				fault_handled = true;
-				break;
-			}
-
-			/*
-			 * Currently, to simplify the code, write-protection can
-			 * be removed in the fast path only if the SPTE was
-			 * write-protected for dirty-logging.
-			 */
-			remove_write_prot =
-				spte_can_locklessly_be_made_writable(spte);
-		} else {
-			/* Fault was on Read access */
-			if (spte & PT_PRESENT_MASK) {
-				fault_handled = true;
-				break;
-			}
+		if (is_access_allowed(error_code, spte)) {
+			fault_handled = true;
+			break;
 		}
 
-		remove_acc_track = is_access_track_spte(spte);
+		new_spte = spte;
 
-		/* Verify that the fault can be handled in the fast path */
-		if (!remove_acc_track && !remove_write_prot)
-			break;
+		if (is_access_track_spte(spte))
+			new_spte = restore_acc_track_spte(new_spte);
 
 		/*
-		 * Do not fix write-permission on the large spte since we only
-		 * dirty the first page into the dirty-bitmap in
-		 * fast_pf_fix_direct_spte() that means other pages are missed
-		 * if its slot is dirty-logged.
-		 *
-		 * Instead, we let the slow page fault path create a normal spte
-		 * to fix the access.
-		 *
-		 * See the comments in kvm_arch_commit_memory_region().
+		 * Currently, to simplify the code, write-protection can
+		 * be removed in the fast path only if the SPTE was
+		 * write-protected for dirty-logging or access tracking.
 		 */
-		if (sp->role.level > PT_PAGE_TABLE_LEVEL && remove_write_prot)
+		if ((error_code & PFERR_WRITE_MASK) &&
+		    spte_can_locklessly_be_made_writable(spte))
+		{
+			new_spte |= PT_WRITABLE_MASK;
+
+			/*
+			 * Do not fix write-permission on the large spte.  Since
+			 * we only dirty the first page into the dirty-bitmap in
+			 * fast_pf_fix_direct_spte(), other pages are missed
+			 * if its slot has dirty logging enabled.
+			 *
+			 * Instead, we let the slow page fault path create a
+			 * normal spte to fix the access.
+			 *
+			 * See the comments in kvm_arch_commit_memory_region().
+			 */
+			if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+				break;
+		}
+
+		/* Verify that the fault can be handled in the fast path */
+		if (new_spte == spte ||
+		    !is_access_allowed(error_code, new_spte))
 			break;
 
 		/*
@@ -3167,8 +3177,7 @@ static bool fast_page_fault(struct kvm_vcpu *vcpu, gva_t gva, int level,
 		 */
 		fault_handled = fast_pf_fix_direct_spte(vcpu, sp,
 							iterator.sptep, spte,
-							remove_write_prot,
-							remove_acc_track);
+							new_spte);
 		if (fault_handled)
 			break;