d6b48230a54015d9b8cac977fe913e36ab07c79e
[linux-2.6.git] / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/netlink.h>
19 #include <linux/rculist.h>
20 #include <linux/version.h>
21 #include <net/netlink.h>
22
23 #include <linux/netfilter.h>
24 #include <linux/netfilter/nfnetlink.h>
25 #include <linux/netfilter/ipset/ip_set.h>
26
27 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
28 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
29
30 static struct ip_set **ip_set_list;             /* all individual sets */
31 static ip_set_id_t ip_set_max = CONFIG_IP_SET_MAX; /* max number of sets */
32
33 #define STREQ(a, b)     (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
34
35 static unsigned int max_sets;
36
37 module_param(max_sets, int, 0600);
38 MODULE_PARM_DESC(max_sets, "maximal number of sets");
39 MODULE_LICENSE("GPL");
40 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
41 MODULE_DESCRIPTION("core IP set support");
42 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
43
44 /*
45  * The set types are implemented in modules and registered set types
46  * can be found in ip_set_type_list. Adding/deleting types is
47  * serialized by ip_set_type_mutex.
48  */
49
50 static inline void
51 ip_set_type_lock(void)
52 {
53         mutex_lock(&ip_set_type_mutex);
54 }
55
56 static inline void
57 ip_set_type_unlock(void)
58 {
59         mutex_unlock(&ip_set_type_mutex);
60 }
61
62 /* Register and deregister settype */
63
64 static struct ip_set_type *
65 find_set_type(const char *name, u8 family, u8 revision)
66 {
67         struct ip_set_type *type;
68
69         list_for_each_entry_rcu(type, &ip_set_type_list, list)
70                 if (STREQ(type->name, name) &&
71                     (type->family == family || type->family == AF_UNSPEC) &&
72                     type->revision == revision)
73                         return type;
74         return NULL;
75 }
76
77 /* Unlock, try to load a set type module and lock again */
78 static int
79 try_to_load_type(const char *name)
80 {
81         nfnl_unlock();
82         pr_debug("try to load ip_set_%s\n", name);
83         if (request_module("ip_set_%s", name) < 0) {
84                 pr_warning("Can't find ip_set type %s\n", name);
85                 nfnl_lock();
86                 return -IPSET_ERR_FIND_TYPE;
87         }
88         nfnl_lock();
89         return -EAGAIN;
90 }
91
92 /* Find a set type and reference it */
93 static int
94 find_set_type_get(const char *name, u8 family, u8 revision,
95                   struct ip_set_type **found)
96 {
97         struct ip_set_type *type;
98         int err;
99
100         rcu_read_lock();
101         *found = find_set_type(name, family, revision);
102         if (*found) {
103                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
104                 goto unlock;
105         }
106         /* Make sure the type is loaded but we don't support the revision */
107         list_for_each_entry_rcu(type, &ip_set_type_list, list)
108                 if (STREQ(type->name, name)) {
109                         err = -IPSET_ERR_FIND_TYPE;
110                         goto unlock;
111                 }
112         rcu_read_unlock();
113
114         return try_to_load_type(name);
115
116 unlock:
117         rcu_read_unlock();
118         return err;
119 }
120
121 /* Find a given set type by name and family.
122  * If we succeeded, the supported minimal and maximum revisions are
123  * filled out.
124  */
125 static int
126 find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max)
127 {
128         struct ip_set_type *type;
129         bool found = false;
130
131         *min = 255; *max = 0;
132         rcu_read_lock();
133         list_for_each_entry_rcu(type, &ip_set_type_list, list)
134                 if (STREQ(type->name, name) &&
135                     (type->family == family || type->family == AF_UNSPEC)) {
136                         found = true;
137                         if (type->revision < *min)
138                                 *min = type->revision;
139                         if (type->revision > *max)
140                                 *max = type->revision;
141                 }
142         rcu_read_unlock();
143         if (found)
144                 return 0;
145
146         return try_to_load_type(name);
147 }
148
149 #define family_name(f)  ((f) == AF_INET ? "inet" : \
150                          (f) == AF_INET6 ? "inet6" : "any")
151
152 /* Register a set type structure. The type is identified by
153  * the unique triple of name, family and revision.
154  */
155 int
156 ip_set_type_register(struct ip_set_type *type)
157 {
158         int ret = 0;
159
160         if (type->protocol != IPSET_PROTOCOL) {
161                 pr_warning("ip_set type %s, family %s, revision %u uses "
162                            "wrong protocol version %u (want %u)\n",
163                            type->name, family_name(type->family),
164                            type->revision, type->protocol, IPSET_PROTOCOL);
165                 return -EINVAL;
166         }
167
168         ip_set_type_lock();
169         if (find_set_type(type->name, type->family, type->revision)) {
170                 /* Duplicate! */
171                 pr_warning("ip_set type %s, family %s, revision %u "
172                            "already registered!\n", type->name,
173                            family_name(type->family), type->revision);
174                 ret = -EINVAL;
175                 goto unlock;
176         }
177         list_add_rcu(&type->list, &ip_set_type_list);
178         pr_debug("type %s, family %s, revision %u registered.\n",
179                  type->name, family_name(type->family), type->revision);
180 unlock:
181         ip_set_type_unlock();
182         return ret;
183 }
184 EXPORT_SYMBOL_GPL(ip_set_type_register);
185
186 /* Unregister a set type. There's a small race with ip_set_create */
187 void
188 ip_set_type_unregister(struct ip_set_type *type)
189 {
190         ip_set_type_lock();
191         if (!find_set_type(type->name, type->family, type->revision)) {
192                 pr_warning("ip_set type %s, family %s, revision %u "
193                            "not registered\n", type->name,
194                            family_name(type->family), type->revision);
195                 goto unlock;
196         }
197         list_del_rcu(&type->list);
198         pr_debug("type %s, family %s, revision %u unregistered.\n",
199                  type->name, family_name(type->family), type->revision);
200 unlock:
201         ip_set_type_unlock();
202
203         synchronize_rcu();
204 }
205 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
206
207 /* Utility functions */
208 void *
209 ip_set_alloc(size_t size)
210 {
211         void *members = NULL;
212
213         if (size < KMALLOC_MAX_SIZE)
214                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
215
216         if (members) {
217                 pr_debug("%p: allocated with kmalloc\n", members);
218                 return members;
219         }
220
221         members = vzalloc(size);
222         if (!members)
223                 return NULL;
224         pr_debug("%p: allocated with vmalloc\n", members);
225
226         return members;
227 }
228 EXPORT_SYMBOL_GPL(ip_set_alloc);
229
230 void
231 ip_set_free(void *members)
232 {
233         pr_debug("%p: free with %s\n", members,
234                  is_vmalloc_addr(members) ? "vfree" : "kfree");
235         if (is_vmalloc_addr(members))
236                 vfree(members);
237         else
238                 kfree(members);
239 }
240 EXPORT_SYMBOL_GPL(ip_set_free);
241
242 static inline bool
243 flag_nested(const struct nlattr *nla)
244 {
245         return nla->nla_type & NLA_F_NESTED;
246 }
247
248 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
249         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
250         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
251                                             .len = sizeof(struct in6_addr) },
252 };
253
254 int
255 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
256 {
257         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
258
259         if (unlikely(!flag_nested(nla)))
260                 return -IPSET_ERR_PROTOCOL;
261         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
262                 return -IPSET_ERR_PROTOCOL;
263         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
264                 return -IPSET_ERR_PROTOCOL;
265
266         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
267         return 0;
268 }
269 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
270
271 int
272 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
273 {
274         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
275
276         if (unlikely(!flag_nested(nla)))
277                 return -IPSET_ERR_PROTOCOL;
278
279         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
280                 return -IPSET_ERR_PROTOCOL;
281         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
282                 return -IPSET_ERR_PROTOCOL;
283
284         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
285                 sizeof(struct in6_addr));
286         return 0;
287 }
288 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
289
290 /*
291  * Creating/destroying/renaming/swapping affect the existence and
292  * the properties of a set. All of these can be executed from userspace
293  * only and serialized by the nfnl mutex indirectly from nfnetlink.
294  *
295  * Sets are identified by their index in ip_set_list and the index
296  * is used by the external references (set/SET netfilter modules).
297  *
298  * The set behind an index may change by swapping only, from userspace.
299  */
300
301 static inline void
302 __ip_set_get(ip_set_id_t index)
303 {
304         atomic_inc(&ip_set_list[index]->ref);
305 }
306
307 static inline void
308 __ip_set_put(ip_set_id_t index)
309 {
310         atomic_dec(&ip_set_list[index]->ref);
311 }
312
313 /*
314  * Add, del and test set entries from kernel.
315  *
316  * The set behind the index must exist and must be referenced
317  * so it can't be destroyed (or changed) under our foot.
318  */
319
320 int
321 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
322             u8 family, u8 dim, u8 flags)
323 {
324         struct ip_set *set = ip_set_list[index];
325         int ret = 0;
326
327         BUG_ON(set == NULL || atomic_read(&set->ref) == 0);
328         pr_debug("set %s, index %u\n", set->name, index);
329
330         if (dim < set->type->dimension ||
331             !(family == set->family || set->family == AF_UNSPEC))
332                 return 0;
333
334         read_lock_bh(&set->lock);
335         ret = set->variant->kadt(set, skb, IPSET_TEST, family, dim, flags);
336         read_unlock_bh(&set->lock);
337
338         if (ret == -EAGAIN) {
339                 /* Type requests element to be completed */
340                 pr_debug("element must be competed, ADD is triggered\n");
341                 write_lock_bh(&set->lock);
342                 set->variant->kadt(set, skb, IPSET_ADD, family, dim, flags);
343                 write_unlock_bh(&set->lock);
344                 ret = 1;
345         }
346
347         /* Convert error codes to nomatch */
348         return (ret < 0 ? 0 : ret);
349 }
350 EXPORT_SYMBOL_GPL(ip_set_test);
351
352 int
353 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
354            u8 family, u8 dim, u8 flags)
355 {
356         struct ip_set *set = ip_set_list[index];
357         int ret;
358
359         BUG_ON(set == NULL || atomic_read(&set->ref) == 0);
360         pr_debug("set %s, index %u\n", set->name, index);
361
362         if (dim < set->type->dimension ||
363             !(family == set->family || set->family == AF_UNSPEC))
364                 return 0;
365
366         write_lock_bh(&set->lock);
367         ret = set->variant->kadt(set, skb, IPSET_ADD, family, dim, flags);
368         write_unlock_bh(&set->lock);
369
370         return ret;
371 }
372 EXPORT_SYMBOL_GPL(ip_set_add);
373
374 int
375 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
376            u8 family, u8 dim, u8 flags)
377 {
378         struct ip_set *set = ip_set_list[index];
379         int ret = 0;
380
381         BUG_ON(set == NULL || atomic_read(&set->ref) == 0);
382         pr_debug("set %s, index %u\n", set->name, index);
383
384         if (dim < set->type->dimension ||
385             !(family == set->family || set->family == AF_UNSPEC))
386                 return 0;
387
388         write_lock_bh(&set->lock);
389         ret = set->variant->kadt(set, skb, IPSET_DEL, family, dim, flags);
390         write_unlock_bh(&set->lock);
391
392         return ret;
393 }
394 EXPORT_SYMBOL_GPL(ip_set_del);
395
396 /*
397  * Find set by name, reference it once. The reference makes sure the
398  * thing pointed to, does not go away under our feet.
399  *
400  * The nfnl mutex must already be activated.
401  */
402 ip_set_id_t
403 ip_set_get_byname(const char *name, struct ip_set **set)
404 {
405         ip_set_id_t i, index = IPSET_INVALID_ID;
406         struct ip_set *s;
407
408         for (i = 0; i < ip_set_max; i++) {
409                 s = ip_set_list[i];
410                 if (s != NULL && STREQ(s->name, name)) {
411                         __ip_set_get(i);
412                         index = i;
413                         *set = s;
414                 }
415         }
416
417         return index;
418 }
419 EXPORT_SYMBOL_GPL(ip_set_get_byname);
420
421 /*
422  * If the given set pointer points to a valid set, decrement
423  * reference count by 1. The caller shall not assume the index
424  * to be valid, after calling this function.
425  *
426  * The nfnl mutex must already be activated.
427  */
428 void
429 ip_set_put_byindex(ip_set_id_t index)
430 {
431         if (ip_set_list[index] != NULL) {
432                 BUG_ON(atomic_read(&ip_set_list[index]->ref) == 0);
433                 __ip_set_put(index);
434         }
435 }
436 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
437
438 /*
439  * Get the name of a set behind a set index.
440  * We assume the set is referenced, so it does exist and
441  * can't be destroyed. The set cannot be renamed due to
442  * the referencing either.
443  *
444  * The nfnl mutex must already be activated.
445  */
446 const char *
447 ip_set_name_byindex(ip_set_id_t index)
448 {
449         const struct ip_set *set = ip_set_list[index];
450
451         BUG_ON(set == NULL);
452         BUG_ON(atomic_read(&set->ref) == 0);
453
454         /* Referenced, so it's safe */
455         return set->name;
456 }
457 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
458
459 /*
460  * Routines to call by external subsystems, which do not
461  * call nfnl_lock for us.
462  */
463
464 /*
465  * Find set by name, reference it once. The reference makes sure the
466  * thing pointed to, does not go away under our feet.
467  *
468  * The nfnl mutex is used in the function.
469  */
470 ip_set_id_t
471 ip_set_nfnl_get(const char *name)
472 {
473         struct ip_set *s;
474         ip_set_id_t index;
475
476         nfnl_lock();
477         index = ip_set_get_byname(name, &s);
478         nfnl_unlock();
479
480         return index;
481 }
482 EXPORT_SYMBOL_GPL(ip_set_nfnl_get);
483
484 /*
485  * Find set by index, reference it once. The reference makes sure the
486  * thing pointed to, does not go away under our feet.
487  *
488  * The nfnl mutex is used in the function.
489  */
490 ip_set_id_t
491 ip_set_nfnl_get_byindex(ip_set_id_t index)
492 {
493         if (index > ip_set_max)
494                 return IPSET_INVALID_ID;
495
496         nfnl_lock();
497         if (ip_set_list[index])
498                 __ip_set_get(index);
499         else
500                 index = IPSET_INVALID_ID;
501         nfnl_unlock();
502
503         return index;
504 }
505 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
506
507 /*
508  * If the given set pointer points to a valid set, decrement
509  * reference count by 1. The caller shall not assume the index
510  * to be valid, after calling this function.
511  *
512  * The nfnl mutex is used in the function.
513  */
514 void
515 ip_set_nfnl_put(ip_set_id_t index)
516 {
517         nfnl_lock();
518         if (ip_set_list[index] != NULL) {
519                 BUG_ON(atomic_read(&ip_set_list[index]->ref) == 0);
520                 __ip_set_put(index);
521         }
522         nfnl_unlock();
523 }
524 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
525
526 /*
527  * Communication protocol with userspace over netlink.
528  *
529  * We already locked by nfnl_lock.
530  */
531
532 static inline bool
533 protocol_failed(const struct nlattr * const tb[])
534 {
535         return !tb[IPSET_ATTR_PROTOCOL] ||
536                nla_get_u8(tb[IPSET_ATTR_PROTOCOL]) != IPSET_PROTOCOL;
537 }
538
539 static inline u32
540 flag_exist(const struct nlmsghdr *nlh)
541 {
542         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
543 }
544
545 static struct nlmsghdr *
546 start_msg(struct sk_buff *skb, u32 pid, u32 seq, unsigned int flags,
547           enum ipset_cmd cmd)
548 {
549         struct nlmsghdr *nlh;
550         struct nfgenmsg *nfmsg;
551
552         nlh = nlmsg_put(skb, pid, seq, cmd | (NFNL_SUBSYS_IPSET << 8),
553                         sizeof(*nfmsg), flags);
554         if (nlh == NULL)
555                 return NULL;
556
557         nfmsg = nlmsg_data(nlh);
558         nfmsg->nfgen_family = AF_INET;
559         nfmsg->version = NFNETLINK_V0;
560         nfmsg->res_id = 0;
561
562         return nlh;
563 }
564
565 /* Create a set */
566
567 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
568         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
569         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
570                                     .len = IPSET_MAXNAMELEN - 1 },
571         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
572                                     .len = IPSET_MAXNAMELEN - 1},
573         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
574         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
575         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
576 };
577
578 static ip_set_id_t
579 find_set_id(const char *name)
580 {
581         ip_set_id_t i, index = IPSET_INVALID_ID;
582         const struct ip_set *set;
583
584         for (i = 0; index == IPSET_INVALID_ID && i < ip_set_max; i++) {
585                 set = ip_set_list[i];
586                 if (set != NULL && STREQ(set->name, name))
587                         index = i;
588         }
589         return index;
590 }
591
592 static inline struct ip_set *
593 find_set(const char *name)
594 {
595         ip_set_id_t index = find_set_id(name);
596
597         return index == IPSET_INVALID_ID ? NULL : ip_set_list[index];
598 }
599
600 static int
601 find_free_id(const char *name, ip_set_id_t *index, struct ip_set **set)
602 {
603         ip_set_id_t i;
604
605         *index = IPSET_INVALID_ID;
606         for (i = 0;  i < ip_set_max; i++) {
607                 if (ip_set_list[i] == NULL) {
608                         if (*index == IPSET_INVALID_ID)
609                                 *index = i;
610                 } else if (STREQ(name, ip_set_list[i]->name)) {
611                         /* Name clash */
612                         *set = ip_set_list[i];
613                         return -EEXIST;
614                 }
615         }
616         if (*index == IPSET_INVALID_ID)
617                 /* No free slot remained */
618                 return -IPSET_ERR_MAX_SETS;
619         return 0;
620 }
621
622 static int
623 ip_set_create(struct sock *ctnl, struct sk_buff *skb,
624               const struct nlmsghdr *nlh,
625               const struct nlattr * const attr[])
626 {
627         struct ip_set *set, *clash = NULL;
628         ip_set_id_t index = IPSET_INVALID_ID;
629         struct nlattr *tb[IPSET_ATTR_CREATE_MAX+1] = {};
630         const char *name, *typename;
631         u8 family, revision;
632         u32 flags = flag_exist(nlh);
633         int ret = 0;
634
635         if (unlikely(protocol_failed(attr) ||
636                      attr[IPSET_ATTR_SETNAME] == NULL ||
637                      attr[IPSET_ATTR_TYPENAME] == NULL ||
638                      attr[IPSET_ATTR_REVISION] == NULL ||
639                      attr[IPSET_ATTR_FAMILY] == NULL ||
640                      (attr[IPSET_ATTR_DATA] != NULL &&
641                       !flag_nested(attr[IPSET_ATTR_DATA]))))
642                 return -IPSET_ERR_PROTOCOL;
643
644         name = nla_data(attr[IPSET_ATTR_SETNAME]);
645         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
646         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
647         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
648         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
649                  name, typename, family_name(family), revision);
650
651         /*
652          * First, and without any locks, allocate and initialize
653          * a normal base set structure.
654          */
655         set = kzalloc(sizeof(struct ip_set), GFP_KERNEL);
656         if (!set)
657                 return -ENOMEM;
658         rwlock_init(&set->lock);
659         strlcpy(set->name, name, IPSET_MAXNAMELEN);
660         atomic_set(&set->ref, 0);
661         set->family = family;
662
663         /*
664          * Next, check that we know the type, and take
665          * a reference on the type, to make sure it stays available
666          * while constructing our new set.
667          *
668          * After referencing the type, we try to create the type
669          * specific part of the set without holding any locks.
670          */
671         ret = find_set_type_get(typename, family, revision, &(set->type));
672         if (ret)
673                 goto out;
674
675         /*
676          * Without holding any locks, create private part.
677          */
678         if (attr[IPSET_ATTR_DATA] &&
679             nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
680                              set->type->create_policy)) {
681                 ret = -IPSET_ERR_PROTOCOL;
682                 goto put_out;
683         }
684
685         ret = set->type->create(set, tb, flags);
686         if (ret != 0)
687                 goto put_out;
688
689         /* BTW, ret==0 here. */
690
691         /*
692          * Here, we have a valid, constructed set and we are protected
693          * by nfnl_lock. Find the first free index in ip_set_list and
694          * check clashing.
695          */
696         if ((ret = find_free_id(set->name, &index, &clash)) != 0) {
697                 /* If this is the same set and requested, ignore error */
698                 if (ret == -EEXIST &&
699                     (flags & IPSET_FLAG_EXIST) &&
700                     STREQ(set->type->name, clash->type->name) &&
701                     set->type->family == clash->type->family &&
702                     set->type->revision == clash->type->revision &&
703                     set->variant->same_set(set, clash))
704                         ret = 0;
705                 goto cleanup;
706         }
707
708         /*
709          * Finally! Add our shiny new set to the list, and be done.
710          */
711         pr_debug("create: '%s' created with index %u!\n", set->name, index);
712         ip_set_list[index] = set;
713
714         return ret;
715
716 cleanup:
717         set->variant->destroy(set);
718 put_out:
719         module_put(set->type->me);
720 out:
721         kfree(set);
722         return ret;
723 }
724
725 /* Destroy sets */
726
727 static const struct nla_policy
728 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
729         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
730         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
731                                     .len = IPSET_MAXNAMELEN - 1 },
732 };
733
734 static void
735 ip_set_destroy_set(ip_set_id_t index)
736 {
737         struct ip_set *set = ip_set_list[index];
738
739         pr_debug("set: %s\n",  set->name);
740         ip_set_list[index] = NULL;
741
742         /* Must call it without holding any lock */
743         set->variant->destroy(set);
744         module_put(set->type->me);
745         kfree(set);
746 }
747
748 static int
749 ip_set_destroy(struct sock *ctnl, struct sk_buff *skb,
750                const struct nlmsghdr *nlh,
751                const struct nlattr * const attr[])
752 {
753         ip_set_id_t i;
754
755         if (unlikely(protocol_failed(attr)))
756                 return -IPSET_ERR_PROTOCOL;
757
758         /* References are protected by the nfnl mutex */
759         if (!attr[IPSET_ATTR_SETNAME]) {
760                 for (i = 0; i < ip_set_max; i++) {
761                         if (ip_set_list[i] != NULL &&
762                             (atomic_read(&ip_set_list[i]->ref)))
763                                 return -IPSET_ERR_BUSY;
764                 }
765                 for (i = 0; i < ip_set_max; i++) {
766                         if (ip_set_list[i] != NULL)
767                                 ip_set_destroy_set(i);
768                 }
769         } else {
770                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
771                 if (i == IPSET_INVALID_ID)
772                         return -ENOENT;
773                 else if (atomic_read(&ip_set_list[i]->ref))
774                         return -IPSET_ERR_BUSY;
775
776                 ip_set_destroy_set(i);
777         }
778         return 0;
779 }
780
781 /* Flush sets */
782
783 static void
784 ip_set_flush_set(struct ip_set *set)
785 {
786         pr_debug("set: %s\n",  set->name);
787
788         write_lock_bh(&set->lock);
789         set->variant->flush(set);
790         write_unlock_bh(&set->lock);
791 }
792
793 static int
794 ip_set_flush(struct sock *ctnl, struct sk_buff *skb,
795              const struct nlmsghdr *nlh,
796              const struct nlattr * const attr[])
797 {
798         ip_set_id_t i;
799
800         if (unlikely(protocol_failed(attr)))
801                 return -EPROTO;
802
803         if (!attr[IPSET_ATTR_SETNAME]) {
804                 for (i = 0; i < ip_set_max; i++)
805                         if (ip_set_list[i] != NULL)
806                                 ip_set_flush_set(ip_set_list[i]);
807         } else {
808                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
809                 if (i == IPSET_INVALID_ID)
810                         return -ENOENT;
811
812                 ip_set_flush_set(ip_set_list[i]);
813         }
814
815         return 0;
816 }
817
818 /* Rename a set */
819
820 static const struct nla_policy
821 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
822         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
823         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
824                                     .len = IPSET_MAXNAMELEN - 1 },
825         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
826                                     .len = IPSET_MAXNAMELEN - 1 },
827 };
828
829 static int
830 ip_set_rename(struct sock *ctnl, struct sk_buff *skb,
831               const struct nlmsghdr *nlh,
832               const struct nlattr * const attr[])
833 {
834         struct ip_set *set;
835         const char *name2;
836         ip_set_id_t i;
837
838         if (unlikely(protocol_failed(attr) ||
839                      attr[IPSET_ATTR_SETNAME] == NULL ||
840                      attr[IPSET_ATTR_SETNAME2] == NULL))
841                 return -IPSET_ERR_PROTOCOL;
842
843         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
844         if (set == NULL)
845                 return -ENOENT;
846         if (atomic_read(&set->ref) != 0)
847                 return -IPSET_ERR_REFERENCED;
848
849         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
850         for (i = 0; i < ip_set_max; i++) {
851                 if (ip_set_list[i] != NULL &&
852                     STREQ(ip_set_list[i]->name, name2))
853                         return -IPSET_ERR_EXIST_SETNAME2;
854         }
855         strncpy(set->name, name2, IPSET_MAXNAMELEN);
856
857         return 0;
858 }
859
860 /* Swap two sets so that name/index points to the other.
861  * References and set names are also swapped.
862  *
863  * We are protected by the nfnl mutex and references are
864  * manipulated only by holding the mutex. The kernel interfaces
865  * do not hold the mutex but the pointer settings are atomic
866  * so the ip_set_list always contains valid pointers to the sets.
867  */
868
869 static int
870 ip_set_swap(struct sock *ctnl, struct sk_buff *skb,
871             const struct nlmsghdr *nlh,
872             const struct nlattr * const attr[])
873 {
874         struct ip_set *from, *to;
875         ip_set_id_t from_id, to_id;
876         char from_name[IPSET_MAXNAMELEN];
877         u32 from_ref;
878
879         if (unlikely(protocol_failed(attr) ||
880                      attr[IPSET_ATTR_SETNAME] == NULL ||
881                      attr[IPSET_ATTR_SETNAME2] == NULL))
882                 return -IPSET_ERR_PROTOCOL;
883
884         from_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
885         if (from_id == IPSET_INVALID_ID)
886                 return -ENOENT;
887
888         to_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME2]));
889         if (to_id == IPSET_INVALID_ID)
890                 return -IPSET_ERR_EXIST_SETNAME2;
891
892         from = ip_set_list[from_id];
893         to = ip_set_list[to_id];
894
895         /* Features must not change.
896          * Not an artifical restriction anymore, as we must prevent
897          * possible loops created by swapping in setlist type of sets. */
898         if (!(from->type->features == to->type->features &&
899               from->type->family == to->type->family))
900                 return -IPSET_ERR_TYPE_MISMATCH;
901
902         /* No magic here: ref munging protected by the nfnl_lock */
903         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
904         from_ref = atomic_read(&from->ref);
905
906         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
907         atomic_set(&from->ref, atomic_read(&to->ref));
908         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
909         atomic_set(&to->ref, from_ref);
910
911         ip_set_list[from_id] = to;
912         ip_set_list[to_id] = from;
913
914         return 0;
915 }
916
917 /* List/save set data */
918
919 #define DUMP_INIT       0L
920 #define DUMP_ALL        1L
921 #define DUMP_ONE        2L
922 #define DUMP_LAST       3L
923
924 static int
925 ip_set_dump_done(struct netlink_callback *cb)
926 {
927         if (cb->args[2]) {
928                 pr_debug("release set %s\n", ip_set_list[cb->args[1]]->name);
929                 __ip_set_put((ip_set_id_t) cb->args[1]);
930         }
931         return 0;
932 }
933
934 static inline void
935 dump_attrs(struct nlmsghdr *nlh)
936 {
937         const struct nlattr *attr;
938         int rem;
939
940         pr_debug("dump nlmsg\n");
941         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
942                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
943         }
944 }
945
946 static int
947 dump_init(struct netlink_callback *cb)
948 {
949         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
950         int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
951         struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
952         struct nlattr *attr = (void *)nlh + min_len;
953         ip_set_id_t index;
954
955         /* Second pass, so parser can't fail */
956         nla_parse(cda, IPSET_ATTR_CMD_MAX,
957                   attr, nlh->nlmsg_len - min_len, ip_set_setname_policy);
958
959         /* cb->args[0] : dump single set/all sets
960          *         [1] : set index
961          *         [..]: type specific
962          */
963
964         if (!cda[IPSET_ATTR_SETNAME]) {
965                 cb->args[0] = DUMP_ALL;
966                 return 0;
967         }
968
969         index = find_set_id(nla_data(cda[IPSET_ATTR_SETNAME]));
970         if (index == IPSET_INVALID_ID)
971                 return -ENOENT;
972
973         cb->args[0] = DUMP_ONE;
974         cb->args[1] = index;
975         return 0;
976 }
977
978 static int
979 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
980 {
981         ip_set_id_t index = IPSET_INVALID_ID, max;
982         struct ip_set *set = NULL;
983         struct nlmsghdr *nlh = NULL;
984         unsigned int flags = NETLINK_CB(cb->skb).pid ? NLM_F_MULTI : 0;
985         int ret = 0;
986
987         if (cb->args[0] == DUMP_INIT) {
988                 ret = dump_init(cb);
989                 if (ret < 0) {
990                         nlh = nlmsg_hdr(cb->skb);
991                         /* We have to create and send the error message
992                          * manually :-( */
993                         if (nlh->nlmsg_flags & NLM_F_ACK)
994                                 netlink_ack(cb->skb, nlh, ret);
995                         return ret;
996                 }
997         }
998
999         if (cb->args[1] >= ip_set_max)
1000                 goto out;
1001
1002         pr_debug("args[0]: %ld args[1]: %ld\n", cb->args[0], cb->args[1]);
1003         max = cb->args[0] == DUMP_ONE ? cb->args[1] + 1 : ip_set_max;
1004         for (; cb->args[1] < max; cb->args[1]++) {
1005                 index = (ip_set_id_t) cb->args[1];
1006                 set = ip_set_list[index];
1007                 if (set == NULL) {
1008                         if (cb->args[0] == DUMP_ONE) {
1009                                 ret = -ENOENT;
1010                                 goto out;
1011                         }
1012                         continue;
1013                 }
1014                 /* When dumping all sets, we must dump "sorted"
1015                  * so that lists (unions of sets) are dumped last.
1016                  */
1017                 if (cb->args[0] != DUMP_ONE &&
1018                     !((cb->args[0] == DUMP_ALL) ^
1019                       (set->type->features & IPSET_DUMP_LAST)))
1020                         continue;
1021                 pr_debug("List set: %s\n", set->name);
1022                 if (!cb->args[2]) {
1023                         /* Start listing: make sure set won't be destroyed */
1024                         pr_debug("reference set\n");
1025                         __ip_set_get(index);
1026                 }
1027                 nlh = start_msg(skb, NETLINK_CB(cb->skb).pid,
1028                                 cb->nlh->nlmsg_seq, flags,
1029                                 IPSET_CMD_LIST);
1030                 if (!nlh) {
1031                         ret = -EMSGSIZE;
1032                         goto release_refcount;
1033                 }
1034                 NLA_PUT_U8(skb, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1035                 NLA_PUT_STRING(skb, IPSET_ATTR_SETNAME, set->name);
1036                 switch (cb->args[2]) {
1037                 case 0:
1038                         /* Core header data */
1039                         NLA_PUT_STRING(skb, IPSET_ATTR_TYPENAME,
1040                                        set->type->name);
1041                         NLA_PUT_U8(skb, IPSET_ATTR_FAMILY,
1042                                    set->family);
1043                         NLA_PUT_U8(skb, IPSET_ATTR_REVISION,
1044                                    set->type->revision);
1045                         ret = set->variant->head(set, skb);
1046                         if (ret < 0)
1047                                 goto release_refcount;
1048                         /* Fall through and add elements */
1049                 default:
1050                         read_lock_bh(&set->lock);
1051                         ret = set->variant->list(set, skb, cb);
1052                         read_unlock_bh(&set->lock);
1053                         if (!cb->args[2]) {
1054                                 /* Set is done, proceed with next one */
1055                                 if (cb->args[0] == DUMP_ONE)
1056                                         cb->args[1] = IPSET_INVALID_ID;
1057                                 else
1058                                         cb->args[1]++;
1059                         }
1060                         goto release_refcount;
1061                 }
1062         }
1063         goto out;
1064
1065 nla_put_failure:
1066         ret = -EFAULT;
1067 release_refcount:
1068         /* If there was an error or set is done, release set */
1069         if (ret || !cb->args[2]) {
1070                 pr_debug("release set %s\n", ip_set_list[index]->name);
1071                 __ip_set_put(index);
1072         }
1073
1074         /* If we dump all sets, continue with dumping last ones */
1075         if (cb->args[0] == DUMP_ALL && cb->args[1] >= max && !cb->args[2])
1076                 cb->args[0] = DUMP_LAST;
1077
1078 out:
1079         if (nlh) {
1080                 nlmsg_end(skb, nlh);
1081                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1082                 dump_attrs(nlh);
1083         }
1084
1085         return ret < 0 ? ret : skb->len;
1086 }
1087
1088 static int
1089 ip_set_dump(struct sock *ctnl, struct sk_buff *skb,
1090             const struct nlmsghdr *nlh,
1091             const struct nlattr * const attr[])
1092 {
1093         if (unlikely(protocol_failed(attr)))
1094                 return -IPSET_ERR_PROTOCOL;
1095
1096         return netlink_dump_start(ctnl, skb, nlh,
1097                                   ip_set_dump_start,
1098                                   ip_set_dump_done);
1099 }
1100
1101 /* Add, del and test */
1102
1103 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1104         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1105         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1106                                     .len = IPSET_MAXNAMELEN - 1 },
1107         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1108         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1109         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1110 };
1111
1112 static int
1113 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1114         struct nlattr *tb[], enum ipset_adt adt,
1115         u32 flags, bool use_lineno)
1116 {
1117         int ret, retried = 0;
1118         u32 lineno = 0;
1119         bool eexist = flags & IPSET_FLAG_EXIST;
1120
1121         do {
1122                 write_lock_bh(&set->lock);
1123                 ret = set->variant->uadt(set, tb, adt, &lineno, flags);
1124                 write_unlock_bh(&set->lock);
1125         } while (ret == -EAGAIN &&
1126                  set->variant->resize &&
1127                  (ret = set->variant->resize(set, retried++)) == 0);
1128
1129         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1130                 return 0;
1131         if (lineno && use_lineno) {
1132                 /* Error in restore/batch mode: send back lineno */
1133                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1134                 struct sk_buff *skb2;
1135                 struct nlmsgerr *errmsg;
1136                 size_t payload = sizeof(*errmsg) + nlmsg_len(nlh);
1137                 int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
1138                 struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1139                 struct nlattr *cmdattr;
1140                 u32 *errline;
1141
1142                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1143                 if (skb2 == NULL)
1144                         return -ENOMEM;
1145                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).pid,
1146                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1147                 errmsg = nlmsg_data(rep);
1148                 errmsg->error = ret;
1149                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1150                 cmdattr = (void *)&errmsg->msg + min_len;
1151
1152                 nla_parse(cda, IPSET_ATTR_CMD_MAX,
1153                           cmdattr, nlh->nlmsg_len - min_len,
1154                           ip_set_adt_policy);
1155
1156                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1157
1158                 *errline = lineno;
1159
1160                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1161                 /* Signal netlink not to send its ACK/errmsg.  */
1162                 return -EINTR;
1163         }
1164
1165         return ret;
1166 }
1167
1168 static int
1169 ip_set_uadd(struct sock *ctnl, struct sk_buff *skb,
1170             const struct nlmsghdr *nlh,
1171             const struct nlattr * const attr[])
1172 {
1173         struct ip_set *set;
1174         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1175         const struct nlattr *nla;
1176         u32 flags = flag_exist(nlh);
1177         bool use_lineno;
1178         int ret = 0;
1179
1180         if (unlikely(protocol_failed(attr) ||
1181                      attr[IPSET_ATTR_SETNAME] == NULL ||
1182                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1183                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1184                      (attr[IPSET_ATTR_DATA] != NULL &&
1185                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1186                      (attr[IPSET_ATTR_ADT] != NULL &&
1187                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1188                        attr[IPSET_ATTR_LINENO] == NULL))))
1189                 return -IPSET_ERR_PROTOCOL;
1190
1191         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1192         if (set == NULL)
1193                 return -ENOENT;
1194
1195         use_lineno = !!attr[IPSET_ATTR_LINENO];
1196         if (attr[IPSET_ATTR_DATA]) {
1197                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1198                                      attr[IPSET_ATTR_DATA],
1199                                      set->type->adt_policy))
1200                         return -IPSET_ERR_PROTOCOL;
1201                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1202                               use_lineno);
1203         } else {
1204                 int nla_rem;
1205
1206                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1207                         memset(tb, 0, sizeof(tb));
1208                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1209                             !flag_nested(nla) ||
1210                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1211                                              set->type->adt_policy))
1212                                 return -IPSET_ERR_PROTOCOL;
1213                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1214                                       flags, use_lineno);
1215                         if (ret < 0)
1216                                 return ret;
1217                 }
1218         }
1219         return ret;
1220 }
1221
1222 static int
1223 ip_set_udel(struct sock *ctnl, struct sk_buff *skb,
1224             const struct nlmsghdr *nlh,
1225             const struct nlattr * const attr[])
1226 {
1227         struct ip_set *set;
1228         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1229         const struct nlattr *nla;
1230         u32 flags = flag_exist(nlh);
1231         bool use_lineno;
1232         int ret = 0;
1233
1234         if (unlikely(protocol_failed(attr) ||
1235                      attr[IPSET_ATTR_SETNAME] == NULL ||
1236                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1237                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1238                      (attr[IPSET_ATTR_DATA] != NULL &&
1239                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1240                      (attr[IPSET_ATTR_ADT] != NULL &&
1241                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1242                        attr[IPSET_ATTR_LINENO] == NULL))))
1243                 return -IPSET_ERR_PROTOCOL;
1244
1245         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1246         if (set == NULL)
1247                 return -ENOENT;
1248
1249         use_lineno = !!attr[IPSET_ATTR_LINENO];
1250         if (attr[IPSET_ATTR_DATA]) {
1251                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1252                                      attr[IPSET_ATTR_DATA],
1253                                      set->type->adt_policy))
1254                         return -IPSET_ERR_PROTOCOL;
1255                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1256                               use_lineno);
1257         } else {
1258                 int nla_rem;
1259
1260                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1261                         memset(tb, 0, sizeof(*tb));
1262                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1263                             !flag_nested(nla) ||
1264                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1265                                              set->type->adt_policy))
1266                                 return -IPSET_ERR_PROTOCOL;
1267                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1268                                       flags, use_lineno);
1269                         if (ret < 0)
1270                                 return ret;
1271                 }
1272         }
1273         return ret;
1274 }
1275
1276 static int
1277 ip_set_utest(struct sock *ctnl, struct sk_buff *skb,
1278              const struct nlmsghdr *nlh,
1279              const struct nlattr * const attr[])
1280 {
1281         struct ip_set *set;
1282         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1283         int ret = 0;
1284
1285         if (unlikely(protocol_failed(attr) ||
1286                      attr[IPSET_ATTR_SETNAME] == NULL ||
1287                      attr[IPSET_ATTR_DATA] == NULL ||
1288                      !flag_nested(attr[IPSET_ATTR_DATA])))
1289                 return -IPSET_ERR_PROTOCOL;
1290
1291         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1292         if (set == NULL)
1293                 return -ENOENT;
1294
1295         if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1296                              set->type->adt_policy))
1297                 return -IPSET_ERR_PROTOCOL;
1298
1299         read_lock_bh(&set->lock);
1300         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0);
1301         read_unlock_bh(&set->lock);
1302         /* Userspace can't trigger element to be re-added */
1303         if (ret == -EAGAIN)
1304                 ret = 1;
1305
1306         return ret < 0 ? ret : ret > 0 ? 0 : -IPSET_ERR_EXIST;
1307 }
1308
1309 /* Get headed data of a set */
1310
1311 static int
1312 ip_set_header(struct sock *ctnl, struct sk_buff *skb,
1313               const struct nlmsghdr *nlh,
1314               const struct nlattr * const attr[])
1315 {
1316         const struct ip_set *set;
1317         struct sk_buff *skb2;
1318         struct nlmsghdr *nlh2;
1319         ip_set_id_t index;
1320         int ret = 0;
1321
1322         if (unlikely(protocol_failed(attr) ||
1323                      attr[IPSET_ATTR_SETNAME] == NULL))
1324                 return -IPSET_ERR_PROTOCOL;
1325
1326         index = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
1327         if (index == IPSET_INVALID_ID)
1328                 return -ENOENT;
1329         set = ip_set_list[index];
1330
1331         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1332         if (skb2 == NULL)
1333                 return -ENOMEM;
1334
1335         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1336                          IPSET_CMD_HEADER);
1337         if (!nlh2)
1338                 goto nlmsg_failure;
1339         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1340         NLA_PUT_STRING(skb2, IPSET_ATTR_SETNAME, set->name);
1341         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, set->type->name);
1342         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, set->family);
1343         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, set->type->revision);
1344         nlmsg_end(skb2, nlh2);
1345
1346         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1347         if (ret < 0)
1348                 return ret;
1349
1350         return 0;
1351
1352 nla_put_failure:
1353         nlmsg_cancel(skb2, nlh2);
1354 nlmsg_failure:
1355         kfree_skb(skb2);
1356         return -EMSGSIZE;
1357 }
1358
1359 /* Get type data */
1360
1361 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1362         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1363         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1364                                     .len = IPSET_MAXNAMELEN - 1 },
1365         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1366 };
1367
1368 static int
1369 ip_set_type(struct sock *ctnl, struct sk_buff *skb,
1370             const struct nlmsghdr *nlh,
1371             const struct nlattr * const attr[])
1372 {
1373         struct sk_buff *skb2;
1374         struct nlmsghdr *nlh2;
1375         u8 family, min, max;
1376         const char *typename;
1377         int ret = 0;
1378
1379         if (unlikely(protocol_failed(attr) ||
1380                      attr[IPSET_ATTR_TYPENAME] == NULL ||
1381                      attr[IPSET_ATTR_FAMILY] == NULL))
1382                 return -IPSET_ERR_PROTOCOL;
1383
1384         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1385         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1386         ret = find_set_type_minmax(typename, family, &min, &max);
1387         if (ret)
1388                 return ret;
1389
1390         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1391         if (skb2 == NULL)
1392                 return -ENOMEM;
1393
1394         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1395                          IPSET_CMD_TYPE);
1396         if (!nlh2)
1397                 goto nlmsg_failure;
1398         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1399         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, typename);
1400         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, family);
1401         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, max);
1402         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION_MIN, min);
1403         nlmsg_end(skb2, nlh2);
1404
1405         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1406         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1407         if (ret < 0)
1408                 return ret;
1409
1410         return 0;
1411
1412 nla_put_failure:
1413         nlmsg_cancel(skb2, nlh2);
1414 nlmsg_failure:
1415         kfree_skb(skb2);
1416         return -EMSGSIZE;
1417 }
1418
1419 /* Get protocol version */
1420
1421 static const struct nla_policy
1422 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1423         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1424 };
1425
1426 static int
1427 ip_set_protocol(struct sock *ctnl, struct sk_buff *skb,
1428                 const struct nlmsghdr *nlh,
1429                 const struct nlattr * const attr[])
1430 {
1431         struct sk_buff *skb2;
1432         struct nlmsghdr *nlh2;
1433         int ret = 0;
1434
1435         if (unlikely(attr[IPSET_ATTR_PROTOCOL] == NULL))
1436                 return -IPSET_ERR_PROTOCOL;
1437
1438         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1439         if (skb2 == NULL)
1440                 return -ENOMEM;
1441
1442         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1443                          IPSET_CMD_PROTOCOL);
1444         if (!nlh2)
1445                 goto nlmsg_failure;
1446         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1447         nlmsg_end(skb2, nlh2);
1448
1449         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1450         if (ret < 0)
1451                 return ret;
1452
1453         return 0;
1454
1455 nla_put_failure:
1456         nlmsg_cancel(skb2, nlh2);
1457 nlmsg_failure:
1458         kfree_skb(skb2);
1459         return -EMSGSIZE;
1460 }
1461
1462 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1463         [IPSET_CMD_CREATE]      = {
1464                 .call           = ip_set_create,
1465                 .attr_count     = IPSET_ATTR_CMD_MAX,
1466                 .policy         = ip_set_create_policy,
1467         },
1468         [IPSET_CMD_DESTROY]     = {
1469                 .call           = ip_set_destroy,
1470                 .attr_count     = IPSET_ATTR_CMD_MAX,
1471                 .policy         = ip_set_setname_policy,
1472         },
1473         [IPSET_CMD_FLUSH]       = {
1474                 .call           = ip_set_flush,
1475                 .attr_count     = IPSET_ATTR_CMD_MAX,
1476                 .policy         = ip_set_setname_policy,
1477         },
1478         [IPSET_CMD_RENAME]      = {
1479                 .call           = ip_set_rename,
1480                 .attr_count     = IPSET_ATTR_CMD_MAX,
1481                 .policy         = ip_set_setname2_policy,
1482         },
1483         [IPSET_CMD_SWAP]        = {
1484                 .call           = ip_set_swap,
1485                 .attr_count     = IPSET_ATTR_CMD_MAX,
1486                 .policy         = ip_set_setname2_policy,
1487         },
1488         [IPSET_CMD_LIST]        = {
1489                 .call           = ip_set_dump,
1490                 .attr_count     = IPSET_ATTR_CMD_MAX,
1491                 .policy         = ip_set_setname_policy,
1492         },
1493         [IPSET_CMD_SAVE]        = {
1494                 .call           = ip_set_dump,
1495                 .attr_count     = IPSET_ATTR_CMD_MAX,
1496                 .policy         = ip_set_setname_policy,
1497         },
1498         [IPSET_CMD_ADD] = {
1499                 .call           = ip_set_uadd,
1500                 .attr_count     = IPSET_ATTR_CMD_MAX,
1501                 .policy         = ip_set_adt_policy,
1502         },
1503         [IPSET_CMD_DEL] = {
1504                 .call           = ip_set_udel,
1505                 .attr_count     = IPSET_ATTR_CMD_MAX,
1506                 .policy         = ip_set_adt_policy,
1507         },
1508         [IPSET_CMD_TEST]        = {
1509                 .call           = ip_set_utest,
1510                 .attr_count     = IPSET_ATTR_CMD_MAX,
1511                 .policy         = ip_set_adt_policy,
1512         },
1513         [IPSET_CMD_HEADER]      = {
1514                 .call           = ip_set_header,
1515                 .attr_count     = IPSET_ATTR_CMD_MAX,
1516                 .policy         = ip_set_setname_policy,
1517         },
1518         [IPSET_CMD_TYPE]        = {
1519                 .call           = ip_set_type,
1520                 .attr_count     = IPSET_ATTR_CMD_MAX,
1521                 .policy         = ip_set_type_policy,
1522         },
1523         [IPSET_CMD_PROTOCOL]    = {
1524                 .call           = ip_set_protocol,
1525                 .attr_count     = IPSET_ATTR_CMD_MAX,
1526                 .policy         = ip_set_protocol_policy,
1527         },
1528 };
1529
1530 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
1531         .name           = "ip_set",
1532         .subsys_id      = NFNL_SUBSYS_IPSET,
1533         .cb_count       = IPSET_MSG_MAX,
1534         .cb             = ip_set_netlink_subsys_cb,
1535 };
1536
1537 /* Interface to iptables/ip6tables */
1538
1539 static int
1540 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
1541 {
1542         unsigned *op;
1543         void *data;
1544         int copylen = *len, ret = 0;
1545
1546         if (!capable(CAP_NET_ADMIN))
1547                 return -EPERM;
1548         if (optval != SO_IP_SET)
1549                 return -EBADF;
1550         if (*len < sizeof(unsigned))
1551                 return -EINVAL;
1552
1553         data = vmalloc(*len);
1554         if (!data)
1555                 return -ENOMEM;
1556         if (copy_from_user(data, user, *len) != 0) {
1557                 ret = -EFAULT;
1558                 goto done;
1559         }
1560         op = (unsigned *) data;
1561
1562         if (*op < IP_SET_OP_VERSION) {
1563                 /* Check the version at the beginning of operations */
1564                 struct ip_set_req_version *req_version = data;
1565                 if (req_version->version != IPSET_PROTOCOL) {
1566                         ret = -EPROTO;
1567                         goto done;
1568                 }
1569         }
1570
1571         switch (*op) {
1572         case IP_SET_OP_VERSION: {
1573                 struct ip_set_req_version *req_version = data;
1574
1575                 if (*len != sizeof(struct ip_set_req_version)) {
1576                         ret = -EINVAL;
1577                         goto done;
1578                 }
1579
1580                 req_version->version = IPSET_PROTOCOL;
1581                 ret = copy_to_user(user, req_version,
1582                                    sizeof(struct ip_set_req_version));
1583                 goto done;
1584         }
1585         case IP_SET_OP_GET_BYNAME: {
1586                 struct ip_set_req_get_set *req_get = data;
1587
1588                 if (*len != sizeof(struct ip_set_req_get_set)) {
1589                         ret = -EINVAL;
1590                         goto done;
1591                 }
1592                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
1593                 nfnl_lock();
1594                 req_get->set.index = find_set_id(req_get->set.name);
1595                 nfnl_unlock();
1596                 goto copy;
1597         }
1598         case IP_SET_OP_GET_BYINDEX: {
1599                 struct ip_set_req_get_set *req_get = data;
1600
1601                 if (*len != sizeof(struct ip_set_req_get_set) ||
1602                     req_get->set.index >= ip_set_max) {
1603                         ret = -EINVAL;
1604                         goto done;
1605                 }
1606                 nfnl_lock();
1607                 strncpy(req_get->set.name,
1608                         ip_set_list[req_get->set.index]
1609                                 ? ip_set_list[req_get->set.index]->name : "",
1610                         IPSET_MAXNAMELEN);
1611                 nfnl_unlock();
1612                 goto copy;
1613         }
1614         default:
1615                 ret = -EBADMSG;
1616                 goto done;
1617         }       /* end of switch(op) */
1618
1619 copy:
1620         ret = copy_to_user(user, data, copylen);
1621
1622 done:
1623         vfree(data);
1624         if (ret > 0)
1625                 ret = 0;
1626         return ret;
1627 }
1628
1629 static struct nf_sockopt_ops so_set __read_mostly = {
1630         .pf             = PF_INET,
1631         .get_optmin     = SO_IP_SET,
1632         .get_optmax     = SO_IP_SET + 1,
1633         .get            = &ip_set_sockfn_get,
1634         .owner          = THIS_MODULE,
1635 };
1636
1637 static int __init
1638 ip_set_init(void)
1639 {
1640         int ret;
1641
1642         if (max_sets)
1643                 ip_set_max = max_sets;
1644         if (ip_set_max >= IPSET_INVALID_ID)
1645                 ip_set_max = IPSET_INVALID_ID - 1;
1646
1647         ip_set_list = kzalloc(sizeof(struct ip_set *) * ip_set_max,
1648                               GFP_KERNEL);
1649         if (!ip_set_list) {
1650                 pr_err("ip_set: Unable to create ip_set_list\n");
1651                 return -ENOMEM;
1652         }
1653
1654         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
1655         if (ret != 0) {
1656                 pr_err("ip_set: cannot register with nfnetlink.\n");
1657                 kfree(ip_set_list);
1658                 return ret;
1659         }
1660         ret = nf_register_sockopt(&so_set);
1661         if (ret != 0) {
1662                 pr_err("SO_SET registry failed: %d\n", ret);
1663                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1664                 kfree(ip_set_list);
1665                 return ret;
1666         }
1667
1668         pr_notice("ip_set: protocol %u\n", IPSET_PROTOCOL);
1669         return 0;
1670 }
1671
1672 static void __exit
1673 ip_set_fini(void)
1674 {
1675         /* There can't be any existing set */
1676         nf_unregister_sockopt(&so_set);
1677         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1678         kfree(ip_set_list);
1679         pr_debug("these are the famous last words\n");
1680 }
1681
1682 module_init(ip_set_init);
1683 module_exit(ip_set_fini);