netfilter: qtaguid: add tag delete command, expand stats output.
JP Abgrall [Sat, 16 Jul 2011 05:27:28 +0000 (22:27 -0700)]
* Add a new ctrl command to delete stored data.
  d <acct_tag> [<uid>]
The uid will default to the running process's.
The accounting tag can be 0, in which case all counters and socket tags
associated with the uid will be cleared.

* Simplify the ctrl command handling at the expense of duplicate code.
This should make it easier to maintain.

* /proc/net/xt_qtaguid/stats now returns more stats
  idx iface acct_tag_hex uid_tag_int
  {rx,tx}_{bytes,packets}
  {rx,tx}_{tcp,udp,other}_{bytes,packets}
the {rx,tx}_{bytes,packets} are the totals.

* re-tagging will now allow changing the uid.

Change-Id: I9594621543cefeab557caa3d68a22a3eb320466d
Signed-off-by: JP Abgrall <jpa@google.com>

net/netfilter/xt_qtaguid.c

index 49ed432..0aa33da 100644 (file)
@@ -8,7 +8,8 @@
  * published by the Free Software Foundation.
  */
 
-/* TODO: support ipv6 for iface_stat */
+/* TODO: support ipv6 for iface_stat.
+ * Currently if an iface is only v6 it will not have stats collected. */
 
 #include <linux/file.h>
 #include <linux/inetdevice.h>
@@ -96,9 +97,6 @@ struct tag_stat {
        struct proc_dir_entry *proc_ptr;
 };
 
