diff mbox series

[01/15] handshake: add ref counting to handshake_state

Message ID 20241122151551.286355-2-prestwoj@gmail.com (mailing list archive)
State New
Headers show
Series PMKSA support (SAE only) | expand

Checks

Context Check Description
tedd_an/pre-ci_am success Success
prestwoj/iwd-alpine-ci-fetch success Fetch PR
prestwoj/iwd-alpine-ci-setupell success Prep - Setup ELL
prestwoj/iwd-ci-fetch success Fetch PR
prestwoj/iwd-alpine-ci-makedistcheck success Make Distcheck
prestwoj/iwd-ci-gitlint success GitLint
prestwoj/iwd-ci-setupell success Prep - Setup ELL
prestwoj/iwd-alpine-ci-build success Build - Configure
prestwoj/iwd-ci-makedistcheck success Make Distcheck
prestwoj/iwd-ci-build success Build - Configure
prestwoj/iwd-alpine-ci-incremental_build fail Make FAIL (patch 0): unit/test-wsc.c: In function 'wsc_test_pbc_handshake': unit/test-wsc.c:2096:9: error: implicit declaration of function 'handshake_state_free'; did you mean 'handshake_state_ref'? [-Werror=implicit-function-declaration] 2096 | handshake_state_free(hs); | ^~~~~~~~~~~~~~~~~~~~ | handshake_state_ref unit/test-sae.c: In function 'test_destruct': unit/test-sae.c:186:9: error: implicit declaration of function 'handshake_state_free'; did you mean 'handshake_state_ref'? [-Werror=implicit-function-declaration] 186 | handshake_state_free(td->handshake); | ^~~~~~~~~~~~~~~~~~~~ | handshake_state_ref cc1: all warnings being treated as errors make[1]: *** [Makefile:2586: unit/test-sae.o] Error 1 make[1]: *** Waiting for unfinished jobs.... cc1: all warnings being treated as errors make[1]: *** [Makefile:2586: unit/test-wsc.o] Error 1 make: *** [Makefile:1746: all] Error 2
prestwoj/iwd-alpine-ci-makecheckvalgrind success Make Check w/Valgrind
prestwoj/iwd-alpine-ci-makecheck success Make Check
prestwoj/iwd-ci-incremental_build fail Make FAIL (patch 0): unit/test-wsc.c: In function ‘wsc_test_pbc_handshake’: unit/test-wsc.c:2096:9: error: implicit declaration of function ‘handshake_state_free’; did you mean ‘handshake_state_ref’? [-Werror=implicit-function-declaration] 2096 | handshake_state_free(hs); | ^~~~~~~~~~~~~~~~~~~~ | handshake_state_ref unit/test-sae.c: In function ‘test_destruct’: unit/test-sae.c:186:9: error: implicit declaration of function ‘handshake_state_free’; did you mean ‘handshake_state_ref’? [-Werror=implicit-function-declaration] 186 | handshake_state_free(td->handshake); | ^~~~~~~~~~~~~~~~~~~~ | handshake_state_ref cc1: all warnings being treated as errors make[1]: *** [Makefile:2585: unit/test-sae.o] Error 1 make[1]: *** Waiting for unfinished jobs.... cc1: all warnings being treated as errors make[1]: *** [Makefile:2585: unit/test-wsc.o] Error 1 make: *** [Makefile:1745: all] Error 2
prestwoj/iwd-ci-makecheckvalgrind success Make Check w/Valgrind
prestwoj/iwd-ci-clang success clang PASS
prestwoj/iwd-ci-makecheck success Make Check
prestwoj/iwd-ci-testrunner success test-runner PASS

Commit Message

James Prestwood Nov. 22, 2024, 3:15 p.m. UTC
This adds a ref count to the handshake state object (as well as
ref/unref APIs). Currently IWD is careful to ensure that netdev
holds the root reference to the handshake state. Other modules do
track it themselves, but ensure that it doesn't get referenced
after netdev frees it.

