nf: qtaguid: workaround xt_socket_get_sk() returning bad SKs.
[linux-2.6.git] / net / netfilter / xt_qtaguid.c
1 /*
2  * Kernel iptables module to track stats for packets based on user tags.
3  *
4  * (C) 2011 Google, Inc
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 as
8  * published by the Free Software Foundation.
9  */
10
11 /* TODO: support ipv6 for iface_stat */
12
13 #include <linux/file.h>
14 #include <linux/inetdevice.h>
15 #include <linux/module.h>
16 #include <linux/netfilter/x_tables.h>
17 #include <linux/netfilter/xt_qtaguid.h>
18 #include <linux/skbuff.h>
19 #include <linux/workqueue.h>
20 #include <net/sock.h>
21 #include <net/tcp.h>
22 #include <net/udp.h>
23
24 #include <linux/netfilter/xt_socket.h>
/* We only use the xt_socket funcs within a similar context to avoid unexpected
 * return values.
 * Bitmask (one bit per netfilter hook number) of the hooks from which
 * qtaguid_find_sk() is allowed to call xt_socket_get{4,6}_sk(). */
#define XT_SOCKET_SUPPORTED_HOOKS \
	((1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_IN))
29
30
31 /*---------------------------------------------------------------------------*/
/*
 * Tags:
 *
 * They represent what the data usage counters will be tracked against.
 * By default a tag is just based on the UID.
 * The UID is used as the base for policing, and can not be ignored.
 * So a tag will always at least represent a UID (uid_tag).
 *
 * A tag can be augmented with an "accounting tag" which is associated
 * with a UID.
 * User space can set the acct_tag portion of the tag which is then used
 * with sockets: all data belonging to that socket will be counted against
 * the tag. The policing is then based on the tag's uid_tag portion,
 * and stats are collected for the acct_tag portion separately.
 *
 * There could be
 * a:  {acct_tag=1, uid_tag=10003}
 * b:  {acct_tag=2, uid_tag=10003}
 * c:  {acct_tag=3, uid_tag=10003}
 * d:  {acct_tag=0, uid_tag=10003}
 * (a, b, and c represent tags associated with specific sockets.
 * d is for the totals for that uid, including all untagged traffic.
 * Typically d is used with policing/quota rules.)
 *
 * Layout, as implemented by the accessors below: the low 32 bits hold the
 * uid_tag and the high 32 bits hold the acct_tag (kept in place, not
 * shifted down).
 *
 * We want tag_t big enough to distinguish uid_t and acct_tag.
 * It might become a struct if needed.
 * Nothing should be using it as an int.
 */
typedef uint64_t tag_t;  /* Only used via accessors */
61
/* Name of the proc directory holding one sub-directory per tracked interface. */
static const char *iface_stat_procdirname = "iface_stat";
/* Handle for that directory; created by iface_stat_init(). */
static struct proc_dir_entry *iface_stat_procdir;
64
/* Traffic direction; first index into data_counters.bpc. */
enum ifs_tx_rx {
	IFS_TX,
	IFS_RX,
	IFS_MAX_DIRECTIONS
};
70
/* For now, TCP, UDP, the rest.
 * Protocol bucket; second index into data_counters.bpc. */
enum ifs_proto {
	IFS_TCP,
	IFS_UDP,
	IFS_PROTO_OTHER,
	IFS_MAX_PROTOS
};
78
/* Byte and packet totals for a single (direction, protocol) bucket. */
struct byte_packet_counters {
	uint64_t bytes;
	uint64_t packets;
};

/* One bucket for every direction x protocol combination. */
struct data_counters {
	struct byte_packet_counters bpc[IFS_MAX_DIRECTIONS][IFS_MAX_PROTOS];
};
87
/*
 * Per-interface counters for one tag. Entries live in
 * iface_stat.tag_stat_tree, ordered by tag.
 */
struct tag_stat {
	struct rb_node node;
	tag_t tag;

	struct data_counters counters;
	/* If this tag is acct_tag based, we need to count against the
	 * matching parent uid_tag. */
	struct data_counters *parent_counters;
	struct proc_dir_entry *proc_ptr;
};
98
/* All tracked interfaces; walked and extended under iface_stat_list_lock. */
static LIST_HEAD(iface_stat_list);
static DEFINE_SPINLOCK(iface_stat_list_lock);

/* Statistics tracked for one network interface. */
struct iface_stat {
	struct list_head list;
	char *ifname;
	/* Device totals, accumulated by iface_stat_update() on unregister. */
	uint64_t rx_bytes;
	uint64_t rx_packets;
	uint64_t tx_bytes;
	uint64_t tx_packets;
	/* false for loopback addresses and after the device went down. */
	bool active;
	struct proc_dir_entry *proc_ptr;

	/* tag_stat entries for this interface, ordered by tag. Despite its
	 * name, tag_stat_list_lock is the lock guarding this tree. */
	struct rb_root tag_stat_tree;
	spinlock_t tag_stat_list_lock;
};
115
116
/* Tagged sockets, ordered by struct sock address; guarded by
 * sock_tag_list_lock. */
static struct rb_root sock_tag_tree = RB_ROOT;
static DEFINE_SPINLOCK(sock_tag_list_lock);

/*
 * Track tag that this socket is transferring data for, and not necessarily
 * the uid that owns the socket.
 * This is the tag against which tag_stat.counters will be billed.
 */
struct sock_tag {
	struct rb_node node;
	struct sock *sk;
	tag_t tag;
};
130
/* Forward declaration; the xt match function is defined further down. */
static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par);
132
133 /*----------------------------------------------*/
134 static inline int tag_compare(tag_t t1, tag_t t2)
135 {
136         return t1 < t2 ? -1 : t1 == t2 ? 0 : 1;
137 }
138
139
140 static inline tag_t combine_atag_with_uid(tag_t acct_tag, uid_t uid)
141 {
142         return acct_tag | uid;
143 }
144 static inline tag_t make_tag_from_uid(uid_t uid)
145 {
146         return uid;
147 }
148 static inline uid_t get_uid_from_tag(tag_t tag)
149 {
150         return tag & 0xFFFFFFFFULL;
151 }
152 static inline tag_t get_utag_from_tag(tag_t tag)
153 {
154         return tag & 0xFFFFFFFFULL;
155 }
156 static inline tag_t get_atag_from_tag(tag_t tag)
157 {
158         return tag & ~0xFFFFFFFFULL;
159 }
160
161 static inline bool valid_atag(tag_t tag)
162 {
163         return !(tag & 0xFFFFFFFFULL);
164 }
165
166 static inline void dc_add_byte_packets(struct data_counters *counters,
167                                   enum ifs_tx_rx direction,
168                                   enum ifs_proto ifs_proto,
169                                   int bytes,
170                                   int packets)
171 {
172         counters->bpc[direction][ifs_proto].bytes += bytes;
173         counters->bpc[direction][ifs_proto].packets += packets;
174 }
175
176 static inline uint64_t dc_sum_bytes(struct data_counters *counters,
177                                     enum ifs_tx_rx direction)
178 {
179         return counters->bpc[direction][IFS_TCP].bytes
180                 + counters->bpc[direction][IFS_UDP].bytes
181                 + counters->bpc[direction][IFS_PROTO_OTHER].bytes;
182 }
183
184 static struct tag_stat *tag_stat_tree_search(struct rb_root *root, tag_t tag)
185 {
186         struct rb_node *node = root->rb_node;
187
188         while (node) {
189                 struct tag_stat *data = rb_entry(node, struct tag_stat, node);
190                 int result = tag_compare(tag, data->tag);
191                 pr_debug("qtaguid: tag_stat_tree_search(): tag=0x%llx"
192                          " (uid=%d)\n",
193                          data->tag,
194                          get_uid_from_tag(data->tag));
195
196                 if (result < 0)
197                         node = node->rb_left;
198                 else if (result > 0)
199                         node = node->rb_right;
200                 else
201                         return data;
202         }
203         return NULL;
204 }
205
206 static void tag_stat_tree_insert(struct tag_stat *data, struct rb_root *root)
207 {
208         struct rb_node **new = &(root->rb_node), *parent = NULL;
209
210         /* Figure out where to put new node */
211         while (*new) {
212                 struct tag_stat *this = rb_entry(*new, struct tag_stat,
213                                                  node);
214                 int result = tag_compare(data->tag, this->tag);
215                 pr_debug("qtaguid: tag_stat_tree_insert(): tag=0x%llx"
216                          " (uid=%d)\n",
217                          this->tag,
218                          get_uid_from_tag(this->tag));
219                 parent = *new;
220                 if (result < 0)
221                         new = &((*new)->rb_left);
222                 else if (result > 0)
223                         new = &((*new)->rb_right);
224                 else
225                         BUG();
226         }
227
228         /* Add new node and rebalance tree. */
229         rb_link_node(&data->node, parent, new);
230         rb_insert_color(&data->node, root);
231 }
232
233 static struct sock_tag *sock_tag_tree_search(struct rb_root *root,
234                                              const struct sock *sk)
235 {
236         struct rb_node *node = root->rb_node;
237
238         while (node) {
239                 struct sock_tag *data = rb_entry(node, struct sock_tag, node);
240                 ptrdiff_t result = sk - data->sk;
241                 if (result < 0)
242                         node = node->rb_left;
243                 else if (result > 0)
244                         node = node->rb_right;
245                 else
246                         return data;
247         }
248         return NULL;
249 }
250
251 static void sock_tag_tree_insert(struct sock_tag *data, struct rb_root *root)
252 {
253         struct rb_node **new = &(root->rb_node), *parent = NULL;
254
255         /* Figure out where to put new node */
256         while (*new) {
257                 struct sock_tag *this = rb_entry(*new, struct sock_tag, node);
258                 ptrdiff_t result = data->sk - this->sk;
259                 parent = *new;
260                 if (result < 0)
261                         new = &((*new)->rb_left);
262                 else if (result > 0)
263                         new = &((*new)->rb_right);
264                 else
265                         BUG();
266         }
267
268         /* Add new node and rebalance tree. */
269         rb_link_node(&data->node, parent, new);
270         rb_insert_color(&data->node, root);
271 }
272
/*
 * read_proc handler that prints the uint64_t pointed to by data as
 * "<value>\n".
 * It uses the "start inside page" convention:
 *  - the text is formatted at the beginning of page;
 *  - *start is set to page + off;
 *  - the return value is the number of bytes available from there,
 *    clamped to [0, count];
 *  - *eof is set once the rest of the text fits in count.
 * Returns 0 without touching *start/*eof when data is NULL.
 */
static int read_proc_u64(char *page, char **start, off_t off,
			 int count, int *eof, void *data)
{
	uint64_t *iface_entry = data;
	off_t avail;

	if (!data)
		return 0;

	/* Cast keeps %llu correct wherever uint64_t is not unsigned long long. */
	avail = sprintf(page, "%llu\n", (unsigned long long)*iface_entry);
	avail -= off;
	*eof = (avail <= count) ? 1 : 0;
	*start = page + off;
	/* Never report more than requested, nor a negative length when
	 * off is already past the end of the text. */
	if (avail > count)
		avail = count;
	if (avail < 0)
		avail = 0;
	return (int)avail;
}
290
/*
 * read_proc handler that prints the bool pointed to by data as "0\n" or "1\n".
 * It uses the "start inside page" convention:
 *  - the text is formatted at the beginning of page;
 *  - *start is set to page + off;
 *  - the return value is the number of bytes available from there,
 *    clamped to [0, count];
 *  - *eof is set once the rest of the text fits in count.
 * Returns 0 without touching *start/*eof when data is NULL.
 */