-static LIST_HEAD(iface_stat_list);
-static DEFINE_SPINLOCK(iface_stat_list_lock);
-
 struct iface_stat {
        struct list_head list;
        char *ifname;
@@ -113,9 +111,8 @@ struct iface_stat {
        spinlock_t tag_stat_list_lock;
 };
 
-
-static struct rb_root sock_tag_tree = RB_ROOT;
-static DEFINE_SPINLOCK(sock_tag_list_lock);
+static LIST_HEAD(iface_stat_list);
+static DEFINE_SPINLOCK(iface_stat_list_lock);
 
 /*
  * Track tag that this socket is transferring data for, and not necesseraly
@@ -128,6 +125,9 @@ struct sock_tag {
        tag_t tag;
 };
 
+static struct rb_root sock_tag_tree = RB_ROOT;
+static DEFINE_SPINLOCK(sock_tag_list_lock);
+
 static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par);
 
 /*----------------------------------------------*/
@@ -181,6 +181,14 @@ static inline uint64_t dc_sum_bytes(struct data_counters *counters,
                + counters->bpc[direction][IFS_PROTO_OTHER].bytes;
 }
 
+static inline uint64_t dc_sum_packets(struct data_counters *counters,
+                                     enum ifs_tx_rx direction)
+{
+       return counters->bpc[direction][IFS_TCP].packets
+               + counters->bpc[direction][IFS_UDP].packets
+               + counters->bpc[direction][IFS_PROTO_OTHER].packets;
+}
+
 static struct tag_stat *tag_stat_tree_search(struct rb_root *root, tag_t tag)
 {
        struct rb_node *node = root->rb_node;
@@ -397,12 +405,11 @@ void iface_stat_create(const struct net_device *net_dev)
                return;
        }
 
-       new_iface = kmalloc(sizeof(*new_iface), GFP_KERNEL);
+       new_iface = kzalloc(sizeof(*new_iface), GFP_KERNEL);
        if (new_iface == NULL) {
                pr_err("iface_stat: create(): failed to alloc iface_stat\n");
                return;
        }
-       memset(new_iface, 0, sizeof(*new_iface));
        new_iface->ifname = kstrdup(ifname, GFP_KERNEL);
        if (new_iface->ifname == NULL) {
                pr_err("iface_stat: create(): failed to alloc ifname\n");
@@ -531,12 +538,11 @@ static struct tag_stat *create_if_tag_stat(struct iface_stat *iface_entry,
        pr_debug("iface_stat: create_if_tag_stat(): ife=%p tag=0x%llx"
                 " (uid=%d)\n",
                 iface_entry, tag, get_uid_from_tag(tag));
-       new_tag_stat_entry = kmalloc(sizeof(*new_tag_stat_entry), GFP_ATOMIC);
+       new_tag_stat_entry = kzalloc(sizeof(*new_tag_stat_entry), GFP_ATOMIC);
        if (!new_tag_stat_entry) {
                pr_err("iface_stat: failed to alloc new tag entry\n");
                goto done;
        }
-       memset(new_tag_stat_entry, 0, sizeof(*new_tag_stat_entry));
        new_tag_stat_entry->tag = tag;
        tag_stat_tree_insert(new_tag_stat_entry, &iface_entry->tag_stat_tree);
 done:
@@ -852,7 +858,7 @@ static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par)
                pr_debug("xt_qtaguid[%d]: leaving (sk?sk->sk_socket)=%p\n",
                        par->hooknum,
                        sk ? sk->sk_socket : NULL);
-               res =  (info->match ^ info->invert) == 0;
+               res = (info->match ^ info->invert) == 0;
                goto put_sock_ret_res;
        } else if (info->match & info->invert & XT_QTAGUID_SOCKET) {
                res = false;
@@ -922,7 +928,7 @@ static int qtaguid_ctrl_proc_read(char *page, char **num_items_returned,
        struct rb_node *node;
        int item_index = 0;
 
-       pr_debug("xt_qtaguid:proc ctrl page=%p off=%ld char_count=%d *eof=%d\n",
+       pr_debug("xt_qtaguid: proc ctrl page=%p off=%ld char_count=%d *eof=%d\n",
                page, items_to_skip, char_count, *eof);
 
        if (*eof)
@@ -934,7 +940,7 @@ static int qtaguid_ctrl_proc_read(char *page, char **num_items_returned,
             node = rb_next(node)) {
                if (item_index++ < items_to_skip)
                        continue;
-               sock_tag_entry =  rb_entry(node, struct sock_tag, node);
+               sock_tag_entry = rb_entry(node, struct sock_tag, node);
                uid = get_uid_from_tag(sock_tag_entry->tag);
                pr_debug("xt_qtaguid: proc_read(): sk=%p tag=0x%llx (uid=%d)\n",
                        sock_tag_entry->sk,
@@ -957,7 +963,103 @@ static int qtaguid_ctrl_proc_read(char *page, char **num_items_returned,
        return outp - page;
 }
 
-static int qtaguid_ctrl_parse(const char *input, int count)
+/* Delete socket tags, and stat tags associated with a given
+ * accouting tag and uid. */
+static int ctrl_cmd_delete(const char *input)
+{
+       char cmd;
+       uid_t uid = 0;
+       uid_t entry_uid;
+       tag_t acct_tag = 0;
+       tag_t tag;
+       int res, argc;
+       unsigned long flags, flags2;
+       struct iface_stat *iface_entry;
+       struct rb_node *node;
+       struct sock_tag *st_entry;
+       struct tag_stat *ts_entry;
+
+       pr_debug("xt_qtaguid: ctrl_delete(%s): entered\n", input);
+       argc = sscanf(input, "%c %llu %u", &cmd, &acct_tag, &uid);
+       pr_debug("xt_qtaguid: ctrl_delete(%s): argc=%d cmd=%c "
+                "acct_tag=0x%llx uid=%u\n", input, argc, cmd,
+                acct_tag, uid);
+       if (argc < 2) {
+               res = -EINVAL;
+               goto err;
+       }
+       if (!valid_atag(acct_tag)) {
+               pr_info("xt_qtaguid: ctrl_delete(%s): invalid tag\n", input);
+               res = -EINVAL;
+               goto err;
+       }
+       if (argc < 3)
+               uid = current_fsuid();
+
+       /* TODO: check that the uid == current_fsuid()
+        * except for special uid/gid. */
+
+       spin_lock_irqsave(&sock_tag_list_lock, flags);
+       node = rb_first(&sock_tag_tree);
+       while (node) {
+               st_entry = rb_entry(node, struct sock_tag, node);
+               entry_uid = get_uid_from_tag(st_entry->tag);
+               node = rb_next(node);
+               if (entry_uid != uid)
+                       continue;
+
+               if (!acct_tag || st_entry->tag == tag) {
+                       pr_debug("xt_qtaguid: ctrl_delete(): "
+                                "erase sk=%p tag=0x%llx (uid=%d)\n",
+                                st_entry->sk,
+                                st_entry->tag,
+                                entry_uid);
+                       rb_erase(&ts_entry->node, &sock_tag_tree);
+                       kfree(st_entry);
+               }
+       }
+       spin_unlock_irqrestore(&sock_tag_list_lock, flags);
+
+       /* If acct_tag is 0, then all entries belonging to uid are
+        * erased. */
+       tag = combine_atag_with_uid(acct_tag, uid);
+       spin_lock_irqsave(&iface_stat_list_lock, flags);
+       list_for_each_entry(iface_entry, &iface_stat_list, list) {
+
+               spin_lock_irqsave(&iface_entry->tag_stat_list_lock, flags2);
+               node = rb_first(&iface_entry->tag_stat_tree);
+               while (node) {
+                       ts_entry = rb_entry(node, struct tag_stat, node);
+                       entry_uid = get_uid_from_tag(ts_entry->tag);
+                       node = rb_next(node);
+                       if (entry_uid != uid)
+                               continue;
+                       if (!acct_tag || ts_entry->tag == tag) {
+                               pr_debug("xt_qtaguid: ctrl_delete(): erase "
+                                        "%s 0x%llx %u\n",
+                                        iface_entry->ifname,
+                                        get_atag_from_tag(ts_entry->tag),
+                                        entry_uid);
+                               rb_erase(&ts_entry->node,
+                                        &iface_entry->tag_stat_tree);
+                               kfree(ts_entry);
+                       }
+               }
+               spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock,
+                                      flags2);
+
+       }
+       spin_unlock_irqrestore(&iface_stat_list_lock, flags);
+
+       res = 0;
+
+err:
+       pr_debug("xt_qtaguid: ctrl_delete(%s) res=%d\n", input, res);
+       return res;
+}
+
+
+static int ctrl_cmd_tag(const char *input)
 {
        char cmd;
        int sock_fd = 0;
@@ -968,117 +1070,139 @@ static int qtaguid_ctrl_parse(const char *input, int count)
        struct sock_tag *sock_tag_entry;
        unsigned long flags;
 
-       pr_debug("xt_qtaguid: ctrl(%s): entered\n", input);
        /* Unassigned args will get defaulted later. */
-       /* TODO: get acct_tag_str, keep a list of available tags for the
-        * uid, use num as acct_tag. */
        argc = sscanf(input, "%c %d %llu %u", &cmd, &sock_fd, &acct_tag, &uid);
-       pr_debug("xt_qtaguid: ctrl(%s): argc=%d cmd=%c sock_fd=%d "
-               "acct_tag=0x%llx uid=%u\n", input, argc, cmd, sock_fd,
-               acct_tag, uid);
+       pr_debug("xt_qtaguid: ctrl_tag(%s): argc=%d cmd=%c sock_fd=%d "
+                "acct_tag=0x%llx uid=%u\n", input, argc, cmd, sock_fd,
+                acct_tag, uid);
+       if (argc < 2) {
+               res = -EINVAL;
+               goto err;
+       }
+       el_socket = sockfd_lookup(sock_fd, &res);
+       if (!el_socket) {
+               pr_info("xt_qtaguid: ctrl_tag(%s): failed to lookup"
+                       " sock_fd=%d err=%d\n", input, sock_fd, res);
+               goto err;
+       }
+       if (argc < 3) {
+               acct_tag = 0;
+       } else if (!valid_atag(acct_tag)) {
+               pr_info("xt_qtaguid: ctrl_tag(%s): invalid tag\n", input);
+               res = -EINVAL;
+               goto err;
+       }
+       if (argc < 4)
+               uid = current_fsuid();
 
-       /* Collect params for commands */
-       switch (cmd) {
-       case 't':
-       case 'u':
-               if (argc < 2) {
-                       res = -EINVAL;
-                       goto err;
-               }
-               el_socket = sockfd_lookup(sock_fd, &res);
-               if (!el_socket) {
-                       pr_info("xt_qtaguid: ctrl(%s): failed to lookup"
-                               " sock_fd=%d err=%d\n", input, sock_fd, res);
+       spin_lock_irqsave(&sock_tag_list_lock, flags);
+       sock_tag_entry = get_sock_stat_nl(el_socket->sk);
+       if (sock_tag_entry) {
+               /* TODO: check that the uid == current_fsuid()
+                * except for special uid/gid. */
+               sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
+                                                           uid);
+       } else {
+               spin_unlock_irqrestore(&sock_tag_list_lock, flags);
+               sock_tag_entry = kzalloc(sizeof(*sock_tag_entry),
+                                        GFP_KERNEL);
+               if (!sock_tag_entry) {
+                       res = -ENOMEM;
                        goto err;
                }
+               sock_tag_entry->sk = el_socket->sk;
+               /* TODO: check that uid==current_fsuid() except
+                * for special uid/gid. */
+               sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
+                                                           uid);
                spin_lock_irqsave(&sock_tag_list_lock, flags);
-               /* TODO: optim: pass in the current_fsuid() to do lookups
-                * as look ups will always be initiated form the same uid. */
-               sock_tag_entry = get_sock_stat_nl(el_socket->sk);
-               if (!sock_tag_entry)
-                       spin_unlock_irqrestore(&sock_tag_list_lock, flags);
-               /* HERE: The lock is held if there was a matching sock tag entry */
-               break;
-       default:
+               sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree);
+       }
+       spin_unlock_irqrestore(&sock_tag_list_lock, flags);
+
+       pr_debug("xt_qtaguid: tag: sock_tag_entry->sk=%p "
+                "...->tag=0x%llx (uid=%u)\n",
+                sock_tag_entry->sk, sock_tag_entry->tag,
+                get_uid_from_tag(sock_tag_entry->tag));
+       res = 0;
+
+err:
+       pr_debug("xt_qtaguid: ctrl_tag(%s) res=%d\n", input, res);
+       return res;
+}
+
+
+static int ctrl_cmd_untag(const char *input)
+{
+       char cmd;
+       int sock_fd = 0;
+       struct socket *el_socket;
+       int res, argc;
+       struct sock_tag *sock_tag_entry;
+       unsigned long flags;
+
+       pr_debug("xt_qtaguid: ctrl_untag(%s): entered\n", input);
+       argc = sscanf(input, "%c %d", &cmd, &sock_fd);
+       pr_debug("xt_qtaguid: ctrl_untag(%s): argc=%d cmd=%c sock_fd=%d\n",
+                input, argc, cmd, sock_fd);
+       if (argc < 2) {
                res = -EINVAL;
                goto err;
        }
-       /* HERE: The lock is held if there was a matching sock tag entry */
+       el_socket = sockfd_lookup(sock_fd, &res);
+       if (!el_socket) {
+               pr_info("xt_qtaguid: ctrl_untag(%s): failed to lookup"
+                       " sock_fd=%d err=%d\n", input, sock_fd, res);
+               goto err;
+       }
+       spin_lock_irqsave(&sock_tag_list_lock, flags);
+       sock_tag_entry = get_sock_stat_nl(el_socket->sk);
+       if (!sock_tag_entry) {
+               spin_unlock_irqrestore(&sock_tag_list_lock, flags);
+               res = -EINVAL;
+               goto err;
+       }
+
+       /* TODO: check that the uid==current_fsuid()
+        * except for special uid/gid. */
+       rb_erase(&sock_tag_entry->node, &sock_tag_tree);
+       spin_unlock_irqrestore(&sock_tag_list_lock, flags);
+       kfree(sock_tag_entry);
+
+       res = 0;
+err:
+       pr_debug("xt_qtaguid: ctrl_untag(%s): res=%d\n", input, res);
+       return res;
+}
 
-       /* Process commands */
+static int qtaguid_ctrl_parse(const char *input, int count)
+{
+       char cmd;
+       int res;
+
+       pr_debug("xt_qtaguid: ctrl(%s): entered\n", input);
+       cmd = input[0];
+       /* Collect params for commands */
        switch (cmd) {
+       case 'd':
+               res = ctrl_cmd_delete(input);
+               break;
 
        case 't':
-               if (argc < 2) {
-                       res = -EINVAL;
-                       /* HERE: The lock is held if there was a matching sock
-                        * tag entry */
-                       goto err_unlock;
-               }
-               if (argc < 3) {
-                       acct_tag = 0;
-               } else if (!valid_atag(acct_tag)) {
-                       res = -EINVAL;
-                       /* HERE: The lock is held if there was a matching sock
-                        * tag entry */
-                       goto err_unlock;
-               }
-               if (argc < 4)
-                       uid = current_fsuid();
-               if (!sock_tag_entry) {
-                       /* HERE: There is no lock held because there was no
-                        * sock tag entry */
-                       sock_tag_entry = kmalloc(sizeof(*sock_tag_entry),
-                                               GFP_KERNEL);
-                       if (!sock_tag_entry) {
-                               res = -ENOMEM;
-                               goto err;
-                       }
-                       memset(sock_tag_entry, 0, sizeof(*sock_tag_entry));
-                       sock_tag_entry->sk = el_socket->sk;
-                       /* TODO: check that uid==current_fsuid() except
-                        * for special uid/gid. */
-                       sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
-                                                               uid);
-                       spin_lock_irqsave(&sock_tag_list_lock, flags);
-                       sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree);
-               } else {
-                       /* HERE: The lock is held because there is a matching
-                        * sock tag entry */
-                       /* Just update the acct_tag portion. */
-                       uid_t orig_uid = get_uid_from_tag(sock_tag_entry->tag);
-                       sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
-                                                               orig_uid);
-               }
-               spin_unlock_irqrestore(&sock_tag_list_lock, flags);
-               pr_debug("xt_qtaguid: tag: sock_tag_entry->sk=%p "
-                       "...->tag=0x%llx (uid=%u)\n",
-                       sock_tag_entry->sk, sock_tag_entry->tag,
-                       get_uid_from_tag(sock_tag_entry->tag));
+               res = ctrl_cmd_tag(input);
                break;
 
        case 'u':
-               if (!sock_tag_entry) {
-                       res = -EINVAL;
-                       goto err;
-               }
-               /* TODO: check that the uid==current_fsuid()
-                * except for special uid/gid. */
-               rb_erase(&sock_tag_entry->node, &sock_tag_tree);
-               spin_unlock_irqrestore(&sock_tag_list_lock, flags);
-               kfree(sock_tag_entry);
+               res = ctrl_cmd_untag(input);
                break;
-       }
-
-       /* All of the input has been processed */
-       res = count;
-       goto ok;
 
-err_unlock:
-       if (sock_tag_entry)
-               spin_unlock_irqrestore(&sock_tag_list_lock, flags);
+       default:
+               res = -EINVAL;
+               goto err;
+       }
+       if (!res)
+               res = count;
 err:
-ok:
        pr_debug("xt_qtaguid: ctrl(%s): res=%d\n", input, res);
        return res;
 }
