From 39784ef9a807ff5cda1e00bf73cf0451e2d9da74 Mon Sep 17 00:00:00 2001 From: Miroslav Crnic Date: Fri, 9 Jan 2026 13:13:25 +0000 Subject: [PATCH] kmod: block socket refcount fix --- kmod/block.c | 351 +++++++++++++++++++++++++++------------------------ 1 file changed, 186 insertions(+), 165 deletions(-) diff --git a/kmod/block.c b/kmod/block.c index 40f8336a..1fc56ffc 100644 --- a/kmod/block.c +++ b/kmod/block.c @@ -189,65 +189,6 @@ struct block_socket { void (*saved_write_space)(struct sock *sk); }; - -// This function needs to be called while holding a reference to the socket -// or holding and RCU lock or from under the socket callback lock. -inline void block_socket_hold(struct block_socket* socket) { - int ref_count = atomic_inc_return(&socket->refcount); - BUG_ON(ref_count <= 1); // initial reference is set on init -} - - -inline void block_socket_put(struct block_socket* socket) { - int ref_count = atomic_dec_return(&socket->refcount); - BUG_ON(ref_count < 0); - if (ref_count == 0) { - sock_release(socket->sock); - kfree(socket); - } -} - -// if cleanup was scheduled ref is passed to workqueue otherwise refcount is reduced -// it is not safe to use socket after calling this function -inline void queue_work_or_put(struct block_socket* socket) { - if (!queue_work(ternfs_fast_wq, &socket->write_work)) { - block_socket_put(socket); - } -} - -// Needs to be called with read lock on socket->sock->sock->sk_callback_lock -// or from workqueue context (ternfs_fast_wq) -// Errors socket if not already in error state and schedules cleanup -// if cleanup was scheduled ref is passed to workqueue otherwise refcount is reduced -// it is not safe to use socket after calling this function -inline void error_socket(struct block_socket* socket, int err) { - if (atomic_cmpxchg(&socket->err, 0, err) == 0) { - queue_work_or_put(socket); - } else { - block_socket_put(socket); - } -} - -static void block_ops_exit(struct block_ops* ops) { - ternfs_debug("waiting for all sockets to be done"); - - struct block_socket* sock; - int bucket; - rcu_read_lock(); - hash_for_each_rcu(ops->sockets, bucket, sock, hnode) { - ternfs_debug("scheduling winddown for %d", ntohs(sock->addr.sin_port)); - block_socket_hold(sock); - error_socket(sock, -ECONNABORTED); - } - rcu_read_unlock(); - - // wait for all of them to be freed by work - for (bucket = 0; bucket < BLOCK_SOCKET_BUCKETS; bucket++) { - ternfs_debug("waiting for bucket %d (len %d)", bucket, atomic_read(&ops->len[bucket])); - wait_event(ops->wqs[bucket], atomic_read(&ops->len[bucket]) == 0); - } -} - // used exclusively to track time spent in state for requests enum block_requests_state { BR_ST_QUEUEING = 0, @@ -280,6 +221,15 @@ struct block_request { u32 left_to_read; }; + +// This function needs to be called while holding a reference to the socket +// or holding and RCU lock or from under the socket callback lock. +inline void block_socket_hold(struct block_socket* socket) { + int ref_count = atomic_inc_return(&socket->refcount); + ternfs_debug("socket(%p) %pI4:%d ref_count=%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), ref_count); + BUG_ON(ref_count <= 1); // initial reference is set on init +} + // We log requests which error out or take more than 10s to complete static void block_socket_log_req_completion(struct block_socket* socket, struct block_request* req) { u64 completed = get_jiffies_64(); @@ -307,8 +257,96 @@ static void block_socket_log_req_completion(struct block_socket* socket, struct jiffies64_to_msecs(queued), jiffies64_to_msecs(writing), jiffies64_to_msecs(waiting), jiffies64_to_msecs(reading)); } + +inline void block_socket_put(struct block_socket* socket) { + ternfs_debug("socket(%p) %pI4:%d ref_count=%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)-1); + int ref_count = atomic_dec_return(&socket->refcount); + BUG_ON(ref_count < 0); + if (ref_count == 0) { + // Now complete all the remaining requests with the error. + // We are the only one remaining, no need to take locks + int err = atomic_read(&socket->err); + + struct block_request* req; + struct block_request* tmp; + + struct list_head all_reqs; + INIT_LIST_HEAD(&all_reqs); + + list_for_each_entry_safe(req, tmp, &socket->read, read_list) { + if (!list_empty(&req->write_list)) { + list_del(&req->write_list); + } + list_del(&req->read_list); + list_add(&req->write_list, &all_reqs); + } + list_for_each_entry_safe(req, tmp, &socket->write, write_list) { + list_del(&req->write_list); + list_add(&req->write_list, &all_reqs); + } + + list_for_each_entry_safe(req, tmp, &all_reqs, write_list) { + atomic_cmpxchg(&req->err, 0, err); + ternfs_debug("completing request because of a socket winddown"); + list_del(&req->write_list); + block_socket_log_req_completion(socket, req); + queue_work(ternfs_wq, &req->complete_work); + } + if (socket->sock != NULL) { + sock_release(socket->sock); + } + ternfs_debug("free socket(%p) %pI4:%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port)); + kfree(socket); + } +} + +// if cleanup was scheduled ref is passed to workqueue otherwise refcount is reduced +// it is not safe to use socket after calling this function +inline void queue_work_or_put(struct block_socket* socket) { + ternfs_debug("socket(%p) %pI4:%d, ref_count=%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); + if (!queue_work(ternfs_fast_wq, &socket->write_work)) { + block_socket_put(socket); + } +} + +// Needs to be called with read lock on socket->sock->sock->sk_callback_lock +// or from workqueue context (ternfs_fast_wq) +// Errors socket if not already in error state and schedules cleanup +// if cleanup was scheduled ref is passed to workqueue otherwise refcount is reduced +// it is not safe to use socket after calling this function +inline void error_socket(struct block_socket* socket, int err) { + ternfs_debug("socket(%p) %pI4:%d, setting error %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), err); + if (atomic_cmpxchg(&socket->err, 0, err) == 0) { + queue_work_or_put(socket); + } else { + block_socket_put(socket); + } +} + +static void block_ops_exit(struct block_ops* ops) { + ternfs_debug("waiting for all sockets to be done"); + + struct block_socket* sock; + int bucket; + rcu_read_lock(); + hash_for_each_rcu(ops->sockets, bucket, sock, hnode) { + ternfs_debug("scheduling winddown for %d", ntohs(sock->addr.sin_port)); + block_socket_hold(sock); + error_socket(sock, -ECONNABORTED); + } + rcu_read_unlock(); + + // wait for all of them to be freed by work + for (bucket = 0; bucket < BLOCK_SOCKET_BUCKETS; bucket++) { + ternfs_debug("waiting for bucket %d (len %d)", bucket, atomic_read(&ops->len[bucket])); + wait_event(ops->wqs[bucket], atomic_read(&ops->len[bucket]) == 0); + } +} + + static void block_socket_state_check_locked(struct sock* sk) { struct block_socket* socket = sk->sk_user_data; + ternfs_debug("socket(%p) %pI4:%d refcount %d state check triggered: %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount), sk->sk_state); if (sk->sk_state == TCP_ESTABLISHED) { return; } // the only good one @@ -326,6 +364,7 @@ static void block_socket_sk_state_change(struct sock* sk) { read_lock_bh(&sk->sk_callback_lock); struct block_socket* socket = sk->sk_user_data; if (socket != NULL) { + ternfs_debug("socket(%p) %pI4:%d refcount %d state change triggered: %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount), sk->sk_state); saved_state_change = socket->saved_state_change; block_socket_state_check_locked(sk); } else { @@ -340,6 +379,7 @@ static void block_socket_connect_sk_state_change(struct sock* sk) { read_lock_bh(&sk->sk_callback_lock); struct block_socket* socket = sk->sk_user_data; if (socket != NULL) { + ternfs_debug("socket(%p) %pI4:%d refcount %d connect state change triggered: %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount), sk->sk_state); if (sk->sk_state == TCP_ESTABLISHED) { complete_all(&socket->sock_wait); } @@ -353,6 +393,7 @@ static void block_socket_connect_sk_state_change(struct sock* sk) { } inline bool block_socket_check_timeout(struct block_socket* socket, u64 now, u64 timeout_jiffies, bool set_error) { + ternfs_debug("socket(%p) %pI4:%d refcount %d check timeout triggered", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); // if something is locked activity in progress, nothing to do if (!spin_trylock_bh(&socket->list_lock)) { return false; @@ -391,16 +432,29 @@ static bool block_socket_cleanup( struct block_socket* socket, bool check_timeout ) { + ternfs_debug("socket(%p) %pI4:%d refcount %d cleanup triggered", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); if (check_timeout) { block_socket_check_timeout(socket, get_jiffies_64(), *ops->timeout_jiffies, true); } int err = atomic_read(&socket->err); if (!err) { + ternfs_debug("socket(%p) %pI4:%d refcount %d no error, no cleanup needed", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); return false; } + // if already terminal nothing to do + // no lock needed as this is protecte by write_work_active + if (socket->terminal) { + ternfs_debug("socket(%p) %pI4:%d refcount %d already terminal", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); + goto cleaned_up; + } + + socket->terminal = true; + ternfs_debug("socket(%p) %pI4:%d refcount %d winding down", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); + // wait for callbacks to complete write_lock_bh(&socket->sock->sk->sk_callback_lock); + ternfs_debug("socket(%p) %pI4:%d refcount %d callbacks locked", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); // We are killing the socket either due to error or timeout // We already have callback lock, no callback could be running except for @@ -410,66 +464,34 @@ static bool block_socket_cleanup( socket->sock->sk->sk_write_space = socket->saved_write_space; socket->sock->sk->sk_user_data = NULL; write_unlock_bh(&socket->sock->sk->sk_callback_lock); + ternfs_debug("socket(%p) %pI4:%d refcount %d callbacks unlocked", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); + - ternfs_debug("winding down socket to %pI4:%d due to %d", &socket->addr.sin_addr, ntohs(socket->addr.sin_port), err); u64 key = block_socket_key(&socket->addr); int bucket = hash_min(key, BLOCK_SOCKET_BITS); - if (!socket->terminal) { - socket->terminal = true; - // First, remove socket from hashmap. After we're done with this, - // we know nobody's going to add new requests to this. - spin_lock(&ops->locks[bucket]); - hash_del_rcu(&socket->hnode); // tied to atomic_dec below - spin_unlock(&ops->locks[bucket]); - synchronize_rcu(); + // First, remove socket from hashmap. After we're done with this, + // we know nobody's going to add new requests to this. + spin_lock(&ops->locks[bucket]); + hash_del_rcu(&socket->hnode); // tied to atomic_dec below + spin_unlock(&ops->locks[bucket]); + synchronize_rcu(); + ternfs_debug("socket(%p) %pI4:%d refcount %d removed from hashmap", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); - // Adjust len, notify waiters - smp_mb__before_atomic(); - atomic_dec(&ops->len[bucket]); - wake_up_all(&ops->wqs[bucket]); + // Release the reference held by hash table + block_socket_put(socket); - // wake up waiters on socket removal from list - complete_all(&socket->sock_wait); + // Adjust len, notify waiters + smp_mb__before_atomic(); + atomic_dec(&ops->len[bucket]); + wake_up_all(&ops->wqs[bucket]); - // Now complete all the remaining requests with the error. - // We are not the only one remaining, there could be waiters for cleanup completion - // We need to take locks - - struct block_request* req; - struct block_request* tmp; - - struct list_head all_reqs; - INIT_LIST_HEAD(&all_reqs); - - spin_lock_bh(&socket->list_lock); - - list_for_each_entry_safe(req, tmp, &socket->read, read_list) { - if (!list_empty(&req->write_list)) { - list_del(&req->write_list); - } - list_del(&req->read_list); - list_add(&req->write_list, &all_reqs); - } - list_for_each_entry_safe(req, tmp, &socket->write, write_list) { - list_del(&req->write_list); - list_add(&req->write_list, &all_reqs); - } - spin_unlock_bh(&socket->list_lock); - - list_for_each_entry_safe(req, tmp, &all_reqs, write_list) { - atomic_cmpxchg(&req->err, 0, err); - ternfs_debug("completing request because of a socket winddown"); - list_del(&req->write_list); - block_socket_log_req_completion(socket, req); - queue_work(ternfs_wq, &req->complete_work); - } - // we entered cleanup holding reference which is release after the if - // we also need to release the refence held by hash table - block_socket_put(socket); - } + // wake up waiters on socket removal from list + complete_all(&socket->sock_wait); + ternfs_debug("socket(%p) %pI4:%d refcount %d waiters woken up", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); +cleaned_up: atomic_set(&socket->write_work_active, 0); block_socket_put(socket); return true; @@ -479,7 +501,7 @@ static void block_socket_write_work( struct block_ops* ops, struct block_socket* socket ) { - ternfs_debug("%pI4:%d", &socket->addr.sin_addr, ntohs(socket->addr.sin_port)); + ternfs_debug("socket(%p) %pI4:%d refcount %d write work triggered", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)); if (atomic_cmpxchg(&socket->write_work_active, 0, 1) != 0) { // already active but we need to requeue in case it is just about to exit @@ -599,9 +621,22 @@ static struct block_socket* get_block_socket( memcpy(&sock->addr, addr, sizeof(struct sockaddr_in)); atomic_set(&sock->err, 0); + spin_lock_init(&sock->list_lock); + INIT_LIST_HEAD(&sock->write); + INIT_WORK(&sock->write_work, ops->write_work); + atomic_set(&sock->write_work_active, 0); + atomic_set(&sock->refcount, 1); + INIT_LIST_HEAD(&sock->read); + + + sock->terminal = false; err = sock_create_kern(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock->sock); - if (err != 0) { goto out_err_sock; } + if (err != 0) { + sock->sock = NULL; + block_socket_put(sock); + goto out_err; + } init_completion(&sock->sock_wait); @@ -625,17 +660,26 @@ static struct block_socket* get_block_socket( if (err < 0) { if (unlikely(err != -EINPROGRESS)) { ternfs_warn("could not connect to block service at %pI4:%d: %d", &sock->addr.sin_addr, ntohs(sock->addr.sin_port), err); - sock_release(sock->sock); - goto out_err_sock; + write_lock(&sock->sock->sk->sk_callback_lock); + sock->sock->sk->sk_state_change = sock->saved_state_change; + sock->sock->sk->sk_user_data = NULL; + write_unlock(&sock->sock->sk->sk_callback_lock); + block_socket_put(sock); + goto out_err; } else { if (likely(sock->sock->sk->sk_state != TCP_ESTABLISHED)) { ternfs_debug("waiting for connection to %pI4:%d to be established", &sock->addr.sin_addr, ntohs(sock->addr.sin_port)); err = wait_for_completion_timeout(&sock->sock_wait, ternfs_block_service_connect_timeout_jiffies); if (err <= 0) { ternfs_warn("timed out waiting for connection to %pI4:%d to be established", &sock->addr.sin_addr, ntohs(sock->addr.sin_port)); - sock_release(sock->sock); + // cleanup socket + write_lock(&sock->sock->sk->sk_callback_lock); + sock->sock->sk->sk_state_change = sock->saved_state_change; + sock->sock->sk->sk_user_data = NULL; + write_unlock(&sock->sock->sk->sk_callback_lock); + block_socket_put(sock); err = -ETIMEDOUT; - goto out_err_sock; + goto out_err; } // we need to re-init for removal waiters reinit_completion(&sock->sock_wait); @@ -643,15 +687,6 @@ static struct block_socket* get_block_socket( } } - spin_lock_init(&sock->list_lock); - INIT_LIST_HEAD(&sock->write); - INIT_WORK(&sock->write_work, ops->write_work); - atomic_set(&sock->write_work_active, 0); - atomic_set(&sock->refcount, 1); - INIT_LIST_HEAD(&sock->read); - - - sock->terminal = false; sock->last_read_activity = sock->last_write_activity = get_jiffies_64(); // now insert @@ -680,19 +715,23 @@ static struct block_socket* get_block_socket( hlist_for_each_entry_rcu(other_sock, &ops->sockets[bucket], hnode) { // first we check if somebody didn't get to it first if (block_socket_key(&other_sock->addr) == key) { // somebody got here before us + // we need to restore callbacks before releasing since we already set them up + sock->sock->sk->sk_state_change = sock->saved_state_change; + sock->sock->sk->sk_user_data = NULL; rcu_read_unlock(); spin_unlock(&ops->locks[bucket]); write_unlock(&sock->sock->sk->sk_callback_lock); ternfs_debug("multiple callers tried to get socket to %pI4:%d, dropping one", &other_sock->addr.sin_addr, ntohs(other_sock->addr.sin_port)); // call again rather than trying to `sock_release` with the // RCU read lock held, this might not be safe in atomic context. - sock_release(sock->sock); - kfree(sock); + block_socket_put(sock); return get_block_socket(ops, addr); } } rcu_read_unlock(); + block_socket_hold(sock); // we are now sure we'll return it + // Put the new callbacks in sock->saved_data_ready = sock->sock->sk->sk_data_ready; @@ -701,8 +740,6 @@ static struct block_socket* get_block_socket( sock->sock->sk->sk_state_change = block_socket_sk_state_change; sock->sock->sk->sk_write_space = ops->sk_write_space; - block_socket_hold(sock); // we are now sure we'll return it - // Insert the socket into the hash map -- anyone else which // will find it will be good to do. hlist_add_head_rcu(&sock->hnode, &ops->sockets[bucket]); @@ -716,8 +753,6 @@ static struct block_socket* get_block_socket( // We are holding RCU, let's verify that we also have a socket. BUG_ON(IS_ERR(sock)); return sock; -out_err_sock: - kfree(sock); out_err: return ERR_PTR(err); } @@ -731,12 +766,12 @@ static int block_socket_receive_req_locked( ) { struct block_socket* socket = rd_desc->arg.data; - ternfs_debug("%pI4:%d offset=%u len=%lu", &socket->addr.sin_addr, ntohs(socket->addr.sin_port), offset, len); + ternfs_debug("socket(%p) %pI4:%d offset=%u len=%lu", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), offset, len); size_t len0 = len; - - if (atomic_read(&socket->err)) { - block_socket_put(socket); + int err = atomic_read(&socket->err); + if (err) { + ternfs_debug("socket(%p) error %d, dropping data", socket, err); return len0; } @@ -754,6 +789,7 @@ static int block_socket_receive_req_locked( } int consumed = receive_single_req(req, skb, offset, len); if (consumed < 0) { + block_socket_hold(socket); error_socket(socket, consumed); return len0; } @@ -777,6 +813,7 @@ static int block_socket_receive_req_locked( socket->last_read_activity = get_jiffies_64(); bool scheduleCompletion = list_empty(&completed_req->write_list); spin_unlock_bh(&socket->list_lock); + ternfs_debug("socket(%p) completed request, scheduling completion %d", socket, scheduleCompletion); if (scheduleCompletion) { block_socket_log_req_completion(socket, completed_req); queue_work(ternfs_wq, &completed_req->complete_work); @@ -786,7 +823,6 @@ static int block_socket_receive_req_locked( // We have more data but no requests. This should not happen BUG_ON(req == NULL && len > 0); - block_socket_put(socket); return len0 - len; } @@ -799,11 +835,10 @@ static void block_socket_sk_data_ready( read_descriptor_t rd_desc; // Taken from iscsi -- we set count to 1 because we want the network layer to // hand us all the skbs that are available. - struct block_socket* socket = rd_desc.arg.data = sk->sk_user_data; + rd_desc.arg.data = sk->sk_user_data; rd_desc.count = 1; // while this is protected by callback lock the callback might want to error a socket // and enqueue cleanup work which needs a reference. best to take it here - block_socket_hold(socket); tcp_read_sock(sk, &rd_desc, ops->receive_req); block_socket_state_check_locked(sk); } @@ -847,7 +882,6 @@ static int block_socket_start_req( u32 left_to_read ) { int err, i; - bool is_first, socket_error; struct block_socket* sock; atomic_set(&req->err, 0); @@ -868,35 +902,18 @@ retry_get_socket: err = PTR_ERR(sock); goto out_err; } - err = 0; + err = atomic_read(&sock->err); // We have a socket We need to place the request in the queue, // and schedule work. We need to put it in both write and read queue // to avoid a race where we get response before we move it from write // to read queue - is_first = false; - socket_error = false; - spin_lock_bh(&sock->list_lock); - socket_error = atomic_read(&sock->err) != 0; - if (!socket_error) { - is_first = list_first_entry_or_null(&sock->write, struct block_request, write_list) == NULL; - list_add_tail(&req->write_list, &sock->write); - if (is_first) { - // Prevent timeout - sock->last_write_activity = get_jiffies_64(); - } - list_add_tail(&req->read_list, &sock->read); - } - spin_unlock_bh(&sock->list_lock); - - // current socket is in errored state. Wait for removal. Socket will not get destroyed - // as hold a reference to it. - if (socket_error) { + if (err) { + // current socket is in errored state. Wait for removal. Socket will not get destroyed + // as hold a reference to it. err = wait_for_completion_timeout(&sock->sock_wait, MSECS_TO_JIFFIES(100)); // we timed out or got awoken in any case drop the reference block_socket_put(sock); - - if (err <= 0) { ternfs_warn("timed out waiting for socked to bs=%016llx to be cleaned up", bs->id); err = -ETIMEDOUT; @@ -905,12 +922,16 @@ retry_get_socket: goto retry_get_socket; } - // We are first request in write queue, schedule writing - if (is_first) { - queue_work_or_put(sock); - } else { - block_socket_put(sock); + spin_lock_bh(&sock->list_lock); + if (list_first_entry_or_null(&sock->write, struct block_request, write_list) == NULL) { + // Prevent timeout if first + sock->last_write_activity = get_jiffies_64(); } + list_add_tail(&req->write_list, &sock->write); + list_add_tail(&req->read_list, &sock->read); + spin_unlock_bh(&sock->list_lock); + + queue_work_or_put(sock); return 0;