]> nv-tegra.nvidia Code Review - linux-2.6.git/blobdiff - net/ipv4/tcp_diag.c
[TCPDIAG]: Introduce inet_diag_{register,unregister}
[linux-2.6.git] / net / ipv4 / tcp_diag.c
index f66945cb158fd346c0b2f87d0dec068620b7a838..b13b71cb9ced421205dc1da4255978d7566788e4 100644 (file)
 #include <net/tcp.h>
 #include <net/ipv6.h>
 #include <net/inet_common.h>
+#include <net/inet_connection_sock.h>
+#include <net/inet_hashtables.h>
+#include <net/inet_timewait_sock.h>
+#include <net/inet6_hashtables.h>
 
 #include <linux/inet.h>
 #include <linux/stddef.h>
 
 #include <linux/tcp_diag.h>
 
+static const struct inet_diag_handler **inet_diag_table;
+
 struct tcpdiag_entry
 {
        u32 *saddr;
@@ -45,30 +51,41 @@ static struct sock *tcpnl;
 #define TCPDIAG_PUT(skb, attrtype, attrlen) \
        RTA_DATA(__RTA_PUT(skb, attrtype, attrlen))
 
+#ifdef CONFIG_IP_TCPDIAG_DCCP
+extern struct inet_hashinfo dccp_hashinfo;
+#endif
+
 static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
-                       int ext, u32 pid, u32 seq, u16 nlmsg_flags)
+                       int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+                       const struct nlmsghdr *unlh)
 {
-       struct inet_sock *inet = inet_sk(sk);
-       struct tcp_sock *tp = tcp_sk(sk);
+       const struct inet_sock *inet = inet_sk(sk);
+       const struct inet_connection_sock *icsk = inet_csk(sk);
        struct tcpdiagmsg *r;
        struct nlmsghdr  *nlh;
-       struct tcp_info  *info = NULL;
+       void *info = NULL;
        struct tcpdiag_meminfo  *minfo = NULL;
        unsigned char    *b = skb->tail;
+       const struct inet_diag_handler *handler;
 
-       nlh = NLMSG_PUT(skb, pid, seq, TCPDIAG_GETSOCK, sizeof(*r));
+       handler = inet_diag_table[unlh->nlmsg_type];
+       BUG_ON(handler == NULL);
+
+       nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
        nlh->nlmsg_flags = nlmsg_flags;
+
        r = NLMSG_DATA(nlh);
        if (sk->sk_state != TCP_TIME_WAIT) {
                if (ext & (1<<(TCPDIAG_MEMINFO-1)))
                        minfo = TCPDIAG_PUT(skb, TCPDIAG_MEMINFO, sizeof(*minfo));
                if (ext & (1<<(TCPDIAG_INFO-1)))
-                       info = TCPDIAG_PUT(skb, TCPDIAG_INFO, sizeof(*info));
+                       info = TCPDIAG_PUT(skb, TCPDIAG_INFO,
+                                          handler->idiag_info_size);
                
-               if (ext & (1<<(TCPDIAG_CONG-1))) {
-                       size_t len = strlen(tp->ca_ops->name);
+               if ((ext & (1 << (TCPDIAG_CONG - 1))) && icsk->icsk_ca_ops) {
+                       size_t len = strlen(icsk->icsk_ca_ops->name);
                        strcpy(TCPDIAG_PUT(skb, TCPDIAG_CONG, len+1),
-                              tp->ca_ops->name);
+                              icsk->icsk_ca_ops->name);
                }
        }
        r->tcpdiag_family = sk->sk_family;
@@ -81,7 +98,7 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
        r->id.tcpdiag_cookie[1] = (u32)(((unsigned long)sk >> 31) >> 1);
 
        if (r->tcpdiag_state == TCP_TIME_WAIT) {
-               struct tcp_tw_bucket *tw = (struct tcp_tw_bucket*)sk;
+               const struct inet_timewait_sock *tw = inet_twsk(sk);
                long tmo = tw->tw_ttd - jiffies;
                if (tmo < 0)
                        tmo = 0;
@@ -97,12 +114,14 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
                r->tcpdiag_wqueue = 0;
                r->tcpdiag_uid = 0;
                r->tcpdiag_inode = 0;
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
                if (r->tcpdiag_family == AF_INET6) {
+                       const struct tcp6_timewait_sock *tcp6tw = tcp6_twsk(sk);
+
                        ipv6_addr_copy((struct in6_addr *)r->id.tcpdiag_src,
-                                      &tw->tw_v6_rcv_saddr);
+                                      &tcp6tw->tw_v6_rcv_saddr);
                        ipv6_addr_copy((struct in6_addr *)r->id.tcpdiag_dst,
-                                      &tw->tw_v6_daddr);
+                                      &tcp6tw->tw_v6_daddr);
                }
 #endif
                nlh->nlmsg_len = skb->tail - b;
