summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/anytun.cpp8
-rw-r--r--src/cipher.cpp34
-rw-r--r--src/cipher.h28
-rw-r--r--src/cipherFactory.cpp4
-rw-r--r--src/cipherFactory.h2
5 files changed, 40 insertions, 36 deletions
diff --git a/src/anytun.cpp b/src/anytun.cpp
index bf20d1c..c94a260 100644
--- a/src/anytun.cpp
+++ b/src/anytun.cpp
@@ -152,7 +152,7 @@ void sender(void* p)
{
ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
- std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher()));
+ std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher(), KD_OUTBOUND));
std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
PlainPacket plain_packet(MAX_PACKET_LENGTH);
@@ -207,7 +207,7 @@ void sender(void* p)
}
// encrypt packet
- c->encrypt(conn.kd_, KD_OUTBOUND, plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux);
+ c->encrypt(conn.kd_, plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux);
encrypted_packet.setHeader(conn.seq_nr_, gOpt.getSenderId(), mux);
conn.seq_nr_++;
@@ -241,7 +241,7 @@ void receiver(void* p)
{
ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
- std::auto_ptr<Cipher> c( CipherFactory::create(gOpt.getCipher()) );
+ std::auto_ptr<Cipher> c( CipherFactory::create(gOpt.getCipher(), KD_INBOUND) );
std::auto_ptr<AuthAlgo> a( AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
@@ -299,7 +299,7 @@ void receiver(void* p)
}
// decrypt packet
- c->decrypt(conn.kd_, KD_INBOUND, encrypted_packet, plain_packet);
+ c->decrypt(conn.kd_, encrypted_packet, plain_packet);
// check payload_type
if((param->dev.getType() == TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN4 &&
diff --git a/src/cipher.cpp b/src/cipher.cpp
index c18dcdb..6e325d9 100644
--- a/src/cipher.cpp
+++ b/src/cipher.cpp
@@ -40,31 +40,31 @@
#include "cipher.h"
#include "log.h"
-void Cipher::encrypt(KeyDerivation& kd, kd_dir_t dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+void Cipher::encrypt(KeyDerivation& kd, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
- u_int32_t len = cipher(kd, dir, in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id, mux);
+ u_int32_t len = cipher(kd, in, in.getLength(), out.getPayload(), out.getPayloadLength(), seq_nr, sender_id, mux);
out.setSenderId(sender_id);
out.setSeqNr(seq_nr);
out.setMux(mux);
out.setPayloadLength(len);
}
-void Cipher::decrypt(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket & in, PlainPacket & out)
+void Cipher::decrypt(KeyDerivation& kd, EncryptedPacket & in, PlainPacket & out)
{
- u_int32_t len = decipher(kd, dir, in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId(), in.getMux());
+ u_int32_t len = decipher(kd, in.getPayload() , in.getPayloadLength(), out, out.getLength(), in.getSeqNr(), in.getSenderId(), in.getMux());
out.setLength(len);
}
//******* NullCipher *******
-u_int32_t NullCipher::cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+u_int32_t NullCipher::cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
std::memcpy(out, in, (ilen < olen) ? ilen : olen);
return (ilen < olen) ? ilen : olen;
}
-u_int32_t NullCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+u_int32_t NullCipher::decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
std::memcpy(out, in, (ilen < olen) ? ilen : olen);
return (ilen < olen) ? ilen : olen;
@@ -73,12 +73,12 @@ u_int32_t NullCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_
#ifndef NOCRYPT
//****** AesIcmCipher ******
-AesIcmCipher::AesIcmCipher() : key_(u_int32_t(DEFAULT_KEY_LENGTH/8)), salt_(u_int32_t(SALT_LENGTH))
+AesIcmCipher::AesIcmCipher(kd_dir_t d) : Cipher(d), key_(u_int32_t(DEFAULT_KEY_LENGTH/8)), salt_(u_int32_t(SALT_LENGTH))
{
init();
}
-AesIcmCipher::AesIcmCipher(u_int16_t key_length) : key_(u_int32_t(key_length/8)), salt_(u_int32_t(SALT_LENGTH))
+AesIcmCipher::AesIcmCipher(kd_dir_t d, u_int16_t key_length) : Cipher(d), key_(u_int32_t(key_length/8)), salt_(u_int32_t(SALT_LENGTH))
{
init(key_length);
}
@@ -116,21 +116,21 @@ AesIcmCipher::~AesIcmCipher()
#endif
}
-u_int32_t AesIcmCipher::cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+u_int32_t AesIcmCipher::cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
- calc(kd, dir, in, ilen, out, olen, seq_nr, sender_id, mux);
+ calc(kd, in, ilen, out, olen, seq_nr, sender_id, mux);
return (ilen < olen) ? ilen : olen;
}
-u_int32_t AesIcmCipher::decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+u_int32_t AesIcmCipher::decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
- calc(kd, dir, in, ilen, out, olen, seq_nr, sender_id, mux);
+ calc(kd, in, ilen, out, olen, seq_nr, sender_id, mux);
return (ilen < olen) ? ilen : olen;
}
-void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+void AesIcmCipher::calcCtr(KeyDerivation& kd, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
- kd.generate(dir, LABEL_SATP_SALT, seq_nr, salt_);
+ kd.generate(dir_, LABEL_SATP_SALT, seq_nr, salt_);
#ifdef ANYTUN_02_COMPAT
if(!salt_[int32_t(0)])
@@ -146,14 +146,14 @@ void AesIcmCipher::calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sen
return;
}
-void AesIcmCipher::calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
+void AesIcmCipher::calc(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux)
{
#ifndef USE_SSL_CRYPTO
if(!handle_)
return;
#endif
- kd.generate(dir, LABEL_SATP_ENCRYPTION, seq_nr, key_);
+ kd.generate(dir_, LABEL_SATP_ENCRYPTION, seq_nr, key_);
#ifdef USE_SSL_CRYPTO
int ret = AES_set_encrypt_key(key_.getBuf(), key_.getLength()*8, &aes_key_);
if(ret) {
@@ -170,7 +170,7 @@ void AesIcmCipher::calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t
}
#endif
- calcCtr(kd, dir, seq_nr, sender_id, mux);
+ calcCtr(kd, seq_nr, sender_id, mux);
#ifndef USE_SSL_CRYPTO
err = gcry_cipher_setctr(handle_, ctr_.buf_, CTR_LENGTH);
diff --git a/src/cipher.h b/src/cipher.h
index 7f5ba85..c77142e 100644
--- a/src/cipher.h
+++ b/src/cipher.h
@@ -49,14 +49,18 @@
class Cipher
{
public:
+ Cipher() : dir_(KD_INBOUND) {};
+ Cipher(kd_dir_t d) : dir_(d) {};
virtual ~Cipher() {};
- void encrypt(KeyDerivation& kd, kd_dir_t dir, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
- void decrypt(KeyDerivation& kd, kd_dir_t dir, EncryptedPacket & in, PlainPacket & out);
+ void encrypt(KeyDerivation& kd, PlainPacket & in, EncryptedPacket & out, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ void decrypt(KeyDerivation& kd, EncryptedPacket & in, PlainPacket & out);
protected:
- virtual u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0;
- virtual u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0;
+ virtual u_int32_t cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0;
+ virtual u_int32_t decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux) = 0;
+
+ kd_dir_t dir_;
};
//****** NullCipher ******
@@ -64,8 +68,8 @@ protected:
class NullCipher : public Cipher
{
protected:
- u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
- u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ u_int32_t cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ u_int32_t decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
};
#ifndef NOCRYPT
@@ -74,8 +78,8 @@ protected:
class AesIcmCipher : public Cipher
{
public:
- AesIcmCipher();
- AesIcmCipher(u_int16_t key_length);
+ AesIcmCipher(kd_dir_t d);
+ AesIcmCipher(kd_dir_t d, u_int16_t key_length);
~AesIcmCipher();
static const u_int16_t DEFAULT_KEY_LENGTH = 128;
@@ -83,14 +87,14 @@ public:
static const u_int16_t SALT_LENGTH = 14;
protected:
- u_int32_t cipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
- u_int32_t decipher(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ u_int32_t cipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ u_int32_t decipher(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
private:
void init(u_int16_t key_length = DEFAULT_KEY_LENGTH);
- void calcCtr(KeyDerivation& kd, kd_dir_t dir, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
- void calc(KeyDerivation& kd, kd_dir_t dir, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ void calcCtr(KeyDerivation& kd, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
+ void calc(KeyDerivation& kd, u_int8_t* in, u_int32_t ilen, u_int8_t* out, u_int32_t olen, seq_nr_t seq_nr, sender_id_t sender_id, mux_t mux);
#ifndef USE_SSL_CRYPTO
gcry_cipher_hd_t handle_;
diff --git a/src/cipherFactory.cpp b/src/cipherFactory.cpp
index b02e5bc..bab0d5a 100644
--- a/src/cipherFactory.cpp
+++ b/src/cipherFactory.cpp
@@ -36,13 +36,13 @@
#include "cipher.h"
-Cipher* CipherFactory::create(std::string const& type)
+Cipher* CipherFactory::create(std::string const& type, kd_dir_t dir)
{
if( type == "null" )
return new NullCipher();
#ifndef NOCRYPT
else if( type == "aes-ctr" )
- return new AesIcmCipher();
+ return new AesIcmCipher(dir);
#endif
else
throw std::invalid_argument("cipher not available");
diff --git a/src/cipherFactory.h b/src/cipherFactory.h
index 53b8a57..23d3b92 100644
--- a/src/cipherFactory.h
+++ b/src/cipherFactory.h
@@ -40,7 +40,7 @@
class CipherFactory
{
public:
- static Cipher* create(std::string const& type);
+ static Cipher* create(std::string const& type, kd_dir_t dir);
private:
CipherFactory();