From 5b27fb12f0dc01a4a282aa62028592795890a200 Mon Sep 17 00:00:00 2001 From: Christian Pointner Date: Thu, 15 Jan 2009 13:37:22 +0000 Subject: cipher now stores kd direction --- src/anytun.cpp | 8 ++++---- src/cipher.cpp | 34 +++++++++++++++++----------------- src/cipher.h | 28 ++++++++++++++++------------ src/cipherFactory.cpp | 4 ++-- src/cipherFactory.h | 2 +- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/src/anytun.cpp b/src/anytun.cpp index bf20d1c..c94a260 100644 --- a/src/anytun.cpp +++ b/src/anytun.cpp @@ -152,7 +152,7 @@ void sender(void* p) { ThreadParam* param = reinterpret_cast(p); - std::auto_ptr c(CipherFactory::create(gOpt.getCipher())); + std::auto_ptr c(CipherFactory::create(gOpt.getCipher(), KD_OUTBOUND)); std::auto_ptr a(AuthAlgoFactory::create(gOpt.getAuthAlgo()) ); PlainPacket plain_packet(MAX_PACKET_LENGTH); @@ -207,7 +207,7 @@ void sender(void* p) } // encrypt packet - c->encrypt(conn.kd_, KD_OUTBOUND, plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux); + c->encrypt(conn.kd_, plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux); encrypted_packet.setHeader(conn.seq_nr_, gOpt.getSenderId(), mux); conn.seq_nr_++; @@ -241,7 +241,7 @@ void receiver(void* p) { ThreadParam* param = reinterpret_cast(p); - std::auto_ptr c( CipherFactory::create(gOpt.getCipher()) ); + std::auto_ptr c( CipherFactory::create(gOpt.getCipher(), KD_INBOUND) ); std::auto_ptr a( AuthAlgoFactory::create(gOpt.getAuthAlgo()) ); EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH); @@ -299,7 +299,7 @@ void receiver(void* p) } // decrypt packet - c->decrypt(conn.kd_, KD_INBOUND, encrypted_packet, plain_packet); + c->decrypt(conn.kd_, encrypted_packet, plain_packet); // check payload_type if((param->dev.getType() == TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN4 && diff --git a/src/cipher.cpp b/src/cipher.cpp index c18dcdb..6e325d9 100644 --- a/src/cipher.cpp +++ b/src/cipher.cpp @@ -40,31 +40,31 @@ #include "cipher.h" #include "log.h" -void Cipher::encrypt(KeyDerivation& kd, kd_dir_t dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +void Cipher::encrypt(KeyDerivation& kd, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { - u_int32_t len = cipher(kd, dir, in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id, mux); + u_int32_t len = cipher(kd, in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id, mux); out.setSenderId(sender_id); out.setSeqNr(seq_nr); out.setMux(mux); out.setPayloadLength(len); } -void Cipher::decrypt(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket & in, PlainPacket & out) +void Cipher::decrypt(KeyDerivation& kd, EncryptedPacket & in, PlainPacket & out) { - u_int32_t len = decipher(kd, dir, in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId(), in.getMux()); + u_int32_t len = decipher(kd, in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId(), in.getMux()); out.setLength(len); } //******* NullCipher ******* -u_int32_t NullCipher::cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t NullCipher::cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { std::memcpy(out, in, (ilen < olen) ? ilen : olen); return (ilen < olen) ? ilen : olen; } -u_int32_t NullCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t NullCipher::decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { std::memcpy(out, in, (ilen < olen) ? ilen : olen); return (ilen < olen) ? ilen : olen; @@ -73,12 +73,12 @@ u_int32_t NullCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_ #ifndef NOCRYPT //****** AesIcmCipher ****** -AesIcmCipher::AesIcmCipher() : key_(u_int32_t(DEFAULT_KEY_LENGTH/8)), salt_(u_int32_t(SALT_LENGTH)) +AesIcmCipher::AesIcmCipher(kd_dir_t d) : Cipher(d), key_(u_int32_t(DEFAULT_KEY_LENGTH/8)), salt_(u_int32_t(SALT_LENGTH)) { init(); } -AesIcmCipher::AesIcmCipher(u_int16_t key_length) : key_(u_int32_t(key_length/8)), salt_(u_int32_t(SALT_LENGTH)) +AesIcmCipher::AesIcmCipher(kd_dir_t d, u_int16_t key_length) : Cipher(d), key_(u_int32_t(key_length/8)), salt_(u_int32_t(SALT_LENGTH)) { init(key_length); } @@ -116,21 +116,21 @@ AesIcmCipher::~AesIcmCipher() #endif } -u_int32_t AesIcmCipher::cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t AesIcmCipher::cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { - calc(kd, dir, in, ilen, out, olen, seq_nr, sender_id, mux); + calc(kd, in, ilen, out, olen, seq_nr, sender_id, mux); return (ilen < olen) ? ilen : olen; } -u_int32_t AesIcmCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +u_int32_t AesIcmCipher::decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { - calc(kd, dir, in, ilen, out, olen, seq_nr, sender_id, mux); + calc(kd, in, ilen, out, olen, seq_nr, sender_id, mux); return (ilen < olen) ? ilen : olen; } -void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +void AesIcmCipher::calcCtr(KeyDerivation& kd, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { - kd.generate(dir, LABEL_SATP_SALT, seq_nr, salt_); + kd.generate(dir_, LABEL_SATP_SALT, seq_nr, salt_); #ifdef ANYTUN_02_COMPAT if(!salt_[int32_t(0)]) @@ -146,14 +146,14 @@ void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sen return; } -void AesIcmCipher::calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) +void AesIcmCipher::calc(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) { #ifndef USE_SSL_CRYPTO if(!handle_) return; #endif - kd.generate(dir, LABEL_SATP_ENCRYPTION, seq_nr, key_); + kd.generate(dir_, LABEL_SATP_ENCRYPTION, seq_nr, key_); #ifdef USE_SSL_CRYPTO int ret = AES_set_encrypt_key(key_.getBuf(), key_.getLength()*8, &aes_key_); if(ret) { @@ -170,7 +170,7 @@ void AesIcmCipher::calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t } #endif - calcCtr(kd, dir, seq_nr, sender_id, mux); + calcCtr(kd, seq_nr, sender_id, mux); #ifndef USE_SSL_CRYPTO err = gcry_cipher_setctr(handle_, ctr_.buf_, CTR_LENGTH); diff --git a/src/cipher.h b/src/cipher.h index 7f5ba85..c77142e 100644 --- a/src/cipher.h +++ b/src/cipher.h @@ -49,14 +49,18 @@ class Cipher { public: + Cipher() : dir_(KD_INBOUND) {}; + Cipher(kd_dir_t d) : dir_(d) {}; virtual ~Cipher() {}; - void encrypt(KeyDerivation& kd, kd_dir_t dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - void decrypt(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket & in, PlainPacket & out); + void encrypt(KeyDerivation& kd, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + void decrypt(KeyDerivation& kd, EncryptedPacket & in, PlainPacket & out); protected: - virtual u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; - virtual u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; + virtual u_int32_t cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; + virtual u_int32_t decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0; + + kd_dir_t dir_; }; //****** NullCipher ****** @@ -64,8 +68,8 @@ protected: class NullCipher : public Cipher { protected: - u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); }; #ifndef NOCRYPT @@ -74,8 +78,8 @@ protected: class AesIcmCipher : public Cipher { public: - AesIcmCipher(); - AesIcmCipher(u_int16_t key_length); + AesIcmCipher(kd_dir_t d); + AesIcmCipher(kd_dir_t d, u_int16_t key_length); ~AesIcmCipher(); static const u_int16_t DEFAULT_KEY_LENGTH = 128; @@ -83,14 +87,14 @@ public: static const u_int16_t SALT_LENGTH = 14; protected: - u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + u_int32_t decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); private: void init(u_int16_t key_length = DEFAULT_KEY_LENGTH); - void calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); - void calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + void calcCtr(KeyDerivation& kd, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); + void calc(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux); #ifndef USE_SSL_CRYPTO gcry_cipher_hd_t handle_; diff --git a/src/cipherFactory.cpp b/src/cipherFactory.cpp index b02e5bc..bab0d5a 100644 --- a/src/cipherFactory.cpp +++ b/src/cipherFactory.cpp @@ -36,13 +36,13 @@ #include "cipher.h" -Cipher* CipherFactory::create(std::string const& type) +Cipher* CipherFactory::create(std::string const& type, kd_dir_t dir) { if( type == "null" ) return new NullCipher(); #ifndef NOCRYPT else if( type == "aes-ctr" ) - return new AesIcmCipher(); + return new AesIcmCipher(dir); #endif else throw std::invalid_argument("cipher not available"); diff --git a/src/cipherFactory.h b/src/cipherFactory.h index 53b8a57..23d3b92 100644 --- a/src/cipherFactory.h +++ b/src/cipherFactory.h @@ -40,7 +40,7 @@ class CipherFactory { public: - static Cipher* create(std::string const& type); + static Cipher* create(std::string const& type, kd_dir_t dir); private: CipherFactory(); -- cgit v1.2.3