Future work related to PMKSA will require that station holds a
references to the handshake state, specifically for retry logic,
after netdev is done with it so we need a way to delay the free
until station is also done.
---
 src/adhoc.c     |  4 ++--
 src/ap.c        |  2 +-
 src/handshake.c | 12 +++++++++++-
 src/handshake.h |  9 ++++++---
 src/netdev.c    |  5 +++--
 src/p2p.c       |  2 +-
 src/station.c   |  8 ++++----
 src/wsc.c       |  2 +-
 8 files changed, 29 insertions(+), 15 deletions(-)
diff mbox series

Patch

diff --git a/src/adhoc.c b/src/adhoc.c
index e787dab1..930240ae 100644
--- a/src/adhoc.c
+++ b/src/adhoc.c
@@ -94,13 +94,13 @@  static void adhoc_sta_free(void *data)
 		eapol_sm_free(sta->sm);
 
 	if (sta->hs_sta)
-		handshake_state_free(sta->hs_sta);
+		handshake_state_unref(sta->hs_sta);
 
 	if (sta->sm_a)
 		eapol_sm_free(sta->sm_a);
 
 	if (sta->hs_auth)
-		handshake_state_free(sta->hs_auth);
+		handshake_state_unref(sta->hs_auth);
 
 end:
 	l_free(sta);
diff --git a/src/ap.c b/src/ap.c
index 562e00c8..d52b7e55 100644
--- a/src/ap.c
+++ b/src/ap.c
@@ -230,7 +230,7 @@  static void ap_stop_handshake(struct sta_state *sta)
 	}
 
 	if (sta->hs) {
-		handshake_state_free(sta->hs);
+		handshake_state_unref(sta->hs);
 		sta->hs = NULL;
 	}
 
diff --git a/src/handshake.c b/src/handshake.c
index fc1978df..7fb75dc4 100644
--- a/src/handshake.c
+++ b/src/handshake.c
@@ -103,7 +103,14 @@  void __handshake_set_install_ext_tk_func(handshake_install_ext_tk_func_t func)
 	install_ext_tk = func;
 }
 
-void handshake_state_free(struct handshake_state *s)
+struct handshake_state *handshake_state_ref(struct handshake_state *s)
+{
+	__sync_fetch_and_add(&s->refcount, 1);
+
+	return s;
+}
+
+void handshake_state_unref(struct handshake_state *s)
 {
 	__typeof__(s->free) destroy;
 
@@ -117,6 +124,9 @@  void handshake_state_free(struct handshake_state *s)
 		return;
 	}
 
+	if (__sync_sub_and_fetch(&s->refcount, 1))
+		return;
+
 	l_free(s->authenticator_ie);
 	l_free(s->supplicant_ie);
 	l_free(s->authenticator_rsnxe);
diff --git a/src/handshake.h b/src/handshake.h
index d1116472..6c0946d4 100644
--- a/src/handshake.h
+++ b/src/handshake.h
@@ -170,6 +170,8 @@  struct handshake_state {
 	bool in_event;
 
 	handshake_event_func_t event_func;
+
+	int refcount;
 };
 
 #define HSID(x) UNIQUE_ID(handshake_, x)
@@ -186,7 +188,7 @@  struct handshake_state {
 					##__VA_ARGS__);			\
 									\
 			if (!HSID(hs)->in_event) {			\
-				handshake_state_free(HSID(hs));		\
+				handshake_state_unref(HSID(hs));	\
 				HSID(freed) = true;			\
 			} else						\
 				HSID(hs)->in_event = false;		\
@@ -194,7 +196,8 @@  struct handshake_state {
 		HSID(freed);						\
 	})
 
-void handshake_state_free(struct handshake_state *s);
+struct handshake_state *handshake_state_ref(struct handshake_state *s);
+void handshake_state_unref(struct handshake_state *s);
 
 void handshake_state_set_supplicant_address(struct handshake_state *s,
 						const uint8_t *spa);