static int read_proc_bool(char *page, char **start, off_t off,
			  int count, int *eof, void *data)
{
	bool *bool_entry = data;
	off_t avail;

	if (!data)
		return 0;

	avail = sprintf(page, "%u\n", (unsigned int)*bool_entry);
	avail -= off;
	*eof = (avail <= count) ? 1 : 0;
	*start = page + off;
	/* Never report more than requested, nor a negative length when
	 * off is already past the end of the text. */
	if (avail > count)
		avail = count;
	if (avail < 0)
		avail = 0;
	return (int)avail;
}
308
309 /* Find the entry for tracking the specified interface. */
310 static struct iface_stat *get_iface_stat(const char *ifname)
311 {
312         unsigned long flags;
313         struct iface_stat *iface_entry;
314         if (!ifname)
315                 return NULL;
316
317         spin_lock_irqsave(&iface_stat_list_lock, flags);
318         list_for_each_entry(iface_entry, &iface_stat_list, list) {
319                 if (!strcmp(iface_entry->ifname, ifname))
320                         goto done;
321         }
322         iface_entry = NULL;
323 done:
324         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
325         return iface_entry;
326 }
327
328 /*
329  * Create a new entry for tracking the specified interface.
330  * Do nothing if the entry already exists.
331  * Called when an interface is configured with a valid IP address.
332  */
333 void iface_stat_create(const struct net_device *net_dev)
334 {
335         struct in_device *in_dev;
336         unsigned long flags;
337         struct iface_stat *new_iface;
338         struct proc_dir_entry *proc_entry;
339         const char *ifname;
340         struct iface_stat *entry;
341         __be32 ipaddr = 0;
342         struct in_ifaddr *ifa = NULL;
343
344         ASSERT_RTNL(); /* No need for separate locking */
345
346         pr_debug("iface_stat: create(): netdev=%p->name=%s\n",
347                  net_dev, net_dev ? net_dev->name : "");
348         if (!net_dev) {
349                 pr_err("iface_stat: create(): no net dev!\n");
350                 return;
351         }
352
353         in_dev = __in_dev_get_rtnl(net_dev);
354         if (!in_dev) {
355                 pr_err("iface_stat: create(): no inet dev!\n");
356                 return;
357         }
358
359         pr_debug("iface_stat: create(): in_dev=%p\n", in_dev);
360         ifname = net_dev->name;
361         pr_debug("iface_stat: create(): ifname=%p\n", ifname);
362         for (ifa = in_dev->ifa_list; ifa; ifa = ifa->ifa_next) {
363                 pr_debug("iface_stat: create(): for(): ifa=%p ifname=%p\n",
364                          ifa, ifname);
365                 pr_debug("iface_stat: create(): ifname=%s ifa_label=%s\n",
366                          ifname, ifa->ifa_label ? ifa->ifa_label : "(null)");
367                 if (ifa->ifa_label && !strcmp(ifname, ifa->ifa_label))
368                         break;
369         }
370
371         if (ifa) {
372                 ipaddr = ifa->ifa_local;
373         } else {
374                 pr_err("iface_stat: create(): dev %s has no matching IP\n",
375                        ifname);
376                 return;
377         }
378
379         entry = get_iface_stat(net_dev->name);
380         if (entry != NULL) {
381                 pr_debug("iface_stat: create(): dev %s entry=%p\n", ifname,
382                          entry);
383                 if (ipv4_is_loopback(ipaddr)) {
384                         entry->active = false;
385                         pr_debug("iface_stat: create(): disable tracking of "
386                                  "loopback dev %s\n", ifname);
387                 } else {
388                         entry->active = true;
389                         pr_debug("iface_stat: create(): enable tracking of "
390                                  "dev %s with ip=%pI4\n",
391                                  ifname, &ipaddr);
392                 }
393                 return;
394         } else if (ipv4_is_loopback(ipaddr)) {
395                 pr_debug("iface_stat: create(): ignore loopback dev %s"
396                          " ip=%pI4\n", ifname, &ipaddr);
397                 return;
398         }
399
400         new_iface = kmalloc(sizeof(*new_iface), GFP_KERNEL);
401         if (new_iface == NULL) {
402                 pr_err("iface_stat: create(): failed to alloc iface_stat\n");
403                 return;
404         }
405         memset(new_iface, 0, sizeof(*new_iface));
406         new_iface->ifname = kstrdup(ifname, GFP_KERNEL);
407         if (new_iface->ifname == NULL) {
408                 pr_err("iface_stat: create(): failed to alloc ifname\n");
409                 kfree(new_iface);
410                 return;
411         }
412         spin_lock_init(&new_iface->tag_stat_list_lock);
413
414         new_iface->active = true;
415
416         new_iface->tag_stat_tree = RB_ROOT;
417         spin_lock_irqsave(&iface_stat_list_lock, flags);
418         list_add(&new_iface->list, &iface_stat_list);
419         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
420
421         proc_entry = proc_mkdir(ifname, iface_stat_procdir);
422         new_iface->proc_ptr = proc_entry;
423
424         /* TODO: make root access only */
425         create_proc_read_entry("tx_bytes", S_IRUGO, proc_entry,
426                         read_proc_u64, &new_iface->tx_bytes);
427         create_proc_read_entry("rx_bytes", S_IRUGO, proc_entry,
428                         read_proc_u64, &new_iface->rx_bytes);
429         create_proc_read_entry("tx_packets", S_IRUGO, proc_entry,
430                         read_proc_u64, &new_iface->tx_packets);
431         create_proc_read_entry("rx_packets", S_IRUGO, proc_entry,
432                         read_proc_u64, &new_iface->rx_packets);
433         create_proc_read_entry("active", S_IRUGO, proc_entry,
434                         read_proc_bool, &new_iface->active);
435
436         pr_debug("iface_stat: create(): done entry=%p dev=%s ip=%pI4\n",
437                  new_iface, ifname, &ipaddr);
438 }
439
440 static struct sock_tag *get_sock_stat_nl(const struct sock *sk)
441 {
442         pr_debug("xt_qtaguid: get_sock_stat_nl(sk=%p)\n", sk);
443         return sock_tag_tree_search(&sock_tag_tree, sk);
444 }
445
446 static struct sock_tag *get_sock_stat(const struct sock *sk)
447 {
448         unsigned long flags;
449         struct sock_tag *sock_tag_entry;
450         pr_debug("xt_qtaguid: get_sock_stat(sk=%p)\n", sk);
451         if (!sk)
452                 return NULL;
453         spin_lock_irqsave(&sock_tag_list_lock, flags);
454         sock_tag_entry = get_sock_stat_nl(sk);
455         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
456         return sock_tag_entry;
457 }
458
459 static void
460 data_counters_update(struct data_counters *dc,  enum ifs_tx_rx direction,
461                 int proto, int bytes)
462 {
463         switch (proto) {
464         case IPPROTO_TCP:
465                 dc_add_byte_packets(dc, direction, IFS_TCP, bytes, 1);
466                 break;
467         case IPPROTO_UDP:
468                 dc_add_byte_packets(dc, direction, IFS_UDP, bytes, 1);
469                 break;
470         case IPPROTO_IP:
471         default:
472                 dc_add_byte_packets(dc, direction, IFS_PROTO_OTHER, bytes, 1);
473                 break;
474         }
475 }
476
477
478 /*
479  * Update stats for the specified interface. Do nothing if the entry
480  * does not exist (when a device was never configured with an IP address).
481  * Called when an device is being unregistered.
482  */
483 void iface_stat_update(struct net_device *dev)
484 {
485         struct rtnl_link_stats64 dev_stats, *stats;
486         struct iface_stat *entry;
487         stats = dev_get_stats(dev, &dev_stats);
488         ASSERT_RTNL();
489
490         entry = get_iface_stat(dev->name);
491         if (entry == NULL) {
492                 pr_debug("iface_stat: dev %s monitor not found\n", dev->name);
493                 return;
494         }
495         if (entry->active) {
496                 entry->tx_bytes += stats->tx_bytes;
497                 entry->tx_packets += stats->tx_packets;
498                 entry->rx_bytes += stats->rx_bytes;
499                 entry->rx_packets += stats->rx_packets;
500                 entry->active = false;
501                 pr_debug("iface_stat: Updating stats for "
502                         "dev %s which went down\n", dev->name);
503         } else {
504                 pr_debug("iface_stat: Did not update stats for "
505                         "dev %s which went down\n", dev->name);
506         }
507 }
508
509
510 static void tag_stat_update(struct tag_stat *tag_entry,
511                         enum ifs_tx_rx direction, int proto, int bytes)
512 {
513         pr_debug("xt_qtaguid: tag_stat_update(tag=0x%llx (uid=%d) dir=%d "
514                 "proto=%d bytes=%d)\n",
515                 tag_entry->tag, get_uid_from_tag(tag_entry->tag), direction,
516                 proto, bytes);
517         data_counters_update(&tag_entry->counters, direction, proto, bytes);
518         if (tag_entry->parent_counters)
519                 data_counters_update(tag_entry->parent_counters, direction,
520                                 proto, bytes);
521 }
522
523
524 /* Create a new entry for tracking the specified {acct_tag,uid_tag} within
525  * the interface.
526  * iface_entry->tag_stat_list_lock should be held. */
527 static struct tag_stat *create_if_tag_stat(struct iface_stat *iface_entry,
528                                            tag_t tag)
529 {
530         struct tag_stat *new_tag_stat_entry = NULL;
531         pr_debug("iface_stat: create_if_tag_stat(): ife=%p tag=0x%llx"
532                  " (uid=%d)\n",
533                  iface_entry, tag, get_uid_from_tag(tag));
534         new_tag_stat_entry = kmalloc(sizeof(*new_tag_stat_entry), GFP_ATOMIC);
535         if (!new_tag_stat_entry) {
536                 pr_err("iface_stat: failed to alloc new tag entry\n");
537                 goto done;
538         }
539         memset(new_tag_stat_entry, 0, sizeof(*new_tag_stat_entry));
540         new_tag_stat_entry->tag = tag;
541         tag_stat_tree_insert(new_tag_stat_entry, &iface_entry->tag_stat_tree);
542 done:
543         return new_tag_stat_entry;
544 }
545
546 static struct iface_stat *get_iface_entry(const char *ifname)
547 {
548         struct iface_stat *iface_entry;
549         unsigned long flags;
550
551         /* Find the entry for tracking the specified tag within the interface */
552         if (ifname == NULL) {
553                 pr_info("iface_stat: NULL device name\n");
554                 return NULL;
555         }
556
557
558         /* Iterate over interfaces */
559         spin_lock_irqsave(&iface_stat_list_lock, flags);
560         list_for_each_entry(iface_entry, &iface_stat_list, list) {
561                 if (!strcmp(ifname, iface_entry->ifname))
562                         goto done;
563         }
564         iface_entry = NULL;
565 done:
566         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
567         return iface_entry;
568 }
569
570 static void if_tag_stat_update(const char *ifname, uid_t uid,
571                                const struct sock *sk, enum ifs_tx_rx direction,
572                                int proto, int bytes)
573 {
574         struct tag_stat *tag_stat_entry;
575         tag_t tag, acct_tag;
576         tag_t uid_tag;
577         struct data_counters *uid_tag_counters;
578         struct sock_tag *sock_tag_entry;
579         struct iface_stat *iface_entry;
580         unsigned long flags;
581         struct tag_stat *new_tag_stat;
582         pr_debug("xt_qtaguid: if_tag_stat_update(ifname=%s "
583                 "uid=%d sk=%p dir=%d proto=%d bytes=%d)\n",
584                  ifname, uid, sk, direction, proto, bytes);
585
586
587         iface_entry = get_iface_entry(ifname);
588         if (!iface_entry) {
589                 pr_err("iface_stat: interface %s not found\n", ifname);
590                 return;
591         }
592         /* else { If the iface_entry becomes inactive, it is still ok
593          * to process the data. } */
594
595         pr_debug("iface_stat: stat_update() got entry=%p\n", iface_entry);
596
597         /* Look for a tagged sock.
598          * It will have an acct_uid. */
599         sock_tag_entry = get_sock_stat(sk);
600         if (sock_tag_entry) {
601                 tag = sock_tag_entry->tag;
602                 acct_tag = get_atag_from_tag(tag);
603                 uid_tag = get_utag_from_tag(tag);
604         } else {
605                 uid_tag = make_tag_from_uid(uid);
606                 acct_tag = 0;
607                 tag = combine_atag_with_uid(acct_tag, uid);
608         }
609         pr_debug("iface_stat: stat_update(): looking for tag=0x%llx (uid=%d)"
610                  " in ife=%p\n",
611                  tag, get_uid_from_tag(tag), iface_entry);
612         /* Loop over tag list under this interface for {acct_tag,uid_tag} */
613         spin_lock_irqsave(&iface_entry->tag_stat_list_lock, flags);
614
615         tag_stat_entry = tag_stat_tree_search(&iface_entry->tag_stat_tree,
616                                               tag);
617         if (tag_stat_entry) {
618                 /* Updating the {acct_tag, uid_tag} entry handles both stats:
619                  * {0, uid_tag} will also get updated. */
620                 tag_stat_update(tag_stat_entry, direction, proto, bytes);
621                 spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock, flags);
622                 return;
623         }
624
625         /* Loop over tag list under this interface for {0,uid_tag} */
626         tag_stat_entry = tag_stat_tree_search(&iface_entry->tag_stat_tree,
627                                               uid_tag);
628         if (!tag_stat_entry) {
629                 /* Here: the base uid_tag did not exist */
630                 /*
631                  * No parent counters. So
632                  *  - No {0, uid_tag} stats and no {acc_tag, uid_tag} stats.
633                  */
634                 new_tag_stat = create_if_tag_stat(iface_entry, uid_tag);
635                 uid_tag_counters = &new_tag_stat->counters;
636         } else {
637                 uid_tag_counters = &tag_stat_entry->counters;
638         }
639
640         if (acct_tag) {
641                 new_tag_stat = create_if_tag_stat(iface_entry, tag);
642                 new_tag_stat->parent_counters = uid_tag_counters;
643         }
644         spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock, flags);
645         tag_stat_update(new_tag_stat, direction, proto, bytes);
646 }
647
648 static int iface_netdev_event_handler(struct notifier_block *nb,
649                                       unsigned long event, void *ptr) {
650         struct net_device *dev = ptr;
651
652         pr_debug("iface_stat: netdev_event(): ev=0x%lx netdev=%p->name=%s\n",
653                  event, dev, dev ? dev->name : "");
654
655         switch (event) {
656         case NETDEV_UP:
657         case NETDEV_REBOOT:
658         case NETDEV_CHANGE:
659         case NETDEV_REGISTER:  /* Most likely no IP */
660         case NETDEV_CHANGEADDR:  /* MAC addr change */
661         case NETDEV_CHANGENAME:
662         case NETDEV_FEAT_CHANGE:  /* Might be usefull when cell type changes */
663                 iface_stat_create(dev);
664                 break;
665         case NETDEV_UNREGISTER:
666                 iface_stat_update(dev);
667                 break;
668         }
669         return NOTIFY_DONE;
670 }
671
672 static int iface_inetaddr_event_handler(struct notifier_block *nb,
673                                         unsigned long event, void *ptr) {
674
675         struct in_ifaddr *ifa = ptr;
676         struct in_device *in_dev = ifa->ifa_dev;
677         struct net_device *dev = in_dev->dev;
678
679         pr_debug("iface_stat: inetaddr_event(): ev=0x%lx netdev=%p->name=%s\n",
680                  event, dev, dev ? dev->name : "");
681
682         switch (event) {
683         case NETDEV_UP:
684                 iface_stat_create(dev);
685                 break;
686         }
687         return NOTIFY_DONE;
688 }
689
/* Netdevice notifier; registered in iface_stat_init(). */
static struct notifier_block iface_netdev_notifier_blk = {
	.notifier_call = iface_netdev_event_handler,
};

