SUNRPC: Convert the credcache lookup code to use RCU

Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
diff --git a/include/linux/sunrpc/auth.h b/include/linux/sunrpc/auth.h
index 4e78f0c..5974e8a 100644
--- a/include/linux/sunrpc/auth.h
+++ b/include/linux/sunrpc/auth.h
@@ -16,6 +16,7 @@
 #include <linux/sunrpc/xdr.h>
 
 #include <asm/atomic.h>
+#include <linux/rcupdate.h>
 
 /* size of the nodename buffer */
 #define UNX_MAXNODENAME	32
@@ -35,6 +36,7 @@
 struct rpc_cred {
 	struct hlist_node	cr_hash;	/* hash chain */
 	struct list_head	cr_lru;		/* lru garbage collection */
+	struct rcu_head		cr_rcu;
 	struct rpc_auth *	cr_auth;
 	const struct rpc_credops *cr_ops;
 #ifdef RPC_DEBUG
@@ -50,6 +52,7 @@
 };
 #define RPCAUTH_CRED_NEW	0
 #define RPCAUTH_CRED_UPTODATE	1
+#define RPCAUTH_CRED_HASHED	2
 
 #define RPCAUTH_CRED_MAGIC	0x0f4aa4f0
 
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
index 00f9649..ad7bde2 100644
--- a/net/sunrpc/auth.c
+++ b/net/sunrpc/auth.c
@@ -112,6 +112,14 @@
 
 static DEFINE_SPINLOCK(rpc_credcache_lock);
 
+static void
+rpcauth_unhash_cred_locked(struct rpc_cred *cred)
+{
+	hlist_del_rcu(&cred->cr_hash);
+	smp_mb__before_clear_bit();
+	clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
+}
+
 /*
  * Initialize RPC credential cache
  */
@@ -166,8 +174,7 @@
 			cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
 			get_rpccred(cred);
 			list_move_tail(&cred->cr_lru, &free);
-			smp_wmb();
-			hlist_del_init(&cred->cr_hash);
+			rpcauth_unhash_cred_locked(cred);
 		}
 	}
 	spin_unlock(&rpc_credcache_lock);
@@ -207,8 +214,7 @@
 			continue;
 		get_rpccred(cred);
 		list_add_tail(&cred->cr_lru, free);
-		smp_wmb();
-		hlist_del_init(&cred->cr_hash);
+		rpcauth_unhash_cred_locked(cred);
 	}
 }
 
@@ -218,10 +224,12 @@
 static void
 rpcauth_gc_credcache(struct rpc_cred_cache *cache, struct list_head *free)
 {
-	if (time_before(jiffies, cache->nextgc))
+	if (list_empty(&cred_unused) || time_before(jiffies, cache->nextgc))
 		return;
+	spin_lock(&rpc_credcache_lock);
 	cache->nextgc = jiffies + cache->expire;
 	rpcauth_prune_expired(free);
+	spin_unlock(&rpc_credcache_lock);
 }
 
 /*
@@ -234,42 +242,57 @@
 	LIST_HEAD(free);
 	struct rpc_cred_cache *cache = auth->au_credcache;
 	struct hlist_node *pos;
-	struct rpc_cred	*new = NULL,
-			*cred = NULL,
-			*entry;
+	struct rpc_cred	*cred = NULL,
+			*entry, *new;
 	int		nr = 0;
 
 	if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS))
 		nr = acred->uid & RPC_CREDCACHE_MASK;
-retry:
+
+	rcu_read_lock();
+	hlist_for_each_entry_rcu(entry, pos, &cache->hashtable[nr], cr_hash) {
+		if (!entry->cr_ops->crmatch(acred, entry, flags))
+			continue;
+		spin_lock(&rpc_credcache_lock);
+		if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) {
+			spin_unlock(&rpc_credcache_lock);
+			continue;
+		}
+		cred = get_rpccred(entry);
+		spin_unlock(&rpc_credcache_lock);
+		break;
+	}
+	rcu_read_unlock();
+
+	if (cred != NULL) {
+		rpcauth_gc_credcache(cache, &free);
+		goto found;
+	}
+
+	new = auth->au_ops->crcreate(auth, acred, flags);
+	if (IS_ERR(new)) {
+		cred = new;
+		goto out;
+	}
+
 	spin_lock(&rpc_credcache_lock);
 	hlist_for_each_entry(entry, pos, &cache->hashtable[nr], cr_hash) {
 		if (!entry->cr_ops->crmatch(acred, entry, flags))
 			continue;
 		cred = get_rpccred(entry);
-		hlist_del(&entry->cr_hash);
 		break;
 	}
-	if (new) {
-		if (cred)
-			list_add_tail(&new->cr_lru, &free);
-		else
-			cred = new;
-	}
-	if (cred) {
-		hlist_add_head(&cred->cr_hash, &cache->hashtable[nr]);
-	}
-	rpcauth_gc_credcache(cache, &free);
-	spin_unlock(&rpc_credcache_lock);
-
-	rpcauth_destroy_credlist(&free);
-
-	if (!cred) {
-		new = auth->au_ops->crcreate(auth, acred, flags);
-		if (!IS_ERR(new))
-			goto retry;
+	if (cred == NULL) {
 		cred = new;
-	} else if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)
+		set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
+		hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
+	} else
+		list_add_tail(&new->cr_lru, &free);
+	rpcauth_prune_expired(&free);
+	cache->nextgc = jiffies + cache->expire;
+	spin_unlock(&rpc_credcache_lock);
+found:
+	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)
 			&& cred->cr_ops->cr_init != NULL
 			&& !(flags & RPCAUTH_LOOKUP_NEW)) {
 		int res = cred->cr_ops->cr_init(auth, cred);
@@ -278,8 +301,9 @@
 			cred = ERR_PTR(res);
 		}
 	}
-
-	return (struct rpc_cred *) cred;
+	rpcauth_destroy_credlist(&free);
+out:
+	return cred;
 }
 
 struct rpc_cred *
@@ -357,21 +381,20 @@
 put_rpccred(struct rpc_cred *cred)
 {
 	/* Fast path for unhashed credentials */