@@ -316,4 +319,4 @@  void handshake_util_build_gtk_kde(enum crypto_cipher cipher, const uint8_t *key,
 void handshake_util_build_igtk_kde(enum crypto_cipher cipher, const uint8_t *key,
 					unsigned int key_index, uint8_t *to);
 
-DEFINE_CLEANUP_FUNC(handshake_state_free);
+DEFINE_CLEANUP_FUNC(handshake_state_unref);
diff --git a/src/netdev.c b/src/netdev.c
index e86ef1bd..4dccb78a 100644
--- a/src/netdev.c
+++ b/src/netdev.c
@@ -376,6 +376,7 @@  struct handshake_state *netdev_handshake_state_new(struct netdev *netdev)
 
 	nhs->super.ifindex = netdev->index;
 	nhs->super.free = netdev_handshake_state_free;
+	nhs->super.refcount = 1;
 
 	nhs->netdev = netdev;
 	/*
@@ -828,7 +829,7 @@  static void netdev_connect_free(struct netdev *netdev)
 	eapol_preauth_cancel(netdev->index);
 
 	if (netdev->handshake) {
-		handshake_state_free(netdev->handshake);
+		handshake_state_unref(netdev->handshake);
 		netdev->handshake = NULL;
 	}
 
@@ -4239,7 +4240,7 @@  int netdev_reassociate(struct netdev *netdev, const struct scan_bss *target_bss,
 		eapol_sm_free(old_sm);
 
 	if (old_hs)
-		handshake_state_free(old_hs);
+		handshake_state_unref(old_hs);
 
 	return 0;
 }
diff --git a/src/p2p.c b/src/p2p.c
index 676ef146..7d89da21 100644
--- a/src/p2p.c
+++ b/src/p2p.c
@@ -1497,7 +1497,7 @@  static void p2p_handshake_event(struct handshake_state *hs,
 static void p2p_try_connect_group(struct p2p_device *dev)
 {
 	struct scan_bss *bss = dev->conn_wsc_bss;
-	_auto_(handshake_state_free) struct handshake_state *hs = NULL;
+	_auto_(handshake_state_unref) struct handshake_state *hs = NULL;
 	struct iovec ie_iov[16];
 	int ie_num = 0;
 	int r;
diff --git a/src/station.c b/src/station.c
index 1238734f..c1c7ba9d 100644
--- a/src/station.c
+++ b/src/station.c
@@ -1394,7 +1394,7 @@  static struct handshake_state *station_handshake_setup(struct station *station,
 	return hs;
 
 not_supported:
-	handshake_state_free(hs);
+	handshake_state_unref(hs);
 	return NULL;
 }
 
@@ -2484,7 +2484,7 @@  static void station_preauthenticate_cb(struct netdev *netdev,
 	}
 
 	if (station_transition_reassociate(station, bss, new_hs) < 0) {
-		handshake_state_free(new_hs);
+		handshake_state_unref(new_hs);
 		station_roam_failed(station);
 	}
 }
@@ -2687,7 +2687,7 @@  static bool station_try_next_transition(struct station *station,
 	}
 
 	if (station_transition_reassociate(station, bss, new_hs) < 0) {
-		handshake_state_free(new_hs);
+		handshake_state_unref(new_hs);
 		return false;
 	}
 
@@ -3734,7 +3734,7 @@  int __station_connect_network(struct station *station, struct network *network,
 				station_netdev_event,
 				station_connect_cb, station);
 	if (r < 0) {
-		handshake_state_free(hs);
+		handshake_state_unref(hs);
 		return r;
 	}
 
diff --git a/src/wsc.c b/src/wsc.c
index f88f5deb..44b8d3de 100644
--- a/src/wsc.c
+++ b/src/wsc.c
@@ -393,7 +393,7 @@  static int wsc_enrollee_connect(struct wsc_enrollee *wsce, struct scan_bss *bss,
 		return 0;
 
 error:
-	handshake_state_free(hs);
+	handshake_state_unref(hs);
 	return r;
 }