Commit 4eb9caaa authored by Jonas Termansen's avatar Jonas Termansen

Fix non-blocking accept4(2) and getting the Unix socket peer address.

Rename the internal kernel method from accept to accept4.

fixup! Fix non-blocking accept4(2) and getting the unix socket peer address.
parent 8f3e11b1
......@@ -850,11 +850,15 @@ int Descriptor::poll(ioctx_t* ctx, PollNode* node)
return vnode->poll(ctx, node);
}
Ref<Descriptor> Descriptor::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags)
Ref<Descriptor> Descriptor::accept4(ioctx_t* ctx, uint8_t* addr,
size_t* addrlen, int flags)
{
Ref<Vnode> retvnode = vnode->accept(ctx, addr, addrlen, flags);
int old_ctx_dflags = ctx->dflags;
ctx->dflags = ContextFlags(old_ctx_dflags, dflags);
Ref<Vnode> retvnode = vnode->accept4(ctx, addr, addrlen, flags);
if ( !retvnode )
return Ref<Descriptor>();
ctx->dflags = old_ctx_dflags;
return Ref<Descriptor>(new Descriptor(retvnode, O_READ | O_WRITE));
}
......
......@@ -160,7 +160,7 @@ public:
void Disconnect();
void Unmount();
Channel* Connect(ioctx_t* ctx);
Channel* Accept();
Channel* Accept(ioctx_t* ctx);
Ref<Inode> BootstrapNode(ino_t ino, mode_t type);
Ref<Inode> OpenNode(ino_t ino, mode_t type);
......@@ -181,8 +181,8 @@ class ServerNode : public AbstractInode
public:
ServerNode(Ref<Server> server);
virtual ~ServerNode();
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
private:
Ref<Server> server;
......@@ -242,8 +242,8 @@ public:
virtual int poll(ioctx_t* ctx, PollNode* node);
virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
const char* newname);
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int listen(ioctx_t* ctx, int backlog);
......@@ -594,13 +594,17 @@ Channel* Server::Connect(ioctx_t* ctx)
return channel;
}
Channel* Server::Accept()
Channel* Server::Accept(ioctx_t* ctx)
{
ScopedLock lock(&connect_lock);
listener_system_tid = CurrentThread()->system_tid;
while ( !connecting && !unmounted )
{
if ( ctx->dflags & O_NONBLOCK )
return errno = EWOULDBLOCK, (Channel*) NULL;
if ( !kthread_cond_wait_signal(&connecting_cond, &connect_lock) )
return errno = EINTR, (Channel*) NULL;
}
if ( unmounted )
return errno = ECONNRESET, (Channel*) NULL;
Channel* result = connecting;
......@@ -638,18 +642,19 @@ ServerNode::~ServerNode()
server->Disconnect();
}
Ref<Inode> ServerNode::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags)
Ref<Inode> ServerNode::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags)
{
(void) addr;
(void) flags;
if ( flags & ~(0) )
return errno = EINVAL, Ref<Inode>(NULL);
size_t out_addrlen = 0;
if ( addrlen && !ctx->copy_to_dest(addrlen, &out_addrlen, sizeof(out_addrlen)) )
return Ref<Inode>(NULL);
Ref<ChannelNode> node(new ChannelNode);
if ( !node )
return Ref<Inode>(NULL);
Channel* channel = server->Accept();
Channel* channel = server->Accept(ctx);
if ( !channel )
return Ref<Inode>(NULL);
node->Construct(channel);
......@@ -1462,8 +1467,8 @@ int Unode::rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
return ret;
}
Ref<Inode> Unode::accept(ioctx_t* /*ctx*/, uint8_t* /*addr*/,
size_t* /*addrlen*/, int /*flags*/)
Ref<Inode> Unode::accept4(ioctx_t* /*ctx*/, uint8_t* /*addr*/,
size_t* /*addrlen*/, int /*flags*/)
{
return errno = ENOTSOCK, Ref<Inode>();
}
......
......@@ -94,8 +94,8 @@ public:
int poll(ioctx_t* ctx, PollNode* node);
int rename_here(ioctx_t* ctx, Ref<Descriptor> from, const char* oldpath,
const char* newpath);
Ref<Descriptor> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
Ref<Descriptor> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int listen(ioctx_t* ctx, int backlog);
......
......@@ -104,8 +104,8 @@ public:
virtual int poll(ioctx_t* ctx, PollNode* node) = 0;
virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
const char* newname) = 0;
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags) = 0;
virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags) = 0;
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0;
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0;
virtual int listen(ioctx_t* ctx, int backlog) = 0;
......@@ -210,8 +210,8 @@ public:
virtual int poll(ioctx_t* ctx, PollNode* node);
virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
const char* newname);
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags);
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int listen(ioctx_t* ctx, int backlog);
......
......@@ -93,7 +93,7 @@ public:
int poll(ioctx_t* ctx, PollNode* node);
int rename_here(ioctx_t* ctx, Ref<Vnode> from, const char* oldname,
const char* newname);
Ref<Vnode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags);
Ref<Vnode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags);
int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int listen(ioctx_t* ctx, int backlog);
......
......@@ -512,8 +512,8 @@ int AbstractInode::rename_here(ioctx_t* /*ctx*/, Ref<Inode> /*from*/,
return errno = ENOTDIR, -1;
}
Ref<Inode> AbstractInode::accept(ioctx_t* /*ctx*/, uint8_t* /*addr*/,
size_t* /*addrlen*/, int /*flags*/)
Ref<Inode> AbstractInode::accept4(ioctx_t* /*ctx*/, uint8_t* /*addr*/,
size_t* /*addrlen*/, int /*flags*/)
{
return errno = ENOTSOCK, Ref<Inode>();
}
......
......@@ -731,13 +731,15 @@ int sys_accept4(int fd, void* addr, size_t* addrlen, int flags)
int fdflags = 0;
if ( flags & SOCK_CLOEXEC ) fdflags |= FD_CLOEXEC;
if ( flags & SOCK_CLOFORK ) fdflags |= FD_CLOFORK;
flags &= ~(SOCK_CLOEXEC | SOCK_CLOFORK);
int descflags = 0;
if ( flags & SOCK_NONBLOCK ) descflags |= O_NONBLOCK;
flags &= ~(SOCK_CLOEXEC | SOCK_CLOFORK | SOCK_NONBLOCK);
ioctx_t ctx; SetupUserIOCtx(&ctx);
Ref<Descriptor> conn = desc->accept(&ctx, (uint8_t*) addr, addrlen, flags);
Ref<Descriptor> conn = desc->accept4(&ctx, (uint8_t*) addr, addrlen, flags);
if ( !conn )
return -1;
if ( flags & SOCK_NONBLOCK )
conn->SetFlags(conn->GetFlags() | O_NONBLOCK);
if ( descflags )
conn->SetFlags(conn->GetFlags() | descflags);
return CurrentProcess()->GetDTable()->Allocate(conn, fdflags);
}
......
......@@ -24,6 +24,7 @@
#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sortix/fcntl.h>
......@@ -82,8 +83,8 @@ class StreamSocket : public AbstractInode
public:
StreamSocket(uid_t owner, gid_t group, mode_t mode, Ref<Manager> manager);
virtual ~StreamSocket();
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize,
int flags);
virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize,
int flags);
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize);
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize);
virtual int listen(ioctx_t* ctx, int backlog);
......@@ -116,6 +117,7 @@ public: /* For use by Manager. */
StreamSocket* first_pending;
StreamSocket* last_pending;
struct sockaddr_un* bound_address;
size_t bound_address_size;
bool is_listening;
bool is_connected;
bool is_refused;
......@@ -167,6 +169,7 @@ StreamSocket::StreamSocket(uid_t owner, gid_t group, mode_t mode,
this->first_pending = NULL;
this->last_pending = NULL;
this->bound_address = NULL;
this->bound_address_size = 0;
this->is_listening = false;
this->is_connected = false;
this->is_refused = false;
......@@ -181,11 +184,11 @@ StreamSocket::~StreamSocket()
{
if ( is_listening )
manager->Unlisten(this);
delete[] bound_address;
free(bound_address);
}
Ref<Inode> StreamSocket::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize,
int flags)
Ref<Inode> StreamSocket::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize,
int flags)
{
ScopedLock lock(&socket_lock);
if ( !is_listening )
......@@ -198,33 +201,25 @@ int StreamSocket::do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize)
if ( is_connected || is_listening || bound_address )
return errno = EINVAL, -1;
size_t path_offset = offsetof(struct sockaddr_un, sun_path);
size_t path_len = (path_offset - addrsize) / sizeof(char);
if ( addrsize < path_offset )
return errno = EINVAL, -1;
uint8_t* buffer = new uint8_t[addrsize];
if ( !buffer )
size_t path_len = path_offset - addrsize;
struct sockaddr_un* address = (struct sockaddr_un*) malloc(addrsize);
if ( !address )
return -1;
if ( ctx->copy_from_src(buffer, addr, addrsize) )
{
struct sockaddr_un* address = (struct sockaddr_un*) buffer;
if ( address->sun_family == AF_UNIX )
{
bool found_nul = false;
for ( size_t i = 0; !found_nul && i < path_len; i++ )
if ( address->sun_path[i] == '\0' )
found_nul = true;
if ( found_nul )
{
bound_address = address;
return 0;
}
errno = EINVAL;
}
else
errno = EAFNOSUPPORT;
}
delete[] buffer;
return -1;
if ( !ctx->copy_from_src(address, addr, addrsize) )
return free(address), -1;
if ( address->sun_family != AF_UNIX )
return free(address), errno = EAFNOSUPPORT, -1;
bool found_nul = false;
for ( size_t i = 0; !found_nul && i < path_len; i++ )
if ( address->sun_path[i] == '\0' )
found_nul = true;
if ( !found_nul )
return free(address), errno = EINVAL, -1;
bound_address = address;
bound_address_size = addrsize;
return 0;
}
int StreamSocket::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize)
......@@ -465,40 +460,43 @@ int Manager::AcceptPoll(StreamSocket* socket, ioctx_t* /*ctx*/, PollNode* node)
}
Ref<StreamSocket> Manager::Accept(StreamSocket* socket, ioctx_t* ctx,
uint8_t* addr, size_t* addrsize, int /*flags*/)
uint8_t* addr, size_t* addrsize, int flags)
{
if ( flags & ~(0) )
return errno = EINVAL, Ref<StreamSocket>(NULL);
ScopedLock lock(&manager_lock);
// TODO: Support non-blocking accept!
while ( !socket->first_pending )
{
if ( (ctx->dflags & O_NONBLOCK) || (flags & SOCK_NONBLOCK) )
return errno = EWOULDBLOCK, Ref<StreamSocket>(NULL);
if ( !kthread_cond_wait_signal(&socket->pending_cond, &manager_lock) )
return errno = EINTR, Ref<StreamSocket>(NULL);
}
StreamSocket* client = socket->first_pending;
struct sockaddr_un* client_addr = client->bound_address;
size_t client_addr_size = offsetof(struct sockaddr_un, sun_path) +
(strlen(client_addr->sun_path)+1) * sizeof(char);
struct sockaddr_un* bound_address = socket->bound_address;
size_t bound_address_size = socket->bound_address_size;
if ( addr )
{
size_t caller_addrsize;
if ( !ctx->copy_from_src(&caller_addrsize, addrsize, sizeof(caller_addrsize)) )
size_t used_addrsize;
if ( !ctx->copy_from_src(&used_addrsize, addrsize,
sizeof(used_addrsize)) )
return Ref<StreamSocket>(NULL);
if ( caller_addrsize < client_addr_size )
return errno = ERANGE, Ref<StreamSocket>(NULL);
if ( !ctx->copy_from_src(addrsize, &client_addr_size, sizeof(client_addr_size)) )
if ( bound_address_size < used_addrsize )
used_addrsize = bound_address_size;
if ( !ctx->copy_to_dest(addr, bound_address, bound_address_size) )
return Ref<StreamSocket>(NULL);
if ( !ctx->copy_to_dest(addr, client_addr, client_addr_size) )
if ( !ctx->copy_to_dest(addrsize, &used_addrsize,
sizeof(used_addrsize)) )
return Ref<StreamSocket>(NULL);
}
// TODO: Give the caller the address of the remote!
Ref<StreamSocket> server(new StreamSocket(0, 0, 0666, Ref<Manager>(this)));
if ( !server )
return Ref<StreamSocket>(NULL);
StreamSocket* client = socket->first_pending;
QueuePop(&socket->first_pending, &socket->last_pending);
if ( !client->outgoing.Connect(&server->incoming) )
......@@ -513,10 +511,6 @@ Ref<StreamSocket> Manager::Accept(StreamSocket* socket, ioctx_t* ctx,
client->is_connected = true;
server->is_connected = true;
// TODO: Should the server socket inherit the address of the listening
// socket or perhaps the one of the client's source/destination, or
// nothing at all?
kthread_cond_signal(&client->accepted_cond);
return server;
......
......@@ -391,12 +391,14 @@ int Vnode::poll(ioctx_t* ctx, PollNode* node)
return inode->poll(ctx, node);
}
Ref<Vnode> Vnode::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags)
Ref<Vnode> Vnode::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags)
{
Ref<Inode> retinode = inode->accept(ctx, addr, addrlen, flags);
Ref<Inode> retinode = inode->accept4(ctx, addr, addrlen, flags);
if ( !retinode )
return Ref<Vnode>();
return Ref<Vnode>(new Vnode(retinode, Ref<Vnode>(), retinode->ino, retinode->dev));
return Ref<Vnode>(new Vnode(retinode, Ref<Vnode>(), retinode->ino,
retinode->dev));
}
int Vnode::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment