Skip to content

Commit

Permalink
use new socket API
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Sep 21, 2023
1 parent b246552 commit b2d941e
Showing 1 changed file with 74 additions and 41 deletions.
115 changes: 74 additions & 41 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <cstdint>
#include <memory>
#include <string>

#include "nccl_kernels.h"
#include "mpi_include.h"
Expand Down Expand Up @@ -45,36 +49,88 @@ static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
namespace IPC {
#define FLLOG LOGS_DEFAULT(VERBOSE)
#define FLLOGERRNO LOGS_DEFAULT(WARNING) << "error:" << strerror(errno)
#define FLLOGGAI LOGS_DEFAULT(WARNING) << "error:" << gai_strerror(ret)

int WriteOnRank0(ncclUniqueId* nccl_id, int word_size) {
int fd = socket(AF_INET, /* network versus AF_LOCAL */
SOCK_STREAM, /* reliable, bidirectional, arbitrary payload size */
0); /* system picks underlying protocol (TCP) */
if (fd < 0) {
FLLOGERRNO << (" create socket\n"); /* terminate */
return -1;
typedef std::shared_ptr<struct addrinfo> AddrInfoPtr;

int CreateSocket(bool is_server) {
int sockfd = -1;

struct addrinfo hints;
struct addrinfo *result = nullptr;
AddrInfoPtr result_ptr(result, [](struct addrinfo *p) { if(p){freeaddrinfo(p);} });

memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */
hints.ai_socktype = SOCK_STREAM; /* TCP socket. use SOCK_DGRAM for UDP */
hints.ai_flags = AI_PASSIVE; /* For wildcard IP address */
hints.ai_protocol = 0; /* Any protocol */

std::string rank0_ip = ParseEnvironmentVariableWithDefault<std::string>("RANK0_IP", "localhost");
std::string port_number = ParseEnvironmentVariableWithDefault<std::string>("RANK0_PORT", "18888");

int ret = getaddrinfo(is_server ? nullptr : rank0_ip.c_str(), port_number.c_str(), &hints, &result);
if (ret != 0) {
FLLOGGAI << " getaddrinfo failed\n";
return sockfd;
}

int32_t port_number = ParseEnvironmentVariableWithDefault<int32_t>("RANK0_PORT", 18888);
for (struct addrinfo* rp = result; rp != nullptr; rp = rp->ai_next) {
sockfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sockfd == -1) {
continue;
}

/* bind the server's local address in memory */
struct sockaddr_in saddr;
memset(&saddr, 0, sizeof(saddr)); /* clear the bytes */
saddr.sin_family = AF_INET; /* versus AF_LOCAL */
saddr.sin_addr.s_addr = htonl(INADDR_ANY); // htonl(INADDR_ANY); /* host-to-network endian */
saddr.sin_port = htons(port_number); /* for listening */
int on = 1;
int rc = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on));
if (rc < 0) {
FLLOGERRNO << ("setsockopt() failed\n");
close(sockfd);
continue;
}

if (bind(fd, (struct sockaddr*)&saddr, sizeof(saddr)) < 0) {
FLLOGERRNO << ("bind\n"); /* terminate */
if (is_server){
if (bind(sockfd, rp->ai_addr, rp->ai_addrlen) == 0) {
FLLOG << "Listening on port " << port_number << " for the other GPU processores...\n";
} else {
FLLOGERRNO << ("bind failed\n");
close(sockfd);
sockfd = -1;
}
} else {
time_t start_time = time(0);
int conn_ret = connect(sockfd, rp->ai_addr, rp->ai_addrlen);
while (time(0) - start_time < 40 && conn_ret < 0) {
FLLOGERRNO << (" waiting the RANK 0 ready...\n"); /* terminate */
sleep(1);
conn_ret = connect(sockfd, rp->ai_addr, rp->ai_addrlen);
}
if (conn_ret < 0) {
close(sockfd);
sockfd = -1;
FLLOGERRNO << ("connect failed with timeout\n"); /* terminate */
} else {
FLLOG << "connect to " << rank0_ip << ":" << port_number << "success \n";
}
}
break;
}
return sockfd;
}

int WriteOnRank0(ncclUniqueId* nccl_id, int word_size) {
int fd = CreateSocket(true);
if (fd < 0) {
FLLOGERRNO << (" create socket\n"); /* terminate */
return -1;
}

/* listen to the socket */
if (listen(fd, word_size) < 0) {
FLLOGERRNO << ("listen\n"); /* terminate */
return -1;
}

FLLOG << "Listening on port " << port_number << " for the other GPU processores...\n";
word_size--; // rank 0 is not in word_size
while (word_size-- > 0) {
int client_fd = accept(fd, nullptr, nullptr); /* accept blocks */
Expand All @@ -95,35 +151,12 @@ int WriteOnRank0(ncclUniqueId* nccl_id, int word_size) {


int ReadFromRank0(ncclUniqueId* nccl_id) {
int sockfd = socket(AF_INET, /* versus AF_LOCAL */
SOCK_STREAM, /* reliable, bidirectional */
0); /* system picks protocol (TCP) */
int sockfd = CreateSocket(false);
if (sockfd < 0) {
FLLOGERRNO << ("socket");
return -1;
}

/* connect to the server: configure server's address 1st */
std::string rank0_ip = ParseEnvironmentVariableWithDefault<std::string>("RANK0_IP", "127.0.0.1");
int32_t port_number = ParseEnvironmentVariableWithDefault<int32_t>("RANK0_PORT", 18888);

struct sockaddr_in saddr;
memset(&saddr, 0, sizeof(saddr));
saddr.sin_family = AF_INET;
saddr.sin_addr.s_addr = inet_addr(rank0_ip.c_str());
saddr.sin_port = htons(port_number); /* port number in big-endian */
time_t start_time = time(0);
int conn_ret = connect(sockfd, (struct sockaddr*)&saddr, sizeof(saddr));
while (time(0) - start_time < 40 && conn_ret < 0) {
FLLOGERRNO << (" waiting the RANK 0 ready..."); /* terminate */
sleep(1);
conn_ret = connect(sockfd, (struct sockaddr*)&saddr, sizeof(saddr));
}
if (conn_ret < 0) {
FLLOGERRNO << ("connect"); /* terminate */
return -1;
}

if (read(sockfd, (nccl_id), sizeof(ncclUniqueId)) != sizeof(ncclUniqueId)) {
FLLOGERRNO << ("read"); /* terminate */
return -1;
Expand Down

0 comments on commit b2d941e

Please sign in to comment.