@@ -435,6 +435,11 @@ static unsigned long max_index(void *entry)
return (XA_CHUNK_SIZE << xa_to_node(entry)->shift) - 1;
}
+static inline void *xa_zero_to_null(void *entry)
+{
+ return xa_is_zero(entry) ? NULL : entry;
+}
+
static void xas_shrink(struct xa_state *xas)
{
struct xarray *xa = xas->xa;
@@ -451,8 +456,8 @@ static void xas_shrink(struct xa_state *xas)
break;
if (!xa_is_node(entry) && node->shift)
break;
- if (xa_is_zero(entry) && xa_zero_busy(xa))
- entry = NULL;
+ if (xa_zero_busy(xa))
+ entry = xa_zero_to_null(entry);
xas->xa_node = XAS_BOUNDS;
RCU_INIT_POINTER(xa->xa_head, entry);
@@ -1474,9 +1479,7 @@ void *xa_load(struct xarray *xa, unsigned long index)
rcu_read_lock();
do {
- entry = xas_load(&xas);
- if (xa_is_zero(entry))
- entry = NULL;
+ entry = xa_zero_to_null(xas_load(&xas));
} while (xas_retry(&xas, entry));
rcu_read_unlock();
@@ -1486,11 +1489,9 @@ EXPORT_SYMBOL(xa_load);
static void *xas_result(struct xa_state *xas, void *curr)
{
- if (xa_is_zero(curr))
- return NULL;
if (xas_error(xas))
curr = xas->xa_node;
- return curr;
+ return xa_zero_to_null(curr);
}
/**
Reduce code duplication by extracting a static inline function that returns its argument if it is non-zero and NULL otherwise. This changes xas_result to check for errors before checking for zero but this cannot change the behavior of existing callers: - __xa_erase: passes the result of xas_store(_, NULL) which cannot fail. - __xa_store: passes the result of xas_store(_, entry) which may fail. xas_store calls xas_create when entry is not NULL which returns NULL on error, which is immediately checked. This should not change observable behavior. - __xa_cmpxchg: passes the result of xas_load(_) which might be zero. This would previously return NULL regardless of the outcome of xas_store but xas_store cannot fail if xas_load returns zero because there is no need to allocate memory. - xa_store_range: same as __xa_erase. Signed-off-by: Tamir Duberstein <tamird@gmail.com> --- lib/xarray.c | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-)