kmod: block socket refcount fix

This commit is contained in:
Miroslav Crnic
2026-01-09 13:13:25 +00:00
committed by GitHub
parent a25d7f1064
commit 39784ef9a8
+186 -165
View File
@@ -189,65 +189,6 @@ struct block_socket {
void (*saved_write_space)(struct sock *sk);
};
// This function needs to be called while holding a reference to the socket,
// or holding an RCU lock, or from under the socket callback lock.
inline void block_socket_hold(struct block_socket* socket) {
    int refs = atomic_inc_return(&socket->refcount);
    // The init-time reference must still be live, so the new count is >= 2.
    BUG_ON(refs <= 1);
}
// Drops one reference to the socket; the last reference releases the kernel
// socket and frees the struct.
inline void block_socket_put(struct block_socket* socket) {
    int remaining = atomic_dec_return(&socket->refcount);
    BUG_ON(remaining < 0);
    if (remaining != 0) {
        return;
    }
    // Last reference gone: tear the socket down.
    sock_release(socket->sock);
    kfree(socket);
}
// if cleanup was scheduled ref is passed to workqueue otherwise refcount is reduced
// it is not safe to use socket after calling this function
inline void queue_work_or_put(struct block_socket* socket) {
if (!queue_work(ternfs_fast_wq, &socket->write_work)) {
block_socket_put(socket);
}
}
// Needs to be called with a read lock on socket->sock->sk->sk_callback_lock
// or from workqueue context (ternfs_fast_wq).
// Errors the socket if not already in an error state and schedules cleanup.
// If cleanup was scheduled the ref is passed to the workqueue, otherwise the
// refcount is reduced. It is not safe to use the socket after calling this.
inline void error_socket(struct block_socket* socket, int err) {
    int prev = atomic_cmpxchg(&socket->err, 0, err);
    if (prev != 0) {
        // Somebody else already errored this socket; just drop our ref.
        block_socket_put(socket);
        return;
    }
    // We won the race to set the error: hand our reference to the cleanup work.
    queue_work_or_put(socket);
}
// Winds down every cached block socket and waits until all of them are freed.
// Called on teardown: errors each socket (which schedules its cleanup work),
// then blocks until every hash bucket is empty.
static void block_ops_exit(struct block_ops* ops) {
ternfs_debug("waiting for all sockets to be done");
struct block_socket* sock;
int bucket;
rcu_read_lock();
hash_for_each_rcu(ops->sockets, bucket, sock, hnode) {
ternfs_debug("scheduling winddown for %d", ntohs(sock->addr.sin_port));
// Take a ref under RCU; error_socket() consumes it (either passing it to
// the cleanup work or dropping it).
block_socket_hold(sock);
error_socket(sock, -ECONNABORTED);
}
rcu_read_unlock();
// wait for all of them to be freed by work
for (bucket = 0; bucket < BLOCK_SOCKET_BUCKETS; bucket++) {
ternfs_debug("waiting for bucket %d (len %d)", bucket, atomic_read(&ops->len[bucket]));
wait_event(ops->wqs[bucket], atomic_read(&ops->len[bucket]) == 0);
}
}
// used exclusively to track time spent in state for requests
enum block_requests_state {
BR_ST_QUEUEING = 0,
@@ -280,6 +221,15 @@ struct block_request {
u32 left_to_read;
};
// This function needs to be called while holding a reference to the socket,
// or holding an RCU lock, or from under the socket callback lock.
inline void block_socket_hold(struct block_socket* socket) {
    int refs = atomic_inc_return(&socket->refcount);
    ternfs_debug("socket(%p) %pI4:%d ref_count=%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), refs);
    // The init-time reference must still be live, so the new count is >= 2.
    BUG_ON(refs <= 1);
}
// We log requests which error out or take more than 10s to complete
static void block_socket_log_req_completion(struct block_socket* socket, struct block_request* req) {
u64 completed = get_jiffies_64();
@@ -307,8 +257,96 @@ static void block_socket_log_req_completion(struct block_socket* socket, struct
jiffies64_to_msecs(queued), jiffies64_to_msecs(writing), jiffies64_to_msecs(waiting), jiffies64_to_msecs(reading));
}
// Drops one reference to the socket. When the last reference goes away this
// also completes every request still parked on the socket's read/write lists
// with the socket's error, releases the kernel socket and frees the struct.
inline void block_socket_put(struct block_socket* socket) {
ternfs_debug("socket(%p) %pI4:%d ref_count=%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount)-1);
int ref_count = atomic_dec_return(&socket->refcount);
BUG_ON(ref_count < 0);
if (ref_count == 0) {
// Now complete all the remaining requests with the error.
// We are the only one remaining, no need to take locks
int err = atomic_read(&socket->err);
struct block_request* req;
struct block_request* tmp;
struct list_head all_reqs;
INIT_LIST_HEAD(&all_reqs);
// Collect every request onto one private list, reusing write_list as the
// link node. Requests in the read list may also still sit in the write
// list, so unlink from both before re-adding.
list_for_each_entry_safe(req, tmp, &socket->read, read_list) {
if (!list_empty(&req->write_list)) {
list_del(&req->write_list);
}
list_del(&req->read_list);
list_add(&req->write_list, &all_reqs);
}
list_for_each_entry_safe(req, tmp, &socket->write, write_list) {
list_del(&req->write_list);
list_add(&req->write_list, &all_reqs);
}
list_for_each_entry_safe(req, tmp, &all_reqs, write_list) {
// Only set the socket error on requests that didn't already fail.
atomic_cmpxchg(&req->err, 0, err);
ternfs_debug("completing request because of a socket winddown");
list_del(&req->write_list);
block_socket_log_req_completion(socket, req);
queue_work(ternfs_wq, &req->complete_work);
}
// sock can be NULL if sock_create_kern failed during setup.
if (socket->sock != NULL) {
sock_release(socket->sock);
}
ternfs_debug("free socket(%p) %pI4:%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port));
kfree(socket);
}
}
// if cleanup was scheduled ref is passed to workqueue otherwise refcount is reduced
// it is not safe to use socket after calling this function
inline void queue_work_or_put(struct block_socket* socket) {
ternfs_debug("socket(%p) %pI4:%d, ref_count=%d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
if (!queue_work(ternfs_fast_wq, &socket->write_work)) {
block_socket_put(socket);
}
}
// Needs to be called with a read lock on socket->sock->sk->sk_callback_lock
// or from workqueue context (ternfs_fast_wq).
// Errors the socket if not already in an error state and schedules cleanup.
// If cleanup was scheduled the ref is passed to the workqueue, otherwise the
// refcount is reduced. It is not safe to use the socket after calling this.
inline void error_socket(struct block_socket* socket, int err) {
    ternfs_debug("socket(%p) %pI4:%d, setting error %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), err);
    int prev = atomic_cmpxchg(&socket->err, 0, err);
    if (prev != 0) {
        // Somebody else already errored this socket; just drop our ref.
        block_socket_put(socket);
        return;
    }
    // We won the race to set the error: hand our reference to the cleanup work.
    queue_work_or_put(socket);
}
// Winds down every cached block socket and waits until all of them are freed.
// Called on teardown: errors each socket (which schedules its cleanup work),
// then blocks until every hash bucket is empty.
static void block_ops_exit(struct block_ops* ops) {
ternfs_debug("waiting for all sockets to be done");
struct block_socket* sock;
int bucket;
rcu_read_lock();
hash_for_each_rcu(ops->sockets, bucket, sock, hnode) {
ternfs_debug("scheduling winddown for %d", ntohs(sock->addr.sin_port));
// Take a ref under RCU; error_socket() consumes it (either passing it to
// the cleanup work or dropping it).
block_socket_hold(sock);
error_socket(sock, -ECONNABORTED);
}
rcu_read_unlock();
// wait for all of them to be freed by work
for (bucket = 0; bucket < BLOCK_SOCKET_BUCKETS; bucket++) {
ternfs_debug("waiting for bucket %d (len %d)", bucket, atomic_read(&ops->len[bucket]));
wait_event(ops->wqs[bucket], atomic_read(&ops->len[bucket]) == 0);
}
}
static void block_socket_state_check_locked(struct sock* sk) {
struct block_socket* socket = sk->sk_user_data;
ternfs_debug("socket(%p) %pI4:%d refcount %d state check triggered: %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount), sk->sk_state);
if (sk->sk_state == TCP_ESTABLISHED) { return; } // the only good one
@@ -326,6 +364,7 @@ static void block_socket_sk_state_change(struct sock* sk) {
read_lock_bh(&sk->sk_callback_lock);
struct block_socket* socket = sk->sk_user_data;
if (socket != NULL) {
ternfs_debug("socket(%p) %pI4:%d refcount %d state change triggered: %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount), sk->sk_state);
saved_state_change = socket->saved_state_change;
block_socket_state_check_locked(sk);
} else {
@@ -340,6 +379,7 @@ static void block_socket_connect_sk_state_change(struct sock* sk) {
read_lock_bh(&sk->sk_callback_lock);
struct block_socket* socket = sk->sk_user_data;
if (socket != NULL) {
ternfs_debug("socket(%p) %pI4:%d refcount %d connect state change triggered: %d", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount), sk->sk_state);
if (sk->sk_state == TCP_ESTABLISHED) {
complete_all(&socket->sock_wait);
}
@@ -353,6 +393,7 @@ static void block_socket_connect_sk_state_change(struct sock* sk) {
}
inline bool block_socket_check_timeout(struct block_socket* socket, u64 now, u64 timeout_jiffies, bool set_error) {
ternfs_debug("socket(%p) %pI4:%d refcount %d check timeout triggered", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
// if the lock is already held, activity is in progress, nothing to do
if (!spin_trylock_bh(&socket->list_lock)) {
return false;
@@ -391,16 +432,29 @@ static bool block_socket_cleanup(
struct block_socket* socket,
bool check_timeout
) {
ternfs_debug("socket(%p) %pI4:%d refcount %d cleanup triggered", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
if (check_timeout) {
block_socket_check_timeout(socket, get_jiffies_64(), *ops->timeout_jiffies, true);
}
int err = atomic_read(&socket->err);
if (!err) {
ternfs_debug("socket(%p) %pI4:%d refcount %d no error, no cleanup needed", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
return false;
}
// if already terminal nothing to do
// no lock needed as this is protected by write_work_active
if (socket->terminal) {
ternfs_debug("socket(%p) %pI4:%d refcount %d already terminal", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
goto cleaned_up;
}
socket->terminal = true;
ternfs_debug("socket(%p) %pI4:%d refcount %d winding down", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
// wait for callbacks to complete
write_lock_bh(&socket->sock->sk->sk_callback_lock);
ternfs_debug("socket(%p) %pI4:%d refcount %d callbacks locked", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
// We are killing the socket either due to error or timeout
// We already have callback lock, no callback could be running except for
@@ -410,66 +464,34 @@ static bool block_socket_cleanup(
socket->sock->sk->sk_write_space = socket->saved_write_space;
socket->sock->sk->sk_user_data = NULL;
write_unlock_bh(&socket->sock->sk->sk_callback_lock);
ternfs_debug("socket(%p) %pI4:%d refcount %d callbacks unlocked", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
ternfs_debug("winding down socket to %pI4:%d due to %d", &socket->addr.sin_addr, ntohs(socket->addr.sin_port), err);
u64 key = block_socket_key(&socket->addr);
int bucket = hash_min(key, BLOCK_SOCKET_BITS);
if (!socket->terminal) {
socket->terminal = true;
// First, remove socket from hashmap. After we're done with this,
// we know nobody's going to add new requests to this.
spin_lock(&ops->locks[bucket]);
hash_del_rcu(&socket->hnode); // tied to atomic_dec below
spin_unlock(&ops->locks[bucket]);
synchronize_rcu();
// First, remove socket from hashmap. After we're done with this,
// we know nobody's going to add new requests to this.
spin_lock(&ops->locks[bucket]);
hash_del_rcu(&socket->hnode); // tied to atomic_dec below
spin_unlock(&ops->locks[bucket]);
synchronize_rcu();
ternfs_debug("socket(%p) %pI4:%d refcount %d removed from hashmap", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
// Adjust len, notify waiters
smp_mb__before_atomic();
atomic_dec(&ops->len[bucket]);
wake_up_all(&ops->wqs[bucket]);
// Release the reference held by hash table
block_socket_put(socket);
// wake up waiters on socket removal from list
complete_all(&socket->sock_wait);
// Adjust len, notify waiters
smp_mb__before_atomic();
atomic_dec(&ops->len[bucket]);
wake_up_all(&ops->wqs[bucket]);
// Now complete all the remaining requests with the error.
// We are not the only one remaining, there could be waiters for cleanup completion
// We need to take locks
struct block_request* req;
struct block_request* tmp;
struct list_head all_reqs;
INIT_LIST_HEAD(&all_reqs);
spin_lock_bh(&socket->list_lock);
list_for_each_entry_safe(req, tmp, &socket->read, read_list) {
if (!list_empty(&req->write_list)) {
list_del(&req->write_list);
}
list_del(&req->read_list);
list_add(&req->write_list, &all_reqs);
}
list_for_each_entry_safe(req, tmp, &socket->write, write_list) {
list_del(&req->write_list);
list_add(&req->write_list, &all_reqs);
}
spin_unlock_bh(&socket->list_lock);
list_for_each_entry_safe(req, tmp, &all_reqs, write_list) {
atomic_cmpxchg(&req->err, 0, err);
ternfs_debug("completing request because of a socket winddown");
list_del(&req->write_list);
block_socket_log_req_completion(socket, req);
queue_work(ternfs_wq, &req->complete_work);
}
// we entered cleanup holding reference which is release after the if
// we also need to release the refence held by hash table
block_socket_put(socket);
}
// wake up waiters on socket removal from list
complete_all(&socket->sock_wait);
ternfs_debug("socket(%p) %pI4:%d refcount %d waiters woken up", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
cleaned_up:
atomic_set(&socket->write_work_active, 0);
block_socket_put(socket);
return true;
@@ -479,7 +501,7 @@ static void block_socket_write_work(
struct block_ops* ops,
struct block_socket* socket
) {
ternfs_debug("%pI4:%d", &socket->addr.sin_addr, ntohs(socket->addr.sin_port));
ternfs_debug("socket(%p) %pI4:%d refcount %d write work triggered", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), atomic_read(&socket->refcount));
if (atomic_cmpxchg(&socket->write_work_active, 0, 1) != 0) {
// already active but we need to requeue in case it is just about to exit
@@ -599,9 +621,22 @@ static struct block_socket* get_block_socket(
memcpy(&sock->addr, addr, sizeof(struct sockaddr_in));
atomic_set(&sock->err, 0);
spin_lock_init(&sock->list_lock);
INIT_LIST_HEAD(&sock->write);
INIT_WORK(&sock->write_work, ops->write_work);
atomic_set(&sock->write_work_active, 0);
atomic_set(&sock->refcount, 1);
INIT_LIST_HEAD(&sock->read);
sock->terminal = false;
err = sock_create_kern(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock->sock);
if (err != 0) { goto out_err_sock; }
if (err != 0) {
sock->sock = NULL;
block_socket_put(sock);
goto out_err;
}
init_completion(&sock->sock_wait);
@@ -625,17 +660,26 @@ static struct block_socket* get_block_socket(
if (err < 0) {
if (unlikely(err != -EINPROGRESS)) {
ternfs_warn("could not connect to block service at %pI4:%d: %d", &sock->addr.sin_addr, ntohs(sock->addr.sin_port), err);
sock_release(sock->sock);
goto out_err_sock;
write_lock(&sock->sock->sk->sk_callback_lock);
sock->sock->sk->sk_state_change = sock->saved_state_change;
sock->sock->sk->sk_user_data = NULL;
write_unlock(&sock->sock->sk->sk_callback_lock);
block_socket_put(sock);
goto out_err;
} else {
if (likely(sock->sock->sk->sk_state != TCP_ESTABLISHED)) {
ternfs_debug("waiting for connection to %pI4:%d to be established", &sock->addr.sin_addr, ntohs(sock->addr.sin_port));
err = wait_for_completion_timeout(&sock->sock_wait, ternfs_block_service_connect_timeout_jiffies);
if (err <= 0) {
ternfs_warn("timed out waiting for connection to %pI4:%d to be established", &sock->addr.sin_addr, ntohs(sock->addr.sin_port));
sock_release(sock->sock);
// cleanup socket
write_lock(&sock->sock->sk->sk_callback_lock);
sock->sock->sk->sk_state_change = sock->saved_state_change;
sock->sock->sk->sk_user_data = NULL;
write_unlock(&sock->sock->sk->sk_callback_lock);
block_socket_put(sock);
err = -ETIMEDOUT;
goto out_err_sock;
goto out_err;
}
// we need to re-init for removal waiters
reinit_completion(&sock->sock_wait);
@@ -643,15 +687,6 @@ static struct block_socket* get_block_socket(
}
}
spin_lock_init(&sock->list_lock);
INIT_LIST_HEAD(&sock->write);
INIT_WORK(&sock->write_work, ops->write_work);
atomic_set(&sock->write_work_active, 0);
atomic_set(&sock->refcount, 1);
INIT_LIST_HEAD(&sock->read);
sock->terminal = false;
sock->last_read_activity = sock->last_write_activity = get_jiffies_64();
// now insert
@@ -680,19 +715,23 @@ static struct block_socket* get_block_socket(
hlist_for_each_entry_rcu(other_sock, &ops->sockets[bucket], hnode) { // first we check if somebody didn't get to it first
if (block_socket_key(&other_sock->addr) == key) {
// somebody got here before us
// we need to restore callbacks before releasing since we already set them up
sock->sock->sk->sk_state_change = sock->saved_state_change;
sock->sock->sk->sk_user_data = NULL;
rcu_read_unlock();
spin_unlock(&ops->locks[bucket]);
write_unlock(&sock->sock->sk->sk_callback_lock);
ternfs_debug("multiple callers tried to get socket to %pI4:%d, dropping one", &other_sock->addr.sin_addr, ntohs(other_sock->addr.sin_port));
// call again rather than trying to `sock_release` with the
// RCU read lock held, this might not be safe in atomic context.
sock_release(sock->sock);
kfree(sock);
block_socket_put(sock);
return get_block_socket(ops, addr);
}
}
rcu_read_unlock();
block_socket_hold(sock); // we are now sure we'll return it
// Put the new callbacks in
sock->saved_data_ready = sock->sock->sk->sk_data_ready;
@@ -701,8 +740,6 @@ static struct block_socket* get_block_socket(
sock->sock->sk->sk_state_change = block_socket_sk_state_change;
sock->sock->sk->sk_write_space = ops->sk_write_space;
block_socket_hold(sock); // we are now sure we'll return it
// Insert the socket into the hash map -- anyone else which
// will find it will be good to do.
hlist_add_head_rcu(&sock->hnode, &ops->sockets[bucket]);
@@ -716,8 +753,6 @@ static struct block_socket* get_block_socket(
// We are holding RCU, let's verify that we also have a socket.
BUG_ON(IS_ERR(sock));
return sock;
out_err_sock:
kfree(sock);
out_err:
return ERR_PTR(err);
}
@@ -731,12 +766,12 @@ static int block_socket_receive_req_locked(
) {
struct block_socket* socket = rd_desc->arg.data;
ternfs_debug("%pI4:%d offset=%u len=%lu", &socket->addr.sin_addr, ntohs(socket->addr.sin_port), offset, len);
ternfs_debug("socket(%p) %pI4:%d offset=%u len=%lu", socket, &socket->addr.sin_addr, ntohs(socket->addr.sin_port), offset, len);
size_t len0 = len;
if (atomic_read(&socket->err)) {
block_socket_put(socket);
int err = atomic_read(&socket->err);
if (err) {
ternfs_debug("socket(%p) error %d, dropping data", socket, err);
return len0;
}
@@ -754,6 +789,7 @@ static int block_socket_receive_req_locked(
}
int consumed = receive_single_req(req, skb, offset, len);
if (consumed < 0) {
block_socket_hold(socket);
error_socket(socket, consumed);
return len0;
}
@@ -777,6 +813,7 @@ static int block_socket_receive_req_locked(
socket->last_read_activity = get_jiffies_64();
bool scheduleCompletion = list_empty(&completed_req->write_list);
spin_unlock_bh(&socket->list_lock);
ternfs_debug("socket(%p) completed request, scheduling completion %d", socket, scheduleCompletion);
if (scheduleCompletion) {
block_socket_log_req_completion(socket, completed_req);
queue_work(ternfs_wq, &completed_req->complete_work);
@@ -786,7 +823,6 @@ static int block_socket_receive_req_locked(
// We have more data but no requests. This should not happen
BUG_ON(req == NULL && len > 0);
block_socket_put(socket);
return len0 - len;
}
@@ -799,11 +835,10 @@ static void block_socket_sk_data_ready(
read_descriptor_t rd_desc;
// Taken from iscsi -- we set count to 1 because we want the network layer to
// hand us all the skbs that are available.
struct block_socket* socket = rd_desc.arg.data = sk->sk_user_data;
rd_desc.arg.data = sk->sk_user_data;
rd_desc.count = 1;
// while this is protected by callback lock the callback might want to error a socket
// and enqueue cleanup work which needs a reference. best to take it here
block_socket_hold(socket);
tcp_read_sock(sk, &rd_desc, ops->receive_req);
block_socket_state_check_locked(sk);
}
@@ -847,7 +882,6 @@ static int block_socket_start_req(
u32 left_to_read
) {
int err, i;
bool is_first, socket_error;
struct block_socket* sock;
atomic_set(&req->err, 0);
@@ -868,35 +902,18 @@ retry_get_socket:
err = PTR_ERR(sock);
goto out_err;
}
err = 0;
err = atomic_read(&sock->err);
// We have a socket We need to place the request in the queue,
// and schedule work. We need to put it in both write and read queue
// to avoid a race where we get response before we move it from write
// to read queue
is_first = false;
socket_error = false;
spin_lock_bh(&sock->list_lock);
socket_error = atomic_read(&sock->err) != 0;
if (!socket_error) {
is_first = list_first_entry_or_null(&sock->write, struct block_request, write_list) == NULL;
list_add_tail(&req->write_list, &sock->write);
if (is_first) {
// Prevent timeout
sock->last_write_activity = get_jiffies_64();
}
list_add_tail(&req->read_list, &sock->read);
}
spin_unlock_bh(&sock->list_lock);
// current socket is in an errored state. Wait for removal. Socket will not get destroyed
// as we hold a reference to it.
if (socket_error) {
if (err) {
// current socket is in an errored state. Wait for removal. Socket will not get destroyed
// as we hold a reference to it.
err = wait_for_completion_timeout(&sock->sock_wait, MSECS_TO_JIFFIES(100));
// we timed out or got awoken in any case drop the reference
block_socket_put(sock);
if (err <= 0) {
ternfs_warn("timed out waiting for socked to bs=%016llx to be cleaned up", bs->id);
err = -ETIMEDOUT;
@@ -905,12 +922,16 @@ retry_get_socket:
goto retry_get_socket;
}
// We are first request in write queue, schedule writing
if (is_first) {
queue_work_or_put(sock);
} else {
block_socket_put(sock);
spin_lock_bh(&sock->list_lock);
if (list_first_entry_or_null(&sock->write, struct block_request, write_list) == NULL) {
// Prevent timeout if first
sock->last_write_activity = get_jiffies_64();
}
list_add_tail(&req->write_list, &sock->write);
list_add_tail(&req->read_list, &sock->read);
spin_unlock_bh(&sock->list_lock);
queue_work_or_put(sock);
return 0;