/* IPv4 address notifier; registered in iface_stat_init(). */
static struct notifier_block iface_inetaddr_notifier_blk = {
	.notifier_call = iface_inetaddr_event_handler,
};
697
698 static int __init iface_stat_init(struct proc_dir_entry *parent_procdir)
699 {
700         int err;
701
702         iface_stat_procdir = proc_mkdir(iface_stat_procdirname, parent_procdir);
703         if (!iface_stat_procdir) {
704                 pr_err("iface_stat: failed to create proc entry\n");
705                 err = -1;
706                 goto err;
707         }
708         err = register_netdevice_notifier(&iface_netdev_notifier_blk);
709         if (err) {
710                 pr_err("iface_stat: failed to register dev event handler\n");
711                 goto err_unreg_nd;
712         }
713         err = register_inetaddr_notifier(&iface_inetaddr_notifier_blk);
714         if (err) {
715                 pr_err("iface_stat: failed to register dev event handler\n");
716                 goto err_zap_entry;
717         }
718         return 0;
719
720 err_unreg_nd:
721         unregister_netdevice_notifier(&iface_netdev_notifier_blk);
722 err_zap_entry:
723         remove_proc_entry(iface_stat_procdirname, parent_procdir);
724 err:
725         return err;
726 }
727
728 static struct sock *qtaguid_find_sk(const struct sk_buff *skb,
729                                     struct xt_action_param *par)
730 {
731         struct sock *sk;
732         unsigned int hook_mask = (1 << par->hooknum);
733
734         pr_debug("xt_qtaguid: find_sk(skb=%p) hooknum=%d family=%d\n", skb,
735                  par->hooknum, par->family);
736
737         /* Let's not abuse the the xt_socket_get*_sk(), or else it will
738          * return garbage SKs. */
739         if (!(hook_mask & XT_SOCKET_SUPPORTED_HOOKS))
740                 return NULL;
741
742         switch (par->family) {
743         case NFPROTO_IPV6:
744                 sk = xt_socket_get6_sk(skb, par);
745                 break;
746         case NFPROTO_IPV4:
747                 sk = xt_socket_get4_sk(skb, par);
748                 break;
749         default:
750                 return NULL;
751         }
752
753         /* Seems to be issues on the file ptr for TCP_TIME_WAIT SKs.
754          * http://kerneltrap.org/mailarchive/linux-netdev/2010/10/21/6287959
755          * Not fixed in 3.0-r3 :(
756          */
757         if (sk) {
758                 pr_debug("xt_qtaguid: %p->sk_proto=%u "
759                          "->sk_state=%d\n", sk, sk->sk_protocol, sk->sk_state);
760                 if (sk->sk_state  == TCP_TIME_WAIT) {
761                         xt_socket_put_sk(sk);
762                         sk = NULL;
763                 }
764         }
765         return sk;
766 }
767
768 static void account_for_uid(const struct sk_buff *skb,
769                             const struct sock *alternate_sk, uid_t uid,
770                             struct xt_action_param *par)
771 {
772         const struct net_device *el_dev;
773
774         if (!skb->dev) {
775                 pr_debug("xt_qtaguid[%d]: no skb->dev\n", par->hooknum);
776                 el_dev = par->in ? : par->out;
777         } else {
778                 const struct net_device *other_dev;
779                 el_dev = skb->dev;
780                 other_dev = par->in ? : par->out;
781                 if (el_dev != other_dev) {
782                         pr_debug("xt_qtaguid[%d]: skb->dev=%p %s vs "
783                                 "par->(in/out)=%p %s\n",
784                                 par->hooknum, el_dev, el_dev->name, other_dev,
785                                 other_dev->name);
786                 }
787         }
788
789         if (unlikely(!el_dev)) {
790                 pr_info("xt_qtaguid[%d]: no par->in/out?!!\n", par->hooknum);
791         } else if (unlikely(!el_dev->name)) {
792                 pr_info("xt_qtaguid[%d]: no dev->name?!!\n", par->hooknum);
793         } else {
794                 pr_debug("xt_qtaguid[%d]: dev name=%s type=%d\n",
795                         par->hooknum,
796                         el_dev->name,
797                         el_dev->type);
798
799                 if_tag_stat_update(el_dev->name, uid,
800                                 skb->sk ? skb->sk : alternate_sk,
801                                 par->in ? IFS_RX : IFS_TX,
802                                 ip_hdr(skb)->protocol, skb->len);
803         }
804 }
805
806 static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par)
807 {
808         const struct xt_qtaguid_match_info *info = par->matchinfo;
809         const struct file *filp;
810         bool got_sock = false;
811         struct sock *sk;
812         uid_t sock_uid;
813         bool res;
814         pr_debug("xt_qtaguid[%d]: entered skb=%p par->in=%p/out=%p fam=%d\n",
815                  par->hooknum, skb, par->in, par->out, par->family);
816         if (skb == NULL) {
817                 res = (info->match ^ info->invert) == 0;
818                 goto ret_res;
819         }
820
821         sk = skb->sk;
822
823         if (sk == NULL) {
824                 /*  A missing sk->sk_socket happens when packets are in-flight
825                  * and the matching socket is already closed and gone.
826                  */
827                 sk = qtaguid_find_sk(skb, par);
828                 /* If we got the socket from the find_sk(), we will need to put
829                  * it back, as nf_tproxy_get_sock_v4() got it. */
830                 got_sock = sk;
831         }
832         pr_debug("xt_qtaguid[%d]: sk=%p got_sock=%d proto=%d\n",
833                 par->hooknum, sk, got_sock, ip_hdr(skb)->protocol);
834         if (sk != NULL) {
835                 pr_debug("xt_qtaguid[%d]: sk=%p->sk_socket=%p->file=%p\n",
836                         par->hooknum, sk, sk->sk_socket,
837                         sk->sk_socket ? sk->sk_socket->file : (void *)-1LL);
838                 filp = sk->sk_socket ? sk->sk_socket->file : NULL;
839                 pr_debug("xt_qtaguid[%d]: filp...uid=%d\n",
840                         par->hooknum, filp ? filp->f_cred->fsuid : -1);
841         }
842
843         if (sk == NULL || sk->sk_socket == NULL) {
844                 /* Here, the qtaguid_find_sk() using connection tracking
845                  * couldn't find the owner, so for now we just count them
846                  * against the system. */
847                 /* TODO: unhack how to force just accounting.
848                  * For now we only do iface stats when the uid-owner is not
849                  * requested */
850                 if (!(info->match & XT_QTAGUID_UID))
851                         account_for_uid(skb, sk, 0, par);
852                 pr_debug("xt_qtaguid[%d]: leaving (sk?sk->sk_socket)=%p\n",
853                         par->hooknum,
854                         sk ? sk->sk_socket : NULL);
855                 res =  (info->match ^ info->invert) == 0;
856                 goto put_sock_ret_res;
857         } else if (info->match & info->invert & XT_QTAGUID_SOCKET) {
858                 res = false;
859                 goto put_sock_ret_res;
860         }
861         filp = sk->sk_socket->file;
862         if (filp == NULL) {
863                 pr_debug("xt_qtaguid[%d]: leaving filp=NULL\n", par->hooknum);
864                 res = ((info->match ^ info->invert) &
865                         (XT_QTAGUID_UID | XT_QTAGUID_GID)) == 0;
866                 goto put_sock_ret_res;
867         }
868         sock_uid = filp->f_cred->fsuid;
869         /* TODO: unhack how to force just accounting.
870          * For now we only do iface stats when the uid-owner is not requested */
871         if (!(info->match & XT_QTAGUID_UID))
872                 account_for_uid(skb, sk, sock_uid, par);
873
874         /* The following two tests fail the match when:
875          *    id not in range AND no inverted condition requested
876          * or id     in range AND    inverted condition requested
877          * Thus (!a && b) || (a && !b) == a ^ b
878          */
879         if (info->match & XT_QTAGUID_UID)
880                 if ((filp->f_cred->fsuid >= info->uid_min &&
881                      filp->f_cred->fsuid <= info->uid_max) ^
882                     !(info->invert & XT_QTAGUID_UID)) {
883                         pr_debug("xt_qtaguid[%d]: leaving uid not matching\n",
884                                  par->hooknum);
885                         res = false;
886                         goto put_sock_ret_res;
887                 }
888         if (info->match & XT_QTAGUID_GID)
889                 if ((filp->f_cred->fsgid >= info->gid_min &&
890                                 filp->f_cred->fsgid <= info->gid_max) ^
891                         !(info->invert & XT_QTAGUID_GID)) {
892                         pr_debug("xt_qtaguid[%d]: leaving gid not matching\n",
893                                 par->hooknum);
894                         res = false;
895                         goto put_sock_ret_res;
896                 }
897
898         pr_debug("xt_qtaguid[%d]: leaving matched\n", par->hooknum);
899         res = true;
900
901 put_sock_ret_res:
902         if (got_sock)
903                 xt_socket_put_sk(sk);
904 ret_res:
905         pr_debug("xt_qtaguid[%d]: left %d\n", par->hooknum, res);
906         return res;
907 }
908
909 /* TODO: Use Documentation/filesystems/seq_file.txt? */
910 static int qtaguid_ctrl_proc_read(char *page, char **start, off_t off,
911                                 int count, int *eof, void *data)
912 {
913         char *out = page + off;
914         int len;
915         unsigned long flags;
916         uid_t uid;
917         struct sock_tag *sock_tag_entry;
918         struct rb_node *node;
919         pr_debug("xt_qtaguid:proc ctrl page=%p off=%ld count=%d eof=%p\n",
920                 page, off, count, eof);
921
922         *eof = 0;
923         spin_lock_irqsave(&sock_tag_list_lock, flags);
924         for (node = rb_first(&sock_tag_tree);
925              node;
926              node = rb_next(node)) {
927                 sock_tag_entry =  rb_entry(node, struct sock_tag, node);
928                 uid = get_uid_from_tag(sock_tag_entry->tag);
929                 pr_debug("xt_qtaguid: proc_read(): sk=%p tag=0x%llx (uid=%d)\n",
930                         sock_tag_entry->sk,
931                         sock_tag_entry->tag,
932                         uid);
933                 len = snprintf(out, count, "sock=%p tag=0x%llx (uid=%u)\n",
934                         sock_tag_entry->sk, sock_tag_entry->tag, uid);
935                 out += len;
936                 count -= len;
937                 if (!count) {
938                         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
939                         return out - page;
940                 }
941         }
942         *eof = 1;
943         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
944         return out - page;
945 }
946
947 static int qtaguid_ctrl_parse(const char *input, int count)
948 {
949         char cmd;
950         int sock_fd = 0;
951         uid_t uid = 0;
952         tag_t acct_tag = 0;
953         struct socket *el_socket;
954         int res, argc;
955         struct sock_tag *sock_tag_entry;
956         unsigned long flags;
957
958         pr_debug("xt_qtaguid: ctrl(%s): entered\n", input);
959         /* Unassigned args will get defaulted later. */
960         /* TODO: get acct_tag_str, keep a list of available tags for the
961          * uid, use num as acct_tag. */
962         argc = sscanf(input, "%c %d %llu %u", &cmd, &sock_fd, &acct_tag, &uid);
963         pr_debug("xt_qtaguid: ctrl(%s): argc=%d cmd=%c sock_fd=%d "
964                 "acct_tag=0x%llx uid=%u\n", input, argc, cmd, sock_fd,
965                 acct_tag, uid);
966
967         /* Collect params for commands */
968         switch (cmd) {
969         case 't':
970         case 'u':
971                 if (argc < 2) {
972                         res = -EINVAL;
973                         goto err;
974                 }
975                 el_socket = sockfd_lookup(sock_fd, &res);
976                 if (!el_socket) {
977                         pr_info("xt_qtaguid: ctrl(%s): failed to lookup"
978                                 " sock_fd=%d err=%d\n", input, sock_fd, res);
979                         goto err;
980                 }
981                 spin_lock_irqsave(&sock_tag_list_lock, flags);
982                 /* TODO: optim: pass in the current_fsuid() to do lookups
983                  * as look ups will always be initiated form the same uid. */
984                 sock_tag_entry = get_sock_stat_nl(el_socket->sk);
985                 if (!sock_tag_entry)
986                         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
987                 break;
988         default:
989                 res = -EINVAL;
990                 goto err;
991         }
992
993         /* Process commands */
994         switch (cmd) {
995
996         case 't':
997                 if (argc < 2) {
998                         res = -EINVAL;
999                         goto err_unlock;
1000                 }
1001                 if (argc < 3) {
1002                         acct_tag = 0;
1003                 } else if (!valid_atag(acct_tag)) {
1004                         res = -EINVAL;
1005                         goto err_unlock;
1006                 }
1007                 if (argc < 4)
1008                         uid = current_fsuid();
1009                 if (!sock_tag_entry) {
1010                         sock_tag_entry = kmalloc(sizeof(*sock_tag_entry),
1011                                                 GFP_KERNEL);
1012                         if (!sock_tag_entry) {
1013                                 res = -ENOMEM;
1014                                 goto err;
1015                         }
1016                         memset(sock_tag_entry, 0, sizeof(*sock_tag_entry));
1017                         sock_tag_entry->sk = el_socket->sk;
1018                         /* TODO: check that uid==current_fsuid() except
1019                          * for special uid/gid. */
1020                         sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
1021                                                                 uid);
1022                         spin_lock_irqsave(&sock_tag_list_lock, flags);
1023                         sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree);
1024                         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
1025                 } else {
1026                         /* Just update the acct_tag portion. */
1027                         uid_t orig_uid = get_uid_from_tag(sock_tag_entry->tag);
1028                         sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
1029                                                                 orig_uid);
1030                 }
1031                 pr_debug("xt_qtaguid: tag: sock_tag_entry->sk=%p "
1032                         "...->tag=0x%llx (uid=%u)\n",
1033                         sock_tag_entry->sk, sock_tag_entry->tag,
1034                         get_uid_from_tag(sock_tag_entry->tag));
1035                 break;
1036
1037         case 'u':
1038                 if (!sock_tag_entry) {
1039                         res = -EINVAL;
1040                         goto err;
1041                 }
1042                 /* TODO: check that the uid==current_fsuid()
1043                  * except for special uid/gid. */
1044                 rb_erase(&sock_tag_entry->node, &sock_tag_tree);
1045                 spin_unlock_irqrestore(&sock_tag_list_lock, flags);
1046                 kfree(sock_tag_entry);
1047                 break;
1048         }
1049
1050         /* All of the input has been processed */
1051         res = count;
1052         goto ok;
1053
1054 err_unlock:
1055         if (!sock_tag_entry)
1056                 spin_unlock_irqrestore(&sock_tag_list_lock, flags);
1057 err:
1058 ok:
1059         pr_debug("xt_qtaguid: ctrl(%s): res=%d\n", input, res);
1060         return res;
1061 }
1062
1063 #define MAX_QTAGUID_CTRL_INPUT_LEN 255
1064 static int qtaguid_ctrl_proc_write(struct file *file, const char __user *buffer,
1065                         unsigned long count, void *data)
1066 {
1067         char input_buf[MAX_QTAGUID_CTRL_INPUT_LEN];
1068
1069         if (count >= MAX_QTAGUID_CTRL_INPUT_LEN)
1070                 return -EINVAL;
1071
1072         if (copy_from_user(input_buf, buffer, count))
1073                 return -EFAULT;
1074
1075         input_buf[count] = '\0';
1076         return qtaguid_ctrl_parse(input_buf, count);
1077 }
1078
1079 /*
1080  * Procfs reader to get all tag stats using style "1)" as described in
1081  * fs/proc/generic.c
1082  * Groups all protocols tx/rx bytes.
1083  */
1084 static int qtaguid_stats_proc_read(char *page, char **num_items_returned,
1085                                 off_t items_to_skip, int char_count, int *eof,
1086                                 void *data)
1087 {
1088         char *outp = page;
1089         int len;
1090         unsigned long flags, flags2;
1091         struct iface_stat *iface_entry;
1092         struct tag_stat *ts_entry;
1093         int item_index = 0;
1094
1095         /* TODO: make root access only */
1096
1097         pr_debug("xt_qtaguid:proc stats page=%p *num_items_returned=%p off=%ld "
1098                 "char_count=%d *eof=%d\n", page, *num_items_returned,
1099                 items_to_skip, char_count, *eof);
1100
1101         if (*eof)
1102                 return 0;
1103
1104         if (!items_to_skip) {
1105                 /* The idx is there to help debug when things go belly up. */
1106                 len = snprintf(outp, char_count,
1107                         "idx iface acct_tag_hex uid_tag_int rx_bytes "
1108                         "tx_bytes\n");
1109                 /* Don't advance the outp unless the whole line was printed */
1110                 if (len >= char_count) {
1111                         *outp = '\0';
1112                         return outp - page;
1113                 }
1114                 outp += len;
1115                 char_count -= len;
1116         }
1117
1118         spin_lock_irqsave(&iface_stat_list_lock, flags);
1119         list_for_each_entry(iface_entry, &iface_stat_list, list) {
1120                 struct rb_node *node;
1121                 spin_lock_irqsave(&iface_entry->tag_stat_list_lock, flags2);
1122                 for (node = rb_first(&iface_entry->tag_stat_tree);
1123                      node;
1124                      node = rb_next(node)) {
1125                         ts_entry =  rb_entry(node, struct tag_stat, node);
1126                         if (item_index++ < items_to_skip)
1127                                 continue;
1128                         len = snprintf(outp, char_count,
1129                                        "%d %s 0x%llx %u %llu %llu\n",
1130                                        item_index,
1131                                        iface_entry->ifname,
1132                                        get_atag_from_tag(ts_entry->tag),
1133                                        get_uid_from_tag(ts_entry->tag),
1134                                        dc_sum_bytes(&ts_entry->counters,
1135                                                     IFS_RX),
1136                                        dc_sum_bytes(&ts_entry->counters,
1137                                                     IFS_TX));
1138                         if (len >= char_count) {
1139                                 spin_unlock_irqrestore(
1140                                         &iface_entry->tag_stat_list_lock,
1141                                         flags2);
1142                                 spin_unlock_irqrestore(
1143                                         &iface_stat_list_lock, flags);
1144                                 *outp = '\0';
1145                                 return outp - page;
1146                         }
1147                         outp += len;
1148                         char_count -= len;
1149                         (*(int *)num_items_returned)++;
1150                 }
1151                 spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock,
1152                                 flags2);
1153         }
1154         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
1155
1156         *eof = 1;
1157         return outp - page;
1158 }
1159
/*------------------------------------------*/
/*
 * procfs plumbing: the name of this module's directory and the entries
 * created by qtaguid_proc_register().
 */
