[XFRM]: Speed up xfrm_policy and xfrm_state walking

Change the xfrm_policy and xfrm_state walking algorithm from O(n^2) to
O(n). This is achieved by adding the entries to one more list which is
used solely for walking the entries.

This also fixes some races where the dump can have duplicate or missing
entries when the SPD/SADB is modified during an ongoing dump.

When dumping an SADB with 20000 entries using "time ip xfrm state", the
sys time dropped from 1.012s to 0.080s.

Signed-off-by: Timo Teras <timo.teras@iki.fi>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/linux/xfrm.h b/include/linux/xfrm.h
index e31b8c8..0c82c80 100644
--- a/include/linux/xfrm.h
+++ b/include/linux/xfrm.h
@@ -113,7 +113,8 @@
 {
 	XFRM_POLICY_TYPE_MAIN	= 0,
 	XFRM_POLICY_TYPE_SUB	= 1,
-	XFRM_POLICY_TYPE_MAX	= 2
+	XFRM_POLICY_TYPE_MAX	= 2,
+	XFRM_POLICY_TYPE_ANY	= 255
 };
 
 enum
diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index eea7785..9b62056 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -121,6 +121,7 @@
 struct xfrm_state
 {
 	/* Note: bydst is re-used during gc */
+	struct list_head	all;
 	struct hlist_node	bydst;
 	struct hlist_node	bysrc;
 	struct hlist_node	byspi;
@@ -424,6 +425,7 @@
 struct xfrm_policy
 {
 	struct xfrm_policy	*next;
+	struct list_head	bytype;
 	struct hlist_node	bydst;
 	struct hlist_node	byidx;
 
@@ -1160,6 +1162,18 @@
 	int priority;
 };
 
+struct xfrm_state_walk {
+	struct xfrm_state *state;
+	int count;
+	u8 proto;
+};
+
+struct xfrm_policy_walk {
+	struct xfrm_policy *policy;
+	int count;
+	u8 type, cur_type;
+};
+
 extern void xfrm_init(void);
 extern void xfrm4_init(void);
 extern void xfrm_state_init(void);
@@ -1184,7 +1198,23 @@
 extern int xfrm_proc_init(void);
 #endif
 
-extern int xfrm_state_walk(u8 proto, int (*func)(struct xfrm_state *, int, void*), void *);
+static inline void xfrm_state_walk_init(struct xfrm_state_walk *walk, u8 proto)
+{
+	walk->proto = proto;
+	walk->state = NULL;
+	walk->count = 0;
+}
+
+static inline void xfrm_state_walk_done(struct xfrm_state_walk *walk)
+{
+	if (walk->state != NULL) {
+		xfrm_state_put(walk->state);
+		walk->state = NULL;
+	}
+}
+
+extern int xfrm_state_walk(struct xfrm_state_walk *walk,
+			   int (*func)(struct xfrm_state *, int, void*), void *);
 extern struct xfrm_state *xfrm_state_alloc(void);
 extern struct xfrm_state *xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr, 
 					  struct flowi *fl, struct xfrm_tmpl *tmpl,
@@ -1306,7 +1336,25 @@
 #endif
 
 struct xfrm_policy *xfrm_policy_alloc(gfp_t gfp);
-extern int xfrm_policy_walk(u8 type, int (*func)(struct xfrm_policy *, int, int, void*), void *);
+
+static inline void xfrm_policy_walk_init(struct xfrm_policy_walk *walk, u8 type)
+{
+	walk->cur_type = XFRM_POLICY_TYPE_MAIN;
+	walk->type = type;
+	walk->policy = NULL;
+	walk->count = 0;
+}
+
+static inline void xfrm_policy_walk_done(struct xfrm_policy_walk *walk)
+{
+	if (walk->policy != NULL) {
+		xfrm_pol_put(walk->policy);
+		walk->policy = NULL;
+	}
+}
+
+extern int xfrm_policy_walk(struct xfrm_policy_walk *walk,
+	int (*func)(struct xfrm_policy *, int, int, void*), void *);
 int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl);
 struct xfrm_policy *xfrm_policy_bysel_ctx(u8 type, int dir,
 					  struct xfrm_selector *sel,
diff --git a/net/key/af_key.c b/net/key/af_key.c
index 8b5f486..7cb6f12 100644
--- a/net/key/af_key.c
+++ b/net/key/af_key.c
@@ -1742,12 +1742,18 @@
 {
 	u8 proto;
 	struct pfkey_dump_data data = { .skb = skb, .hdr = hdr, .sk = sk };
+	struct xfrm_state_walk walk;
+	int rc;
 
 	proto = pfkey_satype2proto(hdr->sadb_msg_satype);
 	if (proto == 0)
 		return -EINVAL;
 
-	return xfrm_state_walk(proto, dump_sa, &data);
+	xfrm_state_walk_init(&walk, proto);
+	rc = xfrm_state_walk(&walk, dump_sa, &data);
+	xfrm_state_walk_done(&walk);
+
+	return rc;
 }
 
 static int pfkey_promisc(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
@@ -1780,7 +1786,9 @@
 
 static u32 gen_reqid(void)
 {
+	struct xfrm_policy_walk walk;
 	u32 start;
+	int rc;
 	static u32 reqid = IPSEC_MANUAL_REQID_MAX;
 
 	start = reqid;
@@ -1788,8 +1796,10 @@
 		++reqid;
 		if (reqid == 0)
 			reqid = IPSEC_MANUAL_REQID_MAX+1;
-		if (xfrm_policy_walk(XFRM_POLICY_TYPE_MAIN, check_reqid,
-				     (void*)&reqid) != -EEXIST)
+		xfrm_policy_walk_init(&walk, XFRM_POLICY_TYPE_MAIN);
+		rc = xfrm_policy_walk(&walk, check_reqid, (void*)&reqid);
+		xfrm_policy_walk_done(&walk);
+		if (rc != -EEXIST)
 			return reqid;
 	} while (reqid != start);
 	return 0;
@@ -2665,8 +2675,14 @@
 static int pfkey_spddump(struct sock *sk, struct sk_buff *skb, struct sadb_msg *hdr, void **ext_hdrs)
 {
 	struct pfkey_dump_data data = { .skb = skb, .hdr = hdr, .sk = sk };
+	struct xfrm_policy_walk walk;
+	int rc;
 
-	return xfrm_policy_walk(XFRM_POLICY_TYPE_MAIN, dump_sp, &data);
+	xfrm_policy_walk_init(&walk, XFRM_POLICY_TYPE_MAIN);
+	rc = xfrm_policy_walk(&walk, dump_sp, &data);
+	xfrm_policy_walk_done(&walk);
+
+	return rc;
 }
 
 static int key_notify_policy_flush(struct km_event *c)
diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c
index 9fc4c31..bae94a8 100644
--- a/net/xfrm/xfrm_policy.c
+++ b/net/xfrm/xfrm_policy.c
@@ -46,6 +46,7 @@
 
 static DEFINE_RWLOCK(xfrm_policy_lock);
 
+static struct list_head xfrm_policy_bytype[XFRM_POLICY_TYPE_MAX];
 unsigned int xfrm_policy_count[XFRM_POLICY_MAX*2];
 EXPORT_SYMBOL(xfrm_policy_count);
 
@@ -208,6 +209,7 @@
 	policy = kzalloc(sizeof(struct xfrm_policy), gfp);
 
 	if (policy) {
+		INIT_LIST_HEAD(&policy->bytype);
 		INIT_HLIST_NODE(&policy->bydst);
 		INIT_HLIST_NODE(&policy->byidx);
 		rwlock_init(&policy->lock);
@@ -230,6 +232,10 @@
 	if (del_timer(&policy->timer))
 		BUG();
 
+	write_lock_bh(&xfrm_policy_lock);
+	list_del(&policy->bytype);
+	write_unlock_bh(&xfrm_policy_lock);
+
 	security_xfrm_policy_free(policy);
 	kfree(policy);
 }
@@ -584,6 +590,7 @@
 	policy->curlft.use_time = 0;
 	if (!mod_timer(&policy->timer, jiffies + HZ))
 		xfrm_pol_hold(policy);
+	list_add_tail(&policy->bytype, &xfrm_policy_bytype[policy->type]);
 	write_unlock_bh(&xfrm_policy_lock);
 
 	if (delpol)
@@ -822,57 +829,60 @@
 }
 EXPORT_SYMBOL(xfrm_policy_flush);
 
-int xfrm_policy_walk(u8 type, int (*func)(struct xfrm_policy *, int, int, void*),
+int xfrm_policy_walk(struct xfrm_policy_walk *walk,
+		     int (*func)(struct xfrm_policy *, int, int, void*),
 		     void *data)
 {
-	struct xfrm_policy *pol, *last = NULL;
-	struct hlist_node *entry;
-	int dir, last_dir = 0, count, error;
+	struct xfrm_policy *old, *pol, *last = NULL;
+	int error = 0;
 
+	if (walk->type >= XFRM_POLICY_TYPE_MAX &&
+	    walk->type != XFRM_POLICY_TYPE_ANY)
+		return -EINVAL;
+
+	if (walk->policy == NULL && walk->count != 0)
+		return 0;
+
+	old = pol = walk->policy;
+	walk->policy = NULL;
 	read_lock_bh(&xfrm_policy_lock);
-	count = 0;
 
-	for (dir = 0; dir < 2*XFRM_POLICY_MAX; dir++) {
-		struct hlist_head *table = xfrm_policy_bydst[dir].table;
-		int i;
+	for (; walk->cur_type < XFRM_POLICY_TYPE_MAX; walk->cur_type++) {
+		if (walk->type != walk->cur_type &&
+		    walk->type != XFRM_POLICY_TYPE_ANY)
+			continue;
 
-		hlist_for_each_entry(pol, entry,
-				     &xfrm_policy_inexact[dir], bydst) {
-			if (pol->type != type)
+		if (pol == NULL) {
+			pol = list_first_entry(&xfrm_policy_bytype[walk->cur_type],
+					       struct xfrm_policy, bytype);
+		}
+		list_for_each_entry_from(pol, &xfrm_policy_bytype[walk->cur_type], bytype) {
+			if (pol->dead)
 				continue;
 			if (last) {
-				error = func(last, last_dir % XFRM_POLICY_MAX,
-					     count, data);
-				if (error)
+				error = func(last, xfrm_policy_id2dir(last->index),
+					     walk->count, data);
+				if (error) {
+					xfrm_pol_hold(last);
+					walk->policy = last;
 					goto out;
+				}
 			}
 			last = pol;
-			last_dir = dir;
-			count++;
+			walk->count++;
 		}
-		for (i = xfrm_policy_bydst[dir].hmask; i >= 0; i--) {
-			hlist_for_each_entry(pol, entry, table + i, bydst) {
-				if (pol->type != type)
-					continue;
-				if (last) {
-					error = func(last, last_dir % XFRM_POLICY_MAX,
-						     count, data);
-					if (error)
-						goto out;
-				}
-				last = pol;
-				last_dir = dir;
-				count++;
-			}
-		}
+		pol = NULL;
 	}
-	if (count == 0) {
+	if (walk->count == 0) {
 		error = -ENOENT;
 		goto out;
 	}
-	error = func(last, last_dir % XFRM_POLICY_MAX, 0, data);
+	if (last)
+		error = func(last, xfrm_policy_id2dir(last->index), 0, data);
 out:
 	read_unlock_bh(&xfrm_policy_lock);
+	if (old != NULL)
+		xfrm_pol_put(old);
 	return error;
 }
 EXPORT_SYMBOL(xfrm_policy_walk);
@@ -2365,6 +2375,9 @@
 			panic("XFRM: failed to allocate bydst hash\n");
 	}
 
+	for (dir = 0; dir < XFRM_POLICY_TYPE_MAX; dir++)
+		INIT_LIST_HEAD(&xfrm_policy_bytype[dir]);
+
 	INIT_WORK(&xfrm_policy_gc_work, xfrm_policy_gc_task);
 	register_netdevice_notifier(&xfrm_dev_notifier);
 }
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 7ba65e8..9880b79 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -50,6 +50,7 @@
  * Main use is finding SA after policy selected tunnel or transport mode.
  * Also, it can be used by ah/esp icmp error handler to find offending SA.
  */
+static LIST_HEAD(xfrm_state_all);
 static struct hlist_head *xfrm_state_bydst __read_mostly;
 static struct hlist_head *xfrm_state_bysrc __read_mostly;
 static struct hlist_head *xfrm_state_byspi __read_mostly;
@@ -510,6 +511,7 @@
 	if (x) {
 		atomic_set(&x->refcnt, 1);
 		atomic_set(&x->tunnel_users, 0);
+		INIT_LIST_HEAD(&x->all);
 		INIT_HLIST_NODE(&x->bydst);
 		INIT_HLIST_NODE(&x->bysrc);
 		INIT_HLIST_NODE(&x->byspi);
@@ -533,6 +535,10 @@
 {
 	BUG_TRAP(x->km.state == XFRM_STATE_DEAD);
 
+	spin_lock_bh(&xfrm_state_lock);
+	list_del(&x->all);
+	spin_unlock_bh(&xfrm_state_lock);
+
 	spin_lock_bh(&xfrm_state_gc_lock);
 	hlist_add_head(&x->bydst, &xfrm_state_gc_list);
 	spin_unlock_bh(&xfrm_state_gc_lock);
@@ -909,6 +915,8 @@
 
 	x->genid = ++xfrm_state_genid;
 
+	list_add_tail(&x->all, &xfrm_state_all);
+
 	h = xfrm_dst_hash(&x->id.daddr, &x->props.saddr,
 			  x->props.reqid, x->props.family);
 	hlist_add_head(&x->bydst, xfrm_state_bydst+h);
@@ -1518,36 +1526,47 @@
 }
 EXPORT_SYMBOL(xfrm_alloc_spi);
 
-int xfrm_state_walk(u8 proto, int (*func)(struct xfrm_state *, int, void*),
+int xfrm_state_walk(struct xfrm_state_walk *walk,
+		    int (*func)(struct xfrm_state *, int, void*),
 		    void *data)
 {
-	int i;
-	struct xfrm_state *x, *last = NULL;
-	struct hlist_node *entry;
-	int count = 0;
+	struct xfrm_state *old, *x, *last = NULL;
 	int err = 0;
 
+	if (walk->state == NULL && walk->count != 0)
+		return 0;
+
+	old = x = walk->state;
+	walk->state = NULL;
 	spin_lock_bh(&xfrm_state_lock);
-	for (i = 0; i <= xfrm_state_hmask; i++) {
-		hlist_for_each_entry(x, entry, xfrm_state_bydst+i, bydst) {
-			if (!xfrm_id_proto_match(x->id.proto, proto))
-				continue;
-			if (last) {
-				err = func(last, count, data);
-				if (err)
-					goto out;
+	if (x == NULL)
+		x = list_first_entry(&xfrm_state_all, struct xfrm_state, all);
+	list_for_each_entry_from(x, &xfrm_state_all, all) {
+		if (x->km.state == XFRM_STATE_DEAD)
+			continue;
+		if (!xfrm_id_proto_match(x->id.proto, walk->proto))
+			continue;
+		if (last) {
+			err = func(last, walk->count, data);
+			if (err) {
+				xfrm_state_hold(last);
+				walk->state = last;
+				goto out;
 			}
-			last = x;
-			count++;
 		}
+		last = x;
+		walk->count++;
 	}
-	if (count == 0) {
+	if (walk->count == 0) {
 		err = -ENOENT;
 		goto out;
 	}
-	err = func(last, 0, data);
+	if (last)
+		err = func(last, 0, data);
 out:
 	spin_unlock_bh(&xfrm_state_lock);
+	if (old != NULL)
+		xfrm_state_put(old);
 	return err;
 }
 EXPORT_SYMBOL(xfrm_state_walk);
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index f971ca5..f5fd5b3 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -532,8 +532,6 @@
 	struct sk_buff *out_skb;
 	u32 nlmsg_seq;
 	u16 nlmsg_flags;
-	int start_idx;
-	int this_idx;
 };
 
 static int copy_sec_ctx(struct xfrm_sec_ctx *s, struct sk_buff *skb)
@@ -600,9 +598,6 @@
 	struct nlmsghdr *nlh;
 	int err;
 
-	if (sp->this_idx < sp->start_idx)
-		goto out;
-
 	nlh = nlmsg_put(skb, NETLINK_CB(in_skb).pid, sp->nlmsg_seq,
 			XFRM_MSG_NEWSA, sizeof(*p), sp->nlmsg_flags);
 	if (nlh == NULL)
@@ -615,8 +610,6 @@
 		goto nla_put_failure;
 
 	nlmsg_end(skb, nlh);
-out:
-	sp->this_idx++;
 	return 0;
 
 nla_put_failure:
@@ -624,18 +617,32 @@
 	return err;
 }
 
+static int xfrm_dump_sa_done(struct netlink_callback *cb)
+{
+	struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1];
+	xfrm_state_walk_done(walk);
+	return 0;
+}
+
 static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb)
 {
+	struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1];
 	struct xfrm_dump_info info;
 
+	BUILD_BUG_ON(sizeof(struct xfrm_state_walk) >
+		     sizeof(cb->args) - sizeof(cb->args[0]));
+
 	info.in_skb = cb->skb;
 	info.out_skb = skb;
 	info.nlmsg_seq = cb->nlh->nlmsg_seq;
 	info.nlmsg_flags = NLM_F_MULTI;
-	info.this_idx = 0;
-	info.start_idx = cb->args[0];
-	(void) xfrm_state_walk(0, dump_one_state, &info);
-	cb->args[0] = info.this_idx;
+
+	if (!cb->args[0]) {
+		cb->args[0] = 1;
+		xfrm_state_walk_init(walk, 0);
+	}
+
+	(void) xfrm_state_walk(walk, dump_one_state, &info);
 
 	return skb->len;
 }
@@ -654,7 +661,6 @@
 	info.out_skb = skb;
 	info.nlmsg_seq = seq;
 	info.nlmsg_flags = 0;
-	info.this_idx = info.start_idx = 0;
 
 	if (dump_one_state(x, 0, &info)) {
 		kfree_skb(skb);
@@ -1232,9 +1238,6 @@
 	struct sk_buff *skb = sp->out_skb;
 	struct nlmsghdr *nlh;
 
-	if (sp->this_idx < sp->start_idx)
-		goto out;
-
 	nlh = nlmsg_put(skb, NETLINK_CB(in_skb).pid, sp->nlmsg_seq,
 			XFRM_MSG_NEWPOLICY, sizeof(*p), sp->nlmsg_flags);
 	if (nlh == NULL)
@@ -1250,8 +1253,6 @@
 		goto nlmsg_failure;
 
 	nlmsg_end(skb, nlh);
-out:
-	sp->this_idx++;
 	return 0;
 
 nlmsg_failure:
@@ -1259,21 +1260,33 @@
 	return -EMSGSIZE;
 }
 
+static int xfrm_dump_policy_done(struct netlink_callback *cb)
+{
+	struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *) &cb->args[1];
+
+	xfrm_policy_walk_done(walk);
+	return 0;
+}
+
 static int xfrm_dump_policy(struct sk_buff *skb, struct netlink_callback *cb)
 {
+	struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *) &cb->args[1];
 	struct xfrm_dump_info info;
 
+	BUILD_BUG_ON(sizeof(struct xfrm_policy_walk) >
+		     sizeof(cb->args) - sizeof(cb->args[0]));
+
 	info.in_skb = cb->skb;
 	info.out_skb = skb;
 	info.nlmsg_seq = cb->nlh->nlmsg_seq;
 	info.nlmsg_flags = NLM_F_MULTI;
-	info.this_idx = 0;
-	info.start_idx = cb->args[0];
-	(void) xfrm_policy_walk(XFRM_POLICY_TYPE_MAIN, dump_one_policy, &info);
-#ifdef CONFIG_XFRM_SUB_POLICY
-	(void) xfrm_policy_walk(XFRM_POLICY_TYPE_SUB, dump_one_policy, &info);
-#endif
-	cb->args[0] = info.this_idx;
+
+	if (!cb->args[0]) {
+		cb->args[0] = 1;
+		xfrm_policy_walk_init(walk, XFRM_POLICY_TYPE_ANY);
+	}
+
+	(void) xfrm_policy_walk(walk, dump_one_policy, &info);
 
 	return skb->len;
 }
@@ -1293,7 +1306,6 @@
 	info.out_skb = skb;
 	info.nlmsg_seq = seq;
 	info.nlmsg_flags = 0;
-	info.this_idx = info.start_idx = 0;
 
 	if (dump_one_policy(xp, dir, 0, &info) < 0) {
 		kfree_skb(skb);
@@ -1891,15 +1903,18 @@
 static struct xfrm_link {
 	int (*doit)(struct sk_buff *, struct nlmsghdr *, struct nlattr **);
 	int (*dump)(struct sk_buff *, struct netlink_callback *);
+	int (*done)(struct netlink_callback *);
 } xfrm_dispatch[XFRM_NR_MSGTYPES] = {
 	[XFRM_MSG_NEWSA       - XFRM_MSG_BASE] = { .doit = xfrm_add_sa        },
 	[XFRM_MSG_DELSA       - XFRM_MSG_BASE] = { .doit = xfrm_del_sa        },
 	[XFRM_MSG_GETSA       - XFRM_MSG_BASE] = { .doit = xfrm_get_sa,
-						   .dump = xfrm_dump_sa       },
+						   .dump = xfrm_dump_sa,
+						   .done = xfrm_dump_sa_done  },
 	[XFRM_MSG_NEWPOLICY   - XFRM_MSG_BASE] = { .doit = xfrm_add_policy    },
 	[XFRM_MSG_DELPOLICY   - XFRM_MSG_BASE] = { .doit = xfrm_get_policy    },
 	[XFRM_MSG_GETPOLICY   - XFRM_MSG_BASE] = { .doit = xfrm_get_policy,
-						   .dump = xfrm_dump_policy   },
+						   .dump = xfrm_dump_policy,
+						   .done = xfrm_dump_policy_done },
 	[XFRM_MSG_ALLOCSPI    - XFRM_MSG_BASE] = { .doit = xfrm_alloc_userspi },
 	[XFRM_MSG_ACQUIRE     - XFRM_MSG_BASE] = { .doit = xfrm_add_acquire   },
 	[XFRM_MSG_EXPIRE      - XFRM_MSG_BASE] = { .doit = xfrm_add_sa_expire },
@@ -1938,7 +1953,7 @@
 		if (link->dump == NULL)
 			return -EINVAL;
 
-		return netlink_dump_start(xfrm_nl, skb, nlh, link->dump, NULL);
+		return netlink_dump_start(xfrm_nl, skb, nlh, link->dump, link->done);
 	}
 
 	err = nlmsg_parse(nlh, xfrm_msg_min[type], attrs, XFRMA_MAX,