@@ -1099,6 +1223,57 @@ static int qtaguid_ctrl_proc_write(struct file *file, const char __user *buffer,
        return qtaguid_ctrl_parse(input_buf, count);
 }
 
+static int print_stats_line(char *outp, int char_count, int item_index,
+                           char *ifname, tag_t tag,
+                           struct data_counters *counters)
+{
+       int len;
+       if (!item_index)
+               len = snprintf(outp, char_count,
+                        "idx iface acct_tag_hex uid_tag_int "
+                        "rx_bytes rx_packets "
+                        "tx_bytes tx_packets "
+                        "rx_tcp_packets rx_tcp_bytes "
+                        "rx_udp_packets rx_udp_bytes "
+                        "rx_other_packets rx_other_bytes "
+                        "tx_tcp_packets tx_tcp_bytes "
+                        "tx_udp_packets tx_udp_bytes "
+                        "tx_other_packets tx_other_bytes\n");
+       else
+               len = snprintf(outp, char_count,
+                              "%d %s 0x%llx %u "
+                              "%llu %llu "
+                              "%llu %llu "
+                              "%llu %llu "
+                              "%llu %llu "
+                              "%llu %llu "
+                              "%llu %llu "
+                              "%llu %llu "
+                              "%llu %llu\n",
+                              item_index,
+                              ifname,
+                              get_atag_from_tag(tag),
+                              get_uid_from_tag(tag),
+                              dc_sum_bytes(counters, IFS_RX),
+                              dc_sum_packets(counters, IFS_RX),
+                              dc_sum_bytes(counters, IFS_TX),
+                              dc_sum_packets(counters, IFS_TX),
+                              counters->bpc[IFS_RX][IFS_TCP].bytes,
+                              counters->bpc[IFS_RX][IFS_TCP].packets,
+                              counters->bpc[IFS_RX][IFS_UDP].bytes,
+                              counters->bpc[IFS_RX][IFS_UDP].packets,
+                              counters->bpc[IFS_RX][IFS_PROTO_OTHER].bytes,
+                              counters->bpc[IFS_RX][IFS_PROTO_OTHER].packets,
+                              counters->bpc[IFS_TX][IFS_TCP].bytes,
+                              counters->bpc[IFS_TX][IFS_TCP].packets,
+                              counters->bpc[IFS_TX][IFS_UDP].bytes,
+                              counters->bpc[IFS_TX][IFS_UDP].packets,
+                              counters->bpc[IFS_TX][IFS_PROTO_OTHER].bytes,
+                              counters->bpc[IFS_TX][IFS_PROTO_OTHER].packets);
+       return len;
+}
+
+
 /*
  * Procfs reader to get all tag stats using style "1)" as described in
  * fs/proc/generic.c
@@ -1126,9 +1301,8 @@ static int qtaguid_stats_proc_read(char *page, char **num_items_returned,
 
        if (!items_to_skip) {
                /* The idx is there to help debug when things go belly up. */
