a764552cc3871345e81029fedfd769cd8fc01e91
[linux-2.6.git] / net / netfilter / xt_qtaguid.c
1 /*
2  * Kernel iptables module to track stats for packets based on user tags.
3  *
4  * (C) 2011 Google, Inc
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 as
8  * published by the Free Software Foundation.
9  */
10
11 /* TODO: support ipv6 */
12
13 #include <linux/file.h>
14 #include <linux/inetdevice.h>
15 #include <linux/module.h>
16 #include <linux/netfilter/x_tables.h>
17 #include <linux/netfilter/xt_qtaguid.h>
18 #include <linux/skbuff.h>
19 #include <linux/workqueue.h>
20 #include <net/sock.h>
21 #include <net/tcp.h>
22 #include <net/udp.h>
23
24 #include <linux/netfilter/xt_socket.h>
25
26
27 /*---------------------------------------------------------------------------*/
/*
 * Tags:
 *
 * They represent what the data usage counters will be tracked against.
 * By default a tag is just based on the UID.
 * The UID is used as the base for policing, and cannot be ignored.
 * So a tag will always at least represent a UID (uid_tag).
 *
 * A tag can be augmented with an "accounting tag" which is associated
 * with a UID.
 * User space can set the acct_tag portion of the tag which is then used
 * with sockets: all data belonging to that socket will be counted against
 * the tag. The policing is then based on the tag's uid_tag portion,
 * and stats are collected for the acct_tag portion separately.
 *
 * There could be
 * a:  {acct_tag=1, uid_tag=10003}
 * b:  {acct_tag=2, uid_tag=10003}
 * c:  {acct_tag=3, uid_tag=10003}
 * d:  {acct_tag=0, uid_tag=10003}
 * (a, b, and c represent tags associated with specific sockets.
 * d is for the totals for that uid, including all untagged traffic.
 * Typically d is used with policing/quota rules.)
 *
 * We want tag_t big enough to distinguish uid_t and acct_tag.
 * It might become a struct if needed.
 * Nothing should be using it as an int.
 *
 * Current encoding (as implemented by the get_*_from_tag() accessors):
 * the low 32 bits are the uid_tag and the high 32 bits are the acct_tag
 * portion, kept in place (not shifted down).
 */
typedef uint64_t tag_t;  /* Only used via accessors */
57
/* Name of the directory iface_stat_init() creates under its parent proc dir. */
static const char *iface_stat_procdirname = "iface_stat";
/* That directory; each tracked interface gets a subdirectory in it. */
static struct proc_dir_entry *iface_stat_procdir;
60
/* Direction index (first dimension of data_counters.bpc). */
enum ifs_tx_rx {
	IFS_TX,
	IFS_RX,
	IFS_MAX_DIRECTIONS
};

/* For now, TCP, UDP, the rest */
/* Protocol bucket (second dimension of data_counters.bpc). */
enum ifs_proto {
	IFS_TCP,
	IFS_UDP,
	IFS_PROTO_OTHER,
	IFS_MAX_PROTOS
};

/* A byte count paired with a packet count. */
struct byte_packet_counters {
	uint64_t bytes;
	uint64_t packets;
};

/* Byte/packet counters split by direction and by protocol bucket. */
struct data_counters {
	struct byte_packet_counters bpc[IFS_MAX_DIRECTIONS][IFS_MAX_PROTOS];
};

/*
 * Stats for one tag on one interface. Instances are nodes of
 * iface_stat.tag_stat_tree, ordered by @tag.
 */
struct tag_stat {
	struct rb_node node;
	tag_t tag;

	struct data_counters counters;
	/* If this tag is acct_tag based, we need to count against the
	 * matching parent uid_tag. NULL for plain uid_tag entries. */
	struct data_counters *parent_counters;
	struct proc_dir_entry *proc_ptr;
};
94
/* All tracked interfaces; list walks and insertions take iface_stat_list_lock. */
static LIST_HEAD(iface_stat_list);
static DEFINE_SPINLOCK(iface_stat_list_lock);

/* Per-interface tracking entry, created by iface_stat_create(). */
struct iface_stat {
	struct list_head list;
	char *ifname;
	/*
	 * Device counters folded in by iface_stat_update() when the device
	 * is unregistered; exported through this interface's proc directory.
	 */
	uint64_t rx_bytes;
	uint64_t rx_packets;
	uint64_t tx_bytes;
	uint64_t tx_packets;
	/* false for loopback devices and after unregister; true otherwise. */
	bool active;
	struct proc_dir_entry *proc_ptr;

	/* Per-tag stats for this interface, guarded by tag_stat_list_lock. */
	struct rb_root tag_stat_tree;
	spinlock_t tag_stat_list_lock;
};
111
112
/* Tagged sockets, ordered by socket address; lookups take sock_tag_list_lock. */
static struct rb_root sock_tag_tree = RB_ROOT;
static DEFINE_SPINLOCK(sock_tag_list_lock);

/*
 * Track tag that this socket is transferring data for, and not necessarily
 * the uid that owns the socket.
 * This is the tag against which tag_stat.counters will be billed.
 */
struct sock_tag {
	struct rb_node node;
	struct sock *sk;
	tag_t tag;
};

