netfilter: xt_qtaguid: Fix socket refcounts when tagging
JP Abgrall [Wed, 17 Aug 2011 23:43:00 +0000 (16:43 -0700)]
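* Don't leak socket references when tagging.
  sockfd_lookup() does a get() on the associated file.
  There was no matching put(), so a closed socket could never be
  freed. The reference is now held only while the socket is tagged;
  it is dropped on untag, delete, re-tag, and on the error paths.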
* Don't rely on struct member order for tag_node
  The structs that had a struct tag_node member would work with
  the *_tree_* routines only because tag_node was the first member.
  Reference the member explicitly (rb_entry()) instead of casting.
* Improve debug messages
  Provide info on who the caller is. Use unsigned int for uid.
* Only call iface_stat_create() on NETDEV_UP events
  (NETDEV_UNREGISTER is still handled separately).
* Pacifier (module_passive): skip the netfilter matching work,
  but still output the .../stats header line.

Change-Id: Iccb8ae3cca9608210c417597287a2391010dff2c
Signed-off-by: JP Abgrall <jpa@google.com>

net/netfilter/xt_qtaguid.c

index b04b471..b9dcfde 100644 (file)
@@ -183,7 +183,6 @@ struct tag_stat {
         * matching parent uid_tag.
         */
        struct data_counters *parent_counters;
-       struct proc_dir_entry *proc_ptr;
 };
 
 struct iface_stat {
@@ -215,8 +214,11 @@ struct iface_stat_work {
  * This is the tag against which tag_stat.counters will be billed.
  */
 struct sock_tag {
-       struct rb_node node;
-       struct sock *sk;
+       struct rb_node sock_node;
+       struct sock *sk;  /* Only used as a number, never dereferenced */
+       /* The socket is needed for sockfd_put() */
+       struct socket *socket;
+
        tag_t tag;
 };
 
@@ -345,24 +347,31 @@ static void tag_node_tree_insert(struct tag_node *data, struct rb_root *root)
 
 static void tag_stat_tree_insert(struct tag_stat *data, struct rb_root *root)
 {
-       tag_node_tree_insert((struct tag_node *)data, root);
+       tag_node_tree_insert(&data->tn, root);
 }
 
 static struct tag_stat *tag_stat_tree_search(struct rb_root *root, tag_t tag)
 {
-       return (struct tag_stat *)tag_node_tree_search(root, tag);
+       struct tag_node *node = tag_node_tree_search(root, tag);
+       if (!node)
+               return NULL;
+       return rb_entry(&node->node, struct tag_stat, tn.node);
 }
 
 static void tag_counter_set_tree_insert(struct tag_counter_set *data,
                                        struct rb_root *root)
 {
-       tag_node_tree_insert((struct tag_node *)data, root);
+       tag_node_tree_insert(&data->tn, root);
 }
 
 static struct tag_counter_set *tag_counter_set_tree_search(struct rb_root *root,
                                                           tag_t tag)
 {
-       return (struct tag_counter_set *)tag_node_tree_search(root, tag);
+       struct tag_node *node = tag_node_tree_search(root, tag);
+       if (!node)
+               return NULL;
+       return rb_entry(&node->node, struct tag_counter_set, tn.node);
+
 }
 
 static struct sock_tag *sock_tag_tree_search(struct rb_root *root,
@@ -371,7 +380,8 @@ static struct sock_tag *sock_tag_tree_search(struct rb_root *root,
        struct rb_node *node = root->rb_node;
 
        while (node) {
-               struct sock_tag *data = rb_entry(node, struct sock_tag, node);
+               struct sock_tag *data = rb_entry(node, struct sock_tag,
+                                                sock_node);
                ptrdiff_t result = sk - data->sk;
                if (result < 0)
                        node = node->rb_left;
@@ -389,7 +399,8 @@ static void sock_tag_tree_insert(struct sock_tag *data, struct rb_root *root)
 
        /* Figure out where to put new node */
        while (*new) {
-               struct sock_tag *this = rb_entry(*new, struct sock_tag, node);
+               struct sock_tag *this = rb_entry(*new, struct sock_tag,
+                                                sock_node);
                ptrdiff_t result = data->sk - this->sk;
                parent = *new;
                if (result < 0)
@@ -401,8 +412,8 @@ static void sock_tag_tree_insert(struct sock_tag *data, struct rb_root *root)
        }
 
        /* Add new node and rebalance tree. */
-       rb_link_node(&data->node, parent, new);
-       rb_insert_color(&data->node, root);
+       rb_link_node(&data->sock_node, parent, new);
+       rb_insert_color(&data->sock_node, root);
 }
 
 static int read_proc_u64(char *page, char **start, off_t off,
@@ -412,6 +423,7 @@ static int read_proc_u64(char *page, char **start, off_t off,
        uint64_t value;
        char *p = page;
        uint64_t *iface_entry = data;
+
        if (!data)
                return 0;
 
@@ -430,6 +442,7 @@ static int read_proc_bool(char *page, char **start, off_t off,
        bool value;
        char *p = page;
        bool *bool_entry = data;
+
        if (!data)
                return 0;
 
@@ -447,7 +460,7 @@ static int get_active_counter_set(tag_t tag)
        struct tag_counter_set *tcs;
 
        MT_DEBUG("qtaguid: get_active_counter_set(tag=0x%llx)"
-                " (uid=%d)\n",
+                " (uid=%u)\n",
                 tag, get_uid_from_tag(tag));
        /* For now we only handle UID tags for active sets */
        tag = get_utag_from_tag(tag);
@@ -469,11 +482,10 @@ static struct iface_stat *get_iface_entry(const char *ifname)
 
        /* Find the entry for tracking the specified tag within the interface */
        if (ifname == NULL) {
-               pr_info("iface_stat: NULL device name\n");
+               pr_info("qtaguid: iface_stat: get() NULL device name\n");
                return NULL;
        }
 
-
        /* Iterate over interfaces */
        list_for_each_entry(iface_entry, &iface_stat_list, list) {
                if (!strcmp(ifname, iface_entry->ifname))
@@ -525,12 +537,14 @@ static struct iface_stat *iface_alloc(const char *ifname)
 
        new_iface = kzalloc(sizeof(*new_iface), GFP_ATOMIC);
        if (new_iface == NULL) {
-               pr_err("qtaguid: iface_stat: create(): failed to alloc iface_stat\n");
+               pr_err("qtaguid: iface_stat: create(%s): "
+                      "iface_stat alloc failed\n", ifname);
                return NULL;
        }
        new_iface->ifname = kstrdup(ifname, GFP_ATOMIC);
        if (new_iface->ifname == NULL) {
-               pr_err("qtaguid: iface_stat: create(): failed to alloc ifname\n");
+               pr_err("qtaguid: iface_stat: create(%s): "
+                      "ifname alloc failed\n", ifname);
                kfree(new_iface);
                return NULL;
        }
@@ -544,8 +558,8 @@ static struct iface_stat *iface_alloc(const char *ifname)
         */
        isw = kmalloc(sizeof(*isw), GFP_ATOMIC);
        if (!isw) {
-               pr_err("qtaguid: iface_stat: create(): "
-                      "failed to alloc work for %s\n", new_iface->ifname);
+               pr_err("qtaguid: iface_stat: create(%s): "
+                      "work alloc failed\n", new_iface->ifname);
                kfree(new_iface->ifname);
                kfree(new_iface);
                return NULL;
@@ -571,10 +585,11 @@ void iface_stat_create(const struct net_device *net_dev,
        __be32 ipaddr = 0;
        struct iface_stat *new_iface;
 
-       IF_DEBUG("qtaguid: iface_stat: create(): ifa=%p netdev=%p->name=%s\n",
-                ifa, net_dev, net_dev ? net_dev->name : "");
+       IF_DEBUG("qtaguid: iface_stat: create(%s): ifa=%p netdev=%p\n",
+                net_dev ? net_dev->name : "?",
+                ifa, net_dev);
        if (!net_dev) {
-               pr_err("qtaguid: iface_stat: create(): no net dev!\n");
+               pr_err("qtaguid: iface_stat: create(): no net dev\n");
                return;
        }
 
@@ -582,16 +597,16 @@ void iface_stat_create(const struct net_device *net_dev,
        if (!ifa) {
                in_dev = in_dev_get(net_dev);
                if (!in_dev) {
-                       pr_err("qtaguid: iface_stat: create(): "
-                              "no inet dev for %s!\n", ifname);
+                       pr_err("qtaguid: iface_stat: create(%s): no inet dev\n",
+                              ifname);
                        return;
                }
-               IF_DEBUG("qtaguid: iface_stat: create(): in_dev=%p ifname=%p\n",
-                        in_dev, ifname);
+               IF_DEBUG("qtaguid: iface_stat: create(%s): in_dev=%p\n",
+                        ifname, in_dev);
                for (ifa = in_dev->ifa_list; ifa; ifa = ifa->ifa_next) {
-                       IF_DEBUG("qtaguid: iface_stat: create(): "
-                                "ifa=%p ifname=%s ifa_label=%s\n",
-                                ifa, ifname,
+                       IF_DEBUG("qtaguid: iface_stat: create(%s): "
+                                "ifa=%p ifa_label=%s\n",
+                                ifname, ifa,
                                 ifa->ifa_label ? ifa->ifa_label : "(null)");
                        if (ifa->ifa_label && !strcmp(ifname, ifa->ifa_label))
                                break;
@@ -599,8 +614,7 @@ void iface_stat_create(const struct net_device *net_dev,
        }
 
        if (!ifa) {
-               IF_DEBUG("qtaguid: iface_stat: create(): "
-                        "dev %s has no matching IP\n",
+               IF_DEBUG("qtaguid: iface_stat: create(%s): no matching IP\n",
                         ifname);
                goto done_put;
        }
@@ -609,29 +623,29 @@ void iface_stat_create(const struct net_device *net_dev,
        spin_lock_bh(&iface_stat_list_lock);
        entry = get_iface_entry(ifname);
        if (entry != NULL) {
-               IF_DEBUG("qtaguid: iface_stat: create(): dev %s entry=%p\n",
+               IF_DEBUG("qtaguid: iface_stat: create(%s): entry=%p\n",
                         ifname, entry);
                if (ipv4_is_loopback(ipaddr)) {
                        entry->active = false;
-                       IF_DEBUG("qtaguid: iface_stat: create(): "
-                                "disable tracking of loopback dev %s\n",
+                       IF_DEBUG("qtaguid: iface_stat: create(%s): "
+                                "disable tracking of loopback dev\n",
                                 ifname);
                } else {
                        entry->active = true;
-                       IF_DEBUG("qtaguid: iface_stat: create(): "
-                                "enable tracking of dev %s with ip=%pI4\n",
+                       IF_DEBUG("qtaguid: iface_stat: create(%s): "
+                                "enable tracking. ip=%pI4\n",
                                 ifname, &ipaddr);
                }
                goto done_unlock_put;
        } else if (ipv4_is_loopback(ipaddr)) {
-               IF_DEBUG("qtaguid: iface_stat: create(): ignore loopback dev %s"
-                        " ip=%pI4\n", ifname, &ipaddr);
+               IF_DEBUG("qtaguid: iface_stat: create(%s): "
+                        "ignore loopback dev. ip=%pI4\n", ifname, &ipaddr);
                goto done_unlock_put;
        }
 
        new_iface = iface_alloc(ifname);
-       IF_DEBUG("qtaguid: iface_stat: create(): done "
-                "entry=%p dev=%s ip=%pI4\n", new_iface, ifname, &ipaddr);
+       IF_DEBUG("qtaguid: iface_stat: create(%s): done "
+                "entry=%p ip=%pI4\n", ifname, new_iface, &ipaddr);
 
 done_unlock_put:
        spin_unlock_bh(&iface_stat_list_lock);
@@ -659,17 +673,16 @@ void iface_stat_create_ipv6(const struct net_device *net_dev,
 
        in_dev = in_dev_get(net_dev);
        if (!in_dev) {
-               pr_err("qtaguid: iface_stat: create6(): no inet dev for %s!\n",
+               pr_err("qtaguid: iface_stat: create6(%s): no inet dev\n",
                       ifname);
                return;
        }
 
-       IF_DEBUG("qtaguid: iface_stat: create6(): in_dev=%p ifname=%p\n",
-                in_dev, ifname);
+       IF_DEBUG("qtaguid: iface_stat: create6(%s): in_dev=%p\n",
+                ifname, in_dev);
 
        if (!ifa) {
-               IF_DEBUG("qtaguid: iface_stat: create6(): "
-                        "dev %s has no matching IP\n",
+               IF_DEBUG("qtaguid: iface_stat: create6(%s): no matching IP\n",
                         ifname);
                goto done_put;
        }
@@ -678,30 +691,30 @@ void iface_stat_create_ipv6(const struct net_device *net_dev,
        spin_lock_bh(&iface_stat_list_lock);
        entry = get_iface_entry(ifname);
        if (entry != NULL) {
-               IF_DEBUG("qtaguid: iface_stat: create6(): dev %s entry=%p\n",
+               IF_DEBUG("qtaguid: iface_stat: create6(%s): entry=%p\n",
                         ifname, entry);
                if (addr_type & IPV6_ADDR_LOOPBACK) {
                        entry->active = false;
-                       IF_DEBUG("qtaguid: iface_stat: create6(): "
-                                "disable tracking of loopback dev %s\n",
+                       IF_DEBUG("qtaguid: iface_stat: create6(%s): "
+                                "disable tracking of loopback dev\n",
                                 ifname);
                } else {
                        entry->active = true;
-                       IF_DEBUG("qtaguid: iface_stat: create6(): "
-                                "enable tracking of dev %s with ip=%pI6c\n",
+                       IF_DEBUG("qtaguid: iface_stat: create6(%s): "
+                                "enable tracking. ip=%pI6c\n",
                                 ifname, &ifa->addr);
                }
                goto done_unlock_put;
        } else if (addr_type & IPV6_ADDR_LOOPBACK) {
-               IF_DEBUG("qtaguid: iface_stat: create6(): "
-                        "ignore loopback dev %s ip=%pI6c\n",
+               IF_DEBUG("qtaguid: iface_stat: create6(%s): "
+                        "ignore loopback dev. ip=%pI6c\n",
                         ifname, &ifa->addr);
                goto done_unlock_put;
        }
 
        new_iface = iface_alloc(ifname);
-       IF_DEBUG("qtaguid: iface_stat: create6(): done "
-                "entry=%p dev=%s ip=%pI6c\n", new_iface, ifname, &ifa->addr);
+       IF_DEBUG("qtaguid: iface_stat: create6(%s): done "
+                "entry=%p ip=%pI6c\n", ifname, new_iface, &ifa->addr);
 
 done_unlock_put:
        spin_unlock_bh(&iface_stat_list_lock);
@@ -760,22 +773,22 @@ static void iface_stat_update(struct net_device *dev)
        spin_lock_bh(&iface_stat_list_lock);
        entry = get_iface_entry(dev->name);
        if (entry == NULL) {
-               IF_DEBUG("qtaguid: iface_stat: dev %s monitor not found\n",
+               IF_DEBUG("qtaguid: iface_stat_update: dev=%s not tracked\n",
                         dev->name);
                spin_unlock_bh(&iface_stat_list_lock);
                return;
        }
+       IF_DEBUG("qtaguid: iface_stat_update: dev=%s entry=%p\n",
+                dev->name, entry);
        if (entry->active) {
                entry->tx_bytes += stats->tx_bytes;
                entry->tx_packets += stats->tx_packets;
                entry->rx_bytes += stats->rx_bytes;
                entry->rx_packets += stats->rx_packets;
                entry->active = false;
-               IF_DEBUG("qtaguid: iface_stat: Updating stats for "
-                       "dev %s which went down\n", dev->name);
        } else {
-               IF_DEBUG("qtaguid: iface_stat: Did not update stats for "
-                       "dev %s which went down\n", dev->name);
+               IF_DEBUG("qtaguid: iface_stat_update: dev=%s inactive\n",
+                       dev->name);
        }
        spin_unlock_bh(&iface_stat_list_lock);
 }
@@ -785,7 +798,7 @@ static void tag_stat_update(struct tag_stat *tag_entry,
 {
        int active_set;
        active_set = get_active_counter_set(tag_entry->tn.tag);
-       MT_DEBUG("qtaguid: tag_stat_update(tag=0x%llx (uid=%d) set=%d "
+       MT_DEBUG("qtaguid: tag_stat_update(tag=0x%llx (uid=%u) set=%d "
                 "dir=%d proto=%d bytes=%d)\n",
                 tag_entry->tn.tag, get_uid_from_tag(tag_entry->tn.tag),
                 active_set, direction, proto, bytes);
@@ -806,11 +819,11 @@ static struct tag_stat *create_if_tag_stat(struct iface_stat *iface_entry,
 {
        struct tag_stat *new_tag_stat_entry = NULL;
        IF_DEBUG("qtaguid: iface_stat: create_if_tag_stat(): ife=%p tag=0x%llx"
-                " (uid=%d)\n",
+                " (uid=%u)\n",
                 iface_entry, tag, get_uid_from_tag(tag));
        new_tag_stat_entry = kzalloc(sizeof(*new_tag_stat_entry), GFP_ATOMIC);
        if (!new_tag_stat_entry) {
-               pr_err("qtaguid: iface_stat: failed to alloc new tag entry\n");
+               pr_err("qtaguid: iface_stat: tag stat alloc failed\n");
                goto done;
        }
        new_tag_stat_entry->tn.tag = tag;
@@ -831,19 +844,20 @@ static void if_tag_stat_update(const char *ifname, uid_t uid,
        struct iface_stat *iface_entry;
        struct tag_stat *new_tag_stat;
        MT_DEBUG("qtaguid: if_tag_stat_update(ifname=%s "
-               "uid=%d sk=%p dir=%d proto=%d bytes=%d)\n",
+               "uid=%u sk=%p dir=%d proto=%d bytes=%d)\n",
                 ifname, uid, sk, direction, proto, bytes);
 
 
        iface_entry = get_iface_entry(ifname);
        if (!iface_entry) {
-               pr_err("qtaguid: iface_stat: interface %s not found\n", ifname);
+               pr_err("qtaguid: iface_stat: stat_update() %s not found\n",
+                      ifname);
                return;
        }
        /* It is ok to process data when an iface_entry is inactive */
 
-       MT_DEBUG("qtaguid: iface_stat: stat_update() got entry=%p\n",
-                iface_entry);
+       MT_DEBUG("qtaguid: iface_stat: stat_update() dev=%s entry=%p\n",
+                ifname, iface_entry);
 
        /*
         * Look for a tagged sock.
@@ -860,7 +874,7 @@ static void if_tag_stat_update(const char *ifname, uid_t uid,
                tag = combine_atag_with_uid(acct_tag, uid);
        }
        MT_DEBUG("qtaguid: iface_stat: stat_update(): "
-                " looking for tag=0x%llx (uid=%d) in ife=%p\n",
+                " looking for tag=0x%llx (uid=%u) in ife=%p\n",
                 tag, get_uid_from_tag(tag), iface_entry);
        /* Loop over tag list under this interface for {acct_tag,uid_tag} */
        spin_lock_bh(&iface_entry->tag_stat_list_lock);
@@ -913,12 +927,6 @@ static int iface_netdev_event_handler(struct notifier_block *nb,
 
        switch (event) {
        case NETDEV_UP:
-       case NETDEV_REBOOT:
-       case NETDEV_CHANGE:
-       case NETDEV_REGISTER:  /* Most likely no IP */
-       case NETDEV_CHANGEADDR:  /* MAC addr change */
-       case NETDEV_CHANGENAME:
-       case NETDEV_FEAT_CHANGE:  /* Might be usefull when cell type changes */
                iface_stat_create(dev, NULL);
                break;
        case NETDEV_UNREGISTER:
@@ -992,25 +1000,26 @@ static int __init iface_stat_init(struct proc_dir_entry *parent_procdir)
 
        iface_stat_procdir = proc_mkdir(iface_stat_procdirname, parent_procdir);
        if (!iface_stat_procdir) {
-               pr_err("qtaguid: iface_stat: failed to create proc entry\n");
+               pr_err("qtaguid: iface_stat: init failed to create proc entry\n");
                err = -1;
                goto err;
        }
        err = register_netdevice_notifier(&iface_netdev_notifier_blk);
        if (err) {
-               pr_err("qtaguid: iface_stat: failed to register dev event handler\n");
+               pr_err("qtaguid: iface_stat: init "
+                      "failed to register dev event handler\n");
                goto err_zap_entry;
        }
        err = register_inetaddr_notifier(&iface_inetaddr_notifier_blk);
        if (err) {
-               pr_err("qtaguid: iface_stat: "
+               pr_err("qtaguid: iface_stat: init "
                       "failed to register ipv4 dev event handler\n");
                goto err_unreg_nd;
        }
 
        err = register_inet6addr_notifier(&iface_inet6addr_notifier_blk);
        if (err) {
-               pr_err("qtaguid: iface_stat: "
+               pr_err("qtaguid: iface_stat: init "
                       "failed to register ipv6 dev event handler\n");
                goto err_unreg_ip4_addr;
        }
@@ -1116,6 +1125,9 @@ static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par)
        uid_t sock_uid;
        bool res;
 
+       if (unlikely(module_passive))
+               return (info->match ^ info->invert) == 0;
+
        MT_DEBUG("qtaguid[%d]: entered skb=%p par->in=%p/out=%p fam=%d\n",
                 par->hooknum, skb, par->in, par->out, par->family);
 
@@ -1145,7 +1157,7 @@ static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par)
                        par->hooknum, sk, sk->sk_socket,
                        sk->sk_socket ? sk->sk_socket->file : (void *)-1LL);
                filp = sk->sk_socket ? sk->sk_socket->file : NULL;
-               MT_DEBUG("qtaguid[%d]: filp...uid=%d\n",
+               MT_DEBUG("qtaguid[%d]: filp...uid=%u\n",
                        par->hooknum, filp ? filp->f_cred->fsuid : -1);
        }
 
@@ -1254,12 +1266,13 @@ static int qtaguid_ctrl_proc_read(char *page, char **num_items_returned,
             node = rb_next(node)) {
                if (item_index++ < items_to_skip)
                        continue;
-               sock_tag_entry = rb_entry(node, struct sock_tag, node);
+               sock_tag_entry = rb_entry(node, struct sock_tag, sock_node);
                uid = get_uid_from_tag(sock_tag_entry->tag);
-               CT_DEBUG("qtaguid: proc_read(): sk=%p tag=0x%llx (uid=%d)\n",
-                       sock_tag_entry->sk,
-                       sock_tag_entry->tag,
-                       uid);
+               CT_DEBUG("qtaguid: proc_read(): sk=%p tag=0x%llx (uid=%u)\n",
+                        sock_tag_entry->sk,
+                        sock_tag_entry->tag,
+                        uid
+                       );
                len = snprintf(outp, char_count,
                               "sock=%p tag=0x%llx (uid=%u)\n",
                               sock_tag_entry->sk, sock_tag_entry->tag, uid);
@@ -1315,10 +1328,9 @@ static int ctrl_cmd_delete(const char *input)
        struct tag_stat *ts_entry;
        struct tag_counter_set *tcs_entry;
 
-       CT_DEBUG("qtaguid: ctrl_delete(%s): entered\n", input);
        argc = sscanf(input, "%c %llu %u", &cmd, &acct_tag, &uid);
        CT_DEBUG("qtaguid: ctrl_delete(%s): argc=%d cmd=%c "
-                "acct_tag=0x%llx uid=%u\n", input, argc, cmd,
+                "user_tag=0x%llx uid=%u\n", input, argc, cmd,
                 acct_tag, uid);
        if (argc < 2) {
                res = -EINVAL;
@@ -1332,8 +1344,9 @@ static int ctrl_cmd_delete(const char *input)
        if (argc < 3) {
                uid = current_fsuid();
        } else if (!can_impersonate_uid(uid)) {
-               pr_info("qtaguid: ctrl_delete(%s): insuficient priv\n",
-                       input);
+               pr_info("qtaguid: ctrl_delete(%s): "
+                       "insufficient priv from pid=%u uid=%u\n",
+                       input, current->pid, current_fsuid());
                res = -EPERM;
                goto err;
        }
@@ -1342,7 +1355,7 @@ static int ctrl_cmd_delete(const char *input)
        spin_lock_bh(&sock_tag_list_lock);
        node = rb_first(&sock_tag_tree);
        while (node) {
-               st_entry = rb_entry(node, struct sock_tag, node);
+               st_entry = rb_entry(node, struct sock_tag, sock_node);
                entry_uid = get_uid_from_tag(st_entry->tag);
                node = rb_next(node);
                if (entry_uid != uid)
@@ -1350,11 +1363,12 @@ static int ctrl_cmd_delete(const char *input)
 
                if (!acct_tag || st_entry->tag == tag) {
                        CT_DEBUG("qtaguid: ctrl_delete(): "
-                                "erase st: sk=%p tag=0x%llx (uid=%d)\n",
+                                "erase st: sk=%p tag=0x%llx (uid=%u)\n",
                                 st_entry->sk,
                                 st_entry->tag,
                                 entry_uid);
-                       rb_erase(&st_entry->node, &sock_tag_tree);
+                       rb_erase(&st_entry->sock_node, &sock_tag_tree);
+                       sockfd_put(st_entry->socket);
                        kfree(st_entry);
                }
        }
@@ -1367,7 +1381,7 @@ static int ctrl_cmd_delete(const char *input)
        tcs_entry = tag_counter_set_tree_search(&tag_counter_set_tree, tag);
        if (tcs_entry) {
                CT_DEBUG("qtaguid: ctrl_delete(): "
-                        "erase tcs: tag=0x%llx (uid=%d) set=%d\n",
+                        "erase tcs: tag=0x%llx (uid=%u) set=%d\n",
                         tcs_entry->tn.tag,
                         get_uid_from_tag(tcs_entry->tn.tag),
                         tcs_entry->active_set);
@@ -1408,7 +1422,6 @@ static int ctrl_cmd_delete(const char *input)
        res = 0;
 
 err:
-       CT_DEBUG("qtaguid: ctrl_delete(%s) res=%d\n", input, res);
        return res;
 }
 
@@ -1421,7 +1434,6 @@ static int ctrl_cmd_counter_set(const char *input)
        struct tag_counter_set *tcs;
        int counter_set;
 
-       CT_DEBUG("qtaguid: ctrl_counterset(%s): entered\n", input);
        argc = sscanf(input, "%c %d %u", &cmd, &counter_set, &uid);
        CT_DEBUG("qtaguid: ctrl_counterset(%s): argc=%d cmd=%c "
                 "set=%d uid=%u\n", input, argc, cmd,
@@ -1437,8 +1449,9 @@ static int ctrl_cmd_counter_set(const char *input)
                goto err;
        }
        if (!can_manipulate_uids()) {
-               pr_info("qtaguid: ctrl_counterset(%s): insufficient priv\n",
-                       input);
+               pr_info("qtaguid: ctrl_counterset(%s): "
+                       "insufficient priv from pid=%u uid=%u\n",
+                       input, current->pid, current_fsuid());
                res = -EPERM;
                goto err;
        }
@@ -1459,7 +1472,7 @@ static int ctrl_cmd_counter_set(const char *input)
                tcs->tn.tag = tag;
                tag_counter_set_tree_insert(tcs, &tag_counter_set_tree);
                CT_DEBUG("qtaguid: ctrl_counterset(%s): added tcs tag=0x%llx "
-                        "(uid=%d) set=%d\n",
+                        "(uid=%u) set=%d\n",
                         input, tag, get_uid_from_tag(tag), counter_set);
        }
        tcs->active_set = counter_set;
@@ -1468,7 +1481,6 @@ static int ctrl_cmd_counter_set(const char *input)
        res = 0;
 
 err:
-       CT_DEBUG("qtaguid: ctrl_counterset(%s) res=%d\n", input, res);
        return res;
 }
 
@@ -1479,6 +1491,7 @@ static int ctrl_cmd_tag(const char *input)
        uid_t uid = 0;
        tag_t acct_tag = 0;
        struct socket *el_socket;
+       int refcnt = -1;
        int res, argc;
        struct sock_tag *sock_tag_entry;
 
@@ -1491,55 +1504,80 @@ static int ctrl_cmd_tag(const char *input)
                res = -EINVAL;
                goto err;
        }
-       el_socket = sockfd_lookup(sock_fd, &res);
+       el_socket = sockfd_lookup(sock_fd, &res);  /* This locks the file */
        if (!el_socket) {
                pr_info("qtaguid: ctrl_tag(%s): failed to lookup"
                        " sock_fd=%d err=%d\n", input, sock_fd, res);
                goto err;
        }
+       refcnt = atomic_read(&el_socket->file->f_count);
+       CT_DEBUG("qtaguid: ctrl_tag(%s): socket->...->f_count=%d\n",
+                input, refcnt);
        if (argc < 3) {
                acct_tag = 0;
        } else if (!valid_atag(acct_tag)) {
                pr_info("qtaguid: ctrl_tag(%s): invalid tag\n", input);
                res = -EINVAL;
-               goto err;
-       }
+               goto err_put;
+       }
+       CT_DEBUG("qtaguid: ctrl_tag(%s): "
+                "uid=%u euid=%u fsuid=%u "
+                "in_group=%d in_egroup=%d\n",
+                input, current_uid(), current_euid(), current_fsuid(),
+                in_group_p(proc_stats_readall_gid),
+                in_egroup_p(proc_stats_readall_gid));
        if (argc < 4) {
                uid = current_fsuid();
        } else if (!can_impersonate_uid(uid)) {
-               pr_info("qtaguid: ctrl_tag(%s): insuficient priv\n",
-                       input);
+               pr_info("qtaguid: ctrl_tag(%s): "
+                       "insufficient priv from pid=%u uid=%u\n",
+                       input, current->pid, current_fsuid());
                res = -EPERM;
-               goto err;
+               goto err_put;
        }
 
        spin_lock_bh(&sock_tag_list_lock);
        sock_tag_entry = get_sock_stat_nl(el_socket->sk);
        if (sock_tag_entry) {
+               /*
+                * This is a re-tagging, so release the sock_fd that was
+                * locked at the time of the 1st tagging.
+                */
+               sockfd_put(sock_tag_entry->socket);
+               refcnt--;
                sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
                                                            uid);
        } else {
                sock_tag_entry = kzalloc(sizeof(*sock_tag_entry),
                                         GFP_ATOMIC);
                if (!sock_tag_entry) {
+                       pr_err("qtaguid: ctrl_tag(%s): "
+                              "socket tag alloc failed\n",
+                              input);
+                       spin_unlock_bh(&sock_tag_list_lock);
                        res = -ENOMEM;
-                       goto err;
+                       goto err_put;
                }
                sock_tag_entry->sk = el_socket->sk;
+               sock_tag_entry->socket = el_socket;
                sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
                                                            uid);
                sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree);
        }
        spin_unlock_bh(&sock_tag_list_lock);
+       /* We keep the ref to the socket (file) until it is untagged */
+       CT_DEBUG("qtaguid: ctrl_tag(%s): done. socket->...->f_count=%d\n",
+                input,
+                el_socket ? atomic_read(&el_socket->file->f_count) : -1);
+       return 0;
 
-       CT_DEBUG("qtaguid: tag: sock_tag_entry->sk=%p "
-                "...->tag=0x%llx (uid=%u)\n",
-                sock_tag_entry->sk, sock_tag_entry->tag,
-                get_uid_from_tag(sock_tag_entry->tag));
-       res = 0;
-
+err_put:
+       /* Release the sock_fd that was grabbed by sockfd_lookup(). */
+       sockfd_put(el_socket);
+       refcnt--;
 err:
-       CT_DEBUG("qtaguid: ctrl_tag(%s) res=%d\n", input, res);
+       CT_DEBUG("qtaguid: ctrl_tag(%s): done. socket->...->f_count=%d\n",
+                input, refcnt);
        return res;
 }
 
@@ -1548,10 +1586,10 @@ static int ctrl_cmd_untag(const char *input)
        char cmd;
        int sock_fd = 0;
        struct socket *el_socket;
+       int refcnt = -1;
        int res, argc;
        struct sock_tag *sock_tag_entry;
 
-       CT_DEBUG("qtaguid: ctrl_untag(%s): entered\n", input);
        argc = sscanf(input, "%c %d", &cmd, &sock_fd);
        CT_DEBUG("qtaguid: ctrl_untag(%s): argc=%d cmd=%c sock_fd=%d\n",
                 input, argc, cmd, sock_fd);
@@ -1559,28 +1597,49 @@ static int ctrl_cmd_untag(const char *input)
                res = -EINVAL;
                goto err;
        }
-       el_socket = sockfd_lookup(sock_fd, &res);
+       el_socket = sockfd_lookup(sock_fd, &res);  /* This locks the file */
        if (!el_socket) {
                pr_info("qtaguid: ctrl_untag(%s): failed to lookup"
                        " sock_fd=%d err=%d\n", input, sock_fd, res);
                goto err;
        }
+       refcnt = atomic_read(&el_socket->file->f_count);
+       CT_DEBUG("qtaguid: ctrl_untag(%s): socket->...->f_count=%d\n",
+                input, refcnt);
        spin_lock_bh(&sock_tag_list_lock);
        sock_tag_entry = get_sock_stat_nl(el_socket->sk);
        if (!sock_tag_entry) {
                spin_unlock_bh(&sock_tag_list_lock);
                res = -EINVAL;
-               goto err;
+               goto err_put;
        }
-       /* The socket already belongs to the current process
-        * so it can do whatever it wants to it. */
-       rb_erase(&sock_tag_entry->node, &sock_tag_tree);
+       /*
+        * The socket already belongs to the current process
+        * so it can do whatever it wants to it.
+        */
+       rb_erase(&sock_tag_entry->sock_node, &sock_tag_tree);
+
+       /*
+        * Release the sock_fd that was grabbed at tag time,
+        * and once more for the sockfd_lookup() here.
+        */
+       sockfd_put(sock_tag_entry->socket);
        spin_unlock_bh(&sock_tag_list_lock);
+       sockfd_put(el_socket);
+       refcnt -= 2;
        kfree(sock_tag_entry);
+       CT_DEBUG("qtaguid: ctrl_untag(%s): done. socket->...->f_count=%d\n",
+                input, refcnt);
 
-       res = 0;
+       return 0;
+
+err_put:
+       /* Release the sock_fd that was grabbed by sockfd_lookup(). */
+       sockfd_put(el_socket);
+       refcnt--;
 err:
-       CT_DEBUG("qtaguid: ctrl_untag(%s): res=%d\n", input, res);
+       CT_DEBUG("qtaguid: ctrl_untag(%s): done. socket->...->f_count=%d\n",
+                input, refcnt);
        return res;
 }
 
@@ -1589,7 +1648,6 @@ static int qtaguid_ctrl_parse(const char *input, int count)
        char cmd;
        int res;
 
-       CT_DEBUG("qtaguid: ctrl(%s): entered\n", input);
        cmd = input[0];
        /* Collect params for commands */
        switch (cmd) {
@@ -1667,10 +1725,12 @@ static int pp_stats_line(struct proc_print_info *ppi, int cnt_set)
                tag_t tag = ppi->ts_entry->tn.tag;
                uid_t stat_uid = get_uid_from_tag(tag);
                if (!can_read_other_uid_stats(stat_uid)) {
-                       CT_DEBUG("qtaguid: insufficient priv for stat line:"
-                                "%s 0x%llx %u\n",
+                       CT_DEBUG("qtaguid: stats line: "
+                                "%s 0x%llx %u: "
+                                "insufficient priv from pid=%u uid=%u\n",
                                 ppi->iface_entry->ifname,
-                                get_atag_from_tag(tag), stat_uid);
+                                get_atag_from_tag(tag), stat_uid,
+                                current->pid, current_fsuid());
                        return 0;
                }
                cnts = &ppi->ts_entry->counters;
@@ -1748,8 +1808,11 @@ static int qtaguid_stats_proc_read(char *page, char **num_items_returned,
        ppi.num_items_returned = num_items_returned;
 
        if (unlikely(module_passive)) {
+               len = pp_stats_line(&ppi, 0);
+               /* The header should always be shorter than the buffer. */
+               WARN_ON(len >= ppi.char_count);
                *eof = 1;
-               return 0;
+               return len;
        }
 
        CT_DEBUG("qtaguid:proc stats page=%p *num_items_returned=%p off=%ld "