-               len = snprintf(outp, char_count,
-                       "idx iface acct_tag_hex uid_tag_int rx_bytes "
-                       "tx_bytes\n");
+               len = print_stats_line(outp, char_count, /*index*/0, NULL,
+                                      make_tag_from_uid(0), NULL);
                /* Don't advance the outp unless the whole line was printed */
                if (len >= char_count) {
                        *outp = '\0';
@@ -1137,7 +1311,6 @@ static int qtaguid_stats_proc_read(char *page, char **num_items_returned,
                outp += len;
                char_count -= len;
        }
-
        spin_lock_irqsave(&iface_stat_list_lock, flags);
        list_for_each_entry(iface_entry, &iface_stat_list, list) {
                struct rb_node *node;
@@ -1145,26 +1318,21 @@ static int qtaguid_stats_proc_read(char *page, char **num_items_returned,
                for (node = rb_first(&iface_entry->tag_stat_tree);
                     node;
                     node = rb_next(node)) {
-                       ts_entry =  rb_entry(node, struct tag_stat, node);
+                       ts_entry = rb_entry(node, struct tag_stat, node);
                        if (item_index++ < items_to_skip)
                                continue;
-                       len = snprintf(outp, char_count,
-                                      "%d %s 0x%llx %u %llu %llu\n",
-                                      item_index,
-                                      iface_entry->ifname,
-                                      get_atag_from_tag(ts_entry->tag),
-                                      get_uid_from_tag(ts_entry->tag),
-                                      dc_sum_bytes(&ts_entry->counters,
-                                                   IFS_RX),
-                                      dc_sum_bytes(&ts_entry->counters,
-                                                   IFS_TX));
+                       len = print_stats_line(outp, char_count,
+                                              item_index,
+                                              iface_entry->ifname,
+                                              ts_entry->tag,
+                                              &ts_entry->counters);
                        if (len >= char_count) {
+                               *outp = '\0';
                                spin_unlock_irqrestore(
                                        &iface_entry->tag_stat_list_lock,
                                        flags2);
                                spin_unlock_irqrestore(
                                        &iface_stat_list_lock, flags);
-                               *outp = '\0';
                                return outp - page;
                        }
                        outp += len;