mac80211: move csa counters from sdata to beacon/presp

Making the CSA counters part of the beacon and
probe_resp structures makes it easier to get rid
of possible races between setting a beacon and
updating the counters on SMP systems, because it
guarantees that the counters are always consistent
with the given beacon struct.

While at it, relax WARN_ON to WARN_ON_ONCE to
avoid spamming the logs and racing.

Signed-off-by: Michal Kazior <michal.kazior@tieto.com>
[remove pointless array check]
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c
index 3c80bf2..ed56f00 100644
--- a/net/mac80211/tx.c
+++ b/net/mac80211/tx.c
@@ -2426,7 +2426,7 @@
 	u8 *beacon_data;
 	size_t beacon_data_len;
 	int i;
-	u8 count = sdata->csa_current_counter;
+	u8 count = beacon->csa_current_counter;
 
 	switch (sdata->vif.type) {
 	case NL80211_IFTYPE_AP:
@@ -2445,46 +2445,53 @@
 		return;
 	}
 
+	rcu_read_lock();
 	for (i = 0; i < IEEE80211_MAX_CSA_COUNTERS_NUM; ++i) {
-		u16 counter_offset_beacon =
-			sdata->csa_counter_offset_beacon[i];
-		u16 counter_offset_presp = sdata->csa_counter_offset_presp[i];
+		resp = rcu_dereference(sdata->u.ap.probe_resp);
 
-		if (counter_offset_beacon) {
-			if (WARN_ON(counter_offset_beacon >= beacon_data_len))
-				return;
-
-			beacon_data[counter_offset_beacon] = count;
-		}
-
-		if (sdata->vif.type == NL80211_IFTYPE_AP &&
-		    counter_offset_presp) {
-			rcu_read_lock();
-			resp = rcu_dereference(sdata->u.ap.probe_resp);
-
-			/* If nl80211 accepted the offset, this should
-			 * not happen.
-			 */
-			if (WARN_ON(!resp)) {
+		if (beacon->csa_counter_offsets[i]) {
+			if (WARN_ON_ONCE(beacon->csa_counter_offsets[i] >=
+					 beacon_data_len)) {
 				rcu_read_unlock();
 				return;
 			}
-			resp->data[counter_offset_presp] = count;
-			rcu_read_unlock();
+
+			beacon_data[beacon->csa_counter_offsets[i]] = count;
 		}
+
+		if (sdata->vif.type == NL80211_IFTYPE_AP && resp)
+			resp->data[resp->csa_counter_offsets[i]] = count;
 	}
+	rcu_read_unlock();
 }
 
 u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif)
 {
 	struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif);
+	struct beacon_data *beacon = NULL;
+	u8 count = 0;
 
-	sdata->csa_current_counter--;
+	rcu_read_lock();
+
+	if (sdata->vif.type == NL80211_IFTYPE_AP)
+		beacon = rcu_dereference(sdata->u.ap.beacon);
+	else if (sdata->vif.type == NL80211_IFTYPE_ADHOC)
+		beacon = rcu_dereference(sdata->u.ibss.presp);
+	else if (ieee80211_vif_is_mesh(&sdata->vif))
+		beacon = rcu_dereference(sdata->u.mesh.beacon);
+
+	if (!beacon)
+		goto unlock;
+
+	beacon->csa_current_counter--;
 
 	/* the counter should never reach 0 */
-	WARN_ON(!sdata->csa_current_counter);
+	WARN_ON_ONCE(!beacon->csa_current_counter);
+	count = beacon->csa_current_counter;
 
-	return sdata->csa_current_counter;
+unlock:
+	rcu_read_unlock();
+	return count;
 }
 EXPORT_SYMBOL(ieee80211_csa_update_counter);
 
@@ -2494,7 +2501,6 @@
 	struct beacon_data *beacon = NULL;
 	u8 *beacon_data;
 	size_t beacon_data_len;
-	int counter_beacon = sdata->csa_counter_offset_beacon[0];
 	int ret = false;
 
 	if (!ieee80211_sdata_running(sdata))
@@ -2532,10 +2538,10 @@
 		goto out;
 	}
 
-	if (WARN_ON(counter_beacon > beacon_data_len))
+	if (WARN_ON_ONCE(beacon->csa_counter_offsets[0] > beacon_data_len))
 		goto out;
 
-	if (beacon_data[counter_beacon] == 1)
+	if (beacon_data[beacon->csa_counter_offsets[0]] == 1)
 		ret = true;
  out:
 	rcu_read_unlock();
