diff options
-rw-r--r-- | src/authAlgo.cpp | 8 | ||||
-rw-r--r-- | src/authAlgo.h | 12 | ||||
-rw-r--r-- | src/buffer.cpp | 2 | ||||
-rw-r--r-- | src/cipher.cpp | 16 | ||||
-rw-r--r-- | src/cipher.h | 20 | ||||
-rw-r--r-- | src/keyDerivation.cpp | 6 | ||||
-rw-r--r-- | src/keyDerivation.h | 21 |
7 files changed, 45 insertions, 40 deletions
diff --git a/src/authAlgo.cpp b/src/authAlgo.cpp index 524d196..f18378f 100644 --- a/src/authAlgo.cpp +++ b/src/authAlgo.cpp @@ -38,11 +38,11 @@ #include <cstring> //****** NullAuthAlgo ****** -void NullAuthAlgo::generate(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet) +void NullAuthAlgo::generate(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet) { } -bool NullAuthAlgo::checkTag(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet) +bool NullAuthAlgo::checkTag(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet) { return true; } @@ -74,7 +74,7 @@ Sha1AuthAlgo::~Sha1AuthAlgo() #endif } -void Sha1AuthAlgo::generate(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet) +void Sha1AuthAlgo::generate(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet) { #ifndef USE_SSL_CRYPTO if(!handle_) @@ -116,7 +116,7 @@ void Sha1AuthAlgo::generate(KeyDerivation& kd, kd_dir dir, EncryptedPacket& pack std::memcpy(&tag[packet.getAuthTagLength() - length], &hmac[DIGEST_LENGTH - length], length); } -bool Sha1AuthAlgo::checkTag(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet) +bool Sha1AuthAlgo::checkTag(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet) { #ifndef USE_SSL_CRYPTO if(!handle_) diff --git a/src/authAlgo.h b/src/authAlgo.h index 5728426..3361ccf 100644 --- a/src/authAlgo.h +++ b/src/authAlgo.h @@ -55,13 +55,13 @@ public: * generate the mac * @param packet the packet to be authenticated */ - virtual void generate(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet) = 0; + virtual void generate(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet) = 0; /** * check the mac * @param packet the packet to be authenticated */ - virtual bool checkTag(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet) = 0; + virtual bool checkTag(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet) = 0; }; //****** NullAuthAlgo ****** @@ -69,8 +69,8 @@ public: class NullAuthAlgo : public AuthAlgo { public: - void generate(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet); - bool checkTag(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet); + void generate(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet); + bool checkTag(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet); }; #ifndef NOCRYPT @@ -83,8 +83,8 @@ public: Sha1AuthAlgo(); ~Sha1AuthAlgo(); - void generate(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet); - bool checkTag(KeyDerivation& kd, kd_dir dir, EncryptedPacket& packet); + void generate(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet); + bool checkTag(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket& packet); static const u_int32_t DIGEST_LENGTH = 20; diff --git a/src/buffer.cpp b/src/buffer.cpp index 5b1c03d..8e7bf98 100644 --- a/src/buffer.cpp +++ b/src/buffer.cpp @@ -128,8 +128,6 @@ void Buffer::operator=(const Buffer &src) std::memcpy(buf_, src.buf_, length_); } - - bool Buffer::operator==(const Buffer &cmp) const { if(length_ != cmp.length_) diff --git a/src/cipher.cpp b/src/cipher.cpp index 3d86fc5..eb958ad 100644 --- a/src/cipher.cpp +++ b/src/cipher.cpp @@ -40,7 +40,7 @@ #include "cipher.h" #include "log.h" -void Cipher::encrypt(KeyDerivation& kd, kd_dir dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +void Cipher::encrypt(KeyDerivation& kd, kd_dir_t dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { u_int32_t len = cipher(kd, dir, in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id, mux); out.setSenderId(sender_id); @@ -49,7 +49,7 @@ void Cipher::encrypt(KeyDerivation& kd, kd_dir dir, PlainPacket & in, EncryptedP out.setPayloadLength(len); } -void Cipher::decrypt(KeyDerivation& kd, kd_dir dir, EncryptedPacket & in, PlainPacket & out) +void Cipher::decrypt(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket & in, PlainPacket & out) { u_int32_t len = decipher(kd, dir, in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId(), in.getMux()); out.setLength(len); @@ -58,13 +58,13 @@ void Cipher::decrypt(KeyDerivation& kd, kd_dir dir, EncryptedPacket & in, PlainP //******* NullCipher ******* -u_int32_t NullCipher::cipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t NullCipher::cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { std::memcpy(out, in, (ilen < olen) ? ilen : olen); return (ilen < olen) ? ilen : olen; } -u_int32_t NullCipher::decipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t NullCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { std::memcpy(out, in, (ilen < olen) ? ilen : olen); return (ilen < olen) ? ilen : olen; @@ -118,19 +118,19 @@ AesIcmCipher::~AesIcmCipher() #endif } -u_int32_t AesIcmCipher::cipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t AesIcmCipher::cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { calc(kd, dir, in, ilen, out, olen, seq_nr, sender_id, mux); return (ilen < olen) ? ilen : olen; } -u_int32_t AesIcmCipher::decipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t AesIcmCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { calc(kd, dir, in, ilen, out, olen, seq_nr, sender_id, mux); return (ilen < olen) ? ilen : olen; } -void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { kd.generate(dir, LABEL_SATP_SALT, seq_nr, salt_); @@ -148,7 +148,7 @@ void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir dir, seq_nr_t seq_nr, sende return; } -void AesIcmCipher::calc(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +void AesIcmCipher::calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { #ifndef USE_SSL_CRYPTO if(!handle_) diff --git a/src/cipher.h b/src/cipher.h index 30bbeed..7f5ba85 100644 --- a/src/cipher.h +++ b/src/cipher.h @@ -51,12 +51,12 @@ class Cipher public: virtual ~Cipher() {}; - void encrypt(KeyDerivation& kd, kd_dir dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - void decrypt(KeyDerivation& kd, kd_dir dir, EncryptedPacket & in, PlainPacket & out); + void encrypt(KeyDerivation& kd, kd_dir_t dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + void decrypt(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket & in, PlainPacket & out); protected: - virtual u_int32_t cipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; - virtual u_int32_t decipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; + virtual u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; + virtual u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; }; //****** NullCipher ****** @@ -64,8 +64,8 @@ protected: class NullCipher : public Cipher { protected: - u_int32_t cipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - u_int32_t decipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); }; #ifndef NOCRYPT @@ -83,14 +83,14 @@ public: static const u_int16_t SALT_LENGTH = 14; protected: - u_int32_t cipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - u_int32_t decipher(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); private: void init(u_int16_t key_length = DEFAULT_KEY_LENGTH); - void calcCtr(KeyDerivation& kd, kd_dir dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - void calc(KeyDerivation& kd, kd_dir dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + void calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + void calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); #ifndef USE_SSL_CRYPTO gcry_cipher_hd_t handle_; diff --git a/src/keyDerivation.cpp b/src/keyDerivation.cpp index fcb3001..fcb4070 100644 --- a/src/keyDerivation.cpp +++ b/src/keyDerivation.cpp @@ -51,7 +51,7 @@ void KeyDerivation::setLogKDRate(const int8_t log_rate) //****** NullKeyDerivation ****** -bool NullKeyDerivation::generate(kd_dir dir, satp_prf_label label, seq_nr_t seq_nr, Buffer& key) +bool NullKeyDerivation::generate(kd_dir_t dir, satp_prf_label_t label, seq_nr_t seq_nr, Buffer& key) { std::memset(key.getBuf(), 0, key.getLength()); return true; @@ -168,7 +168,7 @@ std::string AesIcmKeyDerivation::printType() return sstr.str(); } -bool AesIcmKeyDerivation::calcCtr(kd_dir dir, seq_nr_t* r, satp_prf_label label, seq_nr_t seq_nr) +bool AesIcmKeyDerivation::calcCtr(kd_dir_t dir, seq_nr_t* r, satp_prf_label_t label, seq_nr_t seq_nr) { *r = 0; if(ld_kdr_ >= 0) @@ -194,7 +194,7 @@ bool AesIcmKeyDerivation::calcCtr(kd_dir dir, seq_nr_t* r, satp_prf_label label, return true; } -bool AesIcmKeyDerivation::generate(kd_dir dir, satp_prf_label label, seq_nr_t seq_nr, Buffer& key) +bool AesIcmKeyDerivation::generate(kd_dir_t dir, satp_prf_label_t label, seq_nr_t seq_nr, Buffer& key) { ReadersLock lock(mutex_); diff --git a/src/keyDerivation.h b/src/keyDerivation.h index 621bb36..5a69f72 100644 --- a/src/keyDerivation.h +++ b/src/keyDerivation.h @@ -47,17 +47,22 @@ #include <boost/archive/text_oarchive.hpp> #include <boost/archive/text_iarchive.hpp> - +#define KD_LABEL_COUNT 3 typedef enum { LABEL_SATP_ENCRYPTION = 0x00, LABEL_SATP_MSG_AUTH = 0x01, LABEL_SATP_SALT = 0x02, -} satp_prf_label; +} satp_prf_label_t; typedef enum { KD_INBOUND = 0, KD_OUTBOUND = 1 -} kd_dir; +} kd_dir_t; + +typedef struct { + Buffer key_; + seq_nr_t r_; +} key_store_t; class KeyDerivation { @@ -69,7 +74,7 @@ public: void setLogKDRate(const int8_t ld_rate); virtual void init(Buffer key, Buffer salt) = 0; - virtual bool generate(kd_dir dir, satp_prf_label label, seq_nr_t seq_nr, Buffer& key) = 0; + virtual bool generate(kd_dir_t dir, satp_prf_label_t label, seq_nr_t seq_nr, Buffer& key) = 0; virtual std::string printType() { return "GenericKeyDerivation"; }; @@ -108,7 +113,7 @@ public: ~NullKeyDerivation() {}; void init(Buffer key, Buffer salt) {}; - bool generate(kd_dir dir, satp_prf_label label, seq_nr_t seq_nr, Buffer& key); + bool generate(kd_dir_t dir, satp_prf_label_t label, seq_nr_t seq_nr, Buffer& key); std::string printType() { return "NullKeyDerivation"; }; @@ -139,14 +144,14 @@ public: static const u_int16_t SALT_LENGTH = 14; void init(Buffer key, Buffer salt); - bool generate(kd_dir dir, satp_prf_label label, seq_nr_t seq_nr, Buffer& key); + bool generate(kd_dir_t dir, satp_prf_label_t label, seq_nr_t seq_nr, Buffer& key); std::string printType(); private: void updateMasterKey(); - bool calcCtr(kd_dir dir, seq_nr_t* r, satp_prf_label label, seq_nr_t seq_nr); + bool calcCtr(kd_dir_t dir, seq_nr_t* r, satp_prf_label_t label, seq_nr_t seq_nr); friend class boost::serialization::access; template<class Archive> @@ -162,6 +167,8 @@ private: u_int8_t ecount_buf_[2][AES_BLOCK_SIZE]; #endif + key_store_t key_store_[KD_LABEL_COUNT]; + union __attribute__((__packed__)) key_derivation_aesctr_ctr_union { u_int8_t buf_[CTR_LENGTH]; struct __attribute__ ((__packed__)) { |