From b8766a2041d57a3aa49c3855902953f8de0b0ec1 Mon Sep 17 00:00:00 2001 From: Christian Pointner Date: Sat, 23 Feb 2008 23:05:43 +0000 Subject: - keyderivation and cipher should work now however it needs further testing - rewrite of Buffer and Packets --- anytun.cpp | 104 +++++++++++-------------------- buffer.cpp | 174 ++++++++++++++++++++++++---------------------------- buffer.h | 21 ++++--- cipher.cpp | 31 ++++++---- cipher.h | 5 +- encryptedPacket.cpp | 168 ++++++++++++++++++++++++++------------------------ encryptedPacket.h | 50 +++++++-------- keyDerivation.cpp | 45 +++++--------- mpi.cpp | 80 +++++++++++++++++------- mpi.h | 2 +- plainPacket.cpp | 54 +++++----------- plainPacket.h | 44 ++++++------- 12 files changed, 369 insertions(+), 409 deletions(-) diff --git a/anytun.cpp b/anytun.cpp index 5aab904..e679281 100644 --- a/anytun.cpp +++ b/anytun.cpp @@ -31,7 +31,7 @@ #include #include -#include // for thread safe libgcrypt initialisation +#include #include // for ENOMEM #include "datatypes.h" @@ -53,6 +53,9 @@ #include "seqWindow.h" #include "connectionList.h" +#include "mpi.h" // TODO: remove after debug + + #include "syncQueue.h" #include "syncSocketHandler.h" #include "syncListenSocket.h" @@ -66,9 +69,11 @@ #define PAYLOAD_TYPE_TAP 0x6558 #define PAYLOAD_TYPE_TUN 0x0800 -#define SESSION_KEYLEN_AUTH 20 -#define SESSION_KEYLEN_ENCR 16 -#define SESSION_KEYLEN_SALT 14 +#define MAX_PACKET_LENGTH 1600 + +#define SESSION_KEYLEN_AUTH 20 // TODO: hardcoded size +#define SESSION_KEYLEN_ENCR 16 // TODO: hardcoded size +#define SESSION_KEYLEN_SALT 14 // TODO: hardcoded size void createConnection(const std::string & remote_host, u_int16_t remote_port, ConnectionList & cl, u_int16_t seqSize, SyncQueue & queue) { @@ -128,32 +133,30 @@ void* sender(void* p) std::auto_ptr c(CipherFactory::create(param->opt.getCipher())); // std::auto_ptr a(AuthAlgoFactory::create(param->opt.getAuthAlgo()) ); - PlainPacket plain_packet(1600); // TODO: fix me... mtu size - EncryptedPacket packet(1600); + PlainPacket plain_packet(MAX_PACKET_LENGTH); + EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH); - // TODO: hardcoded keySize!!! - Buffer session_key(SESSION_KEYLEN_ENCR); - Buffer session_salt(SESSION_KEYLEN_SALT); - Buffer session_auth_key(SESSION_KEYLEN_AUTH); + Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size + Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size + Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size //TODO replace mux u_int16_t mux = 0; while(1) { - plain_packet.setLength( plain_packet.getMaxLength()); // Q@NINE wtf??? - // read packet from device u_int32_t len = param->dev.read(plain_packet); plain_packet.setLength(len); - packet.setLength( len ); - if( param->cl.empty()) + + if(param->cl.empty()) continue; - ConnectionMap::iterator cit = param->cl.getConnection(mux); + + ConnectionMap::iterator cit = param->cl.getConnection(mux); if(cit==param->cl.getEnd()) continue; ConnectionParam & conn = cit->second; - // add payload type + // set payload type if(param->dev.getType() == TunDevice::TYPE_TUN) plain_packet.setPayloadType(PAYLOAD_TYPE_TUN); else if(param->dev.getType() == TunDevice::TYPE_TAP) @@ -168,16 +171,16 @@ void* sender(void* p) c->setSalt(session_salt); // encrypt packet - c->encrypt(plain_packet, packet, conn.seq_nr_, param->opt.getSenderId()); + c->encrypt(plain_packet, encrypted_packet, conn.seq_nr_, param->opt.getSenderId()); - packet.setHeader(conn.seq_nr_, param->opt.getSenderId(), mux); + encrypted_packet.setHeader(conn.seq_nr_, param->opt.getSenderId(), mux); conn.seq_nr_++; // TODO: activate authentication -// conn.kd_.generate(LABEL_SATP_MSG_AUTH, packet.getSeqNr(), session_auth_key); +// conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key); // a->setKey(session_auth_key); -// addPacketAuthTag(packet, a.get(), conn); - param->src.send(packet, conn.remote_host_, conn.remote_port_); +// addPacketAuthTag(encrypted_packet, a.get(), conn); + param->src.send(encrypted_packet, conn.remote_host_, conn.remote_port_); } pthread_exit(NULL); } @@ -223,30 +226,26 @@ void* receiver(void* p) std::auto_ptr c( CipherFactory::create(param->opt.getCipher()) ); // std::auto_ptr a( AuthAlgoFactory::create(param->opt.getAuthAlgo()) ); - EncryptedPacket packet(1600); // TODO: dynamic mtu size + EncryptedPacket encrypted_packet(1600); // TODO: dynamic mtu size PlainPacket plain_packet(1600); - // TODO: hardcoded keysize!!! - Buffer session_key(SESSION_KEYLEN_SALT); - Buffer session_salt(SESSION_KEYLEN_SALT); - Buffer session_auth_key(SESSION_KEYLEN_AUTH); + Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size + Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size + Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size while(1) { string remote_host; u_int16_t remote_port; - packet.setLength( packet.getMaxLength() ); // Q@NINE wtf??? - plain_packet.setLength( plain_packet.getMaxLength() ); // Q@NINE wtf??? - // u_int16_t sid = 0, seq = 0; // read packet from socket - u_int32_t len = param->src.recv(packet, remote_host, remote_port); - packet.setLength(len); + u_int32_t len = param->src.recv(encrypted_packet, remote_host, remote_port); + encrypted_packet.setLength(len); // TODO: check auth tag first -// conn.kd_.generate(LABEL_SATP_MSG_AUTH, packet.getSeqNr(), session_auth_key); +// conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key); // a->setKey( session_auth_key ); -// if(!checkPacketAuthTag(packet, a.get(), conn)) +// if(!checkPacketAuthTag(encrypted_packet, a.get(), conn)) // continue; @@ -272,17 +271,17 @@ void* receiver(void* p) } // Replay Protection - if (!checkPacketSeqNr(packet, conn)) + if (!checkPacketSeqNr(encrypted_packet, conn)) continue; // generate packet-key - conn.kd_.generate(LABEL_SATP_ENCRYPTION, packet.getSeqNr(), session_key); - conn.kd_.generate(LABEL_SATP_SALT, packet.getSeqNr(), session_salt); + conn.kd_.generate(LABEL_SATP_ENCRYPTION, encrypted_packet.getSeqNr(), session_key); + conn.kd_.generate(LABEL_SATP_SALT, encrypted_packet.getSeqNr(), session_salt); c->setKey(session_key); c->setSalt(session_salt); // decrypt packet - c->decrypt(packet, plain_packet); + c->decrypt(encrypted_packet, plain_packet); // check payload_type if((param->dev.getType() == TunDevice::TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN) || @@ -341,37 +340,6 @@ bool initLibGCrypt() int main(int argc, char* argv[]) { -// // this must be called before any other libgcrypt call -// if(!initLibGCrypt()) -// return -1; - -// u_int8_t KEY[] = {0xE1,0xF9,0x7A,0x0D,0x3E,0x01,0x8B,0xE0,0xD6,0x4F,0xA3,0x2C,0x06,0xDE,0x41,0x39}; -// u_int8_t SALT[] = {0x0E,0xC6,0x75,0xAD,0x49,0x8A,0xFE,0xEB,0xB6,0x96,0x0B,0x3A,0xAB,0xE6}; -// Buffer master_key(KEY, 16); -// Buffer master_salt(SALT, 14); -// std::cout << "master key: " << std::endl << master_key.getHexDump() << std::endl; -// std::cout << "master salt: " << std::endl << master_salt.getHexDump() << std::endl; -// std::cout << std::endl; -// KeyDerivation kd; -// kd.init(master_key, master_salt); - -// Buffer key(16); -// kd.generate(LABEL_SATP_ENCRYPTION, 0, key); -// std::cout << "key: " << std::endl << key.getHexDump() << std::endl; - -// Buffer salt(14); -// kd.generate(LABEL_SATP_SALT, 0, salt); -// std::cout << "salt: " << std::endl << salt.getHexDump() << std::endl; - -// Buffer auth(14); -// kd.generate(LABEL_SATP_MSG_AUTH, 0, auth); -// std::cout << "auth: " << std::endl << auth.getHexDump() << std::endl; - - -// exit(0); - -// // *++++++++++++++++++ end of kd test - std::cout << "anytun - secure anycast tunneling protocol" << std::endl; Options opt; if(!opt.parse(argc, argv)) diff --git a/buffer.cpp b/buffer.cpp index e8191f2..cea8d2c 100644 --- a/buffer.cpp +++ b/buffer.cpp @@ -36,26 +36,32 @@ #include "datatypes.h" #include "buffer.h" -Buffer::Buffer() : buf_(0), length_(0) +Buffer::Buffer(bool allow_realloc) : buf_(0), length_(0), real_length_(0), allow_realloc_(allow_realloc) { } -Buffer::Buffer(u_int32_t length) : length_(length) +Buffer::Buffer(u_int32_t length, bool allow_realloc) : length_(length), real_length_(length_ + Buffer::OVER_SIZE_), + allow_realloc_(allow_realloc) { - buf_ = new u_int8_t[length_]; - if(buf_) - std::memset(buf_, 0, length_); - else + buf_ = new u_int8_t[real_length_]; + if(!buf_) { length_ = 0; + real_length_ = 0; + throw std::bad_alloc(); + } + std::memset(buf_, 0, real_length_); } -Buffer::Buffer(u_int8_t* data, u_int32_t length) : length_(length) +Buffer::Buffer(u_int8_t* data, u_int32_t length, bool allow_realloc) : length_(length), real_length_(length + Buffer::OVER_SIZE_), + allow_realloc_(allow_realloc) { - buf_ = new u_int8_t[length_]; - if(buf_) - std::memcpy(buf_, data, length_); - else + buf_ = new u_int8_t[real_length_]; + if(!buf_) { length_ = 0; + real_length_ = 0; + throw std::bad_alloc(); + } + std::memcpy(buf_, data, length_); } Buffer::~Buffer() @@ -64,13 +70,15 @@ Buffer::~Buffer() delete[] buf_; } -Buffer::Buffer(const Buffer &src) : length_(src.length_) +Buffer::Buffer(const Buffer &src) : length_(src.length_), real_length_(src.real_length_), allow_realloc_(src.allow_realloc_) { - buf_ = new u_int8_t[length_]; - if(buf_) - std::memcpy(buf_, src.buf_, length_); - else + buf_ = new u_int8_t[real_length_]; + if(!buf_) { length_ = 0; + real_length_ = 0; + throw std::bad_alloc(); + } + std::memcpy(buf_, src.buf_, length_); } void Buffer::operator=(const Buffer &src) @@ -79,12 +87,16 @@ void Buffer::operator=(const Buffer &src) delete[] buf_; length_ = src.length_; + real_length_ = src.real_length_; + allow_realloc_ = src.allow_realloc_; - buf_ = new u_int8_t[length_]; - if(buf_) - std::memcpy(buf_, src.buf_, length_); - else + buf_ = new u_int8_t[real_length_]; + if(!buf_) { length_ = 0; + real_length_ = 0; + throw std::bad_alloc(); + } + std::memcpy(buf_, src.buf_, length_); } @@ -100,58 +112,60 @@ bool Buffer::operator==(const Buffer &cmp) const return false; } - -u_int32_t Buffer::resizeFront(u_int32_t new_length) +Buffer Buffer::operator^(const Buffer &xor_by) const { - if(length_ == new_length) - return length_; + u_int32_t res_length = (xor_by.length_ > length_) ? xor_by.length_ : length_; + u_int32_t min_length = (xor_by.length_ < length_) ? xor_by.length_ : length_; + Buffer res(res_length); - u_int8_t *tmp = new u_int8_t[new_length]; - if(!tmp) - return length_; - - if(buf_) - { - u_int8_t *src=buf_, *dest=tmp; - if(length_ < new_length) - dest = &dest[new_length - length_]; - else - src = &src[length_ - new_length]; - u_int32_t len = length_ < new_length ? length_ : new_length; - std::memcpy(dest, src, len); - delete[] buf_; - } + for( u_int32_t index = 0; index < min_length; index++ ) + res[index] = buf_[index] ^ xor_by[index]; + + return res; +} - length_ = new_length; - buf_ = tmp; +u_int32_t Buffer::getLength() const +{ return length_; } -u_int32_t Buffer::resizeBack(u_int32_t new_length) +void Buffer::setLength(u_int32_t new_length) { - if(length_ == new_length) - return length_; - - u_int8_t *tmp = new u_int8_t[new_length]; - if(!tmp) - return length_; + if(new_length == length_) + return; - if(buf_) + if(new_length > real_length_) { - u_int32_t len = length_ < new_length ? length_ : new_length; - std::memcpy(tmp, buf_, len); - delete[] buf_; - } + if(!allow_realloc_) + throw std::out_of_range("buffer::setLength() - reallocation not allowed for this Buffer"); + + u_int8_t* old_buf = buf_; + u_int32_t old_length = length_; + + length_ = new_length; + real_length_ = length_ + Buffer::OVER_SIZE_; + + buf_ = new u_int8_t[real_length_]; + if(!buf_) { + length_ = 0; + real_length_ = 0; + if(old_buf) + delete[] old_buf; + + throw std::bad_alloc(); + } + std::memcpy(buf_, old_buf, old_length); - length_ = new_length; - buf_ = tmp; - return length_; -} + if(old_buf) + delete[] old_buf; + + old_buf = &buf_[old_length]; + std::memset(old_buf, 0, real_length_ - old_length); + } + else + length_ = new_length; +} -u_int32_t Buffer::getLength() const -{ - return length_; -} u_int8_t* Buffer::getBuf() { @@ -174,7 +188,7 @@ u_int8_t Buffer::operator[](u_int32_t index) const return buf_[index]; } -Buffer::operator u_int8_t*() // just for write/read tun +Buffer::operator u_int8_t*() { return buf_; } @@ -182,10 +196,10 @@ Buffer::operator u_int8_t*() // just for write/read tun std::string Buffer::getHexDump() const { std::stringstream ss; - ss << std::hex; + ss << "Length=" << length_ << std::endl << std::hex << std::uppercase; for( u_int32_t index = 0; index < length_; index++ ) { - ss << std::setw(2) << std::setfill('0') << static_cast(buf_[index]) << " "; + ss << std::setw(2) << std::setfill('0') << u_int32_t(buf_[index]) << " "; if(!((index+1) % 16)) { ss << std::endl; continue; @@ -196,35 +210,7 @@ std::string Buffer::getHexDump() const return ss.str(); } -Buffer Buffer::operator^(const Buffer &xor_by) const +bool Buffer::isReallocAllowed() const { - Buffer res(length_); - if( xor_by.getLength() > length_ ) - throw std::out_of_range("buffer::operator^ const"); - - for( u_int32_t index = 0; index < xor_by.getLength(); index++ ) - res[index] = buf_[index] ^ xor_by[index]; - - return res; + return allow_realloc_; } - -Buffer Buffer::leftByteShift(u_int32_t width) const -{ - Buffer res(length_+width); - - for( u_int32_t index = 0; index < length_; index++ ) - res[index+width] = buf_[index]; - - return res; -} - -Buffer Buffer::rightByteShift(u_int32_t width) const -{ - Buffer res(length_); - - for( u_int32_t index = 0; index < length_-width; index++ ) - res[index] = buf_[index+width]; - - return res; -} - diff --git a/buffer.h b/buffer.h index 1a426e1..063be91 100644 --- a/buffer.h +++ b/buffer.h @@ -40,32 +40,33 @@ class UDPPacketSource; class Buffer { public: - Buffer(); - Buffer(u_int32_t length); - Buffer(u_int8_t* data, u_int32_t length); + Buffer(bool allow_realloc = true); + Buffer(u_int32_t length, bool allow_realloc = true); + Buffer(u_int8_t* data, u_int32_t length, bool allow_realloc = true); virtual ~Buffer(); Buffer(const Buffer &src); void operator=(const Buffer &src); bool operator==(const Buffer &cmp) const; + Buffer operator^(const Buffer &xor_by) const; - // math operations to calculate IVs and keys - virtual Buffer operator^(const Buffer &xor_by) const; - virtual Buffer leftByteShift(u_int32_t width) const; - virtual Buffer rightByteShift(u_int32_t width) const; - - u_int32_t resizeFront(u_int32_t new_length); - u_int32_t resizeBack(u_int32_t new_length); u_int32_t getLength() const; + virtual void setLength(u_int32_t new_length); u_int8_t* getBuf(); u_int8_t& operator[](u_int32_t index); u_int8_t operator[](u_int32_t index) const; std::string getHexDump() const; + bool isReallocAllowed() const; + operator u_int8_t*(); protected: u_int8_t *buf_; u_int32_t length_; + u_int32_t real_length_; + bool allow_realloc_; + + static const u_int32_t OVER_SIZE_ = 100; }; #endif diff --git a/cipher.cpp b/cipher.cpp index 22b8019..07a9117 100644 --- a/cipher.cpp +++ b/cipher.cpp @@ -39,18 +39,20 @@ #include "log.h" -void Cipher::encrypt(const PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id) + // TODO: in should be const but does not work with getBuf() :( +void Cipher::encrypt(PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id) { - u_int32_t len = cipher(out.payload_, out.payload_length_, in.complete_payload_ , in.complete_payload_length_, seq_nr, sender_id); + u_int32_t len = cipher(in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id); out.setSenderId(sender_id); out.setSeqNr(seq_nr); out.setPayloadLength(len); } -void Cipher::decrypt(const EncryptedPacket & in, PlainPacket & out) + // TODO: in should be const but does not work with getBuf() :( +void Cipher::decrypt(EncryptedPacket & in, PlainPacket & out) { - u_int32_t len = decipher(out.complete_payload_, out.complete_payload_length_, in.payload_ , in.payload_length_, in.getSeqNr(), in.getSenderId()); - out.setCompletePayloadLength(len); + u_int32_t len = decipher(in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId()); + out.setLength(len); } @@ -70,9 +72,9 @@ u_int32_t NullCipher::decipher(u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_in //****** AesIcmCipher ****** -AesIcmCipher::AesIcmCipher() : salt_(Buffer(14)) // Q@NINE 14?????? +AesIcmCipher::AesIcmCipher() { - // TODO: hardcoded keysize!!!!! + // TODO: hardcoded keysize gcry_error_t err = gcry_cipher_open( &cipher_, GCRY_CIPHER_AES128, GCRY_CIPHER_MODE_CTR, 0 ); if( err ) cLog.msg(Log::PRIO_CRIT) << "AesIcmCipher::AesIcmCipher: Failed to open cipher"; @@ -82,7 +84,6 @@ AesIcmCipher::AesIcmCipher() : salt_(Buffer(14)) // Q@NINE 14?????? AesIcmCipher::~AesIcmCipher() { gcry_cipher_close( cipher_ ); - cLog.msg(Log::PRIO_DEBUG) << "AesIcmCipher::~AesIcmCipher: closed cipher"; } @@ -129,13 +130,17 @@ void AesIcmCipher::calc(u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t o // // sizeof(k_s) = 112 bit, random Mpi ctr(128); // TODO: hardcoded size - Mpi salt = Mpi(salt_.getBuf(), salt_.getLength()); - Mpi sid = sender_id; // Q@OTTI add mux to sender_id???? - Mpi seq = seq_nr; + Mpi salt(salt_.getBuf(), salt_.getLength()); + Mpi sid(32); // TODO: Q@OTTI add mux to sender_id???? + sid = sender_id; + Mpi seq(32); + seq = seq_nr; ctr = salt.mul2exp(16) ^ sid.mul2exp(64) ^ seq.mul2exp(16); // TODO: hardcoded size - u_int8_t *ctr_buf = ctr.getNewBuf(16); // TODO: hardcoded size - err = gcry_cipher_setctr( cipher_, ctr_buf, 16 ); // TODO: hardcoded size + + u_int32_t written; + u_int8_t *ctr_buf = ctr.getNewBuf(&written); // TODO: hardcoded size + err = gcry_cipher_setctr( cipher_, ctr_buf, written ); // TODO: hardcoded size delete[] ctr_buf; if( err ) { cLog.msg(Log::PRIO_ERR) << "AesIcmCipher: Failed to set cipher CTR: " << gpg_strerror( err ); diff --git a/cipher.h b/cipher.h index 190d9fe..0749859 100644 --- a/cipher.h +++ b/cipher.h @@ -44,8 +44,9 @@ class Cipher public: virtual ~Cipher() {}; - void encrypt(const PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id); - void decrypt(const EncryptedPacket & in, PlainPacket & out); + // TODO: in should be const but does not work with getBuf() :( + void encrypt(PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id); + void decrypt(EncryptedPacket & in, PlainPacket & out); virtual void setKey(Buffer key) = 0; virtual void setSalt(Buffer salt) = 0; diff --git a/encryptedPacket.cpp b/encryptedPacket.cpp index 0731d24..a3e5886 100644 --- a/encryptedPacket.cpp +++ b/encryptedPacket.cpp @@ -38,139 +38,149 @@ #include "authTag.h" #include "log.h" - -EncryptedPacket::EncryptedPacket(u_int32_t max_payload_length) - : Buffer(max_payload_length + sizeof(struct HeaderStruct) + AUTHTAG_SIZE) +// TODO: fix auth_tag stuff +EncryptedPacket::EncryptedPacket(u_int32_t payload_length, bool allow_realloc) + : Buffer(payload_length + sizeof(struct HeaderStruct), allow_realloc) { header_ = reinterpret_cast(buf_); - auth_tag_ = NULL; - payload_ = buf_ + sizeof(struct HeaderStruct); // no authtag yet - length_ = sizeof(struct HeaderStruct); - max_length_ = max_payload_length + AUTHTAG_SIZE; -} - - -EncryptedPacket::~EncryptedPacket() -{ - buf_ = reinterpret_cast(header_); - if( auth_tag_ == NULL ) - length_ = max_length_ + sizeof(struct HeaderStruct) + AUTHTAG_SIZE; - else - length_ = max_length_ + sizeof(struct HeaderStruct); -} - -void EncryptedPacket::setPayloadLength(u_int32_t payload_length) -{ - if( auth_tag_) - length_= payload_length + sizeof(struct HeaderStruct)+AUTHTAG_SIZE; - else - length_= payload_length + sizeof(struct HeaderStruct); + payload_ = buf_ + sizeof(struct HeaderStruct); // TODO: fix auth_tag stuff + auth_tag_ = NULL; // TODO: fix auth_tag stuff + if(header_) + { + header_->seq_nr = 0; + header_->sender_id = 0; + header_->mux = 0; + } } seq_nr_t EncryptedPacket::getSeqNr() const { - return SEQ_NR_T_NTOH(header_->seq_nr); + if(header_) + return SEQ_NR_T_NTOH(header_->seq_nr); + + return 0; } sender_id_t EncryptedPacket::getSenderId() const { - return SENDER_ID_T_NTOH(header_->sender_id); -} + if(header_) + return SENDER_ID_T_NTOH(header_->sender_id); -mux_t EncryptedPacket::getMux() const -{ - return MUX_T_NTOH(header_->mux); + return 0; } -u_int32_t EncryptedPacket::getMaxLength() const +mux_t EncryptedPacket::getMux() const { - return max_length_; -} + if(header_) + return MUX_T_NTOH(header_->mux); -void EncryptedPacket::setLength(u_int32_t length) -{ - if(length > max_length_) - throw std::out_of_range("can't set length greater then size ofsize of allocated memory"); - - length_ = length; - if( auth_tag_) - payload_length_ = length_ - sizeof(struct HeaderStruct)+AUTHTAG_SIZE; - else - payload_length_ = length_ - sizeof(struct HeaderStruct); + return 0; } void EncryptedPacket::setSeqNr(seq_nr_t seq_nr) { - header_->seq_nr = SEQ_NR_T_HTON(seq_nr); + if(header_) + header_->seq_nr = SEQ_NR_T_HTON(seq_nr); } void EncryptedPacket::setSenderId(sender_id_t sender_id) { - header_->sender_id = SENDER_ID_T_HTON(sender_id); + if(header_) + header_->sender_id = SENDER_ID_T_HTON(sender_id); } void EncryptedPacket::setMux(mux_t mux) { - header_->mux = MUX_T_HTON(mux); + if(header_) + header_->mux = MUX_T_HTON(mux); } void EncryptedPacket::setHeader(seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { + if(!header_) + return; + header_->seq_nr = SEQ_NR_T_HTON(seq_nr); header_->sender_id = SENDER_ID_T_HTON(sender_id); header_->mux = MUX_T_HTON(mux); } +u_int32_t EncryptedPacket::getPayloadLength() const +{ + return (length_ > sizeof(struct HeaderStruct)) ? (length_ - sizeof(struct HeaderStruct)) : 0; // TODO: fix auth_tag stuff +} + +void EncryptedPacket::setPayloadLength(u_int32_t payload_length) +{ + Buffer::setLength(payload_length + sizeof(struct HeaderStruct)); + + // depending on allow_realloc buf_ may point to another address + header_ = reinterpret_cast(buf_); + payload_ = buf_ + sizeof(struct HeaderStruct); // TODO: fix auth_tag stuff + auth_tag_ = NULL; // TODO: fix auth_tag stuff +} + +u_int8_t* EncryptedPacket::getPayload() +{ + return payload_; +} + + + + + + +// TODO: fix auth_tag stuff + bool EncryptedPacket::hasAuthTag() const { - if( auth_tag_ == NULL ) - return false; - return true; +// if( auth_tag_ == NULL ) + return false; +// return true; } void EncryptedPacket::withAuthTag(bool b) { - if( b && (auth_tag_ != NULL) ) - throw std::runtime_error("packet already has auth tag function enabled"); - //TODO: return instead? - if( ! b && (auth_tag_ == NULL) ) - throw std::runtime_error("packet already has auth tag function disabled"); - //TODO: return instead? - - if( b ) { - auth_tag_ = reinterpret_cast( buf_ + sizeof(struct HeaderStruct) ); - payload_ = payload_ + AUTHTAG_SIZE; - length_ -= AUTHTAG_SIZE; - max_length_ -= AUTHTAG_SIZE; - } else { - payload_ = reinterpret_cast( auth_tag_ ); - length_ += AUTHTAG_SIZE; - max_length_ += AUTHTAG_SIZE; - auth_tag_ = NULL; - } +// if( b && (auth_tag_ != NULL) ) +// throw std::runtime_error("packet already has auth tag function enabled"); +// //TODO: return instead? +// if( ! b && (auth_tag_ == NULL) ) +// throw std::runtime_error("packet already has auth tag function disabled"); +// //TODO: return instead? + +// if( b ) { +// auth_tag_ = reinterpret_cast( buf_ + sizeof(struct HeaderStruct) ); +// payload_ = payload_ + AUTHTAG_SIZE; +// length_ -= AUTHTAG_SIZE; +// max_length_ -= AUTHTAG_SIZE; +// } else { +// payload_ = reinterpret_cast( auth_tag_ ); +// length_ += AUTHTAG_SIZE; +// max_length_ += AUTHTAG_SIZE; +// auth_tag_ = NULL; +// } } void EncryptedPacket::setAuthTag(AuthTag& tag) { - if( auth_tag_ == NULL ) - throw std::runtime_error("auth tag not enabled"); +// if( auth_tag_ == NULL ) +// throw std::runtime_error("auth tag not enabled"); - if( tag == AuthTag(0) ) - return; +// if( tag == AuthTag(0) ) +// return; - if( tag.getLength() != AUTHTAG_SIZE ) - throw std::length_error("authtag length mismatch with AUTHTAG_SIZE"); +// if( tag.getLength() != AUTHTAG_SIZE ) +// throw std::length_error("authtag length mismatch with AUTHTAG_SIZE"); - std::memcpy( auth_tag_, tag.getBuf(), AUTHTAG_SIZE ); +// std::memcpy( auth_tag_, tag.getBuf(), AUTHTAG_SIZE ); } AuthTag EncryptedPacket::getAuthTag() const { - if( auth_tag_ == NULL ) - throw std::runtime_error("auth tag not enabled"); +// if( auth_tag_ == NULL ) +// throw std::runtime_error("auth tag not enabled"); AuthTag at(AUTHTAG_SIZE); - std::memcpy(at, auth_tag_, AUTHTAG_SIZE ); +// std::memcpy(at, auth_tag_, AUTHTAG_SIZE ); return at; } - diff --git a/encryptedPacket.h b/encryptedPacket.h index afc7d0e..0b934f6 100644 --- a/encryptedPacket.h +++ b/encryptedPacket.h @@ -41,14 +41,15 @@ public: /** * Packet constructor - * @param max_payload_length maximum length of encrypted payload + * @param the length of the payload + * @param allow reallocation of buffer */ - EncryptedPacket(u_int32_t max_payload_length); + EncryptedPacket(u_int32_t payload_length, bool allow_realloc = false); /** * Packet destructor */ - ~EncryptedPacket(); + ~EncryptedPacket() {}; /** * Get the sequence number @@ -94,36 +95,36 @@ public: */ void setHeader(seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - /** - * Get the maximum payload size - * @return maximum payload size + /** + * Get the length of the payload + * @return the length of the payload + */ + u_int32_t getPayloadLength() const; + + /** + * Set the length of the payload + * @param length length of the payload */ - u_int32_t getMaxLength() const; + void setPayloadLength(u_int32_t payload_length); /** - * Set the real length of the payload - * @param length the real length of the payload, has to be smaller than the maximum payload size! + * Get the the payload + * @return the Pointer to the payload */ - void setLength(u_int32_t length); + u_int8_t* getPayload(); + + bool hasAuthTag() const; void withAuthTag(bool b); AuthTag getAuthTag() const; void setAuthTag(AuthTag& tag); - void setPayloadLength(u_int32_t payload_length); - - -// bool hasHeader() const; -// Packet& withHeader(bool b); -// Packet& addHeader(seq_nr_t seq_nr, sender_id_t sender_id); -// Packet& withAuthTag(bool b); -// AuthTag getAuthTag() const; -// Packet& addAuthTag(AuthTag auth_tag); - + private: EncryptedPacket(); EncryptedPacket(const EncryptedPacket &src); + struct HeaderStruct { seq_nr_t seq_nr; @@ -132,14 +133,9 @@ private: }__attribute__((__packed__)); struct HeaderStruct* header_; - AuthTag* auth_tag_; - u_int32_t max_length_; - - static const u_int32_t AUTHTAG_SIZE = 10; // 10byte -protected: - friend class Cipher; u_int8_t * payload_; - u_int32_t payload_length_; + AuthTag* auth_tag_; + static const u_int32_t AUTHTAG_SIZE = 10; // TODO: hardcoded size }; #endif diff --git a/keyDerivation.cpp b/keyDerivation.cpp index f3d1fe6..dbafec6 100644 --- a/keyDerivation.cpp +++ b/keyDerivation.cpp @@ -97,50 +97,37 @@ void KeyDerivation::generate(satp_prf_label label, seq_nr_t seq_nr, Buffer& key) // alignment). // - Mpi r(48); + Mpi r(48); // ld(kdr) <= 48 if( ld_kdr_ == -1 ) // means key_derivation_rate = 0 - r = 0; + r = 0; // TODO: no new key should be generated if r == 0, except it is the first time else { Mpi seq = seq_nr; Mpi rate = 1; rate = rate.mul2exp(ld_kdr_); r = seq / rate; - } - - std::cout << "r: " << std::endl; - std::cout << r.getHexDump(); + } + // TODO: generate key only if index % r == 0, except it is the first time - Mpi key_id(128), l(128); // TODO: hardcoded keySize + Mpi key_id(128); // TODO: hardcoded size + Mpi l(128); // TODO: hardcoded size l = label; - key_id = l.mul2exp(48) + r; // TODO: hardcoded keySize - - std::cout << "label: " << std::endl; - std::cout << l.getHexDump(); + key_id = l.mul2exp(48) + r; - std::cout << "keyid: " << std::endl; - std::cout << key_id.getHexDump(); - - Mpi x(128); // TODO: hardcoded keySize - Mpi salt = Mpi(master_salt_.getBuf(), master_salt_.getLength()); + Mpi salt(master_salt_.getBuf(), master_salt_.getLength()); + Mpi x(128); // TODO: hardcoded size x = key_id ^ salt; - - std::cout << "x: " << std::endl; - std::cout << x.getHexDump(); - - std::cout << "x*2^16(ctr): " << std::endl; - std::cout << x.mul2exp(16).getHexDump(); - u_int8_t *ctr_buf = x.mul2exp(16).getNewBuf(16); // TODO: hardcoded size - err = gcry_cipher_setctr( cipher_ , ctr_buf, 16); // TODO: hardcoded size - + u_int32_t written; + u_int8_t *ctr_buf = x.mul2exp(16).getNewBuf(&written); // TODO: hardcoded size + err = gcry_cipher_setctr( cipher_ , ctr_buf, written ); delete[] ctr_buf; + if( err ) - cLog.msg(Log::PRIO_ERR) << "KeyDerivation::generate: Failed to set IV: " << gpg_strerror( err ); + cLog.msg(Log::PRIO_ERR) << "KeyDerivation::generate: Failed to set CTR: " << gpg_strerror( err ); - u_int8_t *x_buf = x.getNewBuf(16); // TODO: hardcoded size - err = gcry_cipher_encrypt( cipher_, key, key.getLength(), x_buf, 16 ); // TODO: hardcoded size - delete[] x_buf; + for(u_int32_t i=0; i < key.getLength(); ++i) key[i] = 0; + err = gcry_cipher_encrypt( cipher_, key, key.getLength(), NULL, 0); if( err ) cLog.msg(Log::PRIO_ERR) << "KeyDerivation::generate: Failed to generate cipher bitstream: " << gpg_strerror( err ); } diff --git a/mpi.cpp b/mpi.cpp index 18a3349..a2590fa 100644 --- a/mpi.cpp +++ b/mpi.cpp @@ -36,41 +36,71 @@ #include #include +#include -Mpi::Mpi() +Mpi::Mpi() : val_(NULL) { - val_ = gcry_mpi_new(1); + val_ = gcry_mpi_set_ui(NULL, 0); + if(!val_) + throw std::bad_alloc(); } -Mpi::Mpi(u_int8_t length) +Mpi::Mpi(u_int8_t length) : val_(NULL) { val_ = gcry_mpi_new(length); + if(!val_) + throw std::bad_alloc(); } -Mpi::Mpi(const Mpi &src) +Mpi::Mpi(const Mpi &src) : val_(NULL) { val_ = gcry_mpi_copy(src.val_); + if(!val_) + throw std::bad_alloc(); } -Mpi::Mpi(const u_int8_t * src, u_int32_t len) +Mpi::Mpi(const u_int8_t* src, u_int32_t len) : val_(NULL) { - gcry_mpi_scan( &val_, GCRYMPI_FMT_STD, src, len, NULL ); + u_int8_t* src_cpy = new u_int8_t[len+1]; + if(!src_cpy) + throw std::bad_alloc(); + + u_int8_t* buf = src_cpy; + u_int32_t buf_len = len; + if(src[0] & 0x80) // this would be a negative number, scan can't handle this :( + { + src_cpy[0] = 0; + buf++; + buf_len++; + } + std::memcpy(buf, src, len); + + gcry_mpi_scan( &val_, GCRYMPI_FMT_STD, src_cpy, buf_len, NULL ); + delete[] src_cpy; + if(!val_) + throw std::bad_alloc(); } Mpi::~Mpi() { - gcry_mpi_release( val_ ); + gcry_mpi_release( val_ ); } void Mpi::operator=(const Mpi &src) { + gcry_mpi_release( val_ ); val_ = gcry_mpi_copy(src.val_); + if(!val_) + throw std::bad_alloc(); } void Mpi::operator=(const u_int32_t src) { - gcry_mpi_set_ui(val_, src); + gcry_mpi_release( val_ ); + val_ = gcry_mpi_set_ui(NULL, src); + if(!val_) + throw std::bad_alloc(); } Mpi Mpi::operator+(const Mpi &b) const @@ -125,22 +155,29 @@ Mpi Mpi::mul2exp(u_int32_t e) const return res; } -u_int8_t* Mpi::getNewBuf(u_int32_t buf_len) const +//TODO: problem, seems as gcry_mpi_(a)print doesn't work for mpi values of '0' +u_int8_t* Mpi::getNewBuf(u_int32_t* written) const { - // u_int32_t len = 0; - u_int32_t written = 0; + u_int8_t* res_cpy; + gcry_mpi_aprint( GCRYMPI_FMT_STD, &res_cpy, written, val_ ); + if(!res_cpy) + throw std::bad_alloc(); + + u_int8_t* buf = res_cpy; + if(*written > 1 && ! (res_cpy[0])) // positive number with highestBit set + { + buf++; + (*written)--; + } + + u_int8_t* res = new u_int8_t[*written]; + if(!res) + throw std::bad_alloc(); - u_int8_t *res = NULL; - res = new u_int8_t[buf_len]; - std::memset(res, 0, buf_len); + std::memcpy(res, buf, *written); - // len = gcry_mpi_get_nbits( val_ ); - // if( len%8 == 0 ) - // len = len/8; - // else - // len = (len/8)+1; + gcry_free(res_cpy); - gcry_mpi_print( GCRYMPI_FMT_STD, res, buf_len, &written, val_ ); return res; } @@ -151,7 +188,8 @@ std::string Mpi::getHexDump() const // u_int32_t len; // gcry_mpi_aprint( GCRYMPI_FMT_HEX, &buf, &len, val_ ); // std::string res(buf, len); - +// delete[] buf; + gcry_mpi_dump( val_ ); std::string res("\n"); return res; diff --git a/mpi.h b/mpi.h index 70f6681..c746602 100644 --- a/mpi.h +++ b/mpi.h @@ -70,7 +70,7 @@ public: * @param buf_len size of the new buffer that is returned * @return a byte buffer of size buf_len */ - u_int8_t *getNewBuf(u_int32_t buf_len) const; + u_int8_t *getNewBuf(u_int32_t* written) const; std::string getHexDump() const; u_int32_t getLength() const; diff --git a/plainPacket.cpp b/plainPacket.cpp index 3ce1521..0906fa2 100644 --- a/plainPacket.cpp +++ b/plainPacket.cpp @@ -33,43 +33,14 @@ #include #include "datatypes.h" - #include "plainPacket.h" - -PlainPacket::~PlainPacket() -{ - buf_=complete_payload_; - length_=max_length_; -} - -PlainPacket::PlainPacket(u_int32_t max_payload_length) : Buffer(max_payload_length + sizeof(payload_type_t)) +PlainPacket::PlainPacket(u_int32_t payload_length, bool allow_realloc) : Buffer(payload_length + sizeof(payload_type_t), allow_realloc) { - payload_type_ = NULL; - splitPayload(); -} - -void PlainPacket::splitPayload() -{ - complete_payload_length_ = length_; - complete_payload_ = buf_; - payload_type_ = reinterpret_cast(buf_); - buf_ += sizeof(payload_type_t); - length_ -= sizeof(payload_type_t); - max_length_ = length_; -} - -void PlainPacket::setCompletePayloadLength(u_int32_t payload_length) -{ - complete_payload_length_ = payload_length; - length_=complete_payload_length_-sizeof(payload_type_t); -} - -u_int32_t PlainPacket::getCompletePayloadLength() -{ - return complete_payload_length_; + payload_ = buf_ + sizeof(payload_type_t); + *payload_type_ = 0; } payload_type_t PlainPacket::getPayloadType() const @@ -83,16 +54,21 @@ void PlainPacket::setPayloadType(payload_type_t payload_type) *payload_type_ = PAYLOAD_TYPE_T_HTON(payload_type); } -void PlainPacket::setLength(u_int32_t length) +u_int32_t PlainPacket::getPayloadLength() const +{ + return (length_ > sizeof(payload_type_t)) ? (length_ - sizeof(payload_type_t)) : 0; +} + +void PlainPacket::setPayloadLength(u_int32_t payload_length) { - if(length > max_length_) - throw std::out_of_range("can't set length greater then size ofsize of allocated memory"); + Buffer::setLength(payload_length + sizeof(payload_type_t)); - length_ = length; - complete_payload_length_ = length_ + sizeof(payload_type_t); + // depending on allow_realloc buf_ may point to another address + payload_type_ = reinterpret_cast(buf_); + payload_ = buf_ + sizeof(payload_type_t); } -u_int32_t PlainPacket::getMaxLength() const +u_int8_t* PlainPacket::getPayload() { - return max_length_; + return payload_; } diff --git a/plainPacket.h b/plainPacket.h index 176d841..54c387a 100644 --- a/plainPacket.h +++ b/plainPacket.h @@ -43,13 +43,17 @@ class Cipher; class PlainPacket : public Buffer { public: - ~PlainPacket(); - /** * Packet constructor - * @param max_payload_length maximum payload length + * @param the length of the payload + * @param allow reallocation of buffer */ - PlainPacket(u_int32_t max_payload_length); + PlainPacket(u_int32_t payload_length, bool allow_realloc = false); + + /** + * Packet destructor + */ + ~PlainPacket() {}; /** * Get the payload type @@ -63,43 +67,31 @@ public: */ void setPayloadType(payload_type_t payload_type); - void setCompletePayloadLength(u_int32_t payload_length); - u_int32_t getCompletePayloadLength(); - /** - * Set the real payload length - * @param length the real payload length + * Get the length of the payload + * @return the length of the payload */ - //void setRealPayloadLengt(u_int32_t length); - - /** - * Get the real payload length - * @return the real length of the payload - */ - //u_int32_t getRealPayloadLength(); + u_int32_t getPayloadLength() const; /** * Set the length of the payload * @param length length of the payload */ - void setLength(u_int32_t length); + void setPayloadLength(u_int32_t payload_length); /** - * Get the size of the allocated memory for the payload - * @return maximum size of payload + * Get the the payload + * @return the Pointer to the payload */ - u_int32_t getMaxLength() const; + u_int8_t* getPayload(); + private: PlainPacket(); PlainPacket(const PlainPacket &src); - void splitPayload(); - u_int32_t max_length_; + payload_type_t* payload_type_; -protected: - friend class Cipher; - u_int8_t * complete_payload_; - u_int32_t complete_payload_length_; + u_int8_t* payload_; }; #endif -- cgit v1.2.3