SUNRPC: Fix races in rpcauth_create
Address the FIXME in rpcauth_create(): auth_flavors[] really needs a lock, and the flavor's module must be pinned while its ->create() method runs.
Signed-off-by: Trond Myklebust <Trond.Myklebust@netapp.com>
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
index f6b6c81..584f243 100644
--- a/net/sunrpc/auth.c
+++ b/net/sunrpc/auth.c
@@ -18,6 +18,7 @@
# define RPCDBG_FACILITY RPCDBG_AUTH
#endif
+static DEFINE_SPINLOCK(rpc_authflavor_lock); /* guards auth_flavors[] */
static struct rpc_authops * auth_flavors[RPC_AUTH_MAXFLAVOR] = {
&authnull_ops, /* AUTH_NULL */
&authunix_ops, /* AUTH_UNIX */
@@ -35,26 +36,34 @@
rpcauth_register(struct rpc_authops *ops)
{
rpc_authflavor_t flavor;
+ int ret = -EPERM; /* returned if the flavor slot is already occupied */
if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
return -EINVAL;
- if (auth_flavors[flavor] != NULL)
- return -EPERM; /* what else? */
- auth_flavors[flavor] = ops;
- return 0;
+ spin_lock(&rpc_authflavor_lock); /* check and store the slot under one lock hold */
+ if (auth_flavors[flavor] == NULL) {
+ auth_flavors[flavor] = ops;
+ ret = 0;
+ }
+ spin_unlock(&rpc_authflavor_lock);
+ return ret;
}
int
rpcauth_unregister(struct rpc_authops *ops)
{
rpc_authflavor_t flavor;
+ int ret = -EPERM; /* returned if ops is not the registered handler */
if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
return -EINVAL;
- if (auth_flavors[flavor] != ops)
- return -EPERM; /* what else? */
- auth_flavors[flavor] = NULL;
- return 0;
+ spin_lock(&rpc_authflavor_lock); /* compare and clear the slot under one lock hold */
+ if (auth_flavors[flavor] == ops) {
+ auth_flavors[flavor] = NULL;
+ ret = 0;
+ }
+ spin_unlock(&rpc_authflavor_lock);
+ return ret;
}
struct rpc_auth *
@@ -68,15 +77,19 @@
if (flavor >= RPC_AUTH_MAXFLAVOR)
goto out;
- /* FIXME - auth_flavors[] really needs an rw lock,
- * and module refcounting. */
#ifdef CONFIG_KMOD
if ((ops = auth_flavors[flavor]) == NULL)
request_module("rpc-auth-%u", flavor);
#endif
- if ((ops = auth_flavors[flavor]) == NULL)
+ spin_lock(&rpc_authflavor_lock);
+ ops = auth_flavors[flavor];
+ if (ops == NULL || !try_module_get(ops->owner)) {
+ spin_unlock(&rpc_authflavor_lock);
goto out;
+ }
+ spin_unlock(&rpc_authflavor_lock);
auth = ops->create(clnt, pseudoflavor);
+ module_put(ops->owner);
if (IS_ERR(auth))
return auth;
if (clnt->cl_auth)