@@ -2551,6 +2557,7 @@
 		       bool is_template)
 {
 	struct ieee80211_local *local = hw_to_local(hw);
+	struct beacon_data *beacon = NULL;
 	struct sk_buff *skb = NULL;
 	struct ieee80211_tx_info *info;
 	struct ieee80211_sub_if_data *sdata = NULL;
@@ -2572,8 +2579,8 @@
 
 	if (sdata->vif.type == NL80211_IFTYPE_AP) {
 		struct ieee80211_if_ap *ap = &sdata->u.ap;
-		struct beacon_data *beacon = rcu_dereference(ap->beacon);
 
+		beacon = rcu_dereference(ap->beacon);
 		if (beacon) {
 			if (sdata->vif.csa_active) {
 				if (!is_template)
@@ -2616,34 +2623,34 @@
 	} else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) {
 		struct ieee80211_if_ibss *ifibss = &sdata->u.ibss;
 		struct ieee80211_hdr *hdr;
-		struct beacon_data *presp = rcu_dereference(ifibss->presp);
 
-		if (!presp)
+		beacon = rcu_dereference(ifibss->presp);
+		if (!beacon)
 			goto out;
 
 		if (sdata->vif.csa_active) {
 			if (!is_template)
 				ieee80211_csa_update_counter(vif);
 
-			ieee80211_set_csa(sdata, presp);
+			ieee80211_set_csa(sdata, beacon);
 		}
 
-		skb = dev_alloc_skb(local->tx_headroom + presp->head_len +
+		skb = dev_alloc_skb(local->tx_headroom + beacon->head_len +
 				    local->hw.extra_beacon_tailroom);
 		if (!skb)
 			goto out;
 		skb_reserve(skb, local->tx_headroom);
-		memcpy(skb_put(skb, presp->head_len), presp->head,
-		       presp->head_len);
+		memcpy(skb_put(skb, beacon->head_len), beacon->head,
+		       beacon->head_len);
 
 		hdr = (struct ieee80211_hdr *) skb->data;
 		hdr->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT |
 						 IEEE80211_STYPE_BEACON);
 	} else if (ieee80211_vif_is_mesh(&sdata->vif)) {
 		struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh;
-		struct beacon_data *bcn = rcu_dereference(ifmsh->beacon);
 
-		if (!bcn)
+		beacon = rcu_dereference(ifmsh->beacon);
+		if (!beacon)
 			goto out;
 
 		if (sdata->vif.csa_active) {
@@ -2655,40 +2662,42 @@
 				 */
 				ieee80211_csa_update_counter(vif);
 
-			ieee80211_set_csa(sdata, bcn);
+			ieee80211_set_csa(sdata, beacon);
 		}
 
 		if (ifmsh->sync_ops)
-			ifmsh->sync_ops->adjust_tbtt(sdata, bcn);
+			ifmsh->sync_ops->adjust_tbtt(sdata, beacon);
 
 		skb = dev_alloc_skb(local->tx_headroom +
-				    bcn->head_len +
+				    beacon->head_len +
 				    256 + /* TIM IE */
-				    bcn->tail_len +
+				    beacon->tail_len +
 				    local->hw.extra_beacon_tailroom);
 		if (!skb)
 			goto out;
 		skb_reserve(skb, local->tx_headroom);
-		memcpy(skb_put(skb, bcn->head_len), bcn->head, bcn->head_len);
+		memcpy(skb_put(skb, beacon->head_len), beacon->head,
+		       beacon->head_len);
 		ieee80211_beacon_add_tim(sdata, &ifmsh->ps, skb, is_template);
 
 		if (offs) {
-			offs->tim_offset = bcn->head_len;
-			offs->tim_length = skb->len - bcn->head_len;
+			offs->tim_offset = beacon->head_len;
+			offs->tim_length = skb->len - beacon->head_len;
 		}
 
-		memcpy(skb_put(skb, bcn->tail_len), bcn->tail, bcn->tail_len);
+		memcpy(skb_put(skb, beacon->tail_len), beacon->tail,
+		       beacon->tail_len);
 	} else {
 		WARN_ON(1);
 		goto out;
 	}
 
 	/* CSA offsets */
-	if (offs) {
+	if (offs && beacon) {
 		int i;
 
 		for (i = 0; i < IEEE80211_MAX_CSA_COUNTERS_NUM; i++) {
-			u16 csa_off = sdata->csa_counter_offset_beacon[i];
+			u16 csa_off = beacon->csa_counter_offsets[i];
 
 			if (!csa_off)
 				continue;