@@ -114,7 +133,7 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
        r->id.tcpdiag_src[0] = inet->rcv_saddr;
        r->id.tcpdiag_dst[0] = inet->daddr;
 
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        if (r->tcpdiag_family == AF_INET6) {
                struct ipv6_pinfo *np = inet6_sk(sk);
 
@@ -127,17 +146,17 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
 
 #define EXPIRES_IN_MS(tmo)  ((tmo-jiffies)*1000+HZ-1)/HZ
 
-       if (tp->pending == TCP_TIME_RETRANS) {
+       if (icsk->icsk_pending == ICSK_TIME_RETRANS) {
                r->tcpdiag_timer = 1;
-               r->tcpdiag_retrans = tp->retransmits;
-               r->tcpdiag_expires = EXPIRES_IN_MS(tp->timeout);
-       } else if (tp->pending == TCP_TIME_PROBE0) {
+               r->tcpdiag_retrans = icsk->icsk_retransmits;
+               r->tcpdiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
+       } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
                r->tcpdiag_timer = 4;
-               r->tcpdiag_retrans = tp->probes_out;
-               r->tcpdiag_expires = EXPIRES_IN_MS(tp->timeout);
+               r->tcpdiag_retrans = icsk->icsk_probes_out;
+               r->tcpdiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
        } else if (timer_pending(&sk->sk_timer)) {
                r->tcpdiag_timer = 2;
-               r->tcpdiag_retrans = tp->probes_out;
+               r->tcpdiag_retrans = icsk->icsk_probes_out;
                r->tcpdiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
        } else {
                r->tcpdiag_timer = 0;
@@ -145,8 +164,6 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
        }
 #undef EXPIRES_IN_MS
 
-       r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq;
-       r->tcpdiag_wqueue = tp->write_seq - tp->snd_una;
        r->tcpdiag_uid = sock_i_uid(sk);
        r->tcpdiag_inode = sock_i_ino(sk);
 
@@ -157,11 +174,11 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk,
                minfo->tcpdiag_tmem = atomic_read(&sk->sk_wmem_alloc);
        }
 
-       if (info) 
-               tcp_get_info(sk, info);
+       handler->idiag_get_info(sk, r, info);
 
-       if (sk->sk_state < TCP_TIME_WAIT && tp->ca_ops->get_info)
-               tp->ca_ops->get_info(tp, ext, skb);
+       if (sk->sk_state < TCP_TIME_WAIT &&
+           icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info)
+               icsk->icsk_ca_ops->get_info(sk, ext, skb);
 
        nlh->nlmsg_len = skb->tail - b;
        return skb->len;
@@ -172,38 +189,32 @@ nlmsg_failure:
        return -1;
 }
 