static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par);
128
129 /*----------------------------------------------*/
130 static inline int tag_compare(tag_t t1, tag_t t2)
131 {
132         return t1 < t2 ? -1 : t1 == t2 ? 0 : 1;
133 }
134
135
136 static inline tag_t combine_atag_with_uid(tag_t acct_tag, uid_t uid)
137 {
138         return acct_tag | uid;
139 }
140 static inline tag_t make_tag_from_uid(uid_t uid)
141 {
142         return uid;
143 }
144 static inline uid_t get_uid_from_tag(tag_t tag)
145 {
146         return tag & 0xFFFFFFFFULL;
147 }
148 static inline tag_t get_utag_from_tag(tag_t tag)
149 {
150         return tag & 0xFFFFFFFFULL;
151 }
152 static inline tag_t get_atag_from_tag(tag_t tag)
153 {
154         return tag & ~0xFFFFFFFFULL;
155 }
156
157 static inline bool valid_atag(tag_t tag)
158 {
159         return !(tag & 0xFFFFFFFFULL);
160 }
161
162 static inline void dc_add_byte_packets(struct data_counters *counters,
163                                   enum ifs_tx_rx direction,
164                                   enum ifs_proto ifs_proto,
165                                   int bytes,
166                                   int packets)
167 {
168         counters->bpc[direction][ifs_proto].bytes += bytes;
169         counters->bpc[direction][ifs_proto].packets += packets;
170 }
171
172 static inline uint64_t dc_sum_bytes(struct data_counters *counters,
173                                     enum ifs_tx_rx direction)
174 {
175         return counters->bpc[direction][IFS_TCP].bytes
176                 + counters->bpc[direction][IFS_UDP].bytes
177                 + counters->bpc[direction][IFS_PROTO_OTHER].bytes;
178 }
179
180 static struct tag_stat *tag_stat_tree_search(struct rb_root *root, tag_t tag)
181 {
182         struct rb_node *node = root->rb_node;
183
184         while (node) {
185                 struct tag_stat *data = rb_entry(node, struct tag_stat, node);
186                 int result = tag_compare(tag, data->tag);
187                 pr_debug("qtaguid: tag_stat_tree_search(): tag=0x%llx"
188                          " (uid=%d)\n",
189                          data->tag,
190                          get_uid_from_tag(data->tag));
191
192                 if (result < 0)
193                         node = node->rb_left;
194                 else if (result > 0)
195                         node = node->rb_right;
196                 else
197                         return data;
198         }
199         return NULL;
200 }
201
202 static void tag_stat_tree_insert(struct tag_stat *data, struct rb_root *root)
203 {
204         struct rb_node **new = &(root->rb_node), *parent = NULL;
205
206         /* Figure out where to put new node */
207         while (*new) {
208                 struct tag_stat *this = rb_entry(*new, struct tag_stat,
209                                                  node);
210                 int result = tag_compare(data->tag, this->tag);
211                 pr_debug("qtaguid: tag_stat_tree_insert(): tag=0x%llx"
212                          " (uid=%d)\n",
213                          this->tag,
214                          get_uid_from_tag(this->tag));
215                 parent = *new;
216                 if (result < 0)
217                         new = &((*new)->rb_left);
218                 else if (result > 0)
219                         new = &((*new)->rb_right);
220                 else
221                         BUG();
222         }
223
224         /* Add new node and rebalance tree. */
225         rb_link_node(&data->node, parent, new);
226         rb_insert_color(&data->node, root);
227 }
228
229 static struct sock_tag *sock_tag_tree_search(struct rb_root *root,
230                                              const struct sock *sk)
231 {
232         struct rb_node *node = root->rb_node;
233
234         while (node) {
235                 struct sock_tag *data = rb_entry(node, struct sock_tag, node);
236                 ptrdiff_t result = sk - data->sk;
237                 if (result < 0)
238                         node = node->rb_left;
239                 else if (result > 0)
240                         node = node->rb_right;
241                 else
242                         return data;
243         }
244         return NULL;
245 }
246
247 static void sock_tag_tree_insert(struct sock_tag *data, struct rb_root *root)
248 {
249         struct rb_node **new = &(root->rb_node), *parent = NULL;
250
251         /* Figure out where to put new node */
252         while (*new) {
253                 struct sock_tag *this = rb_entry(*new, struct sock_tag, node);
254                 ptrdiff_t result = data->sk - this->sk;
255                 parent = *new;
256                 if (result < 0)
257                         new = &((*new)->rb_left);
258                 else if (result > 0)
259                         new = &((*new)->rb_right);
260                 else
261                         BUG();
262         }
263
264         /* Add new node and rebalance tree. */
265         rb_link_node(&data->node, parent, new);
266         rb_insert_color(&data->node, root);
267 }
268
/*
 * Finish a legacy read_proc handler whose complete output (@written bytes)
 * already sits at the start of @page.
 *
 * Exposes the window [@off, @off + @count) through *@start and the return
 * value, and sets *@eof once that window reaches the end of the data.
 * A reader at or past the end gets 0 rather than a negative length, and
 * the returned length never exceeds @count.
 */
static int read_proc_window(char *page, char **start, off_t off, int count,
			    int *eof, int written)
{
	int len;

	if (off >= written) {
		*start = page;
		*eof = 1;
		return 0;
	}
	*start = page + off;
	len = written - (int)off;
	if (len <= count) {
		*eof = 1;
	} else {
		/* Never report more bytes than the reader asked for. */
		*eof = 0;
		len = count;
	}
	return len;
}

/* read_proc handler: print the uint64_t at @data as decimal plus newline. */
static int read_proc_u64(char *page, char **start, off_t off,
			int count, int *eof, void *data)
{
	const uint64_t *iface_entry = data;
	int written;

	if (!data) {
		*eof = 1;
		return 0;
	}
	/* Cast so %llu matches regardless of how uint64_t is typedef'd. */
	written = sprintf(page, "%llu\n", (unsigned long long)*iface_entry);
	return read_proc_window(page, start, off, count, eof, written);
}

