From 44c94befcb9871450f574dfe0e7c8ca098efdaa9 Mon Sep 17 00:00:00 2001 From: Othmar Gsenger Date: Wed, 4 Mar 2015 19:33:35 +0000 Subject: added auth tag support to refactored crypto implementation --- src/crypto/interface.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++++++++ src/crypto/interface.h | 4 ++++ src/crypto/openssl.cpp | 27 ++++++++++++++++++++++++ src/crypto/openssl.h | 3 +++ src/unittest.cpp | 30 ++++++++++++++++++-------- 5 files changed, 110 insertions(+), 9 deletions(-) diff --git a/src/crypto/interface.cpp b/src/crypto/interface.cpp index 2ae9c16..c11e382 100644 --- a/src/crypto/interface.cpp +++ b/src/crypto/interface.cpp @@ -68,6 +68,61 @@ void Interface::decrypt(EncryptedPacket& in, PlainPacket& out, const Buffer& mas out.setLength(len); } +bool Interface::checkAndRemoveAuthTag(EncryptedPacket& packet, const Buffer& masterkey, const Buffer& mastersalt, role_t role) +{ + uint32_t digest_length = getDigestLength(); + packet.withAuthTag(true); + if(!packet.getAuthTagLength()) { + return true; + } + + Buffer digest(digest_length); + //Buffer key(masterkey.getLength(), false); + Buffer key(digest_length, false); + deriveKey(KD_INBOUND, LABEL_AUTH, role, packet.getSeqNr(), packet.getSeqNr(), packet.getMux(), masterkey, mastersalt, key); + //std::cout << "Interface::checkAndRemoveAuthTag: " << key.getHexDump() << std::endl; + calcAuthKey(key, digest, packet.getAuthenticatedPortion(), packet.getAuthenticatedPortionLength() ); + + uint8_t* tag = packet.getAuthTag(); + uint32_t length = (packet.getAuthTagLength() < digest_length) ? packet.getAuthTagLength() : digest_length; + + if(length > digest_length) + for(uint32_t i=0; i < (packet.getAuthTagLength() - digest_length); ++i) + if(tag[i]) { return false; } + + int ret = std::memcmp(&tag[packet.getAuthTagLength() - length], digest.getBuf() + digest_length - length, length); + packet.removeAuthTag(); + + if(ret) { + return false; + } + + return true; +} + +void Interface::addAuthTag(EncryptedPacket& packet, const Buffer& masterkey, const Buffer& mastersalt, role_t role) +{ + uint32_t digest_length = getDigestLength(); + packet.addAuthTag(); + if(!packet.getAuthTagLength()) { + return; + } + Buffer digest(digest_length); + //Buffer key(masterkey.getLength(), false); + Buffer key(digest_length, false); + deriveKey(KD_OUTBOUND, LABEL_AUTH, role, packet.getSeqNr(), packet.getSeqNr(), packet.getMux(), masterkey, mastersalt, key); + //std::cout << "Interface::addAuthTag: " << key.getHexDump() << std::endl; + calcAuthKey(key, digest, packet.getAuthenticatedPortion(), packet.getAuthenticatedPortionLength() ); + uint8_t* tag = packet.getAuthTag(); + uint32_t length = (packet.getAuthTagLength() < digest_length) ? packet.getAuthTagLength() : digest_length; + + if(length > digest_length) { + std::memset(tag, 0, packet.getAuthTagLength()); + } + + std::memcpy(&tag[packet.getAuthTagLength() - length], digest.getBuf() + digest_length - length, length); + +} satp_prf_label_t Interface::convertLabel(kd_dir_t dir, role_t role, satp_prf_label_t label) { diff --git a/src/crypto/interface.h b/src/crypto/interface.h index 0ca52fb..49013ba 100644 --- a/src/crypto/interface.h +++ b/src/crypto/interface.h @@ -118,6 +118,8 @@ namespace crypto { void decrypt(EncryptedPacket& in, PlainPacket& out, const Buffer& masterkey, const Buffer& mastersalt, role_t role); void calcCryptCtr(const Buffer& masterkey, const Buffer& mastersalt, kd_dir_t dir, role_t role, satp_prf_label_t label, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux, cipher_aesctr_ctr_t * ctr); void calcKeyCtr(const Buffer& mastersalt, kd_dir_t dir, role_t role, satp_prf_label_t label, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux, key_derivation_aesctr_ctr_t * ctr); + bool checkAndRemoveAuthTag(EncryptedPacket& packet, const Buffer& masterkey, const Buffer& mastersalt, role_t role); + void addAuthTag(EncryptedPacket& packet, const Buffer& masterkey, const Buffer& mastersalt, role_t role); // pure virtual @@ -125,6 +127,8 @@ namespace crypto { virtual uint32_t cipher(uint8_t* in, uint32_t ilen, uint8_t* out, uint32_t olen, const Buffer& masterkey, const Buffer& mastersalt, role_t role, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; virtual uint32_t decipher(uint8_t* in, uint32_t ilen, uint8_t* out, uint32_t olen, const Buffer& masterkey, const Buffer& mastersalt, role_t role, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; virtual void deriveKey(kd_dir_t dir, satp_prf_label_t label, role_t role, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux, const Buffer& masterkey, const Buffer& mastersalt, Buffer& key) = 0; + virtual void calcAuthKey(Buffer & key, Buffer & digest, uint8_t * payload, size_t payload_length ) = 0; + virtual uint32_t getDigestLength() = 0; // virtual virtual ~Interface(); diff --git a/src/crypto/openssl.cpp b/src/crypto/openssl.cpp index 9252b48..95ccf3e 100644 --- a/src/crypto/openssl.cpp +++ b/src/crypto/openssl.cpp @@ -47,6 +47,8 @@ #include "../log.h" #include #include +#include + #include "../anytunError.h" namespace crypto { @@ -57,6 +59,26 @@ Openssl::~Openssl() } +void Openssl::calcAuthKey(Buffer & key, Buffer & digest, uint8_t * payload, size_t payload_length ) +{ + uint32_t digest_length = getDigestLength(); + HMAC_CTX ctx; + HMAC_CTX_init(&ctx); + //HMAC_Init_ex(&ctx, NULL, 0, EVP_sha1(), NULL); + + HMAC_Init_ex(&ctx, key.getBuf(), key.getLength(), EVP_sha1(), NULL); + + uint8_t hmac[digest_length]; + HMAC_Update(&ctx, payload, payload_length ); + HMAC_Final(&ctx, hmac, NULL); + + HMAC_CTX_cleanup(&ctx); + digest.setLength(digest_length); + + std::memcpy(digest.getBuf(), hmac, digest_length); +} + + void Openssl::calcMasterKeySalt(std::string passphrase, uint16_t length, Buffer& masterkey , Buffer& mastersalt) { cLog.msg(Log::PRIO_NOTICE) << "KeyDerivation: calculating master key from passphrase"; @@ -162,4 +184,9 @@ bool Openssl::init() return true; } +uint32_t Openssl::getDigestLength() +{ + return 20; +} + } diff --git a/src/crypto/openssl.h b/src/crypto/openssl.h index 35a2d26..27ffd48 100644 --- a/src/crypto/openssl.h +++ b/src/crypto/openssl.h @@ -58,6 +58,9 @@ namespace crypto { virtual uint32_t cipher(uint8_t* in, uint32_t ilen, uint8_t* out, uint32_t olen, const Buffer& masterkey, const Buffer& mastersalt, role_t role, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); virtual uint32_t decipher(uint8_t* in, uint32_t ilen, uint8_t* out, uint32_t olen, const Buffer& masterkey, const Buffer& mastersalt, role_t role, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); virtual void deriveKey(kd_dir_t dir, satp_prf_label_t label, role_t role, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux, const Buffer& masterkey, const Buffer& mastersalt, Buffer& key); + virtual void calcAuthKey(Buffer & key, Buffer & digest, uint8_t * payload, size_t payload_length ); + virtual uint32_t getDigestLength(); + // virtual virtual ~Openssl(); virtual std::string printType(); diff --git a/src/unittest.cpp b/src/unittest.cpp index a5862d3..ef9358b 100644 --- a/src/unittest.cpp +++ b/src/unittest.cpp @@ -110,7 +110,7 @@ void testCrypt() kd->setRole(ROLE_LEFT); PlainPacket plain_packet(MAX_PACKET_LENGTH); - EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH, 10); + EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH, 20); uint16_t mux = 1; plain_packet.setLength(MAX_PACKET_LENGTH); @@ -132,6 +132,16 @@ void testCrypt() // add authentication tag a->generate(*kd, encrypted_packet); + Buffer tag0( encrypted_packet.getAuthTag(), encrypted_packet.getAuthTagLength(), false); + std::cout << "Tag 0:" << tag0.getHexDump() << std::endl; + std::auto_ptr cnew(new crypto::Openssl()); + Buffer masterkey(uint32_t(crypto::DEFAULT_KEY_LENGTH/8) , false); + Buffer mastersalt(crypto::SALT_LENGTH, false); + cnew->calcMasterKeySalt("abc", uint32_t(crypto::DEFAULT_KEY_LENGTH/8), masterkey , mastersalt); + if(!cnew->checkAndRemoveAuthTag(encrypted_packet, masterkey, mastersalt, ROLE_RIGHT )) { + std::cout << STR_ERROR << "wrong Authentication Tag!" << STR_END; + //exit(-1); + } encrypted_packet.withAuthTag(false); memset(plain_packet.getPayload(),0,MAX_PACKET_LENGTH); @@ -151,10 +161,9 @@ void testCrypt() a = std::auto_ptr(AuthAlgoFactory::create("sha1", KD_INBOUND)); // check whether auth tag is ok or not - if(!a->checkTag(*kd, encrypted_packet)) { - std::cerr << "wrong Authentication Tag!" << std::endl; - //exit(-1); - } +// Buffer tag1( encrypted_packet.getAuthTag(), encrypted_packet.getAuthTagLength(), false); +// std::cout << "Tag 1:" << tag1.getHexDump() << std::endl; + c->decrypt(*kd, encrypted_packet, plain_packet); if (memcmp(plain_packet.getPayload(), test, sizeof(test))) { @@ -168,12 +177,15 @@ void testCrypt() std::cout << STR_PASSED << "role RIGHT inbound can decrypt role LEFT's outbound packets"<< STR_END; memset(plain_packet.getPayload(), 0, sizeof(test)); - std::auto_ptr cnew(new crypto::Openssl()); - Buffer masterkey(uint32_t(crypto::DEFAULT_KEY_LENGTH/8) , false); - Buffer mastersalt(crypto::SALT_LENGTH, false); - cnew->calcMasterKeySalt("abc", uint32_t(crypto::DEFAULT_KEY_LENGTH/8), masterkey , mastersalt); std::cout << "Master Key:" << masterkey.getHexDump() << std::endl; std::cout << "Master Salt:" << mastersalt.getHexDump() << std::endl; + cnew->addAuthTag(encrypted_packet, masterkey, mastersalt, ROLE_LEFT ); + Buffer tag2( encrypted_packet.getAuthTag(), encrypted_packet.getAuthTagLength(), false); + std::cout << "Tag 2:" << tag2.getHexDump() << std::endl; + if(!a->checkTag(*kd, encrypted_packet)) { + std::cout << STR_ERROR << "wrong Authentication Tag!" << STR_END; + //exit(-1); + } cnew->decrypt(encrypted_packet, plain_packet, masterkey, mastersalt, ROLE_RIGHT ); if (memcmp(plain_packet.getPayload(), test, sizeof(test))) { std::cerr << "crypto test failed" << std::endl; -- cgit v1.2.3