-extern struct sock *tcp_v4_lookup(u32 saddr, u16 sport, u32 daddr, u16 dport,
-                                 int dif);
-#ifdef CONFIG_IP_TCPDIAG_IPV6
-extern struct sock *tcp_v6_lookup(struct in6_addr *saddr, u16 sport,
-                                 struct in6_addr *daddr, u16 dport,
-                                 int dif);
-#else
-static inline struct sock *tcp_v6_lookup(struct in6_addr *saddr, u16 sport,
-                                        struct in6_addr *daddr, u16 dport,
-                                        int dif)
-{
-       return NULL;
-}
-#endif
-
 static int tcpdiag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nlh)
 {
        int err;
        struct sock *sk;
        struct tcpdiagreq *req = NLMSG_DATA(nlh);
        struct sk_buff *rep;
+       struct inet_hashinfo *hashinfo;
+       const struct inet_diag_handler *handler;
+
+       handler = inet_diag_table[nlh->nlmsg_type];
+       BUG_ON(handler == NULL);
+       hashinfo = handler->idiag_hashinfo;
 
        if (req->tcpdiag_family == AF_INET) {
-               sk = tcp_v4_lookup(req->id.tcpdiag_dst[0], req->id.tcpdiag_dport,
-                                  req->id.tcpdiag_src[0], req->id.tcpdiag_sport,
-                                  req->id.tcpdiag_if);
+               sk = inet_lookup(hashinfo, req->id.tcpdiag_dst[0],
+                                req->id.tcpdiag_dport, req->id.tcpdiag_src[0],
+                                req->id.tcpdiag_sport, req->id.tcpdiag_if);
        }
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        else if (req->tcpdiag_family == AF_INET6) {
-               sk = tcp_v6_lookup((struct in6_addr*)req->id.tcpdiag_dst, req->id.tcpdiag_dport,
-                                  (struct in6_addr*)req->id.tcpdiag_src, req->id.tcpdiag_sport,
-                                  req->id.tcpdiag_if);
+               sk = inet6_lookup(hashinfo,
+                                 (struct in6_addr*)req->id.tcpdiag_dst,
+                                 req->id.tcpdiag_dport,
+                                 (struct in6_addr*)req->id.tcpdiag_src,
+                                 req->id.tcpdiag_sport,
+                                 req->id.tcpdiag_if);
        }
 #endif
        else {
@@ -221,15 +232,16 @@ static int tcpdiag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nlh)
                goto out;
 
        err = -ENOMEM;
-       rep = alloc_skb(NLMSG_SPACE(sizeof(struct tcpdiagmsg)+
-                                   sizeof(struct tcpdiag_meminfo)+
-                                   sizeof(struct tcp_info)+64), GFP_KERNEL);
+       rep = alloc_skb(NLMSG_SPACE((sizeof(struct tcpdiagmsg) +
+                                    sizeof(struct tcpdiag_meminfo) +
+                                    handler->idiag_info_size + 64)),
+                       GFP_KERNEL);
        if (!rep)
                goto out;
 
        if (tcpdiag_fill(rep, sk, req->tcpdiag_ext,
                         NETLINK_CB(in_skb).pid,
-                        nlh->nlmsg_seq, 0) <= 0)
+                        nlh->nlmsg_seq, 0, nlh) <= 0)
                BUG();
 
        err = netlink_unicast(tcpnl, rep, NETLINK_CB(in_skb).pid, MSG_DONTWAIT);
@@ -239,7 +251,7 @@ static int tcpdiag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nlh)
 out:
        if (sk) {
                if (sk->sk_state == TCP_TIME_WAIT)
-                       tcp_tw_put((struct tcp_tw_bucket*)sk);
+                       inet_twsk_put((struct inet_timewait_sock *)sk);
                else
                        sock_put(sk);
        }