/* read_proc handler: print the bool at @data as "0\n" or "1\n". */
static int read_proc_bool(char *page, char **start, off_t off,
			int count, int *eof, void *data)
{
	const bool *bool_entry = data;
	int written;

	if (!data) {
		*eof = 1;
		return 0;
	}
	written = sprintf(page, "%u\n", (unsigned int)*bool_entry);
	return read_proc_window(page, start, off, count, eof, written);
}
304
305 /* Find the entry for tracking the specified interface. */
306 static struct iface_stat *get_iface_stat(const char *ifname)
307 {
308         unsigned long flags;
309         struct iface_stat *iface_entry;
310         if (!ifname)
311                 return NULL;
312
313         spin_lock_irqsave(&iface_stat_list_lock, flags);
314         list_for_each_entry(iface_entry, &iface_stat_list, list) {
315                 if (!strcmp(iface_entry->ifname, ifname))
316                         goto done;
317         }
318         iface_entry = NULL;
319 done:
320         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
321         return iface_entry;
322 }
323
324 /*
325  * Create a new entry for tracking the specified interface.
326  * Do nothing if the entry already exists.
327  * Called when an interface is configured with a valid IP address.
328  */
329 void iface_stat_create(const struct net_device *net_dev)
330 {
331         struct in_device *in_dev;
332         unsigned long flags;
333         struct iface_stat *new_iface;
334         struct proc_dir_entry *proc_entry;
335         const char *ifname;
336         struct iface_stat *entry;
337         __be32 ipaddr = 0;
338         struct in_ifaddr *ifa = NULL;
339
340         ASSERT_RTNL(); /* No need for separate locking */
341
342         pr_debug("iface_stat: create(): netdev=%p->name=%s\n",
343                  net_dev, net_dev ? net_dev->name : "");
344         if (!net_dev) {
345                 pr_err("iface_stat: create(): no net dev!\n");
346                 return;
347         }
348
349         in_dev = __in_dev_get_rtnl(net_dev);
350         if (!in_dev) {
351                 pr_err("iface_stat: create(): no inet dev!\n");
352                 return;
353         }
354
355         pr_debug("iface_stat: create(): in_dev=%p\n", in_dev);
356         ifname = net_dev->name;
357         pr_debug("iface_stat: create(): ifname=%p\n", ifname);
358         for (ifa = in_dev->ifa_list; ifa; ifa = ifa->ifa_next) {
359                 pr_debug("iface_stat: create(): for(): ifa=%p ifname=%p\n",
360                          ifa, ifname);
361                 pr_debug("iface_stat: create(): ifname=%s ifa_label=%s\n",
362                          ifname, ifa->ifa_label ? ifa->ifa_label : "(null)");
363                 if (ifa->ifa_label && !strcmp(ifname, ifa->ifa_label))
364                         break;
365         }
366
367         if (ifa) {
368                 ipaddr = ifa->ifa_local;
369         } else {
370                 pr_err("iface_stat: create(): dev %s has no matching IP\n",
371                        ifname);
372                 return;
373         }
374
375         entry = get_iface_stat(net_dev->name);
376         if (entry != NULL) {
377                 pr_debug("iface_stat: create(): dev %s entry=%p\n", ifname,
378                          entry);
379                 if (ipv4_is_loopback(ipaddr)) {
380                         entry->active = false;
381                         pr_debug("iface_stat: create(): disable tracking of "
382                                  "loopback dev %s\n", ifname);
383                 } else {
384                         entry->active = true;
385                         pr_debug("iface_stat: create(): enable tracking of "
386                                  "dev %s with ip=%pI4\n",
387                                  ifname, &ipaddr);
388                 }
389                 return;
390         } else if (ipv4_is_loopback(ipaddr)) {
391                 pr_debug("iface_stat: create(): ignore loopback dev %s"
392                          " ip=%pI4\n", ifname, &ipaddr);
393                 return;
394         }
395
396         new_iface = kmalloc(sizeof(*new_iface), GFP_KERNEL);
397         if (new_iface == NULL) {
398                 pr_err("iface_stat: create(): failed to alloc iface_stat\n");
399                 return;
400         }
401         memset(new_iface, 0, sizeof(*new_iface));
402         new_iface->ifname = kstrdup(ifname, GFP_KERNEL);
403         if (new_iface->ifname == NULL) {
404                 pr_err("iface_stat: create(): failed to alloc ifname\n");
405                 kfree(new_iface);
406                 return;
407         }
408         spin_lock_init(&new_iface->tag_stat_list_lock);
409
410         new_iface->active = true;
411
412         new_iface->tag_stat_tree = RB_ROOT;
413         spin_lock_irqsave(&iface_stat_list_lock, flags);
414         list_add(&new_iface->list, &iface_stat_list);
415         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
416
417         proc_entry = proc_mkdir(ifname, iface_stat_procdir);
418         new_iface->proc_ptr = proc_entry;
419
420         /* TODO: make root access only */
421         create_proc_read_entry("tx_bytes", S_IRUGO, proc_entry,
422                         read_proc_u64, &new_iface->tx_bytes);
423         create_proc_read_entry("rx_bytes", S_IRUGO, proc_entry,
424                         read_proc_u64, &new_iface->rx_bytes);
425         create_proc_read_entry("tx_packets", S_IRUGO, proc_entry,
426                         read_proc_u64, &new_iface->tx_packets);
427         create_proc_read_entry("rx_packets", S_IRUGO, proc_entry,
428                         read_proc_u64, &new_iface->rx_packets);
429         create_proc_read_entry("active", S_IRUGO, proc_entry,
430                         read_proc_bool, &new_iface->active);
431
432         pr_debug("iface_stat: create(): done entry=%p dev=%s ip=%pI4\n",
433                  new_iface, ifname, &ipaddr);
434 }
435
436 static struct sock_tag *get_sock_stat_nl(const struct sock *sk)
437 {
438         pr_debug("xt_qtaguid: get_sock_stat_nl(sk=%p)\n", sk);
439         return sock_tag_tree_search(&sock_tag_tree, sk);
440 }
441
442 static struct sock_tag *get_sock_stat(const struct sock *sk)
443 {
444         unsigned long flags;
445         struct sock_tag *sock_tag_entry;
446         pr_debug("xt_qtaguid: get_sock_stat(sk=%p)\n", sk);
447         if (!sk)
448                 return NULL;
449         spin_lock_irqsave(&sock_tag_list_lock, flags);
450         sock_tag_entry = get_sock_stat_nl(sk);
451         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
452         return sock_tag_entry;
453 }
454
455 static void
456 data_counters_update(struct data_counters *dc,  enum ifs_tx_rx direction,
457                 int proto, int bytes)
458 {
459         switch (proto) {
460         case IPPROTO_TCP:
461                 dc_add_byte_packets(dc, direction, IFS_TCP, bytes, 1);
462                 break;
463         case IPPROTO_UDP:
464                 dc_add_byte_packets(dc, direction, IFS_UDP, bytes, 1);
465                 break;
466         case IPPROTO_IP:
467         default:
468                 dc_add_byte_packets(dc, direction, IFS_PROTO_OTHER, bytes, 1);
469                 break;
470         }
471 }
472
473
474 /*
475  * Update stats for the specified interface. Do nothing if the entry
476  * does not exist (when a device was never configured with an IP address).
477  * Called when an device is being unregistered.
478  */
479 void iface_stat_update(struct net_device *dev)
480 {
481         struct rtnl_link_stats64 dev_stats, *stats;
482         struct iface_stat *entry;
483         stats = dev_get_stats(dev, &dev_stats);
484         ASSERT_RTNL();
485
486         entry = get_iface_stat(dev->name);
487         if (entry == NULL) {
488                 pr_debug("iface_stat: dev %s monitor not found\n", dev->name);
489                 return;
490         }
491         if (entry->active) {
492                 entry->tx_bytes += stats->tx_bytes;
493                 entry->tx_packets += stats->tx_packets;
494                 entry->rx_bytes += stats->rx_bytes;
495                 entry->rx_packets += stats->rx_packets;
496                 entry->active = false;
497                 pr_debug("iface_stat: Updating stats for "
498                         "dev %s which went down\n", dev->name);
499         } else {
500                 pr_debug("iface_stat: Did not update stats for "
501                         "dev %s which went down\n", dev->name);
502         }
503 }
504
505
506 static void tag_stat_update(struct tag_stat *tag_entry,
507                         enum ifs_tx_rx direction, int proto, int bytes)
508 {
509         pr_debug("xt_qtaguid: tag_stat_update(tag=0x%llx (uid=%d) dir=%d "
510                 "proto=%d bytes=%d)\n",
511                 tag_entry->tag, get_uid_from_tag(tag_entry->tag), direction,
512                 proto, bytes);
513         data_counters_update(&tag_entry->counters, direction, proto, bytes);
514         if (tag_entry->parent_counters)
515                 data_counters_update(tag_entry->parent_counters, direction,
516                                 proto, bytes);
517 }
518
519
520 /* Create a new entry for tracking the specified {acct_tag,uid_tag} within
521  * the interface.
522  * iface_entry->tag_stat_list_lock should be held. */
523 static struct tag_stat *create_if_tag_stat(struct iface_stat *iface_entry,
524                                            tag_t tag)
525 {
526         struct tag_stat *new_tag_stat_entry = NULL;
527         pr_debug("iface_stat: create_if_tag_stat(): ife=%p tag=0x%llx"
528                  " (uid=%d)\n",
529                  iface_entry, tag, get_uid_from_tag(tag));
530         new_tag_stat_entry = kmalloc(sizeof(*new_tag_stat_entry), GFP_ATOMIC);
531         if (!new_tag_stat_entry) {
532                 pr_err("iface_stat: failed to alloc new tag entry\n");
533                 goto done;
534         }
535         memset(new_tag_stat_entry, 0, sizeof(*new_tag_stat_entry));
536         new_tag_stat_entry->tag = tag;
537         tag_stat_tree_insert(new_tag_stat_entry, &iface_entry->tag_stat_tree);
538 done:
539         return new_tag_stat_entry;
540 }
541
542 static struct iface_stat *get_iface_entry(const char *ifname)
543 {
544         struct iface_stat *iface_entry;
545         unsigned long flags;
546
547         /* Find the entry for tracking the specified tag within the interface */
548         if (ifname == NULL) {
549                 pr_info("iface_stat: NULL device name\n");
550                 return NULL;
551         }
552
553
554         /* Iterate over interfaces */
555         spin_lock_irqsave(&iface_stat_list_lock, flags);
556         list_for_each_entry(iface_entry, &iface_stat_list, list) {
557                 if (!strcmp(ifname, iface_entry->ifname))
558                         goto done;
559         }
560         iface_entry = NULL;
561 done:
562         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
563         return iface_entry;
564 }
565
566 static void if_tag_stat_update(const char *ifname, uid_t uid,
567                                const struct sock *sk, enum ifs_tx_rx direction,
568                                int proto, int bytes)
569 {
570         struct tag_stat *tag_stat_entry;
571         tag_t tag, acct_tag;
572         tag_t uid_tag;
573         struct data_counters *uid_tag_counters;
574         struct sock_tag *sock_tag_entry;
575         struct iface_stat *iface_entry;
576         unsigned long flags;
577         struct tag_stat *new_tag_stat;
578         pr_debug("xt_qtaguid: if_tag_stat_update(ifname=%s "
579                 "uid=%d sk=%p dir=%d proto=%d bytes=%d)\n",
580                  ifname, uid, sk, direction, proto, bytes);
581
582
583         iface_entry = get_iface_entry(ifname);
584         if (!iface_entry) {
585                 pr_err("iface_stat: interface %s not found\n", ifname);
586                 return;
587         }
588         /* else { If the iface_entry becomes inactive, it is still ok
589          * to process the data. } */
590
591         pr_debug("iface_stat: stat_update() got entry=%p\n", iface_entry);
592
593         /* Look for a tagged sock.
594          * It will have an acct_uid. */
595         sock_tag_entry = get_sock_stat(sk);
596         if (sock_tag_entry) {
597                 tag = sock_tag_entry->tag;
598                 acct_tag = get_atag_from_tag(tag);
599                 uid_tag = get_utag_from_tag(tag);
600         } else {
601                 uid_tag = make_tag_from_uid(uid);
602                 acct_tag = 0;
603                 tag = combine_atag_with_uid(acct_tag, uid);
604         }
605         pr_debug("iface_stat: stat_update(): looking for tag=0x%llx (uid=%d)"
606                  " in ife=%p\n",
607                  tag, get_uid_from_tag(tag), iface_entry);
608         /* Loop over tag list under this interface for {acct_tag,uid_tag} */
609         spin_lock_irqsave(&iface_entry->tag_stat_list_lock, flags);
610
611         tag_stat_entry = tag_stat_tree_search(&iface_entry->tag_stat_tree,
612                                               tag);
613         if (tag_stat_entry) {
614                 /* Updating the {acct_tag, uid_tag} entry handles both stats:
615                  * {0, uid_tag} will also get updated. */
616                 tag_stat_update(tag_stat_entry, direction, proto, bytes);
617                 spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock, flags);
618                 return;
619         }
620
621         /* Loop over tag list under this interface for {0,uid_tag} */
622         tag_stat_entry = tag_stat_tree_search(&iface_entry->tag_stat_tree,
623                                               uid_tag);
624         if (!tag_stat_entry) {
625                 /* Here: the base uid_tag did not exist */
626                 /*
627                  * No parent counters. So
628                  *  - No {0, uid_tag} stats and no {acc_tag, uid_tag} stats.
629                  */
630                 new_tag_stat = create_if_tag_stat(iface_entry, uid_tag);
631                 uid_tag_counters = &new_tag_stat->counters;
632         } else {
633                 uid_tag_counters = &tag_stat_entry->counters;
634         }
635
636         if (acct_tag) {
637                 new_tag_stat = create_if_tag_stat(iface_entry, tag);
638                 new_tag_stat->parent_counters = uid_tag_counters;
639         }
640         spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock, flags);
641         tag_stat_update(new_tag_stat, direction, proto, bytes);
642 }
643
644 static int iface_netdev_event_handler(struct notifier_block *nb,
645                                       unsigned long event, void *ptr) {
646         struct net_device *dev = ptr;
647
648         pr_debug("iface_stat: netdev_event(): ev=0x%lx netdev=%p->name=%s\n",
649                  event, dev, dev ? dev->name : "");
650
651         switch (event) {
652         case NETDEV_UP:
653         case NETDEV_REBOOT:
654         case NETDEV_CHANGE:
655         case NETDEV_REGISTER:  /* Most likely no IP */
656         case NETDEV_CHANGEADDR:  /* MAC addr change */
657         case NETDEV_CHANGENAME:
658         case NETDEV_FEAT_CHANGE:  /* Might be usefull when cell type changes */
659                 iface_stat_create(dev);
660                 break;
661         case NETDEV_UNREGISTER:
662                 iface_stat_update(dev);
663                 break;
664         }
665         return NOTIFY_DONE;
666 }
667
668 static int iface_inetaddr_event_handler(struct notifier_block *nb,
669                                         unsigned long event, void *ptr) {
670
671         struct in_ifaddr *ifa = ptr;
672         struct in_device *in_dev = ifa->ifa_dev;
673         struct net_device *dev = in_dev->dev;
674
675         pr_debug("iface_stat: inetaddr_event(): ev=0x%lx netdev=%p->name=%s\n",
676                  event, dev, dev ? dev->name : "");
677
678         switch (event) {
679         case NETDEV_UP:
680                 iface_stat_create(dev);
681                 break;
682         }
683         return NOTIFY_DONE;
684 }
685
/* Registered by iface_stat_init() with register_netdevice_notifier(). */
static struct notifier_block iface_netdev_notifier_blk = {
	.notifier_call = iface_netdev_event_handler,
};

