Slab allocators: consistent ZERO_SIZE_PTR support and NULL result semantics

Define a ZERO_OR_NULL_PTR macro so that the allocators can use a single
test in place of their open-coded NULL / ZERO_SIZE_PTR checks.  Move the
ZERO_SIZE_PTR-related definitions into slab.h.

Make ZERO_SIZE_PTR work for all slab allocators and get rid of the
WARN_ON_ONCE(size == 0) that still remains in SLAB.

Make SLUB return NULL, like the other allocators, when too large a memory
segment is requested via __kmalloc, instead of hitting a BUG_ON.

Signed-off-by: Christoph Lameter <clameter@sgi.com>
Acked-by: Pekka Enberg <penberg@cs.helsinki.fi>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/slab.c b/mm/slab.c
index 4bd8a53..d2cd304 100644
--- a/mm/slab.c
+++ b/mm/slab.c
@@ -775,6 +775,9 @@
 	 */
 	BUG_ON(malloc_sizes[INDEX_AC].cs_cachep == NULL);
 #endif
+	if (!size)
+		return ZERO_SIZE_PTR;
+
 	while (size > csizep->cs_size)
 		csizep++;
 
@@ -2351,7 +2354,7 @@
 		 * this should not happen at all.
 		 * But leave a BUG_ON for some lucky dude.
 		 */
-		BUG_ON(!cachep->slabp_cache);
+		BUG_ON(ZERO_OR_NULL_PTR(cachep->slabp_cache));
 	}
 	cachep->ctor = ctor;
 	cachep->name = name;
@@ -3653,8 +3656,8 @@
 	struct kmem_cache *cachep;
 
 	cachep = kmem_find_general_cachep(size, flags);
-	if (unlikely(cachep == NULL))
-		return NULL;
+	if (unlikely(ZERO_OR_NULL_PTR(cachep)))
+		return cachep;
 	return kmem_cache_alloc_node(cachep, flags, node);
 }
 
@@ -3760,7 +3763,7 @@
 	struct kmem_cache *c;
 	unsigned long flags;
 
-	if (unlikely(!objp))
+	if (unlikely(ZERO_OR_NULL_PTR(objp)))
 		return;
 	local_irq_save(flags);
 	kfree_debugcheck(objp);
@@ -4447,7 +4450,7 @@
  */
 size_t ksize(const void *objp)
 {
-	if (unlikely(objp == NULL))
+	if (unlikely(ZERO_OR_NULL_PTR(objp)))
 		return 0;
 
 	return obj_size(virt_to_cache(objp));
diff --git a/mm/slob.c b/mm/slob.c
index 154e7bd..41d32c3 100644
--- a/mm/slob.c
+++ b/mm/slob.c
@@ -347,7 +347,7 @@
 	slobidx_t units;
 	unsigned long flags;
 
-	if (!block)
+	if (ZERO_OR_NULL_PTR(block))
 		return;
 	BUG_ON(!size);
 
@@ -424,10 +424,13 @@
 
 void *__kmalloc_node(size_t size, gfp_t gfp, int node)
 {
+	unsigned int *m;
 	int align = max(ARCH_KMALLOC_MINALIGN, ARCH_SLAB_MINALIGN);
 
 	if (size < PAGE_SIZE - align) {
-		unsigned int *m;
+		if (!size)
+			return ZERO_SIZE_PTR;
+
 		m = slob_alloc(size + align, gfp, align, node);
 		if (m)
 			*m = size;
@@ -450,7 +453,7 @@
 {
 	struct slob_page *sp;
 
-	if (!block)
+	if (ZERO_OR_NULL_PTR(block))
 		return;
 
 	sp = (struct slob_page *)virt_to_page(block);
@@ -468,7 +471,7 @@
 {
 	struct slob_page *sp;
 
-	if (!block)
+	if (ZERO_OR_NULL_PTR(block))
 		return 0;
 
 	sp = (struct slob_page *)virt_to_page(block);
diff --git a/mm/slub.c b/mm/slub.c
index 1b0a95d..548d78d 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -2270,10 +2270,11 @@
 	int index = kmalloc_index(size);
 
 	if (!index)
-		return NULL;
+		return ZERO_SIZE_PTR;
 
 	/* Allocation too large? */
-	BUG_ON(index < 0);
+	if (index < 0)
+		return NULL;
 
 #ifdef CONFIG_ZONE_DMA
 	if ((flags & SLUB_DMA)) {
@@ -2314,9 +2315,10 @@
 {
 	struct kmem_cache *s = get_slab(size, flags);
 
-	if (s)
-		return slab_alloc(s, flags, -1, __builtin_return_address(0));
-	return ZERO_SIZE_PTR;
+	if (ZERO_OR_NULL_PTR(s))
+		return s;
+
+	return slab_alloc(s, flags, -1, __builtin_return_address(0));
 }
 EXPORT_SYMBOL(__kmalloc);
 
@@ -2325,9 +2327,10 @@
 {
 	struct kmem_cache *s = get_slab(size, flags);
 
-	if (s)
-		return slab_alloc(s, flags, node, __builtin_return_address(0));
-	return ZERO_SIZE_PTR;
+	if (ZERO_OR_NULL_PTR(s))
+		return s;
+
+	return slab_alloc(s, flags, node, __builtin_return_address(0));
 }
 EXPORT_SYMBOL(__kmalloc_node);
 #endif
@@ -2378,7 +2381,7 @@
 	 * this comparison would be true for all "negative" pointers
 	 * (which would cover the whole upper half of the address space).
 	 */
-	if ((unsigned long)x <= (unsigned long)ZERO_SIZE_PTR)
+	if (ZERO_OR_NULL_PTR(x))
 		return;
 
 	page = virt_to_head_page(x);
@@ -2687,8 +2690,8 @@
 {
 	struct kmem_cache *s = get_slab(size, gfpflags);
 
-	if (!s)
-		return ZERO_SIZE_PTR;
+	if (ZERO_OR_NULL_PTR(s))
+		return s;
 
 	return slab_alloc(s, gfpflags, -1, caller);
 }
@@ -2698,8 +2701,8 @@
 {
 	struct kmem_cache *s = get_slab(size, gfpflags);
 
-	if (!s)
-		return ZERO_SIZE_PTR;
+	if (ZERO_OR_NULL_PTR(s))
+		return s;
 
 	return slab_alloc(s, gfpflags, node, caller);
 }
diff --git a/mm/util.c b/mm/util.c
index 18396ea..f2f21b7 100644
--- a/mm/util.c
+++ b/mm/util.c
@@ -76,7 +76,7 @@
 
 	if (unlikely(!new_size)) {
 		kfree(p);
-		return NULL;
+		return ZERO_SIZE_PTR;
 	}
 
 	ks = ksize(p);