@@ -414,7 +426,7 @@ static int tcpdiag_dump_sock(struct sk_buff *skb, struct sock *sk,
                struct inet_sock *inet = inet_sk(sk);
 
                entry.family = sk->sk_family;
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
                if (entry.family == AF_INET6) {
                        struct ipv6_pinfo *np = inet6_sk(sk);
 
@@ -435,12 +447,13 @@ static int tcpdiag_dump_sock(struct sk_buff *skb, struct sock *sk,
        }
 
        return tcpdiag_fill(skb, sk, r->tcpdiag_ext, NETLINK_CB(cb->skb).pid,
-                           cb->nlh->nlmsg_seq, NLM_F_MULTI);
+                           cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
 
 static int tcpdiag_fill_req(struct sk_buff *skb, struct sock *sk,
                            struct request_sock *req,
-                           u32 pid, u32 seq)
+                           u32 pid, u32 seq,
+                           const struct nlmsghdr *unlh)
 {
        const struct inet_request_sock *ireq = inet_rsk(req);
        struct inet_sock *inet = inet_sk(sk);
@@ -449,7 +462,7 @@ static int tcpdiag_fill_req(struct sk_buff *skb, struct sock *sk,
        struct nlmsghdr *nlh;
        long tmo;
 
-       nlh = NLMSG_PUT(skb, pid, seq, TCPDIAG_GETSOCK, sizeof(*r));
+       nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
        nlh->nlmsg_flags = NLM_F_MULTI;
        r = NLMSG_DATA(nlh);
 
@@ -475,7 +488,7 @@ static int tcpdiag_fill_req(struct sk_buff *skb, struct sock *sk,
        r->tcpdiag_wqueue = 0;
        r->tcpdiag_uid = sock_i_uid(sk);
        r->tcpdiag_inode = 0;
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        if (r->tcpdiag_family == AF_INET6) {
                ipv6_addr_copy((struct in6_addr *)r->id.tcpdiag_src,
                               &tcp6_rsk(req)->loc_addr);
@@ -497,7 +510,7 @@ static int tcpdiag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 {
        struct tcpdiag_entry entry;
        struct tcpdiagreq *r = NLMSG_DATA(cb->nlh);
-       struct tcp_sock *tp = tcp_sk(sk);
+       struct inet_connection_sock *icsk = inet_csk(sk);
        struct listen_sock *lopt;
        struct rtattr *bc = NULL;
        struct inet_sock *inet = inet_sk(sk);
@@ -513,9 +526,9 @@ static int tcpdiag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 
        entry.family = sk->sk_family;
 
-       read_lock_bh(&tp->accept_queue.syn_wait_lock);
+       read_lock_bh(&icsk->icsk_accept_queue.syn_wait_lock);
 
-       lopt = tp->accept_queue.listen_opt;
+       lopt = icsk->icsk_accept_queue.listen_opt;
        if (!lopt || !lopt->qlen)
                goto out;
 
@@ -525,7 +538,7 @@ static int tcpdiag_dump_reqs(struct sk_buff *skb, struct sock *sk,
                entry.userlocks = sk->sk_userlocks;
        }
 
-       for (j = s_j; j < TCP_SYNQ_HSIZE; j++) {
+       for (j = s_j; j < lopt->nr_table_entries; j++) {
                struct request_sock *req, *head = lopt->syn_table[j];
 
                reqnum = 0;
@@ -540,13 +553,13 @@ static int tcpdiag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 
                        if (bc) {
                                entry.saddr =
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
                                        (entry.family == AF_INET6) ?
                                        tcp6_rsk(req)->loc_addr.s6_addr32 :
 #endif
                                        &ireq->loc_addr;
                                entry.daddr = 
-#ifdef CONFIG_IP_TCPDIAG_IPV6
+#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
                                        (entry.family == AF_INET6) ?
                                        tcp6_rsk(req)->rmt_addr.s6_addr32 :
 #endif
@@ -560,7 +573,7 @@ static int tcpdiag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 
                        err = tcpdiag_fill_req(skb, sk, req,
                                               NETLINK_CB(cb->skb).pid,
-                                              cb->nlh->nlmsg_seq);
+                                              cb->nlh->nlmsg_seq, cb->nlh);
                        if (err < 0) {
                                cb->args[3] = j + 1;
                                cb->args[4] = reqnum;
@@ -572,7 +585,7 @@ static int tcpdiag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        }
 
 out:
-       read_unlock_bh(&tp->accept_queue.syn_wait_lock);
+       read_unlock_bh(&icsk->icsk_accept_queue.syn_wait_lock);
 
        return err;
 }
@@ -582,20 +595,27 @@ static int tcpdiag_dump(struct sk_buff *skb, struct netlink_callback *cb)
        int i, num;
        int s_i, s_num;
        struct tcpdiagreq *r = NLMSG_DATA(cb->nlh);
+       const struct inet_diag_handler *handler;
+       struct inet_hashinfo *hashinfo;
 
+       handler = inet_diag_table[cb->nlh->nlmsg_type];
+       BUG_ON(handler == NULL);
+       hashinfo = handler->idiag_hashinfo;
+               
        s_i = cb->args[1];
        s_num = num = cb->args[2];
 
        if (cb->args[0] == 0) {
                if (!(r->tcpdiag_states&(TCPF_LISTEN|TCPF_SYN_RECV)))
                        goto skip_listen_ht;
-               tcp_listen_lock();
-               for (i = s_i; i < TCP_LHTABLE_SIZE; i++) {
+
+               inet_listen_lock(hashinfo);
+               for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
                        struct sock *sk;
                        struct hlist_node *node;
 
                        num = 0;
-                       sk_for_each(sk, node, &tcp_listening_hash[i]) {
+                       sk_for_each(sk, node, &hashinfo->listening_hash[i]) {
                                struct inet_sock *inet = inet_sk(sk);
 
                                if (num < s_num) {
@@ -613,7 +633,7 @@ static int tcpdiag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                        goto syn_recv;
 
                                if (tcpdiag_dump_sock(skb, sk, cb) < 0) {
-                                       tcp_listen_unlock();
+                                       inet_listen_unlock(hashinfo);
                                        goto done;
                                }
 
@@ -622,7 +642,7 @@ syn_recv:
                                        goto next_listen;
 
                                if (tcpdiag_dump_reqs(skb, sk, cb) < 0) {
-                                       tcp_listen_unlock();
+                                       inet_listen_unlock(hashinfo);
                                        goto done;
                                }
 
@@ -636,7 +656,7 @@ next_listen:
                        cb->args[3] = 0;
                        cb->args[4] = 0;
                }
-               tcp_listen_unlock();
+               inet_listen_unlock(hashinfo);
 skip_listen_ht:
                cb->args[0] = 1;
                s_i = num = s_num = 0;
@@ -645,8 +665,8 @@ skip_listen_ht:
        if (!(r->tcpdiag_states&~(TCPF_LISTEN|TCPF_SYN_RECV)))
                return skb->len;
 
-       for (i = s_i; i < tcp_ehash_size; i++) {
-               struct tcp_ehash_bucket *head = &tcp_ehash[i];
+       for (i = s_i; i < hashinfo->ehash_size; i++) {
+               struct inet_ehash_bucket *head = &hashinfo->ehash[i];
                struct sock *sk;
                struct hlist_node *node;
 
@@ -678,7 +698,7 @@ next_normal:
 
                if (r->tcpdiag_states&TCPF_TIME_WAIT) {
                        sk_for_each(sk, node,
-                                   &tcp_ehash[i + tcp_ehash_size].chain) {
+                                   &hashinfo->ehash[i + hashinfo->ehash_size].chain) {
                                struct inet_sock *inet = inet_sk(sk);
 
                                if (num < s_num)
@@ -718,9 +738,12 @@ tcpdiag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
        if (!(nlh->nlmsg_flags&NLM_F_REQUEST))
                return 0;
 
-       if (nlh->nlmsg_type != TCPDIAG_GETSOCK)
+       if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX)
                goto err_inval;
 
+       if (inet_diag_table[nlh->nlmsg_type] == NULL)
+               return -ENOENT;
+
        if (NLMSG_LENGTH(sizeof(struct tcpdiagreq)) > skb->len)
                goto err_inval;
 
@@ -772,17 +795,95 @@ static void tcpdiag_rcv(struct sock *sk, int len)
        }
 }
 
+static void tcp_diag_get_info(struct sock *sk, struct tcpdiagmsg *r,
+                             void *_info)
+{
+       const struct tcp_sock *tp = tcp_sk(sk);
+       struct tcp_info *info = _info;
+
+       r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq;
+       r->tcpdiag_wqueue = tp->write_seq - tp->snd_una;
+       if (info != NULL)
+               tcp_get_info(sk, info);
+}
+
+static struct inet_diag_handler tcp_diag_handler = {
+       .idiag_hashinfo  = &tcp_hashinfo,
+       .idiag_get_info  = tcp_diag_get_info,
+       .idiag_type      = TCPDIAG_GETSOCK,
+       .idiag_info_size = sizeof(struct tcp_info),
+};
+
+static DEFINE_SPINLOCK(inet_diag_register_lock);
+
+int inet_diag_register(const struct inet_diag_handler *h)
+{
+       const __u16 type = h->idiag_type;
+       int err = -EINVAL;
+
+       if (type >= INET_DIAG_GETSOCK_MAX)
+               goto out;
+
+       spin_lock(&inet_diag_register_lock);
+       err = -EEXIST;
+       if (inet_diag_table[type] == NULL) {
+               inet_diag_table[type] = h;
+               err = 0;
+       }
+       spin_unlock(&inet_diag_register_lock);
+out:
+       return err;
+}
+EXPORT_SYMBOL_GPL(inet_diag_register);
+
+void inet_diag_unregister(const struct inet_diag_handler *h)
+{
+       const __u16 type = h->idiag_type;
+
+       if (type >= INET_DIAG_GETSOCK_MAX)
+               return;
+
+       spin_lock(&inet_diag_register_lock);
+       inet_diag_table[type] = NULL;
+       spin_unlock(&inet_diag_register_lock);
+
+       synchronize_rcu();
+}
+EXPORT_SYMBOL_GPL(inet_diag_unregister);
+
 static int __init tcpdiag_init(void)
 {
-       tcpnl = netlink_kernel_create(NETLINK_TCPDIAG, tcpdiag_rcv);
+       const int inet_diag_table_size = (INET_DIAG_GETSOCK_MAX *
+                                         sizeof(struct inet_diag_handler *));
+       int err = -ENOMEM;
+
+       inet_diag_table = kmalloc(inet_diag_table_size, GFP_KERNEL);
+       if (!inet_diag_table)
+               goto out;
+
+       memset(inet_diag_table, 0, inet_diag_table_size);
+
+       tcpnl = netlink_kernel_create(NETLINK_TCPDIAG, tcpdiag_rcv,
+                                     THIS_MODULE);
        if (tcpnl == NULL)
-               return -ENOMEM;
-       return 0;
+               goto out_free_table;
+
+       err = inet_diag_register(&tcp_diag_handler);
+       if (err)
+               goto out_sock_release;
+out:
+       return err;
+out_sock_release:
+       sock_release(tcpnl->sk_socket);
+out_free_table:
+       kfree(inet_diag_table);
+       goto out;
 }
 
 static void __exit tcpdiag_exit(void)
 {
        sock_release(tcpnl->sk_socket);
+       kfree(inet_diag_table);
 }
 
 module_init(tcpdiag_init);