mm: add tlb_remove_check_page_size_change to track page size change

With commit e77b0852b551 ("mm/mmu_gather: track page size with mmu
gather and force flush if page size change") we added the ability to
force a tlb flush when the page size change in a mmu_gather loop.  We
did that by checking for a page size change every time we added a page
to mmu_gather for lazy flush/remove.  We can improve that by moving the
page size change check early and not doing it every time we add a page.

This also helps us to do tlb flush when invalidating a range covering
dax mapping.  Wrt dax mapping we don't have a backing struct page and
hence we don't call tlb_remove_page, which earlier forced the tlb flush
on page size change.  Moving the page size change check earlier means we
will do the same even for dax mapping.

We also avoid doing this check on architecture other than powerpc.

In a later patch we will remove page size check from tlb_remove_page().

Link: http://lkml.kernel.org/r/20161026084839.27299-5-aneesh.kumar@linux.vnet.ibm.com
Signed-off-by: Aneesh Kumar K.V <aneesh.kumar@linux.vnet.ibm.com>
Cc: "Kirill A. Shutemov" <kirill@shutemov.name>
Cc: Dan Williams <dan.j.williams@intel.com>
Cc: Ross Zwisler <ross.zwisler@linux.intel.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/arch/arm/include/asm/tlb.h b/arch/arm/include/asm/tlb.h
index 82841ba..a9d6de4 100644
--- a/arch/arm/include/asm/tlb.h
+++ b/arch/arm/include/asm/tlb.h
@@ -286,5 +286,11 @@ tlb_remove_pmd_tlb_entry(struct mmu_gather *tlb, pmd_t *pmdp, unsigned long addr
 
 #define tlb_migrate_finish(mm)		do { } while (0)
 
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+}
+
 #endif /* CONFIG_MMU */
 #endif
diff --git a/arch/ia64/include/asm/tlb.h b/arch/ia64/include/asm/tlb.h
index b3f369a..bfe6295 100644
--- a/arch/ia64/include/asm/tlb.h
+++ b/arch/ia64/include/asm/tlb.h
@@ -286,6 +286,12 @@ do {							\
 #define tlb_remove_huge_tlb_entry(h, tlb, ptep, address)	\
 	tlb_remove_tlb_entry(tlb, ptep, address)
 
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+}
+
 #define pte_free_tlb(tlb, ptep, address)		\
 do {							\
 	tlb->need_flush = 1;				\
diff --git a/arch/powerpc/include/asm/tlb.h b/arch/powerpc/include/asm/tlb.h
index 99e1397..6095575 100644
--- a/arch/powerpc/include/asm/tlb.h
+++ b/arch/powerpc/include/asm/tlb.h
@@ -28,6 +28,7 @@
 #define tlb_start_vma(tlb, vma)	do { } while (0)
 #define tlb_end_vma(tlb, vma)	do { } while (0)
 #define __tlb_remove_tlb_entry	__tlb_remove_tlb_entry
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
 
 extern void tlb_flush(struct mmu_gather *tlb);
 
@@ -46,6 +47,21 @@ static inline void __tlb_remove_tlb_entry(struct mmu_gather *tlb, pte_t *ptep,
 #endif
 }
 
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+	if (!tlb->page_size)
+		tlb->page_size = page_size;
+	else if (tlb->page_size != page_size) {
+		tlb_flush_mmu(tlb);
+		/*
+		 * update the page size after flush for the new
+		 * mmu_gather.
+		 */
+		tlb->page_size = page_size;
+	}
+}
+
 #ifdef CONFIG_SMP
 static inline int mm_is_core_local(struct mm_struct *mm)
 {
diff --git a/arch/s390/include/asm/tlb.h b/arch/s390/include/asm/tlb.h
index 094440b..28b159c 100644
--- a/arch/s390/include/asm/tlb.h
+++ b/arch/s390/include/asm/tlb.h
@@ -165,4 +165,10 @@ static inline void pud_free_tlb(struct mmu_gather *tlb, pud_t *pud,
 #define tlb_remove_huge_tlb_entry(h, tlb, ptep, address)	\
 	tlb_remove_tlb_entry(tlb, ptep, address)
 
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+}
+
 #endif /* _S390_TLB_H */
diff --git a/arch/sh/include/asm/tlb.h b/arch/sh/include/asm/tlb.h
index e7d15e8..0f988b3 100644
--- a/arch/sh/include/asm/tlb.h
+++ b/arch/sh/include/asm/tlb.h
@@ -130,6 +130,12 @@ static inline void tlb_remove_page_size(struct mmu_gather *tlb,
 	return tlb_remove_page(tlb, page);
 }
 
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+}
+
 #define pte_free_tlb(tlb, ptep, addr)	pte_free((tlb)->mm, ptep)
 #define pmd_free_tlb(tlb, pmdp, addr)	pmd_free((tlb)->mm, pmdp)
 #define pud_free_tlb(tlb, pudp, addr)	pud_free((tlb)->mm, pudp)
diff --git a/arch/um/include/asm/tlb.h b/arch/um/include/asm/tlb.h
index a442702..8258dd4 100644
--- a/arch/um/include/asm/tlb.h
+++ b/arch/um/include/asm/tlb.h
@@ -144,6 +144,12 @@ static inline void tlb_remove_page_size(struct mmu_gather *tlb,
 #define tlb_remove_huge_tlb_entry(h, tlb, ptep, address)	\
 	tlb_remove_tlb_entry(tlb, ptep, address)
 
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+}
+
 #define pte_free_tlb(tlb, ptep, addr) __pte_free_tlb(tlb, ptep, addr)
 
 #define pud_free_tlb(tlb, pudp, addr) __pud_free_tlb(tlb, pudp, addr)
diff --git a/include/asm-generic/tlb.h b/include/asm-generic/tlb.h
index 38c2b70..256c9de 100644
--- a/include/asm-generic/tlb.h
+++ b/include/asm-generic/tlb.h
@@ -182,6 +182,22 @@ static inline bool __tlb_remove_pte_page(struct mmu_gather *tlb, struct page *pa
 	return __tlb_remove_page(tlb, page);
 }
 
+#ifndef tlb_remove_check_page_size_change
+#define tlb_remove_check_page_size_change tlb_remove_check_page_size_change
+static inline void tlb_remove_check_page_size_change(struct mmu_gather *tlb,
+						     unsigned int page_size)
+{
+	/*
+	 * We don't care about page size change, just update
+	 * mmu_gather page size here so that debug checks
+	 * doesn't throw false warning.
+	 */
+#ifdef CONFIG_DEBUG_VM
+	tlb->page_size = page_size;
+#endif
+}
+#endif
+
 /*
  * In the case of tlb vma handling, we can optimise these away in the
  * case where we're doing a full MM flush.  When we're doing a munmap,
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 0103728..26fd116 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1323,6 +1323,8 @@ bool madvise_free_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	struct mm_struct *mm = tlb->mm;
 	bool ret = false;
 
+	tlb_remove_check_page_size_change(tlb, HPAGE_PMD_SIZE);
+
 	ptl = pmd_trans_huge_lock(pmd, vma);
 	if (!ptl)
 		goto out_unlocked;
@@ -1384,6 +1386,8 @@ int zap_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	pmd_t orig_pmd;
 	spinlock_t *ptl;
 
+	tlb_remove_check_page_size_change(tlb, HPAGE_PMD_SIZE);
+
 	ptl = __pmd_trans_huge_lock(pmd, vma);
 	if (!ptl)
 		return 0;
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index 8e519da..3edb759 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -3286,6 +3286,11 @@ void __unmap_hugepage_range(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	BUG_ON(start & ~huge_page_mask(h));
 	BUG_ON(end & ~huge_page_mask(h));
 
+	/*
+	 * This is a hugetlb vma, all the pte entries should point
+	 * to huge page.
+	 */
+	tlb_remove_check_page_size_change(tlb, sz);
 	tlb_start_vma(tlb, vma);
 	mmu_notifier_invalidate_range_start(mm, mmun_start, mmun_end);
 	address = start;
diff --git a/mm/madvise.c b/mm/madvise.c
index 93fb63e..0e3828ea 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -281,6 +281,7 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
 	if (pmd_trans_unstable(pmd))
 		return 0;
 
+	tlb_remove_check_page_size_change(tlb, PAGE_SIZE);
 	orig_pte = pte = pte_offset_map_lock(mm, pmd, addr, &ptl);
 	arch_enter_lazy_mmu_mode();
 	for (; addr != end; pte++, addr += PAGE_SIZE) {
diff --git a/mm/memory.c b/mm/memory.c
index d86b7b4..eae20eb 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -528,7 +528,11 @@ void free_pgd_range(struct mmu_gather *tlb,
 		end -= PMD_SIZE;
 	if (addr > end - 1)
 		return;
-
+	/*
+	 * We add page table cache pages with PAGE_SIZE,
+	 * (see pte_free_tlb()), flush the tlb if we need
+	 */
+	tlb_remove_check_page_size_change(tlb, PAGE_SIZE);
 	pgd = pgd_offset(tlb->mm, addr);
 	do {
 		next = pgd_addr_end(addr, end);
@@ -1120,6 +1124,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
 	swp_entry_t entry;
 	struct page *pending_page = NULL;
 
+	tlb_remove_check_page_size_change(tlb, PAGE_SIZE);
 again:
 	init_rss_vec(rss);
 	start_pte = pte_offset_map_lock(mm, pmd, addr, &ptl);