summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--anytun.cpp104
-rw-r--r--buffer.cpp174
-rw-r--r--buffer.h21
-rw-r--r--cipher.cpp31
-rw-r--r--cipher.h5
-rw-r--r--encryptedPacket.cpp168
-rw-r--r--encryptedPacket.h50
-rw-r--r--keyDerivation.cpp45
-rw-r--r--mpi.cpp80
-rw-r--r--mpi.h2
-rw-r--r--plainPacket.cpp54
-rw-r--r--plainPacket.h44
12 files changed, 369 insertions, 409 deletions
diff --git a/anytun.cpp b/anytun.cpp
index 5aab904..e679281 100644
--- a/anytun.cpp
+++ b/anytun.cpp
@@ -31,7 +31,7 @@
#include <iostream>
#include <poll.h>
-#include <gcrypt.h> // for thread safe libgcrypt initialisation
+#include <gcrypt.h>
#include <cerrno> // for ENOMEM
#include "datatypes.h"
@@ -53,6 +53,9 @@
#include "seqWindow.h"
#include "connectionList.h"
+#include "mpi.h" // TODO: remove after debug
+
+
#include "syncQueue.h"
#include "syncSocketHandler.h"
#include "syncListenSocket.h"
@@ -66,9 +69,11 @@
#define PAYLOAD_TYPE_TAP 0x6558
#define PAYLOAD_TYPE_TUN 0x0800
-#define SESSION_KEYLEN_AUTH 20
-#define SESSION_KEYLEN_ENCR 16
-#define SESSION_KEYLEN_SALT 14
+#define MAX_PACKET_LENGTH 1600
+
+#define SESSION_KEYLEN_AUTH 20 // TODO: hardcoded size
+#define SESSION_KEYLEN_ENCR 16 // TODO: hardcoded size
+#define SESSION_KEYLEN_SALT 14 // TODO: hardcoded size
void createConnection(const std::string & remote_host, u_int16_t remote_port, ConnectionList & cl, u_int16_t seqSize, SyncQueue & queue)
{
@@ -128,32 +133,30 @@ void* sender(void* p)
std::auto_ptr<Cipher> c(CipherFactory::create(param->opt.getCipher()));
// std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(param->opt.getAuthAlgo()) );
- PlainPacket plain_packet(1600); // TODO: fix me... mtu size
- EncryptedPacket packet(1600);
+ PlainPacket plain_packet(MAX_PACKET_LENGTH);
+ EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
- // TODO: hardcoded keySize!!!
- Buffer session_key(SESSION_KEYLEN_ENCR);
- Buffer session_salt(SESSION_KEYLEN_SALT);
- Buffer session_auth_key(SESSION_KEYLEN_AUTH);
+ Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
+ Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
+ Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
//TODO replace mux
u_int16_t mux = 0;
while(1)
{
- plain_packet.setLength( plain_packet.getMaxLength()); // Q@NINE wtf???
-
// read packet from device
u_int32_t len = param->dev.read(plain_packet);
plain_packet.setLength(len);
- packet.setLength( len );
- if( param->cl.empty())
+
+ if(param->cl.empty())
continue;
- ConnectionMap::iterator cit = param->cl.getConnection(mux);
+
+ ConnectionMap::iterator cit = param->cl.getConnection(mux);
if(cit==param->cl.getEnd())
continue;
ConnectionParam & conn = cit->second;
- // add payload type
+ // set payload type
if(param->dev.getType() == TunDevice::TYPE_TUN)
plain_packet.setPayloadType(PAYLOAD_TYPE_TUN);
else if(param->dev.getType() == TunDevice::TYPE_TAP)
@@ -168,16 +171,16 @@ void* sender(void* p)
c->setSalt(session_salt);
// encrypt packet
- c->encrypt(plain_packet, packet, conn.seq_nr_, param->opt.getSenderId());
+ c->encrypt(plain_packet, encrypted_packet, conn.seq_nr_, param->opt.getSenderId());
- packet.setHeader(conn.seq_nr_, param->opt.getSenderId(), mux);
+ encrypted_packet.setHeader(conn.seq_nr_, param->opt.getSenderId(), mux);
conn.seq_nr_++;
// TODO: activate authentication
-// conn.kd_.generate(LABEL_SATP_MSG_AUTH, packet.getSeqNr(), session_auth_key);
+// conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
// a->setKey(session_auth_key);
-// addPacketAuthTag(packet, a.get(), conn);
- param->src.send(packet, conn.remote_host_, conn.remote_port_);
+// addPacketAuthTag(encrypted_packet, a.get(), conn);
+ param->src.send(encrypted_packet, conn.remote_host_, conn.remote_port_);
}
pthread_exit(NULL);
}
@@ -223,30 +226,26 @@ void* receiver(void* p)
std::auto_ptr<Cipher> c( CipherFactory::create(param->opt.getCipher()) );
// std::auto_ptr<AuthAlgo> a( AuthAlgoFactory::create(param->opt.getAuthAlgo()) );
- EncryptedPacket packet(1600); // TODO: dynamic mtu size
+ EncryptedPacket encrypted_packet(1600); // TODO: dynamic mtu size
PlainPacket plain_packet(1600);
- // TODO: hardcoded keysize!!!
- Buffer session_key(SESSION_KEYLEN_SALT);
- Buffer session_salt(SESSION_KEYLEN_SALT);
- Buffer session_auth_key(SESSION_KEYLEN_AUTH);
+ Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
+ Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
+ Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
while(1)
{
string remote_host;
u_int16_t remote_port;
- packet.setLength( packet.getMaxLength() ); // Q@NINE wtf???
- plain_packet.setLength( plain_packet.getMaxLength() ); // Q@NINE wtf???
- // u_int16_t sid = 0, seq = 0;
// read packet from socket
- u_int32_t len = param->src.recv(packet, remote_host, remote_port);
- packet.setLength(len);
+ u_int32_t len = param->src.recv(encrypted_packet, remote_host, remote_port);
+ encrypted_packet.setLength(len);
// TODO: check auth tag first
-// conn.kd_.generate(LABEL_SATP_MSG_AUTH, packet.getSeqNr(), session_auth_key);
+// conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
// a->setKey( session_auth_key );
-// if(!checkPacketAuthTag(packet, a.get(), conn))
+// if(!checkPacketAuthTag(encrypted_packet, a.get(), conn))
// continue;
@@ -272,17 +271,17 @@ void* receiver(void* p)
}
// Replay Protection
- if (!checkPacketSeqNr(packet, conn))
+ if (!checkPacketSeqNr(encrypted_packet, conn))
continue;
// generate packet-key
- conn.kd_.generate(LABEL_SATP_ENCRYPTION, packet.getSeqNr(), session_key);
- conn.kd_.generate(LABEL_SATP_SALT, packet.getSeqNr(), session_salt);
+ conn.kd_.generate(LABEL_SATP_ENCRYPTION, encrypted_packet.getSeqNr(), session_key);
+ conn.kd_.generate(LABEL_SATP_SALT, encrypted_packet.getSeqNr(), session_salt);
c->setKey(session_key);
c->setSalt(session_salt);
// decrypt packet
- c->decrypt(packet, plain_packet);
+ c->decrypt(encrypted_packet, plain_packet);
// check payload_type
if((param->dev.getType() == TunDevice::TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN) ||
@@ -341,37 +340,6 @@ bool initLibGCrypt()
int main(int argc, char* argv[])
{
-// // this must be called before any other libgcrypt call
-// if(!initLibGCrypt())
-// return -1;
-
-// u_int8_t KEY[] = {0xE1,0xF9,0x7A,0x0D,0x3E,0x01,0x8B,0xE0,0xD6,0x4F,0xA3,0x2C,0x06,0xDE,0x41,0x39};
-// u_int8_t SALT[] = {0x0E,0xC6,0x75,0xAD,0x49,0x8A,0xFE,0xEB,0xB6,0x96,0x0B,0x3A,0xAB,0xE6};
-// Buffer master_key(KEY, 16);
-// Buffer master_salt(SALT, 14);
-// std::cout << "master key: " << std::endl << master_key.getHexDump() << std::endl;
-// std::cout << "master salt: " << std::endl << master_salt.getHexDump() << std::endl;
-// std::cout << std::endl;
-// KeyDerivation kd;
-// kd.init(master_key, master_salt);
-
-// Buffer key(16);
-// kd.generate(LABEL_SATP_ENCRYPTION, 0, key);
-// std::cout << "key: " << std::endl << key.getHexDump() << std::endl;
-
-// Buffer salt(14);
-// kd.generate(LABEL_SATP_SALT, 0, salt);
-// std::cout << "salt: " << std::endl << salt.getHexDump() << std::endl;
-
-// Buffer auth(14);
-// kd.generate(LABEL_SATP_MSG_AUTH, 0, auth);
-// std::cout << "auth: " << std::endl << auth.getHexDump() << std::endl;
-
-
-// exit(0);
-
-// // *++++++++++++++++++ end of kd test
-
std::cout << "anytun - secure anycast tunneling protocol" << std::endl;
Options opt;
if(!opt.parse(argc, argv))
diff --git a/buffer.cpp b/buffer.cpp
index e8191f2..cea8d2c 100644
--- a/buffer.cpp
+++ b/buffer.cpp
@@ -36,26 +36,32 @@
#include "datatypes.h"
#include "buffer.h"
-Buffer::Buffer() : buf_(0), length_(0)
+Buffer::Buffer(bool allow_realloc) : buf_(0), length_(0), real_length_(0), allow_realloc_(allow_realloc)
{
}
-Buffer::Buffer(u_int32_t length) : length_(length)
+Buffer::Buffer(u_int32_t length, bool allow_realloc) : length_(length), real_length_(length_ + Buffer::OVER_SIZE_),
+ allow_realloc_(allow_realloc)
{
- buf_ = new u_int8_t[length_];
- if(buf_)
- std::memset(buf_, 0, length_);
- else
+ buf_ = new u_int8_t[real_length_];
+ if(!buf_) {
length_ = 0;
+ real_length_ = 0;
+ throw std::bad_alloc();
+ }
+ std::memset(buf_, 0, real_length_);
}
-Buffer::Buffer(u_int8_t* data, u_int32_t length) : length_(length)
+Buffer::Buffer(u_int8_t* data, u_int32_t length, bool allow_realloc) : length_(length), real_length_(length + Buffer::OVER_SIZE_),
+ allow_realloc_(allow_realloc)
{
- buf_ = new u_int8_t[length_];
- if(buf_)
- std::memcpy(buf_, data, length_);
- else
+ buf_ = new u_int8_t[real_length_];
+ if(!buf_) {
length_ = 0;
+ real_length_ = 0;
+ throw std::bad_alloc();
+ }
+ std::memcpy(buf_, data, length_);
}
Buffer::~Buffer()
@@ -64,13 +70,15 @@ Buffer::~Buffer()
delete[] buf_;
}
-Buffer::Buffer(const Buffer &src) : length_(src.length_)
+Buffer::Buffer(const Buffer &src) : length_(src.length_), real_length_(src.real_length_), allow_realloc_(src.allow_realloc_)
{
- buf_ = new u_int8_t[length_];
- if(buf_)
- std::memcpy(buf_, src.buf_, length_);
- else
+ buf_ = new u_int8_t[real_length_];
+ if(!buf_) {
length_ = 0;
+ real_length_ = 0;
+ throw std::bad_alloc();
+ }
+ std::memcpy(buf_, src.buf_, length_);
}
void Buffer::operator=(const Buffer &src)
@@ -79,12 +87,16 @@ void Buffer::operator=(const Buffer &src)
delete[] buf_;
length_ = src.length_;
+ real_length_ = src.real_length_;
+ allow_realloc_ = src.allow_realloc_;
- buf_ = new u_int8_t[length_];
- if(buf_)
- std::memcpy(buf_, src.buf_, length_);
- else
+ buf_ = new u_int8_t[real_length_];
+ if(!buf_) {
length_ = 0;
+ real_length_ = 0;
+ throw std::bad_alloc();
+ }
+ std::memcpy(buf_, src.buf_, length_);
}
@@ -100,58 +112,60 @@ bool Buffer::operator==(const Buffer &cmp) const
return false;
}
-
-u_int32_t Buffer::resizeFront(u_int32_t new_length)
+Buffer Buffer::operator^(const Buffer &xor_by) const
{
- if(length_ == new_length)
- return length_;
+ u_int32_t res_length = (xor_by.length_ > length_) ? xor_by.length_ : length_;
+ u_int32_t min_length = (xor_by.length_ < length_) ? xor_by.length_ : length_;
+ Buffer res(res_length);
- u_int8_t *tmp = new u_int8_t[new_length];
- if(!tmp)
- return length_;
-
- if(buf_)
- {
- u_int8_t *src=buf_, *dest=tmp;
- if(length_ < new_length)
- dest = &dest[new_length - length_];
- else
- src = &src[length_ - new_length];
- u_int32_t len = length_ < new_length ? length_ : new_length;
- std::memcpy(dest, src, len);
- delete[] buf_;
- }
+ for( u_int32_t index = 0; index < min_length; index++ )
+ res[index] = buf_[index] ^ xor_by[index];
+
+ return res;
+}
- length_ = new_length;
- buf_ = tmp;
+u_int32_t Buffer::getLength() const
+{
return length_;
}
-u_int32_t Buffer::resizeBack(u_int32_t new_length)
+void Buffer::setLength(u_int32_t new_length)
{
- if(length_ == new_length)
- return length_;
-
- u_int8_t *tmp = new u_int8_t[new_length];
- if(!tmp)
- return length_;
+ if(new_length == length_)
+ return;
- if(buf_)
+ if(new_length > real_length_)
{
- u_int32_t len = length_ < new_length ? length_ : new_length;
- std::memcpy(tmp, buf_, len);
- delete[] buf_;
- }
+ if(!allow_realloc_)
+ throw std::out_of_range("buffer::setLength() - reallocation not allowed for this Buffer");
+
+ u_int8_t* old_buf = buf_;
+ u_int32_t old_length = length_;
+
+ length_ = new_length;
+ real_length_ = length_ + Buffer::OVER_SIZE_;
+
+ buf_ = new u_int8_t[real_length_];
+ if(!buf_) {
+ length_ = 0;
+ real_length_ = 0;
+ if(old_buf)
+ delete[] old_buf;
+
+ throw std::bad_alloc();
+ }
+ std::memcpy(buf_, old_buf, old_length);
- length_ = new_length;
- buf_ = tmp;
- return length_;
-}
+ if(old_buf)
+ delete[] old_buf;
+
+ old_buf = &buf_[old_length];
+ std::memset(old_buf, 0, real_length_ - old_length);
+ }
+ else
+ length_ = new_length;
+}
-u_int32_t Buffer::getLength() const
-{
- return length_;
-}
u_int8_t* Buffer::getBuf()
{
@@ -174,7 +188,7 @@ u_int8_t Buffer::operator[](u_int32_t index) const
return buf_[index];
}
-Buffer::operator u_int8_t*() // just for write/read tun
+Buffer::operator u_int8_t*()
{
return buf_;
}
@@ -182,10 +196,10 @@ Buffer::operator u_int8_t*() // just for write/read tun
std::string Buffer::getHexDump() const
{
std::stringstream ss;
- ss << std::hex;
+ ss << "Length=" << length_ << std::endl << std::hex << std::uppercase;
for( u_int32_t index = 0; index < length_; index++ )
{
- ss << std::setw(2) << std::setfill('0') << static_cast<unsigned int>(buf_[index]) << " ";
+ ss << std::setw(2) << std::setfill('0') << u_int32_t(buf_[index]) << " ";
if(!((index+1) % 16)) {
ss << std::endl;
continue;
@@ -196,35 +210,7 @@ std::string Buffer::getHexDump() const
return ss.str();
}
-Buffer Buffer::operator^(const Buffer &xor_by) const
+bool Buffer::isReallocAllowed() const
{
- Buffer res(length_);
- if( xor_by.getLength() > length_ )
- throw std::out_of_range("buffer::operator^ const");
-
- for( u_int32_t index = 0; index < xor_by.getLength(); index++ )
- res[index] = buf_[index] ^ xor_by[index];
-
- return res;
+ return allow_realloc_;
}
-
-Buffer Buffer::leftByteShift(u_int32_t width) const
-{
- Buffer res(length_+width);
-
- for( u_int32_t index = 0; index < length_; index++ )
- res[index+width] = buf_[index];
-
- return res;
-}
-
-Buffer Buffer::rightByteShift(u_int32_t width) const
-{
- Buffer res(length_);
-
- for( u_int32_t index = 0; index < length_-width; index++ )
- res[index] = buf_[index+width];
-
- return res;
-}
-
diff --git a/buffer.h b/buffer.h
index 1a426e1..063be91 100644
--- a/buffer.h
+++ b/buffer.h
@@ -40,32 +40,33 @@ class UDPPacketSource;
class Buffer
{
public:
- Buffer();
- Buffer(u_int32_t length);
- Buffer(u_int8_t* data, u_int32_t length);
+ Buffer(bool allow_realloc = true);
+ Buffer(u_int32_t length, bool allow_realloc = true);
+ Buffer(u_int8_t* data, u_int32_t length, bool allow_realloc = true);
virtual ~Buffer();
Buffer(const Buffer &src);
void operator=(const Buffer &src);
bool operator==(const Buffer &cmp) const;
+ Buffer operator^(const Buffer &xor_by) const;
- // math operations to calculate IVs and keys
- virtual Buffer operator^(const Buffer &xor_by) const;
- virtual Buffer leftByteShift(u_int32_t width) const;
- virtual Buffer rightByteShift(u_int32_t width) const;
-
- u_int32_t resizeFront(u_int32_t new_length);
- u_int32_t resizeBack(u_int32_t new_length);
u_int32_t getLength() const;
+ virtual void setLength(u_int32_t new_length);
u_int8_t* getBuf();
u_int8_t& operator[](u_int32_t index);
u_int8_t operator[](u_int32_t index) const;
std::string getHexDump() const;
+ bool isReallocAllowed() const;
+
operator u_int8_t*();
protected:
u_int8_t *buf_;
u_int32_t length_;
+ u_int32_t real_length_;
+ bool allow_realloc_;
+
+ static const u_int32_t OVER_SIZE_ = 100;
};
#endif
diff --git a/cipher.cpp b/cipher.cpp
index 22b8019..07a9117 100644
--- a/cipher.cpp
+++ b/cipher.cpp
@@ -39,18 +39,20 @@
#include "log.h"
-void Cipher::encrypt(const PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id)
+ // TODO: in should be const but does not work with getBuf() :(
+void Cipher::encrypt(PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id)
{
- u_int32_t len = cipher(out.payload_, out.payload_length_, in.complete_payload_ , in.complete_payload_length_, seq_nr, sender_id);
+ u_int32_t len = cipher(in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id);
out.setSenderId(sender_id);
out.setSeqNr(seq_nr);
out.setPayloadLength(len);
}
-void Cipher::decrypt(const EncryptedPacket & in, PlainPacket & out)
+ // TODO: in should be const but does not work with getBuf() :(
+void Cipher::decrypt(EncryptedPacket & in, PlainPacket & out)
{
- u_int32_t len = decipher(out.complete_payload_, out.complete_payload_length_, in.payload_ , in.payload_length_, in.getSeqNr(), in.getSenderId());
- out.setCompletePayloadLength(len);
+ u_int32_t len = decipher(in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId());
+ out.setLength(len);
}
@@ -70,9 +72,9 @@ u_int32_t NullCipher::decipher(u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_in
//****** AesIcmCipher ******
-AesIcmCipher::AesIcmCipher() : salt_(Buffer(14)) // Q@NINE 14??????
+AesIcmCipher::AesIcmCipher()
{
- // TODO: hardcoded keysize!!!!!
+ // TODO: hardcoded keysize
gcry_error_t err = gcry_cipher_open( &cipher_, GCRY_CIPHER_AES128, GCRY_CIPHER_MODE_CTR, 0 );
if( err )
cLog.msg(Log::PRIO_CRIT) << "AesIcmCipher::AesIcmCipher: Failed to open cipher";
@@ -82,7 +84,6 @@ AesIcmCipher::AesIcmCipher() : salt_(Buffer(14)) // Q@NINE 14??????
AesIcmCipher::~AesIcmCipher()
{
gcry_cipher_close( cipher_ );
- cLog.msg(Log::PRIO_DEBUG) << "AesIcmCipher::~AesIcmCipher: closed cipher";
}
@@ -129,13 +130,17 @@ void AesIcmCipher::calc(u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t o
// // sizeof(k_s) = 112 bit, random
Mpi ctr(128); // TODO: hardcoded size
- Mpi salt = Mpi(salt_.getBuf(), salt_.getLength());
- Mpi sid = sender_id; // Q@OTTI add mux to sender_id????
- Mpi seq = seq_nr;
+ Mpi salt(salt_.getBuf(), salt_.getLength());
+ Mpi sid(32); // TODO: Q@OTTI add mux to sender_id????
+ sid = sender_id;
+ Mpi seq(32);
+ seq = seq_nr;
ctr = salt.mul2exp(16) ^ sid.mul2exp(64) ^ seq.mul2exp(16); // TODO: hardcoded size
- u_int8_t *ctr_buf = ctr.getNewBuf(16); // TODO: hardcoded size
- err = gcry_cipher_setctr( cipher_, ctr_buf, 16 ); // TODO: hardcoded size
+
+ u_int32_t written;
+ u_int8_t *ctr_buf = ctr.getNewBuf(&written); // TODO: hardcoded size
+ err = gcry_cipher_setctr( cipher_, ctr_buf, written ); // TODO: hardcoded size
delete[] ctr_buf;
if( err ) {
cLog.msg(Log::PRIO_ERR) << "AesIcmCipher: Failed to set cipher CTR: " << gpg_strerror( err );
diff --git a/cipher.h b/cipher.h
index 190d9fe..0749859 100644
--- a/cipher.h
+++ b/cipher.h
@@ -44,8 +44,9 @@ class Cipher
public:
virtual ~Cipher() {};
- void encrypt(const PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id);
- void decrypt(const EncryptedPacket & in, PlainPacket & out);
+ // TODO: in should be const but does not work with getBuf() :(
+ void encrypt(PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id);
+ void decrypt(EncryptedPacket & in, PlainPacket & out);
virtual void setKey(Buffer key) = 0;
virtual void setSalt(Buffer salt) = 0;
diff --git a/encryptedPacket.cpp b/encryptedPacket.cpp
index 0731d24..a3e5886 100644
--- a/encryptedPacket.cpp
+++ b/encryptedPacket.cpp
@@ -38,139 +38,149 @@
#include "authTag.h"
#include "log.h"
-
-EncryptedPacket::EncryptedPacket(u_int32_t max_payload_length)
- : Buffer(max_payload_length + sizeof(struct HeaderStruct) + AUTHTAG_SIZE)
+// TODO: fix auth_tag stuff
+EncryptedPacket::EncryptedPacket(u_int32_t payload_length, bool allow_realloc)
+ : Buffer(payload_length + sizeof(struct HeaderStruct), allow_realloc)
{
header_ = reinterpret_cast<struct HeaderStruct*>(buf_);
- auth_tag_ = NULL;
- payload_ = buf_ + sizeof(struct HeaderStruct); // no authtag yet
- length_ = sizeof(struct HeaderStruct);
- max_length_ = max_payload_length + AUTHTAG_SIZE;
-}
-
-
-EncryptedPacket::~EncryptedPacket()
-{
- buf_ = reinterpret_cast<u_int8_t*>(header_);
- if( auth_tag_ == NULL )
- length_ = max_length_ + sizeof(struct HeaderStruct) + AUTHTAG_SIZE;
- else
- length_ = max_length_ + sizeof(struct HeaderStruct);
-}
-
-void EncryptedPacket::setPayloadLength(u_int32_t payload_length)
-{
- if( auth_tag_)
- length_= payload_length + sizeof(struct HeaderStruct)+AUTHTAG_SIZE;
- else
- length_= payload_length + sizeof(struct HeaderStruct);
+ payload_ = buf_ + sizeof(struct HeaderStruct); // TODO: fix auth_tag stuff
+ auth_tag_ = NULL; // TODO: fix auth_tag stuff
+ if(header_)
+ {
+ header_->seq_nr = 0;
+ header_->sender_id = 0;
+ header_->mux = 0;
+ }
}
seq_nr_t EncryptedPacket::getSeqNr() const
{
- return SEQ_NR_T_NTOH(header_->seq_nr);
+ if(header_)
+ return SEQ_NR_T_NTOH(header_->seq_nr);
+
+ return 0;
}
sender_id_t EncryptedPacket::getSenderId() const
{
- return SENDER_ID_T_NTOH(header_->sender_id);
-}
+ if(header_)
+ return SENDER_ID_T_NTOH(header_->sender_id);
-mux_t EncryptedPacket::getMux() const
-{
- return MUX_T_NTOH(header_->mux);
+ return 0;
}
-u_int32_t EncryptedPacket::getMaxLength() const
+mux_t EncryptedPacket::getMux() const
{
- return max_length_;
-}
+ if(header_)
+ return MUX_T_NTOH(header_->mux);
-void EncryptedPacket::setLength(u_int32_t length)
-{
- if(length > max_length_)
- throw std::out_of_range("can't set length greater then size ofsize of allocated memory");
-
- length_ = length;
- if( auth_tag_)
- payload_length_ = length_ - sizeof(struct HeaderStruct)+AUTHTAG_SIZE;
- else
- payload_length_ = length_ - sizeof(struct HeaderStruct);
+ return 0;
}
void EncryptedPacket::setSeqNr(seq_nr_t seq_nr)
{
- header_->seq_nr = SEQ_NR_T_HTON(seq_nr);
+ if(header_)
+ header_->seq_nr = SEQ_NR_T_HTON(seq_nr);
}
void EncryptedPacket::setSenderId(sender_id_t sender_id)
{
- header_->sender_id = SENDER_ID_T_HTON(sender_id);
+ if(header_)
+ header_->sender_id = SENDER_ID_T_HTON(sender_id);
}
void EncryptedPacket::setMux(mux_t mux)
{
- header_->mux = MUX_T_HTON(mux);
+ if(header_)
+ header_->mux = MUX_T_HTON(mux);
}
void EncryptedPacket::setHeader(seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
+ if(!header_)
+ return;
+
header_->seq_nr = SEQ_NR_T_HTON(seq_nr);
header_->sender_id = SENDER_ID_T_HTON(sender_id);
header_->mux = MUX_T_HTON(mux);
}
+u_int32_t EncryptedPacket::getPayloadLength() const
+{
+ return (length_ > sizeof(struct HeaderStruct)) ? (length_ - sizeof(struct HeaderStruct)) : 0; // TODO: fix auth_tag stuff
+}
+
+void EncryptedPacket::setPayloadLength(u_int32_t payload_length)
+{
+ Buffer::setLength(payload_length + sizeof(struct HeaderStruct));
+
+ // depending on allow_realloc buf_ may point to another address
+ header_ = reinterpret_cast<struct HeaderStruct*>(buf_);
+ payload_ = buf_ + sizeof(struct HeaderStruct); // TODO: fix auth_tag stuff
+ auth_tag_ = NULL; // TODO: fix auth_tag stuff
+}
+
+u_int8_t* EncryptedPacket::getPayload()
+{
+ return payload_;
+}
+
+
+
+
+
+
+// TODO: fix auth_tag stuff
+
bool EncryptedPacket::hasAuthTag() const
{
- if( auth_tag_ == NULL )
- return false;
- return true;
+// if( auth_tag_ == NULL )
+ return false;
+// return true;
}
void EncryptedPacket::withAuthTag(bool b)
{
- if( b && (auth_tag_ != NULL) )
- throw std::runtime_error("packet already has auth tag function enabled");
- //TODO: return instead?
- if( ! b && (auth_tag_ == NULL) )
- throw std::runtime_error("packet already has auth tag function disabled");
- //TODO: return instead?
-
- if( b ) {
- auth_tag_ = reinterpret_cast<AuthTag*>( buf_ + sizeof(struct HeaderStruct) );
- payload_ = payload_ + AUTHTAG_SIZE;
- length_ -= AUTHTAG_SIZE;
- max_length_ -= AUTHTAG_SIZE;
- } else {
- payload_ = reinterpret_cast<u_int8_t*>( auth_tag_ );
- length_ += AUTHTAG_SIZE;
- max_length_ += AUTHTAG_SIZE;
- auth_tag_ = NULL;
- }
+// if( b && (auth_tag_ != NULL) )
+// throw std::runtime_error("packet already has auth tag function enabled");
+// //TODO: return instead?
+// if( ! b && (auth_tag_ == NULL) )
+// throw std::runtime_error("packet already has auth tag function disabled");
+// //TODO: return instead?
+
+// if( b ) {
+// auth_tag_ = reinterpret_cast<AuthTag*>( buf_ + sizeof(struct HeaderStruct) );
+// payload_ = payload_ + AUTHTAG_SIZE;
+// length_ -= AUTHTAG_SIZE;
+// max_length_ -= AUTHTAG_SIZE;
+// } else {
+// payload_ = reinterpret_cast<u_int8_t*>( auth_tag_ );
+// length_ += AUTHTAG_SIZE;
+// max_length_ += AUTHTAG_SIZE;
+// auth_tag_ = NULL;
+// }
}
void EncryptedPacket::setAuthTag(AuthTag& tag)
{
- if( auth_tag_ == NULL )
- throw std::runtime_error("auth tag not enabled");
+// if( auth_tag_ == NULL )
+// throw std::runtime_error("auth tag not enabled");
- if( tag == AuthTag(0) )
- return;
+// if( tag == AuthTag(0) )
+// return;
- if( tag.getLength() != AUTHTAG_SIZE )
- throw std::length_error("authtag length mismatch with AUTHTAG_SIZE");
+// if( tag.getLength() != AUTHTAG_SIZE )
+// throw std::length_error("authtag length mismatch with AUTHTAG_SIZE");
- std::memcpy( auth_tag_, tag.getBuf(), AUTHTAG_SIZE );
+// std::memcpy( auth_tag_, tag.getBuf(), AUTHTAG_SIZE );
}
AuthTag EncryptedPacket::getAuthTag() const
{
- if( auth_tag_ == NULL )
- throw std::runtime_error("auth tag not enabled");
+// if( auth_tag_ == NULL )
+// throw std::runtime_error("auth tag not enabled");
AuthTag at(AUTHTAG_SIZE);
- std::memcpy(at, auth_tag_, AUTHTAG_SIZE );
+// std::memcpy(at, auth_tag_, AUTHTAG_SIZE );
return at;
}
-
diff --git a/encryptedPacket.h b/encryptedPacket.h
index afc7d0e..0b934f6 100644
--- a/encryptedPacket.h
+++ b/encryptedPacket.h
@@ -41,14 +41,15 @@ public:
/**
* Packet constructor
- * @param max_payload_length maximum length of encrypted payload
+ * @param the length of the payload
+ * @param allow reallocation of buffer
*/
- EncryptedPacket(u_int32_t max_payload_length);
+ EncryptedPacket(u_int32_t payload_length, bool allow_realloc = false);
/**
* Packet destructor
*/
- ~EncryptedPacket();
+ ~EncryptedPacket() {};
/**
* Get the sequence number
@@ -94,36 +95,36 @@ public:
*/
void setHeader(seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
- /**
- * Get the maximum payload size
- * @return maximum payload size
+ /**
+ * Get the length of the payload
+ * @return the length of the payload
+ */
+ u_int32_t getPayloadLength() const;
+
+ /**
+ * Set the length of the payload
+ * @param length length of the payload
*/
- u_int32_t getMaxLength() const;
+ void setPayloadLength(u_int32_t payload_length);
/**
- * Set the real length of the payload
- * @param length the real length of the payload, has to be smaller than the maximum payload size!
+ * Get the the payload
+ * @return the Pointer to the payload
*/
- void setLength(u_int32_t length);
+ u_int8_t* getPayload();
+
+
bool hasAuthTag() const;
void withAuthTag(bool b);
AuthTag getAuthTag() const;
void setAuthTag(AuthTag& tag);
- void setPayloadLength(u_int32_t payload_length);
-
-
-// bool hasHeader() const;
-// Packet& withHeader(bool b);
-// Packet& addHeader(seq_nr_t seq_nr, sender_id_t sender_id);
-// Packet& withAuthTag(bool b);
-// AuthTag getAuthTag() const;
-// Packet& addAuthTag(AuthTag auth_tag);
-
+
private:
EncryptedPacket();
EncryptedPacket(const EncryptedPacket &src);
+
struct HeaderStruct
{
seq_nr_t seq_nr;
@@ -132,14 +133,9 @@ private:
}__attribute__((__packed__));
struct HeaderStruct* header_;
- AuthTag* auth_tag_;
- u_int32_t max_length_;
-
- static const u_int32_t AUTHTAG_SIZE = 10; // 10byte
-protected:
- friend class Cipher;
u_int8_t * payload_;
- u_int32_t payload_length_;
+ AuthTag* auth_tag_;
+ static const u_int32_t AUTHTAG_SIZE = 10; // TODO: hardcoded size
};
#endif
diff --git a/keyDerivation.cpp b/keyDerivation.cpp
index f3d1fe6..dbafec6 100644
--- a/keyDerivation.cpp
+++ b/keyDerivation.cpp
@@ -97,50 +97,37 @@ void KeyDerivation::generate(satp_prf_label label, seq_nr_t seq_nr, Buffer& key)
// alignment).
//
- Mpi r(48);
+ Mpi r(48); // ld(kdr) <= 48
if( ld_kdr_ == -1 ) // means key_derivation_rate = 0
- r = 0;
+ r = 0; // TODO: no new key should be generated if r == 0, except it is the first time
else
{
Mpi seq = seq_nr;
Mpi rate = 1;
rate = rate.mul2exp(ld_kdr_);
r = seq / rate;
- }
-
- std::cout << "r: " << std::endl;
- std::cout << r.getHexDump();
+ }
+ // TODO: generate key only if index % r == 0, except it is the first time
- Mpi key_id(128), l(128); // TODO: hardcoded keySize
+ Mpi key_id(128); // TODO: hardcoded size
+ Mpi l(128); // TODO: hardcoded size
l = label;
- key_id = l.mul2exp(48) + r; // TODO: hardcoded keySize
-
- std::cout << "label: " << std::endl;
- std::cout << l.getHexDump();
+ key_id = l.mul2exp(48) + r;
- std::cout << "keyid: " << std::endl;
- std::cout << key_id.getHexDump();
-
- Mpi x(128); // TODO: hardcoded keySize
- Mpi salt = Mpi(master_salt_.getBuf(), master_salt_.getLength());
+ Mpi salt(master_salt_.getBuf(), master_salt_.getLength());
+ Mpi x(128); // TODO: hardcoded size
x = key_id ^ salt;
-
- std::cout << "x: " << std::endl;
- std::cout << x.getHexDump();
-
- std::cout << "x*2^16(ctr): " << std::endl;
- std::cout << x.mul2exp(16).getHexDump();
- u_int8_t *ctr_buf = x.mul2exp(16).getNewBuf(16); // TODO: hardcoded size
- err = gcry_cipher_setctr( cipher_ , ctr_buf, 16); // TODO: hardcoded size
-
+ u_int32_t written;
+ u_int8_t *ctr_buf = x.mul2exp(16).getNewBuf(&written); // TODO: hardcoded size
+ err = gcry_cipher_setctr( cipher_ , ctr_buf, written );
delete[] ctr_buf;
+
if( err )
- cLog.msg(Log::PRIO_ERR) << "KeyDerivation::generate: Failed to set IV: " << gpg_strerror( err );
+ cLog.msg(Log::PRIO_ERR) << "KeyDerivation::generate: Failed to set CTR: " << gpg_strerror( err );
- u_int8_t *x_buf = x.getNewBuf(16); // TODO: hardcoded size
- err = gcry_cipher_encrypt( cipher_, key, key.getLength(), x_buf, 16 ); // TODO: hardcoded size
- delete[] x_buf;
+ for(u_int32_t i=0; i < key.getLength(); ++i) key[i] = 0;
+ err = gcry_cipher_encrypt( cipher_, key, key.getLength(), NULL, 0);
if( err )
cLog.msg(Log::PRIO_ERR) << "KeyDerivation::generate: Failed to generate cipher bitstream: " << gpg_strerror( err );
}
diff --git a/mpi.cpp b/mpi.cpp
index 18a3349..a2590fa 100644
--- a/mpi.cpp
+++ b/mpi.cpp
@@ -36,41 +36,71 @@
#include <stdexcept>
#include <gcrypt.h>
+#include <iostream>
-Mpi::Mpi()
+Mpi::Mpi() : val_(NULL)
{
- val_ = gcry_mpi_new(1);
+ val_ = gcry_mpi_set_ui(NULL, 0);
+ if(!val_)
+ throw std::bad_alloc();
}
-Mpi::Mpi(u_int8_t length)
+Mpi::Mpi(u_int8_t length) : val_(NULL)
{
val_ = gcry_mpi_new(length);
+ if(!val_)
+ throw std::bad_alloc();
}
-Mpi::Mpi(const Mpi &src)
+Mpi::Mpi(const Mpi &src) : val_(NULL)
{
val_ = gcry_mpi_copy(src.val_);
+ if(!val_)
+ throw std::bad_alloc();
}
-Mpi::Mpi(const u_int8_t * src, u_int32_t len)
+Mpi::Mpi(const u_int8_t* src, u_int32_t len) : val_(NULL)
{
- gcry_mpi_scan( &val_, GCRYMPI_FMT_STD, src, len, NULL );
+ u_int8_t* src_cpy = new u_int8_t[len+1];
+ if(!src_cpy)
+ throw std::bad_alloc();
+
+ u_int8_t* buf = src_cpy;
+ u_int32_t buf_len = len;
+ if(src[0] & 0x80) // this would be a negative number, scan can't handle this :(
+ {
+ src_cpy[0] = 0;
+ buf++;
+ buf_len++;
+ }
+ std::memcpy(buf, src, len);
+
+ gcry_mpi_scan( &val_, GCRYMPI_FMT_STD, src_cpy, buf_len, NULL );
+ delete[] src_cpy;
+ if(!val_)
+ throw std::bad_alloc();
}
Mpi::~Mpi()
{
- gcry_mpi_release( val_ );
+ gcry_mpi_release( val_ );
}
void Mpi::operator=(const Mpi &src)
{
+ gcry_mpi_release( val_ );
val_ = gcry_mpi_copy(src.val_);
+ if(!val_)
+ throw std::bad_alloc();
}
void Mpi::operator=(const u_int32_t src)
{
- gcry_mpi_set_ui(val_, src);
+ gcry_mpi_release( val_ );
+ val_ = gcry_mpi_set_ui(NULL, src);
+ if(!val_)
+ throw std::bad_alloc();
}
Mpi Mpi::operator+(const Mpi &b) const
@@ -125,22 +155,29 @@ Mpi Mpi::mul2exp(u_int32_t e) const
return res;
}
-u_int8_t* Mpi::getNewBuf(u_int32_t buf_len) const
+//TODO: problem, seems as gcry_mpi_(a)print doesn't work for mpi values of '0'
+u_int8_t* Mpi::getNewBuf(u_int32_t* written) const
{
- // u_int32_t len = 0;
- u_int32_t written = 0;
+ u_int8_t* res_cpy;
+ gcry_mpi_aprint( GCRYMPI_FMT_STD, &res_cpy, written, val_ );
+ if(!res_cpy)
+ throw std::bad_alloc();
+
+ u_int8_t* buf = res_cpy;
+ if(*written > 1 && ! (res_cpy[0])) // positive number with highestBit set
+ {
+ buf++;
+ (*written)--;
+ }
+
+ u_int8_t* res = new u_int8_t[*written];
+ if(!res)
+ throw std::bad_alloc();
- u_int8_t *res = NULL;
- res = new u_int8_t[buf_len];
- std::memset(res, 0, buf_len);
+ std::memcpy(res, buf, *written);
- // len = gcry_mpi_get_nbits( val_ );
- // if( len%8 == 0 )
- // len = len/8;
- // else
- // len = (len/8)+1;
+ gcry_free(res_cpy);
- gcry_mpi_print( GCRYMPI_FMT_STD, res, buf_len, &written, val_ );
return res;
}
@@ -151,7 +188,8 @@ std::string Mpi::getHexDump() const
// u_int32_t len;
// gcry_mpi_aprint( GCRYMPI_FMT_HEX, &buf, &len, val_ );
// std::string res(buf, len);
-
+// delete[] buf;
+
gcry_mpi_dump( val_ );
std::string res("\n");
return res;
diff --git a/mpi.h b/mpi.h
index 70f6681..c746602 100644
--- a/mpi.h
+++ b/mpi.h
@@ -70,7 +70,7 @@ public:
* @param buf_len size of the new buffer that is returned
* @return a byte buffer of size buf_len
*/
- u_int8_t *getNewBuf(u_int32_t buf_len) const;
+ u_int8_t *getNewBuf(u_int32_t* written) const;
std::string getHexDump() const;
u_int32_t getLength() const;
diff --git a/plainPacket.cpp b/plainPacket.cpp
index 3ce1521..0906fa2 100644
--- a/plainPacket.cpp
+++ b/plainPacket.cpp
@@ -33,43 +33,14 @@
#include <arpa/inet.h>
#include "datatypes.h"
-
#include "plainPacket.h"
-
-PlainPacket::~PlainPacket()
-{
- buf_=complete_payload_;
- length_=max_length_;
-}
-
-PlainPacket::PlainPacket(u_int32_t max_payload_length) : Buffer(max_payload_length + sizeof(payload_type_t))
+PlainPacket::PlainPacket(u_int32_t payload_length, bool allow_realloc) : Buffer(payload_length + sizeof(payload_type_t), allow_realloc)
{
- payload_type_ = NULL;
- splitPayload();
-}
-
-void PlainPacket::splitPayload()
-{
- complete_payload_length_ = length_;
- complete_payload_ = buf_;
-
payload_type_ = reinterpret_cast<payload_type_t*>(buf_);
- buf_ += sizeof(payload_type_t);
- length_ -= sizeof(payload_type_t);
- max_length_ = length_;
-}
-
-void PlainPacket::setCompletePayloadLength(u_int32_t payload_length)
-{
- complete_payload_length_ = payload_length;
- length_=complete_payload_length_-sizeof(payload_type_t);
-}
-
-u_int32_t PlainPacket::getCompletePayloadLength()
-{
- return complete_payload_length_;
+ payload_ = buf_ + sizeof(payload_type_t);
+ *payload_type_ = 0;
}
payload_type_t PlainPacket::getPayloadType() const
@@ -83,16 +54,21 @@ void PlainPacket::setPayloadType(payload_type_t payload_type)
*payload_type_ = PAYLOAD_TYPE_T_HTON(payload_type);
}
-void PlainPacket::setLength(u_int32_t length)
+u_int32_t PlainPacket::getPayloadLength() const
+{
+ return (length_ > sizeof(payload_type_t)) ? (length_ - sizeof(payload_type_t)) : 0;
+}
+
+void PlainPacket::setPayloadLength(u_int32_t payload_length)
{
- if(length > max_length_)
- throw std::out_of_range("can't set length greater then size ofsize of allocated memory");
+ Buffer::setLength(payload_length + sizeof(payload_type_t));
- length_ = length;
- complete_payload_length_ = length_ + sizeof(payload_type_t);
+ // depending on allow_realloc buf_ may point to another address
+ payload_type_ = reinterpret_cast<payload_type_t*>(buf_);
+ payload_ = buf_ + sizeof(payload_type_t);
}
-u_int32_t PlainPacket::getMaxLength() const
+u_int8_t* PlainPacket::getPayload()
{
- return max_length_;
+ return payload_;
}
diff --git a/plainPacket.h b/plainPacket.h
index 176d841..54c387a 100644
--- a/plainPacket.h
+++ b/plainPacket.h
@@ -43,13 +43,17 @@ class Cipher;
class PlainPacket : public Buffer
{
public:
- ~PlainPacket();
-
/**
* Packet constructor
- * @param max_payload_length maximum payload length
+ * @param the length of the payload
+ * @param allow reallocation of buffer
*/
- PlainPacket(u_int32_t max_payload_length);
+ PlainPacket(u_int32_t payload_length, bool allow_realloc = false);
+
+ /**
+ * Packet destructor
+ */
+ ~PlainPacket() {};
/**
* Get the payload type
@@ -63,43 +67,31 @@ public:
*/
void setPayloadType(payload_type_t payload_type);
- void setCompletePayloadLength(u_int32_t payload_length);
- u_int32_t getCompletePayloadLength();
-
/**
- * Set the real payload length
- * @param length the real payload length
+ * Get the length of the payload
+ * @return the length of the payload
*/
- //void setRealPayloadLengt(u_int32_t length);
-
- /**
- * Get the real payload length
- * @return the real length of the payload
- */
- //u_int32_t getRealPayloadLength();
+ u_int32_t getPayloadLength() const;
/**
* Set the length of the payload
* @param length length of the payload
*/
- void setLength(u_int32_t length);
+ void setPayloadLength(u_int32_t payload_length);
/**
- * Get the size of the allocated memory for the payload
- * @return maximum size of payload
+ * Get the the payload
+ * @return the Pointer to the payload
*/
- u_int32_t getMaxLength() const;
+ u_int8_t* getPayload();
+
private:
PlainPacket();
PlainPacket(const PlainPacket &src);
- void splitPayload();
- u_int32_t max_length_;
+
payload_type_t* payload_type_;
-protected:
- friend class Cipher;
- u_int8_t * complete_payload_;
- u_int32_t complete_payload_length_;
+ u_int8_t* payload_;
};
#endif