diff options
Diffstat (limited to 'common-channel.c')
-rw-r--r-- | common-channel.c | 170 |
1 files changed, 77 insertions, 93 deletions
diff --git a/common-channel.c b/common-channel.c index dfff1b4..68d2b48 100644 --- a/common-channel.c +++ b/common-channel.c @@ -52,8 +52,8 @@ static void deletechannel(struct Channel *channel); static void checkinitdone(struct Channel *channel); static void checkclose(struct Channel *channel); -static void closeinfd(struct Channel * channel); -static void closeoutfd(struct Channel * channel, int fd); +static void closewritefd(struct Channel * channel); +static void closereadfd(struct Channel * channel, int fd); static void closechanfd(struct Channel *channel, int fd, int how); #define FD_UNINIT (-2) @@ -143,10 +143,11 @@ struct Channel* newchannel(unsigned int remotechan, newchan->transmaxpacket = transmaxpacket; newchan->typedata = NULL; - newchan->infd = FD_UNINIT; - newchan->outfd = FD_UNINIT; + newchan->writefd = FD_UNINIT; + newchan->readfd = FD_UNINIT; newchan->errfd = FD_CLOSED; /* this isn't always set to start with */ newchan->initconn = 0; + newchan->await_open = 0; newchan->writebuf = cbuf_new(RECV_MAXWINDOW); newchan->extrabuf = NULL; /* The user code can set it up */ @@ -176,11 +177,10 @@ struct Channel* getchannel() { } /* Iterate through the channels, performing IO if available */ -void channelio(fd_set *readfd, fd_set *writefd) { +void channelio(fd_set *readfds, fd_set *writefds) { struct Channel *channel; unsigned int i; - int ret; /* iterate through all the possible channels */ for (i = 0; i < ses.chansize; i++) { @@ -191,48 +191,30 @@ void channelio(fd_set *readfd, fd_set *writefd) { continue; } - /* read from program/pipe stdout */ - if (channel->outfd >= 0 && FD_ISSET(channel->outfd, readfd)) { + /* read data and send it over the wire */ + if (channel->readfd >= 0 && FD_ISSET(channel->readfd, readfds)) { send_msg_channel_data(channel, 0, 0); } - /* read from program/pipe stderr */ + /* read stderr data and send it over the wire */ if (channel->extrabuf == NULL && - channel->errfd >= 0 && FD_ISSET(channel->errfd, readfd)) { + channel->errfd >= 0 && FD_ISSET(channel->errfd, readfds)) { send_msg_channel_data(channel, 1, SSH_EXTENDED_DATA_STDERR); } - /* if we can read from the infd, it might be closed, so we try to - * see if it has errors */ - if (channel->infd >= 0 && channel->infd != channel->outfd - && FD_ISSET(channel->infd, readfd)) { - if (channel->initconn) { - /* Handling for "in progress" connection - this is needed - * to avoid spinning 100% CPU when we connect to a server - * which doesn't send anything (tcpfwding) */ - checkinitdone(channel); - continue; /* Important not to use the channel after - checkinitdone(), as it may be NULL */ - } - ret = write(channel->infd, NULL, 0); /* Fake write */ - if (ret < 0 && errno != EINTR && errno != EAGAIN) { - closeinfd(channel); - } - } - /* write to program/pipe stdin */ - if (channel->infd >= 0 && FD_ISSET(channel->infd, writefd)) { + if (channel->writefd >= 0 && FD_ISSET(channel->writefd, writefds)) { if (channel->initconn) { checkinitdone(channel); continue; /* Important not to use the channel after checkinitdone(), as it may be NULL */ } - writechannel(channel, channel->infd, channel->writebuf); + writechannel(channel, channel->writefd, channel->writebuf); } /* stderr for client mode */ if (channel->extrabuf != NULL - && channel->errfd >= 0 && FD_ISSET(channel->errfd, writefd)) { + && channel->errfd >= 0 && FD_ISSET(channel->errfd, writefds)) { writechannel(channel, channel->errfd, channel->extrabuf); } @@ -243,7 +225,7 @@ void channelio(fd_set *readfd, fd_set *writefd) { /* Listeners such as TCP, X11, agent-auth */ #ifdef USING_LISTENERS - handle_listeners(readfd); + handle_listeners(readfds); #endif } @@ -251,8 +233,8 @@ void channelio(fd_set *readfd, fd_set *writefd) { /* do all the EOF/close type stuff checking for a channel */ static void checkclose(struct Channel *channel) { - TRACE(("checkclose: infd %d, outfd %d, errfd %d, sentclosed %d, recvclosed %d", - channel->infd, channel->outfd, + TRACE(("checkclose: writefd %d, readfd %d, errfd %d, sentclosed %d, recvclosed %d", + channel->writefd, channel->readfd, channel->errfd, channel->sentclosed, channel->recvclosed)) TRACE(("writebuf %d extrabuf %s extrabuf %d", cbuf_getused(channel->writebuf), @@ -265,18 +247,18 @@ static void checkclose(struct Channel *channel) { * if the shell has exited etc */ if (channel->type->checkclose) { if (channel->type->checkclose(channel)) { - closeinfd(channel); + closewritefd(channel); } } if (!channel->senteof - && channel->outfd == FD_CLOSED + && channel->readfd == FD_CLOSED && (channel->extrabuf != NULL || channel->errfd == FD_CLOSED)) { send_msg_channel_eof(channel); } - if (channel->infd == FD_CLOSED - && channel->outfd == FD_CLOSED + if (channel->writefd == FD_CLOSED + && channel->readfd == FD_CLOSED && (channel->extrabuf != NULL || channel->errfd == FD_CLOSED)) { send_msg_channel_close(channel); } @@ -313,17 +295,17 @@ static void checkinitdone(struct Channel *channel) { TRACE(("enter checkinitdone")) - if (getsockopt(channel->infd, SOL_SOCKET, SO_ERROR, &val, &vallen) + if (getsockopt(channel->writefd, SOL_SOCKET, SO_ERROR, &val, &vallen) || val != 0) { send_msg_channel_open_failure(channel->remotechan, SSH_OPEN_CONNECT_FAILED, "", ""); - close(channel->infd); + close(channel->writefd); deletechannel(channel); TRACE(("leave checkinitdone: fail")) } else { send_msg_channel_open_confirmation(channel, channel->recvwindow, channel->recvmaxpacket); - channel->outfd = channel->infd; + channel->readfd = channel->writefd; channel->initconn = 0; TRACE(("leave checkinitdone: success")) } @@ -385,7 +367,7 @@ static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf) { if (len < 0 && errno != EINTR) { /* no more to write - we close it even if the fd was stderr, since * that's a nasty failure too */ - closeinfd(channel); + closewritefd(channel); } TRACE(("leave writechannel: len <= 0")) return; @@ -394,9 +376,9 @@ static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf) { cbuf_incrread(cbuf, len); channel->recvdonelen += len; - if (fd == channel->infd && len == maxlen && channel->recveof) { + if (fd == channel->writefd && cbuf_getused(cbuf) == 0 && channel->recveof) { /* Check if we're closing up */ - closeinfd(channel); + closewritefd(channel); TRACE(("leave writechannel: recveof set")) return; } @@ -409,9 +391,9 @@ static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf) { channel->recvdonelen = 0; } - assert(channel->recvwindow <= RECV_MAXWINDOW); - assert(channel->recvwindow <= cbuf_getavail(channel->writebuf)); - assert(channel->extrabuf == NULL || + dropbear_assert(channel->recvwindow <= RECV_MAXWINDOW); + dropbear_assert(channel->recvwindow <= cbuf_getavail(channel->writebuf)); + dropbear_assert(channel->extrabuf == NULL || channel->recvwindow <= cbuf_getavail(channel->extrabuf)); @@ -420,7 +402,7 @@ static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf) { /* Set the file descriptors for the main select in session.c * This avoid channels which don't have any window available, are closed, etc*/ -void setchannelfds(fd_set *readfd, fd_set *writefd) { +void setchannelfds(fd_set *readfds, fd_set *writefds) { unsigned int i; struct Channel * channel; @@ -435,41 +417,31 @@ void setchannelfds(fd_set *readfd, fd_set *writefd) { /* Stuff to put over the wire */ if (channel->transwindow > 0) { - if (channel->outfd >= 0) { - FD_SET(channel->outfd, readfd); + if (channel->readfd >= 0) { + FD_SET(channel->readfd, readfds); } if (channel->extrabuf == NULL && channel->errfd >= 0) { - FD_SET(channel->errfd, readfd); + FD_SET(channel->errfd, readfds); } } - /* For checking FD status (ie closure etc) - we don't actually - * read data from infd */ - TRACE(("infd = %d, outfd %d, errfd %d, bufused %d", - channel->infd, channel->outfd, - channel->errfd, - cbuf_getused(channel->writebuf) )) - if (channel->infd >= 0 && channel->infd != channel->outfd) { - FD_SET(channel->infd, readfd); - } - - /* Stuff from the wire, to local program/shell/user etc */ - if ((channel->infd >= 0 && cbuf_getused(channel->writebuf) > 0 ) + /* Stuff from the wire */ + if ((channel->writefd >= 0 && cbuf_getused(channel->writebuf) > 0 ) || channel->initconn) { - FD_SET(channel->infd, writefd); + FD_SET(channel->writefd, writefds); } if (channel->extrabuf != NULL && channel->errfd >= 0 && cbuf_getused(channel->extrabuf) > 0 ) { - FD_SET(channel->errfd, writefd); + FD_SET(channel->errfd, writefds); } } /* foreach channel */ #ifdef USING_LISTENERS - set_listener_fds(readfd); + set_listener_fds(readfds); #endif } @@ -492,7 +464,7 @@ void recv_msg_channel_eof() { if (cbuf_getused(channel->writebuf) == 0 && (channel->extrabuf == NULL || cbuf_getused(channel->extrabuf) == 0)) { - closeinfd(channel); + closewritefd(channel); } TRACE(("leave recv_msg_channel_eof")) @@ -540,8 +512,8 @@ static void removechannel(struct Channel * channel) { /* close the FDs in case they haven't been done * yet (ie they were shutdown etc */ - close(channel->infd); - close(channel->outfd); + close(channel->writefd); + close(channel->readfd); close(channel->errfd); channel->typedata = NULL; @@ -603,14 +575,14 @@ static void send_msg_channel_data(struct Channel *channel, int isextended, CHECKCLEARTOWRITE(); - assert(!channel->sentclosed); + dropbear_assert(!channel->sentclosed); if (isextended) { fd = channel->errfd; } else { - fd = channel->outfd; + fd = channel->readfd; } - assert(fd >= 0); + dropbear_assert(fd >= 0); maxlen = MIN(channel->transwindow, channel->transmaxpacket); /* -(1+4+4) is SSH_MSG_CHANNEL_DATA, channel number, string length, and @@ -630,7 +602,7 @@ static void send_msg_channel_data(struct Channel *channel, int isextended, if (len <= 0) { /* on error/eof, send eof */ if (len == 0 || errno != EINTR) { - closeoutfd(channel, fd); + closereadfd(channel, fd); } buf_free(buf); buf = NULL; @@ -668,7 +640,7 @@ void recv_msg_channel_data() { dropbear_exit("Unknown channel"); } - common_recv_msg_channel_data(channel, channel->infd, channel->writebuf); + common_recv_msg_channel_data(channel, channel->writefd, channel->writebuf); } /* Shared for data and stderr data - when we receive data, put it in a buffer @@ -688,7 +660,7 @@ void common_recv_msg_channel_data(struct Channel *channel, int fd, } if (fd < 0) { - dropbear_exit("received data with bad infd"); + dropbear_exit("received data with bad writefd"); } datalen = buf_getint(ses.payload); @@ -718,9 +690,9 @@ void common_recv_msg_channel_data(struct Channel *channel, int fd, len -= buflen; } - assert(channel->recvwindow >= datalen); + dropbear_assert(channel->recvwindow >= datalen); channel->recvwindow -= datalen; - assert(channel->recvwindow <= RECV_MAXWINDOW); + dropbear_assert(channel->recvwindow <= RECV_MAXWINDOW); TRACE(("leave recv_msg_channel_data")) } @@ -930,9 +902,11 @@ int send_msg_channel_open_init(int fd, const struct ChanType *type) { /* set fd non-blocking */ setnonblocking(fd); - chan->infd = chan->outfd = fd; + chan->writefd = chan->readfd = fd; ses.maxfd = MAX(ses.maxfd, fd); + chan->await_open = 1; + /* now open the channel connection */ CHECKCLEARTOWRITE(); @@ -960,6 +934,11 @@ void recv_msg_channel_open_confirmation() { dropbear_exit("Unknown channel"); } + if (!channel->await_open) { + dropbear_exit("unexpected channel reply"); + } + channel->await_open = 0; + channel->remotechan = buf_getint(ses.payload); channel->transwindow = buf_getint(ses.payload); channel->transmaxpacket = buf_getint(ses.payload); @@ -990,26 +969,31 @@ void recv_msg_channel_open_failure() { dropbear_exit("Unknown channel"); } + if (!channel->await_open) { + dropbear_exit("unexpected channel reply"); + } + channel->await_open = 0; + removechannel(channel); } #endif /* USING_LISTENERS */ /* close a stdout/stderr fd */ -static void closeoutfd(struct Channel * channel, int fd) { +static void closereadfd(struct Channel * channel, int fd) { - /* don't close it if it is the same as infd, - * unless infd is already set -1 */ - TRACE(("enter closeoutfd")) + /* don't close it if it is the same as writefd, + * unless writefd is already set -1 */ + TRACE(("enter closereadfd")) closechanfd(channel, fd, 0); - TRACE(("leave closeoutfd")) + TRACE(("leave closereadfd")) } /* close a stdin fd */ -static void closeinfd(struct Channel * channel) { +static void closewritefd(struct Channel * channel) { - TRACE(("enter closeinfd")) - closechanfd(channel, channel->infd, 1); - TRACE(("leave closeinfd")) + TRACE(("enter closewritefd")) + closechanfd(channel, channel->writefd, 1); + TRACE(("leave closewritefd")) } /* close a fd, how is 0 for stdout/stderr, 1 for stdin */ @@ -1031,15 +1015,15 @@ static void closechanfd(struct Channel *channel, int fd, int how) { closein = closeout = 1; } - if (closeout && fd == channel->outfd) { - channel->outfd = FD_CLOSED; + if (closeout && fd == channel->readfd) { + channel->readfd = FD_CLOSED; } if (closeout && (channel->extrabuf == NULL) && (fd == channel->errfd)) { channel->errfd = FD_CLOSED; } - if (closein && fd == channel->infd) { - channel->infd = FD_CLOSED; + if (closein && fd == channel->writefd) { + channel->writefd = FD_CLOSED; } if (closein && (channel->extrabuf != NULL) && (fd == channel->errfd)) { channel->errfd = FD_CLOSED; @@ -1047,8 +1031,8 @@ static void closechanfd(struct Channel *channel, int fd, int how) { /* if we called shutdown on it and all references are gone, then we * need to close() it to stop it lingering */ - if (channel->type->sepfds && channel->outfd == FD_CLOSED - && channel->infd == FD_CLOSED && channel->errfd == FD_CLOSED) { + if (channel->type->sepfds && channel->readfd == FD_CLOSED + && channel->writefd == FD_CLOSED && channel->errfd == FD_CLOSED) { close(fd); } } |