summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/clients.c93
-rw-r--r--src/clients.h5
-rw-r--r--src/tcpproxy.c9
3 files changed, 89 insertions, 18 deletions
diff --git a/src/clients.c b/src/clients.c
index a4b58b3..9e5c9fc 100644
--- a/src/clients.c
+++ b/src/clients.c
@@ -35,6 +35,7 @@
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
+#include <netinet/in.h>
#include "clients.h"
#include "tcp.h"
@@ -47,7 +48,7 @@ void clients_delete_element(void* e)
client_t* element = (client_t*)e;
close(element->fd_[0]);
-// close(element->fd_[1]);
+ close(element->fd_[1]);
free(e);
}
@@ -74,14 +75,35 @@ int clients_add(clients_t* list, int fd, const tcp_endpoint_t* remote_end)
return -2;
}
+ element->write_buf_len_[0] = 0;
+ element->write_buf_len_[1] = 0;
element->fd_[0] = fd;
- element->fd_[1] = 0;
-// TODO: open new socket
-// element->fd_[1] = socket(...);
+
+ element->fd_[1] = socket(remote_end->ss_family, SOCK_STREAM, 0);
+ if(element->fd_[1] < 0) {
+ log_printf(INFO, "Error on socket(): %s, not adding client %d", strerror(errno), element->fd_[0]);
+ close(element->fd_[0]);
+ free(element);
+ return -1;
+ }
+
+ socklen_t socklen = sizeof(*remote_end);
+ if(remote_end->ss_family == AF_INET)
+ socklen = sizeof(struct sockaddr_in);
+ else if (remote_end->ss_family == AF_INET6)
+ socklen = sizeof(struct sockaddr_in6);
+
+ if(connect(element->fd_[1], (struct sockaddr *)remote_end, socklen)==-1) {
+ log_printf(INFO, "Error on connect(): %s, not adding client %d", strerror(errno), element->fd_[0]);
+ close(element->fd_[0]);
+ close(element->fd_[1]);
+ free(element);
+ return -1;
+ }
if(slist_add(list, element) == NULL) {
close(element->fd_[0]);
-// close(element->fd_[1]);
+ close(element->fd_[1]);
free(element);
return -2;
}
@@ -136,9 +158,9 @@ void clients_read_fds(clients_t* list, fd_set* set, int* max_fd)
client_t* c = (client_t*)tmp->data_;
if(c) {
FD_SET(c->fd_[0], set);
-// FD_SET(c->fd_[1], set);
+ FD_SET(c->fd_[1], set);
*max_fd = *max_fd > c->fd_[0] ? *max_fd : c->fd_[0];
-// *max_fd = *max_fd > c->fd_[1] ? *max_fd : c->fd_[1];
+ *max_fd = *max_fd > c->fd_[1] ? *max_fd : c->fd_[1];
}
tmp = tmp->next_;
}
@@ -149,7 +171,21 @@ void clients_write_fds(clients_t* list, fd_set* set, int* max_fd)
if(!list)
return;
- // TODO: add all clients with pending data
+ slist_element_t* tmp = list->first_;
+ while(tmp) {
+ client_t* c = (client_t*)tmp->data_;
+ if(c) {
+ if(c->write_buf_len_[0]) {
+ FD_SET(c->fd_[0], set);
+ *max_fd = *max_fd > c->fd_[0] ? *max_fd : c->fd_[0];
+ }
+ if(c->write_buf_len_[1]) {
+ FD_SET(c->fd_[1], set);
+ *max_fd = *max_fd > c->fd_[1] ? *max_fd : c->fd_[1];
+ }
+ }
+ tmp = tmp->next_;
+ }
}
int clients_read(clients_t* list, fd_set* set)
@@ -173,20 +209,18 @@ int clients_read(clients_t* list, fd_set* set)
}
else continue;
- u_int8_t* buffer[1024];
- int len = recv(c->fd_[in], buffer, sizeof(buffer), 0);
+ // TODO: what when buffer is full?
+ int len = recv(c->fd_[in], &(c->write_buf_[out][c->write_buf_len_[out]]), BUFFER_LENGTH - c->write_buf_len_[out], 0);
if(len < 0) {
log_printf(INFO, "Error on recv(): %s, removing client %d", strerror(errno), c->fd_[0]);
slist_remove(list, c);
}
else if(!len) {
- log_printf(INFO, "client %d closed connection", c->fd_[0]);
+ log_printf(INFO, "client %d closed connection, removing it", c->fd_[0]);
slist_remove(list, c);
}
- else {
- log_printf(INFO, "client %d: read %d bytes", c->fd_[0], len);
- // TODO: add data to write buffer of l->fd_[out]
- }
+ else
+ c->write_buf_len_[out] += len;
}
}
@@ -195,5 +229,34 @@ int clients_read(clients_t* list, fd_set* set)
int clients_write(clients_t* list, fd_set* set)
{
+ if(!list)
+ return -1;
+
+ slist_element_t* tmp = list->first_;
+ while(tmp) {
+ client_t* c = (client_t*)tmp->data_;
+ tmp = tmp->next_;
+ if(c) {
+ int i;
+ for(i=0; i<2; ++i) {
+ if(FD_ISSET(c->fd_[i], set)) {
+ int len = send(c->fd_[i], c->write_buf_[i], c->write_buf_len_[i], 0);
+ if(len < 0) {
+ log_printf(INFO, "Error on send(): %s, removing client %d", strerror(errno), c->fd_[0]);
+ slist_remove(list, c);
+ }
+ else {
+ if(c->write_buf_len_[i] > len) {
+ memmove(c->write_buf_[i], &c->write_buf_[i][len], c->write_buf_len_[i] - len);
+ c->write_buf_len_[i] -= len;
+ }
+ else
+ c->write_buf_len_[i] = 0;
+ }
+ }
+ }
+ }
+ }
+
return 0;
}
diff --git a/src/clients.h b/src/clients.h
index 9ab8f0c..7052161 100644
--- a/src/clients.h
+++ b/src/clients.h
@@ -33,9 +33,12 @@
#include "slist.h"
#include "tcp.h"
+#define BUFFER_LENGTH 1048576
+
typedef struct {
int fd_[2];
- // TODO: add info for each client and write buffers
+ u_int8_t write_buf_[2][BUFFER_LENGTH];
+ u_int32_t write_buf_len_[2];
} client_t;
void clients_delete_element(void* e);
diff --git a/src/tcpproxy.c b/src/tcpproxy.c
index fce2e2c..6a55ea1 100644
--- a/src/tcpproxy.c
+++ b/src/tcpproxy.c
@@ -52,13 +52,15 @@ int main_loop(options_t* opt, listeners_t* listeners)
int return_value = clients_init(&clients);
while(!return_value) {
- fd_set readfds;
+ fd_set readfds, writefds;
FD_ZERO(&readfds);
+ FD_ZERO(&writefds);
FD_SET(sig_fd, &readfds);
int nfds = sig_fd;
listener_read_fds(listeners, &readfds, &nfds);
clients_read_fds(&clients, &readfds, &nfds);
- int ret = select(nfds + 1, &readfds, NULL, NULL, NULL);
+ clients_write_fds(&clients, &writefds, &nfds);
+ int ret = select(nfds + 1, &readfds, &writefds, NULL, NULL);
if(ret == -1 && errno != EINTR) {
log_printf(ERROR, "select returned with error: %s", strerror(errno));
return_value = -1;
@@ -77,6 +79,9 @@ int main_loop(options_t* opt, listeners_t* listeners)
return_value = listener_handle_accept(listeners, &clients, &readfds);
if(return_value) break;
+ return_value = clients_write(&clients, &writefds);
+ if(return_value) break;
+
return_value = clients_read(&clients, &readfds);
}