diff --git a/include/ofi_util.h b/include/ofi_util.h index bc590bb4d1a..b09161747ad 100644 --- a/include/ofi_util.h +++ b/include/ofi_util.h @@ -923,6 +923,7 @@ struct util_av_attr { */ struct util_peer_addr { struct rxm_av *av; + uint64_t av_flags; fi_addr_t fi_addr; struct ofi_rbnode *node; int index; @@ -930,7 +931,8 @@ struct util_peer_addr { union ofi_sock_ip addr; }; -struct util_peer_addr *util_get_peer(struct rxm_av *av, const void *addr); +int util_get_peer(struct rxm_av *av, const void *addr, struct util_peer_addr **peer, + uint64_t flags); void util_put_peer(struct util_peer_addr *peer); /* All peer addresses, whether they've been inserted into the AV diff --git a/include/rdma/fabric.h b/include/rdma/fabric.h index 42c50532797..4d2ad18e495 100644 --- a/include/rdma/fabric.h +++ b/include/rdma/fabric.h @@ -159,6 +159,7 @@ typedef struct fid *fid_t; #define FI_PEER_TRANSFER (1ULL << 36) /* #define FI_MR_DMABUF (1ULL << 40) */ #define FI_AV_USER_ID (1ULL << 41) +#define FI_FIREWALL_ADDR (1ULL << 42) #define FI_PEER (1ULL << 43) /* #define FI_XPU_TRIGGER (1ULL << 44) */ diff --git a/man/fi_av.3.md b/man/fi_av.3.md index 7aeba1802ea..7c92e36c8fe 100644 --- a/man/fi_av.3.md +++ b/man/fi_av.3.md @@ -436,6 +436,16 @@ fi_av_set_user_id. See the user ID section below. +- *FI_FIREWALL_ADDR* +: This flag indicates that the address is behind a firewall and outgoing + connections are not allowed. If there is not an existing connection and the + provider is unable to circumvent the firewall, an FI_EHOSTUNREACH error + should be expected. If multiple addresses are being inserted simultaneously, + the flag applies to all of them. Additionally, it is possible that a + connection is available at insertion time, but is later torn down. Future + reconnects triggered by operations on the ep (fi_send, for example) may also + fail with the same error. + ## fi_av_insertsvc The fi_av_insertsvc call behaves similar to fi_av_insert, but allows the diff --git a/prov/rxm/src/rxm_conn.c b/prov/rxm/src/rxm_conn.c index 73b26f2a9f3..8f289207822 100644 --- a/prov/rxm/src/rxm_conn.c +++ b/prov/rxm/src/rxm_conn.c @@ -460,6 +460,9 @@ ssize_t rxm_get_conn(struct rxm_ep *ep, fi_addr_t addr, struct rxm_conn **conn) return 0; } + if ((*peer)->av_flags & FI_FIREWALL_ADDR) + return -FI_EHOSTUNREACH; + ret = rxm_connect(*conn); /* If the progress function encounters an error trying to establish @@ -657,9 +660,9 @@ rxm_process_connreq(struct rxm_ep *ep, struct rxm_eq_cm_entry *cm_entry) ofi_addr_set_port(&peer_addr.sa, cm_entry->data.connect.port); av = container_of(ep->util_ep.av, struct rxm_av, util_av); - peer = util_get_peer(av, &peer_addr); - if (!peer) { - RXM_WARN_ERR(FI_LOG_EP_CTRL, "util_get_peer", -FI_ENOMEM); + ret = util_get_peer(av, &peer_addr, &peer, 0); + if (ret) { + RXM_WARN_ERR(FI_LOG_EP_CTRL, "util_get_peer", ret); goto reject; } diff --git a/prov/tcp/src/xnet_rdm_cm.c b/prov/tcp/src/xnet_rdm_cm.c index 4dfc966505d..24c7ee2c0bc 100644 --- a/prov/tcp/src/xnet_rdm_cm.c +++ b/prov/tcp/src/xnet_rdm_cm.c @@ -358,6 +358,8 @@ ssize_t xnet_get_conn(struct xnet_rdm *rdm, fi_addr_t addr, return -FI_ENOMEM; if (!(*conn)->ep) { + if ((*peer)->av_flags & FI_FIREWALL_ADDR) + return -FI_EHOSTUNREACH; ret = xnet_rdm_connect(*conn); if (ret) return ret; @@ -438,9 +440,9 @@ static void xnet_process_connreq(struct fi_eq_cm_entry *cm_entry) ofi_addr_set_port(&peer_addr.sa, ntohs(msg->port)); av = container_of(rdm->util_ep.av, struct rxm_av, util_av); - peer = util_get_peer(av, &peer_addr); - if (!peer) { - XNET_WARN_ERR(FI_LOG_EP_CTRL, "util_get_peer", -FI_ENOMEM); + ret = util_get_peer(av, &peer_addr, &peer, 0); + if (ret) { + XNET_WARN_ERR(FI_LOG_EP_CTRL, "util_get_peer", ret); goto reject; } diff --git a/prov/util/src/rxm_av.c b/prov/util/src/rxm_av.c index beb11d0620c..7114efc6211 100644 --- a/prov/util/src/rxm_av.c +++ b/prov/util/src/rxm_av.c @@ -98,23 +98,32 @@ static void rxm_free_peer(struct util_peer_addr *peer) ofi_ibuf_free(peer); } -struct util_peer_addr * -util_get_peer(struct rxm_av *av, const void *addr) + +int +util_get_peer(struct rxm_av *av, const void *addr, struct util_peer_addr **peer, + uint64_t flags) { - struct util_peer_addr *peer; struct ofi_rbnode *node; + int ret = FI_SUCCESS; ofi_mutex_lock(&av->util_av.lock); node = ofi_rbmap_find(&av->addr_map, (void *) addr); if (node) { - peer = node->data; - peer->refcnt++; + *peer = node->data; + (*peer)->refcnt++; + (*peer)->av_flags |= flags; + } else if (flags & FI_FIREWALL_ADDR) { + *peer = NULL; + ret = -FI_EHOSTUNREACH; } else { - peer = rxm_alloc_peer(av, addr); + *peer = rxm_alloc_peer(av, addr); + if (!*peer) + ret = -FI_ENOMEM; + else + (*peer)->av_flags = flags; } - ofi_mutex_unlock(&av->util_av.lock); - return peer; + return ret; } static void util_deref_peer(struct util_peer_addr *peer) @@ -165,17 +174,17 @@ rxm_put_peer_addr(struct rxm_av *av, fi_addr_t fi_addr) static int rxm_av_add_peers(struct rxm_av *av, const void *addr, size_t count, - fi_addr_t *fi_addr, fi_addr_t *user_ids) + fi_addr_t *fi_addr, fi_addr_t *user_ids, uint64_t flags) { struct util_peer_addr *peer; const void *cur_addr; fi_addr_t cur_fi_addr; - size_t i; + size_t i, ret; for (i = 0; i < count; i++) { cur_addr = ((char *) addr + i * av->util_av.addrlen); - peer = util_get_peer(av, cur_addr); - if (!peer) + ret = util_get_peer(av, cur_addr, &peer, flags); + if (ret) goto err; if (user_ids) { @@ -206,7 +215,7 @@ rxm_av_add_peers(struct rxm_av *av, const void *addr, size_t count, ofi_mutex_unlock(&av->util_av.lock); } } - return -FI_ENOMEM; + return ret; } static int rxm_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, @@ -299,7 +308,7 @@ static int rxm_av_insert(struct fid_av *av_fid, const void *addr, size_t count, count = ret; - ret = rxm_av_add_peers(av, addr, count, fi_addr, user_ids); + ret = rxm_av_add_peers(av, addr, count, fi_addr, user_ids, flags); if (ret) { rxm_av_remove(av_fid, fi_addr, count, flags); goto out; @@ -345,7 +354,7 @@ static int rxm_av_insertsym(struct fid_av *av_fid, const char *node, if (ret > 0 && ret < count) count = ret; - ret = rxm_av_add_peers(av, addr, count, fi_addr, NULL); + ret = rxm_av_add_peers(av, addr, count, fi_addr, NULL, flags); if (ret) { rxm_av_remove(av_fid, fi_addr, count, flags); return ret;