From: Paul Moore <paul(a)paul-moore.com>
 Cong Wang correctly pointed out that the RCU read locking of the
 auditd_connection struct was wrong; this patch corrects this by
 adopting a more traditional, and correct, RCU locking model.
 This patch is heavily based on an earlier prototype by Cong Wang.
 [XXX: Cong Wang, as mentioned previously, I'd like to add your
  sign-off; please let me know if that is okay with you.]
 Cc: <stable(a)vger.kernel.org> # 4.11.x-: 264d509637d9
 Reported-by: Cong Wang <xiyou.wangcong(a)gmail.com>
 ??!! -> Signed-off-by: Cong Wang <xiyou.wangcong(a)gmail.com>
 Signed-off-by: Paul Moore <paul(a)paul-moore.com>
 ---
  kernel/audit.c |  157 ++++++++++++++++++++++++++++++++++++--------------------
  1 file changed, 100 insertions(+), 57 deletions(-) 
A quick note: I haven't tested this yet (I'm in the process of
building a kernel now); I just wanted to send this out early in case
anyone notices anything incredibly stupid.
 diff --git a/kernel/audit.c b/kernel/audit.c
 index 10bc2bad2adf..a7c6a50477aa 100644
 --- a/kernel/audit.c
 +++ b/kernel/audit.c
 @@ -112,18 +112,19 @@ struct audit_net {
   * @pid: auditd PID
   * @portid: netlink portid
   * @net: the associated network namespace
 - * @lock: spinlock to protect write access
 + * @rcu: RCU head
   *
   * Description:
   * This struct is RCU protected; you must either hold the RCU lock for reading
 - * or the included spinlock for writing.
 + * or the associated spinlock for writing.
   */
  static struct auditd_connection {
         struct pid *pid;
         u32 portid;
         struct net *net;
 -       spinlock_t lock;
 -} auditd_conn;
 +       struct rcu_head rcu;
 +} *auditd_conn = NULL;
 +static DEFINE_SPINLOCK(auditd_conn_lock);
  /* If audit_rate_limit is non-zero, limit the rate of sending audit records
   * to that number per second.  This prevents DoS attacks, but results in
 @@ -215,9 +216,11 @@ struct audit_reply {
  int auditd_test_task(struct task_struct *task)
  {
         int rc;
 +       struct auditd_connection *ac;
         rcu_read_lock();
 -       rc = (auditd_conn.pid && auditd_conn.pid == task_tgid(task) ? 1 : 0);
 +       ac = rcu_dereference(auditd_conn);
 +       rc = (ac && ac->pid == task_tgid(task) ? 1 : 0);
         rcu_read_unlock();
         return rc;
 @@ -225,22 +228,21 @@ int auditd_test_task(struct task_struct *task)
  /**
   * auditd_pid_vnr - Return the auditd PID relative to the namespace
 - * @auditd: the auditd connection
   *
   * Description:
 - * Returns the PID in relation to the namespace, 0 on failure.  This function
 - * takes the RCU read lock internally, but if the caller needs to protect the
 - * auditd_connection pointer it should take the RCU read lock as well.
 + * Returns the PID in relation to the namespace, 0 on failure.
   */
 -static pid_t auditd_pid_vnr(const struct auditd_connection *auditd)
 +static pid_t auditd_pid_vnr(void)
  {
         pid_t pid;
 +       const struct auditd_connection *ac;
         rcu_read_lock();
 -       if (!auditd || !auditd->pid)
 +       ac = rcu_dereference(auditd_conn);
 +       if (!ac || !ac->pid)
                 pid = 0;
         else
 -               pid = pid_vnr(auditd->pid);
 +               pid = pid_vnr(ac->pid);
         rcu_read_unlock();
         return pid;
 @@ -434,6 +436,24 @@ static int audit_set_failure(u32 state)
  }
  /**
 + * auditd_conn_free - RCU helper to release an auditd connection struct
 + * @rcu: RCU head
 + *
 + * Description:
 + * Drop any references inside the auditd connection tracking struct and free
 + * the memory.
 + */
 + static void auditd_conn_free(struct rcu_head *rcu)
 + {
 +       struct auditd_connection *ac;
 +
 +       ac = container_of(rcu, struct auditd_connection, rcu);
 +       put_pid(ac->pid);
 +       put_net(ac->net);
 +       kfree(ac);
 + }
 +
 +/**
   * auditd_set - Set/Reset the auditd connection state
   * @pid: auditd PID
   * @portid: auditd netlink portid
 @@ -441,27 +461,33 @@ static int audit_set_failure(u32 state)
   *
   * Description:
   * This function will obtain and drop network namespace references as
 - * necessary.
 + * necessary.  Returns zero on success, negative values on failure.
   */
 -static void auditd_set(struct pid *pid, u32 portid, struct net *net)
 +static int auditd_set(struct pid *pid, u32 portid, struct net *net)
  {
         unsigned long flags;
 +       struct auditd_connection *ac_old, *ac_new;
 -       spin_lock_irqsave(&auditd_conn.lock, flags);
 -       if (auditd_conn.pid)
 -               put_pid(auditd_conn.pid);
 -       if (pid)
 -               auditd_conn.pid = get_pid(pid);
 -       else
 -               auditd_conn.pid = NULL;
 -       auditd_conn.portid = portid;
 -       if (auditd_conn.net)
 -               put_net(auditd_conn.net);
 -       if (net)
 -               auditd_conn.net = get_net(net);
 -       else
 -               auditd_conn.net = NULL;
 -       spin_unlock_irqrestore(&auditd_conn.lock, flags);
 +       if (!pid || !net)
 +               return -EINVAL;
 +
 +       ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
 +       if (!ac_new)
 +               return -ENOMEM;
 +       ac_new->pid = get_pid(pid);
 +       ac_new->portid = portid;
 +       ac_new->net = get_net(net);
 +
 +       spin_lock_irqsave(&auditd_conn_lock, flags);
 +       ac_old = rcu_dereference_protected(auditd_conn,
 +                                          lockdep_is_held(&auditd_conn_lock));
 +       rcu_assign_pointer(auditd_conn, ac_new);
 +       spin_unlock_irqrestore(&auditd_conn_lock, flags);
 +
 +       if (ac_old)
 +               call_rcu(&ac_old->rcu, auditd_conn_free);
 +
 +       return 0;
  }
  /**
 @@ -556,13 +582,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
   */
  static void auditd_reset(void)
  {
 +       unsigned long flags;
         struct sk_buff *skb;
 +       struct auditd_connection *ac_old;
         /* if it isn't already broken, break the connection */
 -       rcu_read_lock();
 -       if (auditd_conn.pid)
 -               auditd_set(0, 0, NULL);
 -       rcu_read_unlock();
 +       spin_lock_irqsave(&auditd_conn_lock, flags);
 +       ac_old = rcu_dereference_protected(auditd_conn,
 +                                          lockdep_is_held(&auditd_conn_lock));
 +       rcu_assign_pointer(auditd_conn, NULL);
 +       spin_unlock_irqrestore(&auditd_conn_lock, flags);
 +
 +       if (ac_old)
 +               call_rcu(&ac_old->rcu, auditd_conn_free);
         /* flush all of the main and retry queues to the hold queue */
         while ((skb = skb_dequeue(&audit_retry_queue)))
 @@ -588,6 +620,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
         u32 portid;
         struct net *net;
         struct sock *sk;
 +       struct auditd_connection *ac;
         /* NOTE: we can't call netlink_unicast while in the RCU section so
          *       take a reference to the network namespace and grab local
 @@ -597,15 +630,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
          *       section netlink_unicast() should safely return an error */
         rcu_read_lock();
 -       if (!auditd_conn.pid) {
 +       ac = rcu_dereference(auditd_conn);
 +       if (!ac) {
                 rcu_read_unlock();
                 rc = -ECONNREFUSED;
                 goto err;
         }
 -       net = auditd_conn.net;
 -       get_net(net);
 +       net = get_net(ac->net);
         sk = audit_get_sk(net);
 -       portid = auditd_conn.portid;
 +       portid = ac->portid;
         rcu_read_unlock();
         rc = netlink_unicast(sk, skb, portid, 0);
 @@ -740,6 +773,7 @@ static int kauditd_thread(void *dummy)
         u32 portid = 0;
         struct net *net = NULL;
         struct sock *sk = NULL;
 +       struct auditd_connection *ac;
  #define UNICAST_RETRIES 5
 @@ -747,14 +781,14 @@ static int kauditd_thread(void *dummy)
         while (!kthread_should_stop()) {
                 /* NOTE: see the lock comments in auditd_send_unicast_skb() */
                 rcu_read_lock();
 -               if (!auditd_conn.pid) {
 +               ac = rcu_dereference(auditd_conn);
 +               if (!ac) {
                         rcu_read_unlock();
                         goto main_queue;
                 }
 -               net = auditd_conn.net;
 -               get_net(net);
 +               net = get_net(ac->net);
                 sk = audit_get_sk(net);
 -               portid = auditd_conn.portid;
 +               portid = ac->portid;
                 rcu_read_unlock();
                 /* attempt to flush the hold queue */
 @@ -1117,7 +1151,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                 s.failure               = audit_failure;
                 /* NOTE: use pid_vnr() so the PID is relative to the current
                  *       namespace */
 -               s.pid                   = auditd_pid_vnr(&auditd_conn);
 +               s.pid                   = auditd_pid_vnr();
                 s.rate_limit            = audit_rate_limit;
                 s.backlog_limit         = audit_backlog_limit;
                 s.lost                  = atomic_read(&audit_lost);
 @@ -1160,7 +1194,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                         /* test the auditd connection */
                         audit_replace(req_pid);
 -                       auditd_pid = auditd_pid_vnr(&auditd_conn);
 +                       auditd_pid = auditd_pid_vnr();
                         /* only the current auditd can unregister itself */
                         if ((!new_pid) && (new_pid != auditd_pid)) {
                                 audit_log_config_change("audit_pid", new_pid,
 @@ -1174,19 +1208,30 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                 return -EEXIST;
                         }
 -                       if (audit_enabled != AUDIT_OFF)
 -                               audit_log_config_change("audit_pid", new_pid,
 -                                                       auditd_pid, 1);
 -
                         if (new_pid) {
                                 /* register a new auditd connection */
 -                               auditd_set(req_pid, NETLINK_CB(skb).portid,
 -                                          sock_net(NETLINK_CB(skb).sk));
 +                               err = auditd_set(req_pid,
 +                                                NETLINK_CB(skb).portid,
 +                                                sock_net(NETLINK_CB(skb).sk));
 +                               if (audit_enabled != AUDIT_OFF)
 +                                       audit_log_config_change("audit_pid",
 +                                                               new_pid,
 +                                                               auditd_pid,
 +                                                               err ? 0 : 1);
 +                               if (err)
 +                                       return err;
 +
                                 /* try to process any backlog */
                                 wake_up_interruptible(&kauditd_wait);
 -                       } else
 +                       } else {
 +                               if (audit_enabled != AUDIT_OFF)
 +                                       audit_log_config_change("audit_pid",
 +                                                               new_pid,
 +                                                               auditd_pid, 1);
 +
                                 /* unregister the auditd connection */
                                 auditd_reset();
 +                       }
                 }
                 if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
                         err = audit_set_rate_limit(s.rate_limit);
 @@ -1454,10 +1499,11 @@ static void __net_exit audit_net_exit(struct net *net)
  {
         struct audit_net *aunet = net_generic(net, audit_net_id);
 -       rcu_read_lock();
 -       if (net == auditd_conn.net)
 -               auditd_reset();
 -       rcu_read_unlock();
 +       /* NOTE: you would think that we would want to check the auditd
 +        * connection and potentially reset it here if it lives in this
 +        * namespace, but since the auditd connection tracking struct holds a
 +        * reference to this namespace (see auditd_set()) we are only ever
 +        * going to get here after that connection has been released */
         netlink_kernel_release(aunet->sk);
  }
 @@ -1481,9 +1527,6 @@ static int __init audit_init(void)
                                                sizeof(struct audit_buffer),
                                                0, SLAB_PANIC, NULL);
 -       memset(&auditd_conn, 0, sizeof(auditd_conn));
 -       spin_lock_init(&auditd_conn.lock);
 -
         skb_queue_head_init(&audit_queue);
         skb_queue_head_init(&audit_retry_queue);
         skb_queue_head_init(&audit_hold_queue);
 --
 Linux-audit mailing list
 Linux-audit(a)redhat.com
 https://www.redhat.com/mailman/listinfo/linux-audit