cfg80211/mac80211: allow per-station GTKs

This adds an API for adding per-station GTKs,
updates mac80211 to support them, and also allows
drivers to remove a key from hwaccel again when
that becomes necessary due to multiple GTKs.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: John W. Linville <linville@tuxdriver.com>
diff --git a/net/mac80211/key.c b/net/mac80211/key.c
index 6a63d1a..ccd676b 100644
--- a/net/mac80211/key.c
+++ b/net/mac80211/key.c
@@ -68,15 +68,21 @@
 
 	might_sleep();
 
-	if (!key->local->ops->set_key) {
-		ret = -EOPNOTSUPP;
+	if (!key->local->ops->set_key)
 		goto out_unsupported;
-	}
 
 	assert_key_lock(key->local);
 
 	sta = get_sta_for_key(key);
 
+	/*
+	 * If this is a per-STA GTK, check if it
+	 * is supported; if not, return.
+	 */
+	if (sta && !(key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE) &&
+	    !(key->local->hw.flags & IEEE80211_HW_SUPPORTS_PER_STA_GTK))
+		goto out_unsupported;
+
 	sdata = key->sdata;
 	if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN)
 		sdata = container_of(sdata->bss,
@@ -85,31 +91,28 @@
 
 	ret = drv_set_key(key->local, SET_KEY, sdata, sta, &key->conf);
 
-	if (!ret)
+	if (!ret) {
 		key->flags |= KEY_FLAG_UPLOADED_TO_HARDWARE;
+		return 0;
+	}
 
-	if (ret && ret != -ENOSPC && ret != -EOPNOTSUPP)
+	if (ret != -ENOSPC && ret != -EOPNOTSUPP)
 		wiphy_err(key->local->hw.wiphy,
 			  "failed to set key (%d, %pM) to hardware (%d)\n",
 			  key->conf.keyidx, sta ? sta->addr : bcast_addr, ret);
 
-out_unsupported:
-	if (ret) {
-		switch (key->conf.cipher) {
-		case WLAN_CIPHER_SUITE_WEP40:
-		case WLAN_CIPHER_SUITE_WEP104:
-		case WLAN_CIPHER_SUITE_TKIP:
-		case WLAN_CIPHER_SUITE_CCMP:
-		case WLAN_CIPHER_SUITE_AES_CMAC:
-			/* all of these we can do in software */
-			ret = 0;
-			break;
-		default:
-			ret = -EINVAL;
-		}
+ out_unsupported:
+	switch (key->conf.cipher) {
+	case WLAN_CIPHER_SUITE_WEP40:
+	case WLAN_CIPHER_SUITE_WEP104:
+	case WLAN_CIPHER_SUITE_TKIP:
+	case WLAN_CIPHER_SUITE_CCMP:
+	case WLAN_CIPHER_SUITE_AES_CMAC:
+		/* all of these we can do in software */
+		return 0;
+	default:
+		return -EINVAL;
 	}
-
-	return ret;
 }
 
 static void ieee80211_key_disable_hw_accel(struct ieee80211_key *key)
@@ -147,6 +150,26 @@
 	key->flags &= ~KEY_FLAG_UPLOADED_TO_HARDWARE;
 }
 
+void ieee80211_key_removed(struct ieee80211_key_conf *key_conf)
+{
+	struct ieee80211_key *key;
+
+	key = container_of(key_conf, struct ieee80211_key, conf);
+
+	might_sleep();
+	assert_key_lock(key->local);
+
+	key->flags &= ~KEY_FLAG_UPLOADED_TO_HARDWARE;
+
+	/*
+	 * synchronize_rcu() waits out in-flight RCU readers (TX path),
+	 * so none can still use this key once we return. Until then,
+	 * drivers must be prepared to handle the key.
+	 */
+	synchronize_rcu();
+}
+EXPORT_SYMBOL_GPL(ieee80211_key_removed);
+
 static void __ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata,
 					int idx)
 {
@@ -202,6 +225,7 @@
 
 static void __ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
 				    struct sta_info *sta,
+				    bool pairwise,
 				    struct ieee80211_key *old,
 				    struct ieee80211_key *new)
 {
@@ -210,8 +234,14 @@
 	if (new)
 		list_add(&new->list, &sdata->key_list);
 
-	if (sta) {
-		rcu_assign_pointer(sta->key, new);
+	if (sta && pairwise) {
+		rcu_assign_pointer(sta->ptk, new);
+	} else if (sta) {
+		if (old)
+			idx = old->conf.keyidx;
+		else
+			idx = new->conf.keyidx;
+		rcu_assign_pointer(sta->gtk[idx], new);
 	} else {
 		WARN_ON(new && old && new->conf.keyidx != old->conf.keyidx);
 
@@ -355,6 +385,7 @@
 {
 	struct ieee80211_key *old_key;
 	int idx, ret;
+	bool pairwise = key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE;
 
 	BUG_ON(!sdata);
 	BUG_ON(!key);
@@ -371,13 +402,6 @@
 		 */
 		if (test_sta_flags(sta, WLAN_STA_WME))
 			key->conf.flags |= IEEE80211_KEY_FLAG_WMM_STA;
-
-		/*
-		 * This key is for a specific sta interface,
-		 * inform the driver that it should try to store
-		 * this key as pairwise key.
-		 */
-		key->conf.flags |= IEEE80211_KEY_FLAG_PAIRWISE;
 	} else {
 		if (sdata->vif.type == NL80211_IFTYPE_STATION) {
 			struct sta_info *ap;
@@ -399,12 +423,14 @@
 
 	mutex_lock(&sdata->local->key_mtx);
 
-	if (sta)
-		old_key = sta->key;
+	if (sta && pairwise)
+		old_key = sta->ptk;
+	else if (sta)
+		old_key = sta->gtk[idx];
 	else
 		old_key = sdata->keys[idx];
 
-	__ieee80211_key_replace(sdata, sta, old_key, key);
+	__ieee80211_key_replace(sdata, sta, pairwise, old_key, key);
 	__ieee80211_key_destroy(old_key);
 
 	ieee80211_debugfs_key_add(key);
@@ -423,7 +449,8 @@
 	 */
 	if (key->sdata)
 		__ieee80211_key_replace(key->sdata, key->sta,
-					key, NULL);
+				key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE,
+				key, NULL);
 	__ieee80211_key_destroy(key);
 }