static const char *module_procdirname = "xt_qtaguid";
static struct proc_dir_entry *xt_qtaguid_procdir, *xt_qtaguid_ctrl_file,
			     *xt_qtaguid_stats_file;
1165
1166 static int __init qtaguid_proc_register(struct proc_dir_entry **res_procdir)
1167 {
1168         int ret;
1169         *res_procdir = proc_mkdir(module_procdirname, init_net.proc_net);
1170         if (!*res_procdir) {
1171                 pr_err("xt_qtaguid: failed to create proc/.../xt_qtaguid\n");
1172                 ret = -ENOMEM;
1173                 goto no_dir;
1174         }
1175
1176         xt_qtaguid_ctrl_file = create_proc_entry("ctrl", 0666,
1177                                                 *res_procdir);
1178         if (!xt_qtaguid_ctrl_file) {
1179                 pr_err("xt_qtaguid: failed to create xt_qtaguid/ctrl "
1180                         " file\n");
1181                 ret = -ENOMEM;
1182                 goto no_ctrl_entry;
1183         }
1184         xt_qtaguid_ctrl_file->read_proc = qtaguid_ctrl_proc_read;
1185         xt_qtaguid_ctrl_file->write_proc = qtaguid_ctrl_proc_write;
1186
1187         xt_qtaguid_stats_file = create_proc_entry("stats", 0666,
1188                                                 *res_procdir);
1189         if (!xt_qtaguid_stats_file) {
1190                 pr_err("xt_qtaguid: failed to create xt_qtaguid/stats "
1191                         "file\n");
1192                 ret = -ENOMEM;
1193                 goto no_stats_entry;
1194         }
1195         /*
1196          * TODO: add extra read_proc for full stats with protocol
1197          * breakout
1198          */
1199         xt_qtaguid_stats_file->read_proc = qtaguid_stats_proc_read;
1200         /*
1201          * TODO: add support counter hacking
1202          * xt_qtaguid_stats_file->write_proc = qtaguid_stats_proc_write;
1203          */
1204         return 0;
1205
1206 no_stats_entry:
1207         remove_proc_entry("ctrl", *res_procdir);
1208 no_ctrl_entry:
1209         remove_proc_entry("xt_qtaguid", NULL);
1210 no_dir:
1211         return ret;
1212 }
1213
1214 static struct xt_match qtaguid_mt_reg __read_mostly = {
1215         /*
1216          * This module masquerades as the "owner" module so that iptables
1217          * tools can deal with it.
1218          */
1219         .name       = "owner",
1220         .revision   = 1,
1221         .family     = NFPROTO_UNSPEC,
1222         .match      = qtaguid_mt,
1223         .matchsize  = sizeof(struct xt_qtaguid_match_info),
1224         .me         = THIS_MODULE,
1225 };
1226
1227 static int __init qtaguid_mt_init(void)
1228 {
1229         if (qtaguid_proc_register(&xt_qtaguid_procdir)
1230             || iface_stat_init(xt_qtaguid_procdir)
1231             || xt_register_match(&qtaguid_mt_reg))
1232                 return -1;
1233         return 0;
1234 }
1235
/* TODO: allow unloading of the module.
 * For now stats are permanent.
 * Kconfig forces 'y'/'n' and never an 'm'.
 */

module_init(qtaguid_mt_init);
MODULE_AUTHOR("jpa <jpa@google.com>");
MODULE_DESCRIPTION("Xtables: socket owner+tag matching and associated stats");
MODULE_LICENSE("GPL");
/*
 * Module aliases: the "owner" names match the name under which
 * qtaguid_mt_reg is registered, plus the module's own "qtaguid" names,
 * for both the IPv4 (ipt_) and IPv6 (ip6t_) prefixes.
 */
MODULE_ALIAS("ipt_owner");
MODULE_ALIAS("ip6t_owner");
MODULE_ALIAS("ipt_qtaguid");
MODULE_ALIAS("ip6t_qtaguid");