/* Registered by iface_stat_init() with register_inetaddr_notifier(). */
static struct notifier_block iface_inetaddr_notifier_blk = {
	.notifier_call = iface_inetaddr_event_handler,
};
693
694 static int __init iface_stat_init(struct proc_dir_entry *parent_procdir)
695 {
696         int err;
697
698         iface_stat_procdir = proc_mkdir(iface_stat_procdirname, parent_procdir);
699         if (!iface_stat_procdir) {
700                 pr_err("iface_stat: failed to create proc entry\n");
701                 err = -1;
702                 goto err;
703         }
704         err = register_netdevice_notifier(&iface_netdev_notifier_blk);
705         if (err) {
706                 pr_err("iface_stat: failed to register dev event handler\n");
707                 goto err_unreg_nd;
708         }
709         err = register_inetaddr_notifier(&iface_inetaddr_notifier_blk);
710         if (err) {
711                 pr_err("iface_stat: failed to register dev event handler\n");
712                 goto err_zap_entry;
713         }
714         return 0;
715
716 err_unreg_nd:
717         unregister_netdevice_notifier(&iface_netdev_notifier_blk);
718 err_zap_entry:
719         remove_proc_entry(iface_stat_procdirname, parent_procdir);
720 err:
721         return err;
722 }
723
724 static struct sock *qtaguid_find_sk(const struct sk_buff *skb,
725                                     struct xt_action_param *par)
726 {
727         struct sock *sk;
728
729         sk = xt_socket_get4_sk(skb, par);
730         /* TODO: is this fixed?
731          * Seems to be issues on the file ptr for TCP+TIME_WAIT SKs.
732          * http://kerneltrap.org/mailarchive/linux-netdev/2010/10/21/6287959
733          */
734         if (sk)
735                 pr_debug("xt_qtaguid: %p->sk_proto=%u "
736                         "->sk_state=%d\n", sk, sk->sk_protocol,
737                         sk->sk_state);
738         return sk;
739 }
740
741 static void account_for_uid(const struct sk_buff *skb,
742                             const struct sock *alternate_sk, uid_t uid,
743                             struct xt_action_param *par)
744 {
745         const struct net_device *el_dev;
746
747         if (!skb->dev) {
748                 pr_debug("xt_qtaguid[%d]: no skb->dev\n", par->hooknum);
749                 el_dev = par->in ? : par->out;
750         } else {
751                 const struct net_device *other_dev;
752                 el_dev = skb->dev;
753                 other_dev = par->in ? : par->out;
754                 if (el_dev != other_dev) {
755                         pr_debug("xt_qtaguid[%d]: skb->dev=%p %s vs "
756                                 "par->(in/out)=%p %s\n",
757                                 par->hooknum, el_dev, el_dev->name, other_dev,
758                                 other_dev->name);
759                 }
760         }
761
762         if (unlikely(!el_dev)) {
763                 pr_info("xt_qtaguid[%d]: no par->in/out?!!\n", par->hooknum);
764         } else if (unlikely(!el_dev->name)) {
765                 pr_info("xt_qtaguid[%d]: no dev->name?!!\n", par->hooknum);
766         } else {
767                 pr_debug("xt_qtaguid[%d]: dev name=%s type=%d\n",
768                         par->hooknum,
769                         el_dev->name,
770                         el_dev->type);
771
772                 if_tag_stat_update(el_dev->name, uid,
773                                 skb->sk ? skb->sk : alternate_sk,
774                                 par->in ? IFS_RX : IFS_TX,
775                                 ip_hdr(skb)->protocol, skb->len);
776         }
777 }
778
779 static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par)
780 {
781         const struct xt_qtaguid_match_info *info = par->matchinfo;
782         const struct file *filp;
783         bool got_sock = false;
784         struct sock *sk;
785         uid_t sock_uid;
786         bool res;
787         pr_debug("xt_qtaguid[%d]: entered skb=%p par->in=%p/out=%p\n",
788                 par->hooknum, skb, par->in, par->out);
789         if (skb == NULL) {
790                 res = (info->match ^ info->invert) == 0;
791                 goto ret_res;
792         }
793
794         sk = skb->sk;
795
796         if (sk == NULL) {
797                 /*  A missing sk->sk_socket happens when packets are in-flight
798                  * and the matching socket is already closed and gone.
799                  */
800                 sk = qtaguid_find_sk(skb, par);
801                 /* If we got the socket from the find_sk(), we will need to put
802                  * it back, as nf_tproxy_get_sock_v4() got it. */
803                 got_sock = sk;
804         }
805         pr_debug("xt_qtaguid[%d]: sk=%p got_sock=%d proto=%d\n",
806                 par->hooknum, sk, got_sock, ip_hdr(skb)->protocol);
807         if (sk != NULL) {
808                 pr_debug("xt_qtaguid[%d]: sk=%p->sk_socket=%p->file=%p\n",
809                         par->hooknum, sk, sk->sk_socket,
810                         sk->sk_socket ? sk->sk_socket->file : (void *)-1LL);
811                 filp = sk->sk_socket ? sk->sk_socket->file : NULL;
812                 pr_debug("xt_qtaguid[%d]: filp...uid=%d\n",
813                         par->hooknum, filp ? filp->f_cred->fsuid : -1);
814         }
815
816         if (sk == NULL || sk->sk_socket == NULL) {
817                 /* Here, the qtaguid_find_sk() using connection tracking
818                  * couldn't find the owner, so for now we just count them
819                  * against the system. */
820                 /* TODO: unhack how to force just accounting.
821                  * For now we only do iface stats when the uid-owner is not
822                  * requested */
823                 if (!(info->match & XT_QTAGUID_UID))
824                         account_for_uid(skb, sk, 0, par);
825                 pr_debug("xt_qtaguid[%d]: leaving (sk?sk->sk_socket)=%p\n",
826                         par->hooknum,
827                         sk ? sk->sk_socket : NULL);
828                 res =  (info->match ^ info->invert) == 0;
829                 goto put_sock_ret_res;
830         } else if (info->match & info->invert & XT_QTAGUID_SOCKET) {
831                 res = false;
832                 goto put_sock_ret_res;
833         }
834         filp = sk->sk_socket->file;
835         if (filp == NULL) {
836                 pr_debug("xt_qtaguid[%d]: leaving filp=NULL\n", par->hooknum);
837                 res = ((info->match ^ info->invert) &
838                         (XT_QTAGUID_UID | XT_QTAGUID_GID)) == 0;
839                 goto put_sock_ret_res;
840         }
841         sock_uid = filp->f_cred->fsuid;
842         /* TODO: unhack how to force just accounting.
843          * For now we only do iface stats when the uid-owner is not requested */
844         if (!(info->match & XT_QTAGUID_UID))
845                 account_for_uid(skb, sk, sock_uid, par);
846
847         /* The following two tests fail the match when:
848          *    id not in range AND no inverted condition requested
849          * or id     in range AND    inverted condition requested
850          * Thus (!a && b) || (a && !b) == a ^ b
851          */
852         if (info->match & XT_QTAGUID_UID)
853                 if ((filp->f_cred->fsuid >= info->uid_min &&
854                      filp->f_cred->fsuid <= info->uid_max) ^
855                     !(info->invert & XT_QTAGUID_UID)) {
856                         pr_debug("xt_qtaguid[%d]: leaving uid not matching\n",
857                                  par->hooknum);
858                         res = false;
859                         goto put_sock_ret_res;
860                 }
861         if (info->match & XT_QTAGUID_GID)
862                 if ((filp->f_cred->fsgid >= info->gid_min &&
863                                 filp->f_cred->fsgid <= info->gid_max) ^
864                         !(info->invert & XT_QTAGUID_GID)) {
865                         pr_debug("xt_qtaguid[%d]: leaving gid not matching\n",
866                                 par->hooknum);
867                         res = false;
868                         goto put_sock_ret_res;
869                 }
870
871         pr_debug("xt_qtaguid[%d]: leaving matched\n", par->hooknum);
872         res = true;
873
874 put_sock_ret_res:
875         if (got_sock)
876                 xt_socket_put_sk(sk);
877 ret_res:
878         pr_debug("xt_qtaguid[%d]: left %d\n", par->hooknum, res);
879         return res;
880 }
881
882 /* TODO: Use Documentation/filesystems/seq_file.txt? */
883 static int qtaguid_ctrl_proc_read(char *page, char **start, off_t off,
884                                 int count, int *eof, void *data)
885 {
886         char *out = page + off;
887         int len;
888         unsigned long flags;
889         uid_t uid;
890         struct sock_tag *sock_tag_entry;
891         struct rb_node *node;
892         pr_debug("xt_qtaguid:proc ctrl page=%p off=%ld count=%d eof=%p\n",
893                 page, off, count, eof);
894
895         *eof = 0;
896         spin_lock_irqsave(&sock_tag_list_lock, flags);
897         for (node = rb_first(&sock_tag_tree);
898              node;
899              node = rb_next(node)) {
900                 sock_tag_entry =  rb_entry(node, struct sock_tag, node);
901                 uid = get_uid_from_tag(sock_tag_entry->tag);
902                 pr_debug("xt_qtaguid: proc_read(): sk=%p tag=0x%llx (uid=%d)\n",
903                         sock_tag_entry->sk,
904                         sock_tag_entry->tag,
905                         uid);
906                 len = snprintf(out, count, "sock=%p tag=0x%llx (uid=%u)\n",
907                         sock_tag_entry->sk, sock_tag_entry->tag, uid);
908                 out += len;
909                 count -= len;
910                 if (!count) {
911                         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
912                         return out - page;
913                 }
914         }
915         *eof = 1;
916         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
917         return out - page;
918 }
919
920 static int qtaguid_ctrl_parse(const char *input, int count)
921 {
922         char cmd;
923         int sock_fd = 0;
924         uid_t uid = 0;
925         tag_t acct_tag = 0;
926         struct socket *el_socket;
927         int res, argc;
928         struct sock_tag *sock_tag_entry;
929         unsigned long flags;
930
931         pr_debug("xt_qtaguid: ctrl(%s): entered\n", input);
932         /* Unassigned args will get defaulted later. */
933         /* TODO: get acct_tag_str, keep a list of available tags for the
934          * uid, use num as acct_tag. */
935         argc = sscanf(input, "%c %d %llu %u", &cmd, &sock_fd, &acct_tag, &uid);
936         pr_debug("xt_qtaguid: ctrl(%s): argc=%d cmd=%c sock_fd=%d "
937                 "acct_tag=0x%llx uid=%u\n", input, argc, cmd, sock_fd,
938                 acct_tag, uid);
939
940         /* Collect params for commands */
941         switch (cmd) {
942         case 't':
943         case 'u':
944                 if (argc < 2) {
945                         res = -EINVAL;
946                         goto err;
947                 }
948                 el_socket = sockfd_lookup(sock_fd, &res);
949                 if (!el_socket) {
950                         pr_info("xt_qtaguid: ctrl(%s): failed to lookup"
951                                 " sock_fd=%d err=%d\n", input, sock_fd, res);
952                         goto err;
953                 }
954                 spin_lock_irqsave(&sock_tag_list_lock, flags);
955                 /* TODO: optim: pass in the current_fsuid() to do lookups
956                  * as look ups will always be initiated form the same uid. */
957                 sock_tag_entry = get_sock_stat_nl(el_socket->sk);
958                 if (!sock_tag_entry)
959                         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
960                 break;
961         default:
962                 res = -EINVAL;
963                 goto err;
964         }
965
966         /* Process commands */
967         switch (cmd) {
968
969         case 't':
970                 if (argc < 2) {
971                         res = -EINVAL;
972                         goto err_unlock;
973                 }
974                 if (argc < 3) {
975                         acct_tag = 0;
976                 } else if (!valid_atag(acct_tag)) {
977                         res = -EINVAL;
978                         goto err_unlock;
979                 }
980                 if (argc < 4)
981                         uid = current_fsuid();
982                 if (!sock_tag_entry) {
983                         sock_tag_entry = kmalloc(sizeof(*sock_tag_entry),
984                                                 GFP_KERNEL);
985                         if (!sock_tag_entry) {
986                                 res = -ENOMEM;
987                                 goto err;
988                         }
989                         memset(sock_tag_entry, 0, sizeof(*sock_tag_entry));
990                         sock_tag_entry->sk = el_socket->sk;
991                         /* TODO: check that uid==current_fsuid() except
992                          * for special uid/gid. */
993                         sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
994                                                                 uid);
995                         spin_lock_irqsave(&sock_tag_list_lock, flags);
996                         sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree);
997                         spin_unlock_irqrestore(&sock_tag_list_lock, flags);
998                 } else {
999                         /* Just update the acct_tag portion. */
1000                         uid_t orig_uid = get_uid_from_tag(sock_tag_entry->tag);
1001                         sock_tag_entry->tag = combine_atag_with_uid(acct_tag,
1002                                                                 orig_uid);
1003                 }
1004                 pr_debug("xt_qtaguid: tag: sock_tag_entry->sk=%p "
1005                         "...->tag=0x%llx (uid=%u)\n",
1006                         sock_tag_entry->sk, sock_tag_entry->tag,
1007                         get_uid_from_tag(sock_tag_entry->tag));
1008                 break;
1009
1010         case 'u':
1011                 if (!sock_tag_entry) {
1012                         res = -EINVAL;
1013                         goto err;
1014                 }
1015                 /* TODO: check that the uid==current_fsuid()
1016                  * except for special uid/gid. */
1017                 rb_erase(&sock_tag_entry->node, &sock_tag_tree);
1018                 spin_unlock_irqrestore(&sock_tag_list_lock, flags);
1019                 kfree(sock_tag_entry);
1020                 break;
1021         }
1022
1023         /* All of the input has been processed */
1024         res = count;
1025         goto ok;
1026
1027 err_unlock:
1028         if (!sock_tag_entry)
1029                 spin_unlock_irqrestore(&sock_tag_list_lock, flags);
1030 err:
1031 ok:
1032         pr_debug("xt_qtaguid: ctrl(%s): res=%d\n", input, res);
1033         return res;
1034 }
1035
1036 #define MAX_QTAGUID_CTRL_INPUT_LEN 255
1037 static int qtaguid_ctrl_proc_write(struct file *file, const char __user *buffer,
1038                         unsigned long count, void *data)
1039 {
1040         char input_buf[MAX_QTAGUID_CTRL_INPUT_LEN];
1041
1042         if (count >= MAX_QTAGUID_CTRL_INPUT_LEN)
1043                 return -EINVAL;
1044
1045         if (copy_from_user(input_buf, buffer, count))
1046                 return -EFAULT;
1047
1048         input_buf[count] = '\0';
1049         return qtaguid_ctrl_parse(input_buf, count);
1050 }
1051
1052 /*
1053  * Procfs reader to get all tag stats using style "1)" as described in
1054  * fs/proc/generic.c
1055  * Groups all protocols tx/rx bytes.
1056  */
1057 static int qtaguid_stats_proc_read(char *page, char **num_items_returned,
1058                                 off_t items_to_skip, int char_count, int *eof,
1059                                 void *data)
1060 {
1061         char *outp = page;
1062         int len;
1063         unsigned long flags, flags2;
1064         struct iface_stat *iface_entry;
1065         struct tag_stat *ts_entry;
1066         int item_index = 0;
1067
1068         /* TODO: make root access only */
1069
1070         pr_debug("xt_qtaguid:proc stats page=%p *num_items_returned=%p off=%ld "
1071                 "char_count=%d *eof=%d\n", page, *num_items_returned,
1072                 items_to_skip, char_count, *eof);
1073
1074         if (*eof)
1075                 return 0;
1076
1077         if (!items_to_skip) {
1078                 /* The idx is there to help debug when things go belly up. */
1079                 len = snprintf(outp, char_count,
1080                         "idx iface acct_tag_hex uid_tag_int rx_bytes "
1081                         "tx_bytes\n");
1082                 /* Don't advance the outp unless the whole line was printed */
1083                 if (len >= char_count) {
1084                         *outp = '\0';
1085                         return outp - page;
1086                 }
1087                 outp += len;
1088                 char_count -= len;
1089         }
1090
1091         spin_lock_irqsave(&iface_stat_list_lock, flags);
1092         list_for_each_entry(iface_entry, &iface_stat_list, list) {
1093                 struct rb_node *node;
1094                 spin_lock_irqsave(&iface_entry->tag_stat_list_lock, flags2);
1095                 for (node = rb_first(&iface_entry->tag_stat_tree);
1096                      node;
1097                      node = rb_next(node)) {
1098                         ts_entry =  rb_entry(node, struct tag_stat, node);
1099                         if (item_index++ < items_to_skip)
1100                                 continue;
1101                         len = snprintf(outp, char_count,
1102                                        "%d %s 0x%llx %u %llu %llu\n",
1103                                        item_index,
1104                                        iface_entry->ifname,
1105                                        get_atag_from_tag(ts_entry->tag),
1106                                        get_uid_from_tag(ts_entry->tag),
1107                                        dc_sum_bytes(&ts_entry->counters,
1108                                                     IFS_RX),
1109                                        dc_sum_bytes(&ts_entry->counters,
1110                                                     IFS_TX));
1111                         if (len >= char_count) {
1112                                 spin_unlock_irqrestore(
1113                                         &iface_entry->tag_stat_list_lock,
1114                                         flags2);
1115                                 spin_unlock_irqrestore(
1116                                         &iface_stat_list_lock, flags);
1117                                 *outp = '\0';
1118                                 return outp - page;
1119                         }
1120                         outp += len;
1121                         char_count -= len;
1122                         (*(int *)num_items_returned)++;
1123                 }
1124                 spin_unlock_irqrestore(&iface_entry->tag_stat_list_lock,
1125                                 flags2);
1126         }
1127         spin_unlock_irqrestore(&iface_stat_list_lock, flags);
1128
1129         *eof = 1;
1130         return outp - page;
1131 }
1132
1133 /*------------------------------------------*/
1134 static const char *module_procdirname = "xt_qtaguid";
1135 static struct proc_dir_entry *xt_qtaguid_procdir;
1136 static struct proc_dir_entry *xt_qtaguid_ctrl_file;
1137 static struct proc_dir_entry *xt_qtaguid_stats_file;
1138
1139 static int __init qtaguid_proc_register(struct proc_dir_entry **res_procdir)
1140 {
1141         int ret;
1142         *res_procdir = proc_mkdir(module_procdirname, init_net.proc_net);
1143         if (!*res_procdir) {
1144                 pr_err("xt_qtaguid: failed to create proc/.../xt_qtaguid\n");
1145                 ret = -ENOMEM;
1146                 goto no_dir;
1147         }
1148
1149         xt_qtaguid_ctrl_file = create_proc_entry("ctrl", 0666,
1150                                                 *res_procdir);
1151         if (!xt_qtaguid_ctrl_file) {
1152                 pr_err("xt_qtaguid: failed to create xt_qtaguid/ctrl "
1153                         " file\n");
1154                 ret = -ENOMEM;
1155                 goto no_ctrl_entry;
1156         }
1157         xt_qtaguid_ctrl_file->read_proc = qtaguid_ctrl_proc_read;
1158         xt_qtaguid_ctrl_file->write_proc = qtaguid_ctrl_proc_write;
1159
1160         xt_qtaguid_stats_file = create_proc_entry("stats", 0666,
1161                                                 *res_procdir);
1162         if (!xt_qtaguid_stats_file) {
1163                 pr_err("xt_qtaguid: failed to create xt_qtaguid/stats "
1164                         "file\n");
1165                 ret = -ENOMEM;
1166                 goto no_stats_entry;
1167         }
1168         /*
1169          * TODO: add extra read_proc for full stats with protocol
1170          * breakout
1171          */
1172         xt_qtaguid_stats_file->read_proc = qtaguid_stats_proc_read;
1173         /*
1174          * TODO: add support counter hacking
1175          * xt_qtaguid_stats_file->write_proc = qtaguid_stats_proc_write;
1176          */
1177         return 0;
1178
1179 no_stats_entry:
1180         remove_proc_entry("ctrl", *res_procdir);
1181 no_ctrl_entry:
1182         remove_proc_entry("xt_qtaguid", NULL);
1183 no_dir:
1184         return ret;
1185 }
1186
1187 static struct xt_match qtaguid_mt_reg __read_mostly = {
1188         /*
1189          * This module masquerades as the "owner" module so that iptables
1190          * tools can deal with it.
1191          */
1192         .name       = "owner",
1193         .revision   = 1,
1194         .family     = NFPROTO_UNSPEC,
1195         .match      = qtaguid_mt,
1196         .matchsize  = sizeof(struct xt_qtaguid_match_info),
1197         .me         = THIS_MODULE,
1198 };
1199
1200 static int __init qtaguid_mt_init(void)
1201 {
1202         if (qtaguid_proc_register(&xt_qtaguid_procdir)
1203             || iface_stat_init(xt_qtaguid_procdir)
1204             || xt_register_match(&qtaguid_mt_reg))
1205                 return -1;
1206         return 0;
1207 }
1208
1209 /* TODO: allow unloading of the module.
1210  * For now stats are permanent.
1211  * Kconfig forces'y/n' and never an 'm'.
1212  */
1213
1214 module_init(qtaguid_mt_init);
1215 MODULE_AUTHOR("jpa <jpa@google.com>");
1216 MODULE_DESCRIPTION("Xtables: socket owner+tag matching and associated stats");
1217 MODULE_LICENSE("GPL");
1218 MODULE_ALIAS("ipt_owner");
1219 MODULE_ALIAS("ip6t_owner");
1220 MODULE_ALIAS("ipt_qtaguid");
1221 MODULE_ALIAS("ip6t_qtaguid");