-	if (!hlist_unhashed(&cred->cr_hash))
+	if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0)
 		goto need_lock;
 
 	if (!atomic_dec_and_test(&cred->cr_count))
 		return;
 	goto out_destroy;
-
 need_lock:
 	if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
 		return;
 	if (!list_empty(&cred->cr_lru))
 		list_del_init(&cred->cr_lru);
 	if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
-		hlist_del(&cred->cr_hash);
-	else if (!hlist_unhashed(&cred->cr_hash)) {
+		rpcauth_unhash_cred_locked(cred);
+	else if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) {
 		cred->cr_expire = jiffies;
 		list_add_tail(&cred->cr_lru, &cred_unused);
 		spin_unlock(&rpc_credcache_lock);
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
index 55c47ae..068fa6d 100644
--- a/net/sunrpc/auth_gss/auth_gss.c
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -694,15 +694,25 @@
 }
 
 static void
-gss_destroy_cred(struct rpc_cred *rc)
+gss_free_cred(struct gss_cred *gss_cred)
 {
-	struct gss_cred *cred = container_of(rc, struct gss_cred, gc_base);
+	dprintk("RPC:       gss_free_cred %p\n", gss_cred);
+	if (gss_cred->gc_ctx)
+		gss_put_ctx(gss_cred->gc_ctx);
+	kfree(gss_cred);
+}
 
-	dprintk("RPC:       gss_destroy_cred \n");
+static void
+gss_free_cred_callback(struct rcu_head *head)
+{
+	struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
+	gss_free_cred(gss_cred);
+}
 
-	if (cred->gc_ctx)
-		gss_put_ctx(cred->gc_ctx);
-	kfree(cred);
+static void
+gss_destroy_cred(struct rpc_cred *cred)
+{
+	call_rcu(&cred->cr_rcu, gss_free_cred_callback);
 }
 
 /*
diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c
index 29d50ff..f7ff6ad 100644
--- a/net/sunrpc/auth_unix.c
+++ b/net/sunrpc/auth_unix.c
@@ -93,11 +93,23 @@
 }
 
 static void
-unx_destroy_cred(struct rpc_cred *rcred)
+unx_free_cred(struct unx_cred *unx_cred)
 {
-	struct unx_cred	*cred = container_of(rcred, struct unx_cred, uc_base);
+	dprintk("RPC:       unx_free_cred %p\n", unx_cred);
+	kfree(unx_cred);
+}
 
-	kfree(cred);
+static void
+unx_free_cred_callback(struct rcu_head *head)
+{
+	struct unx_cred *unx_cred = container_of(head, struct unx_cred, uc_base.cr_rcu);
+	unx_free_cred(unx_cred);
+}
+
+static void
+unx_destroy_cred(struct rpc_cred *cred)
+{
+	call_rcu(&cred->cr_rcu, unx_free_cred_callback);
 }
 
 /*