[TCPDIAG]: Introduce inet_diag_{register,unregister}

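This changeset replaces the hard-coded TCP/DCCP special cases in tcp_diag.c with
a table of struct inet_diag_handler entries, indexed by netlink message type and
managed through inet_diag_register() and inet_diag_unregister().

The next changeset will rename tcp_diag to inet_diag and move the TCP-specific
code out of it into a new tcp_diag.c, similar to the net/dccp/diag.c introduced
in this changeset, completing the transition to a generic inet_diag
infrastructure.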

Signed-off-by: Arnaldo Carvalho de Melo <acme@mandriva.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c
index b812191..b13b71c 100644
--- a/net/ipv4/tcp_diag.c
+++ b/net/ipv4/tcp_diag.c
@@ -34,6 +34,8 @@
 
 #include <linux/tcp_diag.h>
 
+static const struct inet_diag_handler **inet_diag_table;
+
 struct tcpdiag_entry
 {
 	u32 *saddr;
@@ -61,18 +63,24 @@
 	const struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tcpdiagmsg *r;
 	struct nlmsghdr  *nlh;
-	struct tcp_info  *info = NULL;
+	void *info = NULL;
 	struct tcpdiag_meminfo  *minfo = NULL;
 	unsigned char	 *b = skb->tail;
+	const struct inet_diag_handler *handler;
+
+	handler = inet_diag_table[unlh->nlmsg_type];
+	BUG_ON(handler == NULL);
 
 	nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
 	nlh->nlmsg_flags = nlmsg_flags;
+
 	r = NLMSG_DATA(nlh);
 	if (sk->sk_state != TCP_TIME_WAIT) {
 		if (ext & (1<<(TCPDIAG_MEMINFO-1)))
 			minfo = TCPDIAG_PUT(skb, TCPDIAG_MEMINFO, sizeof(*minfo));
 		if (ext & (1<<(TCPDIAG_INFO-1)))
-			info = TCPDIAG_PUT(skb, TCPDIAG_INFO, sizeof(*info));
+			info = TCPDIAG_PUT(skb, TCPDIAG_INFO,
+					   handler->idiag_info_size);
 		
 		if ((ext & (1 << (TCPDIAG_CONG - 1))) && icsk->icsk_ca_ops) {
 			size_t len = strlen(icsk->icsk_ca_ops->name);
@@ -155,19 +163,6 @@
 		r->tcpdiag_expires = 0;
 	}
 #undef EXPIRES_IN_MS
-	/*
-	 * Ahem... for now we'll have some knowledge about TCP -acme
-	 * But this is just one of two small exceptions, both in this
-	 * function, so lets close our eyes for some 15 lines or so... 8)
-	 * -acme
-	 */
-	if (sk->sk_protocol == IPPROTO_TCP) {
-		const struct tcp_sock *tp = tcp_sk(sk);
-
-		r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq;
-		r->tcpdiag_wqueue = tp->write_seq - tp->snd_una;
-	} else
-		r->tcpdiag_rqueue = r->tcpdiag_wqueue = 0;
 
 	r->tcpdiag_uid = sock_i_uid(sk);
 	r->tcpdiag_inode = sock_i_ino(sk);
@@ -179,13 +174,7 @@
 		minfo->tcpdiag_tmem = atomic_read(&sk->sk_wmem_alloc);
 	}
 
-	/* Ahem... for now we'll have some knowledge about TCP -acme */
-	if (info) {
-		if (sk->sk_protocol == IPPROTO_TCP) 
-			tcp_get_info(sk, info);
-		else
-			memset(info, 0, sizeof(*info));
-	}
+	handler->idiag_get_info(sk, r, info);
 
 	if (sk->sk_state < TCP_TIME_WAIT &&
 	    icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info)
@@ -206,11 +195,13 @@
 	struct sock *sk;
 	struct tcpdiagreq *req = NLMSG_DATA(nlh);
 	struct sk_buff *rep;
-	struct inet_hashinfo *hashinfo = &tcp_hashinfo;
-#ifdef CONFIG_IP_TCPDIAG_DCCP
-	if (nlh->nlmsg_type == DCCPDIAG_GETSOCK)
-		hashinfo = &dccp_hashinfo;
-#endif
+	struct inet_hashinfo *hashinfo;
+	const struct inet_diag_handler *handler;
+
+	handler = inet_diag_table[nlh->nlmsg_type];
+	BUG_ON(handler == NULL);
+	hashinfo = handler->idiag_hashinfo;
+
 	if (req->tcpdiag_family == AF_INET) {
 		sk = inet_lookup(hashinfo, req->id.tcpdiag_dst[0],
 				 req->id.tcpdiag_dport, req->id.tcpdiag_src[0],
@@ -241,9 +232,10 @@
 		goto out;
 
 	err = -ENOMEM;
-	rep = alloc_skb(NLMSG_SPACE(sizeof(struct tcpdiagmsg)+
-				    sizeof(struct tcpdiag_meminfo)+
-				    sizeof(struct tcp_info)+64), GFP_KERNEL);
+	rep = alloc_skb(NLMSG_SPACE((sizeof(struct tcpdiagmsg) +
+				     sizeof(struct tcpdiag_meminfo) +
+				     handler->idiag_info_size + 64)),
+			GFP_KERNEL);
 	if (!rep)
 		goto out;
 
@@ -603,15 +595,16 @@
 	int i, num;
 	int s_i, s_num;
 	struct tcpdiagreq *r = NLMSG_DATA(cb->nlh);
+	const struct inet_diag_handler *handler;
 	struct inet_hashinfo *hashinfo;
 
+	handler = inet_diag_table[cb->nlh->nlmsg_type];
+	BUG_ON(handler == NULL);
+	hashinfo = handler->idiag_hashinfo;
+		
 	s_i = cb->args[1];
 	s_num = num = cb->args[2];
-		hashinfo = &tcp_hashinfo;
-#ifdef CONFIG_IP_TCPDIAG_DCCP
-	if (cb->nlh->nlmsg_type == DCCPDIAG_GETSOCK)
-		hashinfo = &dccp_hashinfo;
-#endif
+
 	if (cb->args[0] == 0) {
 		if (!(r->tcpdiag_states&(TCPF_LISTEN|TCPF_SYN_RECV)))
 			goto skip_listen_ht;
@@ -745,13 +738,12 @@
 	if (!(nlh->nlmsg_flags&NLM_F_REQUEST))
 		return 0;
 
-	if (nlh->nlmsg_type != TCPDIAG_GETSOCK
-#ifdef CONFIG_IP_TCPDIAG_DCCP
-	    && nlh->nlmsg_type != DCCPDIAG_GETSOCK
-#endif
-	   )
+	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX)
 		goto err_inval;
 
+	if (inet_diag_table[nlh->nlmsg_type] == NULL)
+		return -ENOENT;
+
 	if (NLMSG_LENGTH(sizeof(struct tcpdiagreq)) > skb->len)
 		goto err_inval;
 
@@ -803,18 +795,124 @@
 	}
 }
 
+/*
+ * TCP hook for the per-socket dump: report the receive and send queue
+ * sizes in @r and, when the requester asked for TCPDIAG_INFO (@_info is
+ * non-NULL), fill in the struct tcp_info payload.
+ */
+static void tcp_diag_get_info(struct sock *sk, struct tcpdiagmsg *r,
+			      void *_info)
+{
+	const struct tcp_sock *tp = tcp_sk(sk);
+	struct tcp_info *info = _info;
+
+	r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq;
+	r->tcpdiag_wqueue = tp->write_seq - tp->snd_una;
+	if (info != NULL)
+		tcp_get_info(sk, info);
+}
+
+/* Never modified at run time, so it can live in read-only data. */
+static const struct inet_diag_handler tcp_diag_handler = {
+	.idiag_hashinfo	 = &tcp_hashinfo,
+	.idiag_get_info	 = tcp_diag_get_info,
+	.idiag_type	 = TCPDIAG_GETSOCK,
+	.idiag_info_size = sizeof(struct tcp_info),
+};
+
+/* Serialises updates of inet_diag_table by register/unregister. */
+static DEFINE_SPINLOCK(inet_diag_register_lock);
+
+/**
+ * inet_diag_register - install the handler for one netlink message type
+ * @h: handler to install; @h->idiag_type selects its inet_diag_table slot
+ *
+ * Returns 0 on success, -EINVAL if @h->idiag_type is not below
+ * INET_DIAG_GETSOCK_MAX, or -EEXIST if that slot is already taken.
+ */
+int inet_diag_register(const struct inet_diag_handler *h)
+{
+	const __u16 type = h->idiag_type;
+	int err = -EINVAL;
+
+	if (type >= INET_DIAG_GETSOCK_MAX)
+		goto out;
+
+	spin_lock(&inet_diag_register_lock);
+	err = -EEXIST;
+	if (inet_diag_table[type] == NULL) {
+		inet_diag_table[type] = h;
+		err = 0;
+	}
+	spin_unlock(&inet_diag_register_lock);
+out:
+	return err;
+}
+EXPORT_SYMBOL_GPL(inet_diag_register);
+
+/**
+ * inet_diag_unregister - remove a handler added by inet_diag_register()
+ * @h: the handler that was registered
+ *
+ * The slot is cleared only if it still points at @h, so a stray call
+ * cannot drop another module's handler. Sleeps in synchronize_rcu().
+ */
+void inet_diag_unregister(const struct inet_diag_handler *h)
+{
+	const __u16 type = h->idiag_type;
+
+	if (type >= INET_DIAG_GETSOCK_MAX)
+		return;
+
+	spin_lock(&inet_diag_register_lock);
+	if (inet_diag_table[type] == h)
+		inet_diag_table[type] = NULL;
+	spin_unlock(&inet_diag_register_lock);
+
+	/*
+	 * NOTE(review): synchronize_rcu() only waits for readers inside
+	 * rcu_read_lock() sections, and the inet_diag_table[] lookups seen in
+	 * this file take no such lock themselves -- confirm the netlink input
+	 * path provides one, or a concurrent request can still observe the
+	 * slot changing under it (and hit its BUG_ON(handler == NULL)).
+	 */
+	synchronize_rcu();
+}
+EXPORT_SYMBOL_GPL(inet_diag_unregister);
+
 static int __init tcpdiag_init(void)
 {
+	const size_t inet_diag_table_size = (INET_DIAG_GETSOCK_MAX *
+					     sizeof(struct inet_diag_handler *));
+	int err = -ENOMEM;
+
+	inet_diag_table = kmalloc(inet_diag_table_size, GFP_KERNEL);
+	if (!inet_diag_table)
+		goto out;
+
+	memset(inet_diag_table, 0, inet_diag_table_size);
+
 	tcpnl = netlink_kernel_create(NETLINK_TCPDIAG, tcpdiag_rcv,
 				      THIS_MODULE);
 	if (tcpnl == NULL)
-		return -ENOMEM;
-	return 0;
+		goto out_free_table;
+
+	err = inet_diag_register(&tcp_diag_handler);
+	if (err)
+		goto out_sock_release;
+out:
+	return err;
+out_sock_release:
+	sock_release(tcpnl->sk_socket);
+out_free_table:
+	kfree(inet_diag_table);
+	goto out;
 }
 
 static void __exit tcpdiag_exit(void)
 {
 	sock_release(tcpnl->sk_socket);
+	kfree(inet_diag_table);
 }
 
 module_init(tcpdiag_init);