I am implementing an RSA-encrypted socket connection in Java. To do that I use two classes: the first is the abstract Connection class, which represents the actual socket connection, and the second is the ConnectionCallback class, which is called when the Connection class receives data.
When the Connection class receives data, the data is decrypted using a previously shared public key that comes from the connected endpoint (there can be only one connected endpoint).
ByteArray class:
package connection.data;
public class ByteArray {
private byte[] bytes;
public ByteArray(byte[] bytes){
this.bytes = bytes;
}
public ByteArray(){
}
public void add(byte[] data) {
if(this.bytes == null) this.bytes = new byte[0];
this.bytes = joinArrays(this.bytes, data);
}
private byte[] joinArrays(byte[] array1, byte[] array2) {
byte[] array = new byte[array1.length + array2.length];
System.arraycopy(array1, 0, array, 0, array1.length);
System.arraycopy(array2, 0, array, array1.length, array2.length);
return array;
}
public byte[] getBytes(){
return this.bytes;
}
}
Connection class:
package connection;
import connection.data.ByteArray;
import connection.protocols.ProtectedConnectionProtocol;
import crypto.CryptoUtils;
import crypto.algorithm.asymmetric.rsa.RSAAlgorithm;
import protocol.connection.ConnectionProtocol;
import util.function.Callback;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.PublicKey;
import java.util.Base64;
public abstract class Connection implements Runnable {
private DataInputStream in;
private DataOutputStream out;
ConnectionProtocol protocol;
private Callback callback;
private boolean isConnected = false;
public Connection() throws Exception {
this.protocol = new ProtectedConnectionProtocol(new RSAAlgorithm(1024));
this.callback = new ConnectionCallback(this);
}
public Connection(ConnectionProtocol connectionProtocol, Callback callback) throws Exception {
this.protocol = connectionProtocol;
this.callback = callback;
}
@Override
public void run() {
while(isConnected){
try {
ByteArray data = new ByteArray();
while(this.in.available() > 0){
data.add(this.read());
}
if(data.getBytes() != null){
callback.run(data);
}
} catch (Exception e){
e.printStackTrace();
break;
}
}
}
protected void openConnection(InputStream in, OutputStream out) throws Exception{
this.in = new DataInputStream(in);
this.out = new DataOutputStream(out);
this.isConnected = true;
new Thread(this).start();
this.write(CryptoUtils.encode(((PublicKey) this.protocol.getPublicKey()).getEncoded()));
}
private void write(byte[] data) throws Exception{
System.out.println(new String(data,"UTF-8"));
this.out.write(data);
this.out.flush();
}
private byte[] read() throws Exception{
byte[] bytes = new byte[8192];
int read = this.in.read(bytes);
if (read <= 0) return new byte[0]; // or return null, or something, read might be -1 when there was no data.
byte[] readBytes = new byte[read];
System.arraycopy(bytes, 0, readBytes, 0, read);
return bytes;
}
}
ConnectionCallback class:
package connection;
import connection.data.ByteArray;
import crypto.CryptoUtils;
import util.function.Callback;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;
public class ConnectionCallback implements Callback {
private Connection connection;
public ConnectionCallback(Connection connection){
this.connection = connection;
}
@Override
public void run(Object data) throws Exception {
ByteArray bytes = (ByteArray) data;
byte[] dataToBytes = CryptoUtils.decode(bytes.getBytes());
if(this.connection.protocol.getSharedKey() == null){
X509EncodedKeySpec spec = new X509EncodedKeySpec(dataToBytes);
KeyFactory kf = KeyFactory.getInstance("RSA");
PublicKey publicKey = kf.generatePublic(spec);
this.connection.protocol.setSharedKey(publicKey);
} else {
//this.so = StrongboxObject.parse(new String(bytes.getBytes()));
}
}
}
RSAAlgorithm class:
package crypto.algorithm.asymmetric.rsa;
import crypto.CryptoUtils;
import crypto.algorithm.asymmetric.AssimetricalAlgorithm;
import javax.crypto.Cipher;
import java.security.*;
import java.util.Base64;
public class RSAAlgorithm extends AssimetricalAlgorithm {
private KeyPairGenerator keyGen;
public RSAAlgorithm(int keyLength) throws Exception {
super();
this.keyGen = KeyPairGenerator.getInstance("RSA");
this.keyGen.initialize(keyLength);
this.generateKeys();
}
@Override
public void generateKeys() {
KeyPair pair = this.keyGen.generateKeyPair();
super.setPublicKey(pair.getPublic());
super.setPrivateKey(pair.getPrivate());
}
@Override
public byte[] encrypt(byte[] message) {
try {
super.cipher.init(Cipher.ENCRYPT_MODE, (PublicKey) super.getSharedKey());
return CryptoUtils.encode(super.cipher.doFinal(message));
} catch (Exception e) {
e.printStackTrace();
}
return new byte[0];
}
@Override
public byte[] decrypt(byte[] message) {
message = CryptoUtils.decode(message);
try {
super.cipher.init(Cipher.DECRYPT_MODE, (PrivateKey) super.getPrivateKey());
return super.cipher.doFinal(message);
} catch (Exception e) {
e.printStackTrace();
}
return new byte[0];
}
}
ProtectedConnectionProtocol class:
package connection.protocols;
import protocol.connection.ConnectionProtocol;
import crypto.algorithm.asymmetric.AssimetricalAlgorithm;
public class ProtectedConnectionProtocol extends ConnectionProtocol {
private AssimetricalAlgorithm algorithm;
public ProtectedConnectionProtocol(AssimetricalAlgorithm algorithm){
this.algorithm = algorithm;
}
@Override
public Object getPublicKey() {
return this.algorithm.getPublicKey();
}
@Override
public Object getPrivateKey() {
return this.algorithm.getPrivateKey();
}
@Override
public Object getSharedKey() {
return this.algorithm.getSharedKey();
}
@Override
public void setSharedKey(Object sharedKey){
this.algorithm.setSharedKey(sharedKey);
}
@Override
public byte[] decrypt(byte[] message) {
return this.algorithm.decrypt(message);
}
@Override
public byte[] encrypt(byte[] message) {
return this.algorithm.encrypt(message);
}
}
CryptoUtils class:
package crypto;
import java.util.Base64;
public class CryptoUtils {
public static byte[] encode(byte[] data){
return Base64.getEncoder().encode(data);
}
public static byte[] decode(byte[] data){
return Base64.getDecoder().decode(data);
}
}
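Before suspecting the crypto, it can help to check the key exchange in isolation. The sketch below (the class and method names are hypothetical, not part of the question's code) encodes an RSA public key the same way openConnection() does (X.509 bytes from getEncoded(), then Base64) and decodes it the same way ConnectionCallback.run() does, with no socket in between:

```java
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.Base64;

public class PublicKeyRoundTrip {

    // Same encoding as Connection.openConnection(): X.509 (SubjectPublicKeyInfo) bytes, then Base64.
    static byte[] encodePublicKey(PublicKey key) {
        return Base64.getEncoder().encode(key.getEncoded());
    }

    // Same decoding as ConnectionCallback.run(): Base64, then X509EncodedKeySpec + KeyFactory.
    static PublicKey decodePublicKey(byte[] base64) throws Exception {
        byte[] der = Base64.getDecoder().decode(base64);
        return KeyFactory.getInstance("RSA").generatePublic(new X509EncodedKeySpec(der));
    }

    public static void main(String[] args) throws Exception {
        KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA");
        gen.initialize(1024);
        KeyPair pair = gen.generateKeyPair();

        byte[] wire = encodePublicKey(pair.getPublic());
        PublicKey received = decodePublicKey(wire);

        // Compare the encoded forms of the sent and the reconstructed key.
        System.out.println(Arrays.equals(received.getEncoded(), pair.getPublic().getEncoded())); // true
    }
}
```

If this round trip works but the socket version fails, the bytes are being altered by the read/accumulate code rather than by the key handling.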
UPDATE of 05/09/2019:
Updated code, same exception:
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCcrbJGHqpJdhDbVoZCJ0bucb8YnvcVWx9HIUfJOgmAKIuTmw1VUCk85ztqDq0VP2k6IP2bSD5MegR10FtqGtGEQrv+m0eNgbvE3O7czUzvedb5wKbA8eiSPbcX8JElobOhrolOb8JQRQzWAschBNp4MDljlu+0KZQHtZa6pPYJ0wIDAQAB
java.lang.IllegalArgumentException: Illegal base64 character 0
at java.base/java.util.Base64$Decoder.decode0(Base64.java:743)
at java.base/java.util.Base64$Decoder.decode(Base64.java:535)
at crypto.CryptoUtils.decode(CryptoUtils.java:12)
at connection.ConnectionCallback.run(ConnectionCallback.java:21)
at connection.Connection.run(Connection.java:42)
at java.base/java.lang.Thread.run(Thread.java:834)
Please help me. I am exasperated with this and have only 2 more days of bounty; I would rather give my bounty to someone who helps me find the solution to this problem than lose it.
This is probably caused by your read method:
private byte[] read() throws Exception{
byte[] bytes = new byte[8192];
this.in.read(bytes);
return bytes;
}
You always read into an array of 8192 bytes, even if the input stream does not contain that many bytes. this.in.read(bytes) returns the number of bytes read; you should use that value and take only that many bytes from the array, ignoring the rest. The rest of the array is just zeros, so when you try to decode it as Base64 you get java.lang.IllegalArgumentException: Illegal base64 character 0
So when reading your bytes, you can just copy them to a new array:
private byte[] read() throws Exception{
byte[] bytes = new byte[8192];
int read = this.in.read(bytes);
if (read <= 0) return new byte[0]; // or return null; read is -1 at end of stream
byte[] readBytes = new byte[read];
System.arraycopy(bytes, 0, readBytes, 0, read);
return readBytes;
}
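To see why the trailing zeros matter, here is a small self-contained sketch (TrailingZeroDemo and readChunk are hypothetical names) that Base64-decodes a payload once from the whole zero-padded buffer and once trimmed to the count returned by read(). The payload "hello!" is chosen so that its Base64 form has no '=' padding, like the public key printed in the question:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;

public class TrailingZeroDemo {

    // Reads once into a fixed buffer and keeps only the bytes actually read,
    // like the corrected read() above. Returns an empty array at end of stream.
    static byte[] readChunk(InputStream in) throws IOException {
        byte[] buffer = new byte[8192];
        int read = in.read(buffer);
        if (read <= 0) return new byte[0];
        return Arrays.copyOf(buffer, read);
    }

    public static void main(String[] args) throws IOException {
        byte[] payload = Base64.getEncoder().encode("hello!".getBytes(StandardCharsets.UTF_8));

        // Buggy variant: the whole 8192-byte buffer, zero-filled after the payload.
        byte[] padded = Arrays.copyOf(payload, 8192);
        try {
            Base64.getDecoder().decode(padded);
        } catch (IllegalArgumentException e) {
            System.out.println("padded buffer rejected: " + e.getMessage());
        }

        // Fixed variant: only the bytes that read() reported.
        byte[] trimmed = readChunk(new ByteArrayInputStream(payload));
        System.out.println(new String(Base64.getDecoder().decode(trimmed), StandardCharsets.UTF_8));
    }
}
```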
Note that reading like this is actually a pretty bad idea for performance, as you are allocating a lot of arrays for each read. More advanced libraries like Netty have their own byte buffers with separate read/write positions and store everything in a single self-resizing array of bytes. But first make it work, and if you later run into performance issues, remember that this is one of the places where you might find a solution.
Also, in your ByteArray you are copying both arrays into the same spot:
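The "single self-resizing array" idea can be tried without a library: java.io.ByteArrayOutputStream already grows its internal array as bytes are written. A minimal sketch, assuming the same available()-based loop as Connection.run(); drainAvailable is a hypothetical helper that could stand in for ByteArray plus read():

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

public class ChunkAccumulator {

    // Collects everything currently readable from the stream into one byte[].
    // ByteArrayOutputStream grows its internal array as needed, so there is no
    // per-chunk joinArrays() copy; the single final copy happens in toByteArray().
    static byte[] drainAvailable(InputStream in) throws IOException {
        ByteArrayOutputStream collected = new ByteArrayOutputStream();
        byte[] buffer = new byte[8192]; // reused for every read
        while (in.available() > 0) {
            int read = in.read(buffer);
            if (read == -1) break;            // end of stream
            collected.write(buffer, 0, read); // only the bytes actually read
        }
        return collected.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        byte[] data = new byte[20_000]; // larger than one buffer, so several reads happen
        for (int i = 0; i < data.length; i++) data[i] = (byte) ('A' + i % 26);
        byte[] result = drainAvailable(new ByteArrayInputStream(data));
        System.out.println(result.length); // 20000
    }
}
```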
for(int i = 0; i < this.bytes.length; i++){
bytes1[i] = this.bytes[i];
}
for(int i = 0; i < data.length; i++){
bytes1[i] = data[i]; // this loop starts from 0 too
}
You need to use i + this.bytes.length as the index in the second loop (and it's better to use System.arraycopy):
public byte[] joinArrays(byte[] array1, byte[] array2) {
byte[] array = new byte[array1.length + array2.length];
System.arraycopy(array1, 0, array, 0, array1.length);
System.arraycopy(array2, 0, array, array1.length, array2.length);
return array;
}
And then just:
public void add(byte[] data) {
if(this.bytes == null) this.bytes = new byte[0];
this.bytes = joinArrays(this.bytes, data);
}
Also, as in the other answer, it might be a good idea to change the flush method to just set the field to null, or even better, remove that method entirely, as I don't see it being used and you could just create a new instance of this object anyway.
I looked into your code and found that the problem is in the add() method of the ByteArray class. Let me show you (see the comments):
Original : ByteArray
public void add(byte[] data){
if(this.bytes == null)
this.bytes = new byte[data.length];
byte[] bytes1 = new byte[this.bytes.length + data.length];
for(int i = 0; i < this.bytes.length; i++){
bytes1[i] = this.bytes[i]; // when this.bytes was null, this copies data.length zeros, which is probably not what you want. These zeros prevent the Base64 decoder from decoding
}
for(int i = 0; i < data.length; i++){
bytes1[i] = data[i];
}
this.bytes = bytes1;
}
Solution: ByteArray
public void add(byte[] data){
if(this.bytes == null) {
this.bytes = data; // just store it because the field is null
} else {
byte[] bytes1 = new byte[this.bytes.length + data.length];
for (int i = 0; i < this.bytes.length; i++) {
bytes1[i] = this.bytes[i];
}
for (int i = 0; i < data.length; i++) {
bytes1[this.bytes.length + i] = data[i]; // append after the existing bytes
}
this.bytes = bytes1;
}
}
public void flush(){
this.bytes = null; // Important
}
EDIT
After looking at the code that reads bytes in the Connection class, I found that it reads unnecessary 0 bytes at the end, so I came up with the following workaround:
Refactor: Connection
...
public abstract class Connection implements Runnable {
...
@Override
public void run() {
while(isConnected){
try {
ByteArray data = new ByteArray();
while(this.in.available() > 0){
byte[] read = this.read();
if (read != null) {
data.add(read);
}
}
if(data.getBytes() != null){
callback.run(data);
}
} catch (Exception e){
e.printStackTrace();
break;
}
}
}
...
private byte[] read() throws Exception{
byte[] bytes = new byte[this.in.available()];
int read = this.in.read(bytes);
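if (read <= 0) return null; // read is 0 for an empty buffer and -1 at end of stream
if (read == bytes.length) return bytes; // buffer filled completely, no copy needed
byte[] readBytes = new byte[read]; // read() may return fewer bytes than available() reported
System.arraycopy(bytes, 0, readBytes, 0, read);
return readBytes;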
}
}
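Both versions still rely on available(), which only reports how many bytes have arrived so far, so one Base64 message can be split across two callbacks (or two messages merged into one) and decoding can still fail. A common alternative, shown here as a hedged sketch that is not part of either answer, is to prefix each message with its length. writeFrame and readFrame are hypothetical helpers; writeInt, readInt and readFully are standard DataOutputStream/DataInputStream methods. In Connection, write() would call writeFrame and the run() loop would call readFrame instead of polling available():

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class LengthPrefixedFrames {

    // Writes one message as a 4-byte big-endian length followed by the payload.
    static void writeFrame(DataOutputStream out, byte[] payload) throws IOException {
        out.writeInt(payload.length);
        out.write(payload);
        out.flush();
    }

    // Blocks until one complete message has arrived, however TCP split it.
    static byte[] readFrame(DataInputStream in) throws IOException {
        int length = in.readInt(); // throws EOFException at end of stream
        if (length < 0) throw new IOException("negative frame length: " + length);
        byte[] payload = new byte[length];
        in.readFully(payload);     // keeps reading until the whole payload is in
        return payload;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream wire = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(wire);
        writeFrame(out, "first message".getBytes(StandardCharsets.UTF_8));
        writeFrame(out, "second".getBytes(StandardCharsets.UTF_8));

        // Both frames sit back to back in one buffer, as they might in a socket.
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(wire.toByteArray()));
        System.out.println(new String(readFrame(in), StandardCharsets.UTF_8)); // first message
        System.out.println(new String(readFrame(in), StandardCharsets.UTF_8)); // second
    }
}
```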
Related
I'm working on an Android application that needs to encrypt (and later decrypt) files on the file system. I wrote an Android test for code that I found on the web and adapted to my needs. I encrypt a simple text and then try to decrypt it. The problem is that when I decrypt it, some strange characters appear at the beginning of the content. For example, I encrypt/decrypt a string like this:
Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)
And I received
X��YK�P���$BProgramming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)Concentration - Programming Music 0100 (Part 4)
The test code is
@Test
public void test() throws IOException, GeneralSecurityException {
String input = "Concentration - Programming Music 0100 (Part 4)";
for (int i=0;i<10;i++) {
input+=input;
}
String password = EncryptSystem.encrypt(new ByteArrayInputStream(input.getBytes(Charset.forName("UTF-8"))), new File(this.context.getFilesDir(), "test.txt"));
InputStream inputStream = EncryptSystem.decrypt(password, new File(this.context.getFilesDir(), "test.txt"));
//creating an InputStreamReader object
InputStreamReader isReader = new InputStreamReader(inputStream, Charset.forName("UTF-8"));
//Creating a BufferedReader object
BufferedReader reader = new BufferedReader(isReader);
StringBuilder sb = new StringBuilder();
String str;
while ((str = reader.readLine()) != null) {
sb.append(str);
}
System.out.println(sb.toString());
Assert.assertEquals(input, sb.toString());
}
The class code is:
import android.os.Build;
import android.os.Process;
import android.util.Base64;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.security.*;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
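import javax.crypto.*;
import java.security.spec.KeySpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;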
public class EncryptSystem {
public static class SecretKeys {
private SecretKey confidentialityKey;
private byte[] iv;
/**
* An aes key derived from a base64 encoded key. This does not generate the
* key. It's not random or a PBE key.
*
* @param keysStr a base64 encoded AES key and IV as base64(aesKey):base64(iv).
* @return an AES key and IV set suitable for other functions.
*/
public static SecretKeys of(String keysStr) throws InvalidKeyException {
String[] keysArr = keysStr.split(":");
if (keysArr.length != 2) {
throw new IllegalArgumentException("Cannot parse aesKey:iv");
} else {
byte[] confidentialityKey = Base64.decode(keysArr[0], BASE64_FLAGS);
if (confidentialityKey.length != AES_KEY_LENGTH_BITS / 8) {
throw new InvalidKeyException("Base64 decoded key is not " + AES_KEY_LENGTH_BITS + " bytes");
}
byte[] iv = Base64.decode(keysArr[1], BASE64_FLAGS);
/* if (iv.length != HMAC_KEY_LENGTH_BITS / 8) {
throw new InvalidKeyException("Base64 decoded key is not " + HMAC_KEY_LENGTH_BITS + " bytes");
}*/
return new SecretKeys(
new SecretKeySpec(confidentialityKey, 0, confidentialityKey.length, CIPHER),
iv);
}
}
public SecretKeys(SecretKey confidentialityKeyIn, byte[] i) {
setConfidentialityKey(confidentialityKeyIn);
iv = new byte[i.length];
System.arraycopy(i, 0, iv, 0, i.length);
}
public SecretKey getConfidentialityKey() {
return confidentialityKey;
}
public void setConfidentialityKey(SecretKey confidentialityKey) {
this.confidentialityKey = confidentialityKey;
}
@Override
public String toString() {
return Base64.encodeToString(getConfidentialityKey().getEncoded(), BASE64_FLAGS)
+ ":" + Base64.encodeToString(this.iv, BASE64_FLAGS);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SecretKeys that = (SecretKeys) o;
return confidentialityKey.equals(that.confidentialityKey) &&
Arrays.equals(iv, that.iv);
}
@Override
public int hashCode() {
int result = Objects.hash(confidentialityKey);
result = 31 * result + Arrays.hashCode(iv);
return result;
}
public byte[] getIv() {
return this.iv;
}
}
// If the PRNG fix would not succeed for some reason, we normally will throw an exception.
// If ALLOW_BROKEN_PRNG is true, however, we will simply log instead.
private static final boolean ALLOW_BROKEN_PRNG = false;
private static final String CIPHER_TRANSFORMATION = "AES/CBC/PKCS5Padding";
private static final String CIPHER = "AES";
private static final int AES_KEY_LENGTH_BITS = 128;
private static final int IV_LENGTH_BYTES = 16;
private static final int PBE_ITERATION_COUNT = 10000;
private static final int PBE_SALT_LENGTH_BITS = AES_KEY_LENGTH_BITS; // same size as key output
private static final String PBE_ALGORITHM = "PBKDF2WithHmacSHA1";
//Made BASE_64_FLAGS public as it's useful to know for compatibility.
public static final int BASE64_FLAGS = Base64.NO_WRAP;
//default for testing
static final AtomicBoolean prngFixed = new AtomicBoolean(false);
private static final String HMAC_ALGORITHM = "HmacSHA256";
private static final int HMAC_KEY_LENGTH_BITS = 256;
public static SecretKeys generateKey() throws GeneralSecurityException {
fixPrng();
KeyGenerator keyGen = KeyGenerator.getInstance(CIPHER);
// No need to provide a SecureRandom or set a seed since that will
// happen automatically.
keyGen.init(AES_KEY_LENGTH_BITS);
SecretKey confidentialityKey = keyGen.generateKey();
return new SecretKeys(confidentialityKey, generateIv());
}
private static void fixPrng() {
if (!prngFixed.get()) {
synchronized (PrngFixes.class) {
if (!prngFixed.get()) {
PrngFixes.apply();
prngFixed.set(true);
}
}
}
}
private static byte[] randomBytes(int length) throws GeneralSecurityException {
fixPrng();
SecureRandom random = new SecureRandom();
byte[] b = new byte[length];
random.nextBytes(b);
return b;
}
private static byte[] generateIv() throws GeneralSecurityException {
return randomBytes(IV_LENGTH_BYTES);
}
private static String keyString(SecretKeys keys) {
return keys.toString();
}
public static SecretKeys generateKeyFromPassword(String password, byte[] salt) throws GeneralSecurityException {
fixPrng();
//Get enough random bytes for both the AES key and the HMAC key:
KeySpec keySpec = new PBEKeySpec(password.toCharArray(), salt,
PBE_ITERATION_COUNT, AES_KEY_LENGTH_BITS + HMAC_KEY_LENGTH_BITS);
SecretKeyFactory keyFactory = SecretKeyFactory
.getInstance(PBE_ALGORITHM);
byte[] keyBytes = keyFactory.generateSecret(keySpec).getEncoded();
// Split the random bytes into two parts:
byte[] confidentialityKeyBytes = copyOfRange(keyBytes, 0, AES_KEY_LENGTH_BITS / 8);
byte[] integrityKeyBytes = copyOfRange(keyBytes, AES_KEY_LENGTH_BITS / 8, AES_KEY_LENGTH_BITS / 8 + HMAC_KEY_LENGTH_BITS / 8);
//Generate the AES key
SecretKey confidentialityKey = new SecretKeySpec(confidentialityKeyBytes, CIPHER);
return new SecretKeys(confidentialityKey, generateIv());
}
private static byte[] copyOfRange(byte[] from, int start, int end) {
int length = end - start;
byte[] result = new byte[length];
System.arraycopy(from, start, result, 0, length);
return result;
}
public static SecretKeys generateKeyFromPassword(String password, String salt) throws GeneralSecurityException {
return generateKeyFromPassword(password, Base64.decode(salt, BASE64_FLAGS));
}
public static String encrypt(InputStream inputStream, File fileToWrite)
throws GeneralSecurityException {
SecretKeys secretKeys = generateKey();
return encrypt(inputStream, secretKeys, fileToWrite);
}
public static InputStream decrypt(String secretKey, File fileToRead) throws InvalidKeyException, NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, FileNotFoundException {
SecretKeys secretKeys = SecretKeys.of(secretKey);
Cipher aesCipherForDecryption = Cipher.getInstance(CIPHER_TRANSFORMATION);
aesCipherForDecryption.init(Cipher.DECRYPT_MODE, secretKeys.getConfidentialityKey(),
new IvParameterSpec(secretKeys.getIv()));
return new CipherInputStream(new FileInputStream(fileToRead), aesCipherForDecryption);
}
private static String encrypt(InputStream inputStream, SecretKeys secretKeys, File fileToWrite)
throws GeneralSecurityException {
byte[] iv = generateIv();
Cipher aesCipherForEncryption = Cipher.getInstance(CIPHER_TRANSFORMATION);
aesCipherForEncryption.init(Cipher.ENCRYPT_MODE, secretKeys.getConfidentialityKey(), new IvParameterSpec(iv));
saveFile(inputStream, aesCipherForEncryption, fileToWrite);
/*
* Now we get back the IV that will actually be used. Some Android
* versions do funny stuff w/ the IV, so this is to work around bugs:
*/
/*iv = aesCipherForEncryption.getIV();
//byte[] byteCipherText = aesCipherForEncryption.doFinal(plaintext);
byte[] ivCipherConcat = CipherTextIvMac.ivCipherConcat(iv, byteCipherText);
byte[] integrityMac = generateMac(ivCipherConcat, secretKeys.getIntegrityKey());
return new CipherTextIvMac(byteCipherText, iv, integrityMac);*/
return secretKeys.toString();
}
private static boolean saveFile(InputStream inputStream, Cipher aesCipherForEncryption, File fileToWrite) {
try {
OutputStream outputStream = null;
try {
byte[] fileReader = new byte[4096];
/*long fileSize = body.contentLength();*/
long fileSizeDownloaded = 0;
outputStream = new CipherOutputStream(new FileOutputStream(fileToWrite), aesCipherForEncryption);
while (true) {
int read = inputStream.read(fileReader);
if (read == -1) {
break;
}
outputStream.write(fileReader, 0, read);
fileSizeDownloaded += read;
}
outputStream.flush();
return true;
} catch (IOException e) {
e.printStackTrace();
return false;
} finally {
if (inputStream != null) {
inputStream.close();
}
if (outputStream != null) {
outputStream.close();
}
}
} catch (IOException e) {
return false;
}
}
public static final class PrngFixes {
private static final int VERSION_CODE_JELLY_BEAN = 16;
private static final int VERSION_CODE_JELLY_BEAN_MR2 = 18;
private static final byte[] BUILD_FINGERPRINT_AND_DEVICE_SERIAL = getBuildFingerprintAndDeviceSerial();
/**
* Hidden constructor to prevent instantiation.
*/
private PrngFixes() {
}
/**
* Applies all fixes.
*
* @throws SecurityException if a fix is needed but could not be
* applied.
*/
public static void apply() {
applyOpenSSLFix();
installLinuxPRNGSecureRandom();
}
/**
* Applies the fix for OpenSSL PRNG having low entropy. Does nothing if
* the fix is not needed.
*
* @throws SecurityException if the fix is needed but could not be
* applied.
*/
private static void applyOpenSSLFix() throws SecurityException {
if ((Build.VERSION.SDK_INT < VERSION_CODE_JELLY_BEAN)
|| (Build.VERSION.SDK_INT > VERSION_CODE_JELLY_BEAN_MR2)) {
// No need to apply the fix
return;
}
try {
// Mix in the device- and invocation-specific seed.
Class.forName("org.apache.harmony.xnet.provider.jsse.NativeCrypto")
.getMethod("RAND_seed", byte[].class).invoke(null, generateSeed());
// Mix output of Linux PRNG into OpenSSL's PRNG
int bytesRead = (Integer) Class
.forName("org.apache.harmony.xnet.provider.jsse.NativeCrypto")
.getMethod("RAND_load_file", String.class, long.class)
.invoke(null, "/dev/urandom", 1024);
if (bytesRead != 1024) {
throw new IOException("Unexpected number of bytes read from Linux PRNG: "
+ bytesRead);
}
} catch (Exception e) {
if (ALLOW_BROKEN_PRNG) {
Log.w(PrngFixes.class.getSimpleName(), "Failed to seed OpenSSL PRNG", e);
} else {
throw new SecurityException("Failed to seed OpenSSL PRNG", e);
}
}
}
/**
* Installs a Linux PRNG-backed {@code SecureRandom} implementation as
* the default. Does nothing if the implementation is already the
* default or if there is no need to install the implementation.
*
* @throws SecurityException if the fix is needed but could not be
* applied.
*/
private static void installLinuxPRNGSecureRandom() throws SecurityException {
if (Build.VERSION.SDK_INT > VERSION_CODE_JELLY_BEAN_MR2) {
// No need to apply the fix
return;
}
// Install a Linux PRNG-based SecureRandom implementation as the
// default, if not yet installed.
Provider[] secureRandomProviders = Security.getProviders("SecureRandom.SHA1PRNG");
// Insert and check the provider atomically.
// The official Android Java libraries use synchronized methods for
// insertProviderAt, etc., so synchronizing on the class should
// make things more stable, and prevent race conditions with other
// versions of this code.
synchronized (java.security.Security.class) {
if ((secureRandomProviders == null)
|| (secureRandomProviders.length < 1)
|| (!secureRandomProviders[0].getClass().getSimpleName().equals("LinuxPRNGSecureRandomProvider"))) {
Security.insertProviderAt(new PrngFixes.LinuxPRNGSecureRandomProvider(), 1);
}
// Assert that new SecureRandom() and
// SecureRandom.getInstance("SHA1PRNG") return a SecureRandom backed
// by the Linux PRNG-based SecureRandom implementation.
SecureRandom rng1 = new SecureRandom();
if (!rng1.getProvider().getClass().getSimpleName().equals("LinuxPRNGSecureRandomProvider")) {
if (ALLOW_BROKEN_PRNG) {
Log.w(PrngFixes.class.getSimpleName(),
"new SecureRandom() backed by wrong Provider: " + rng1.getProvider().getClass());
return;
} else {
throw new SecurityException("new SecureRandom() backed by wrong Provider: "
+ rng1.getProvider().getClass());
}
}
SecureRandom rng2 = null;
try {
rng2 = SecureRandom.getInstance("SHA1PRNG");
} catch (NoSuchAlgorithmException e) {
if (ALLOW_BROKEN_PRNG) {
Log.w(PrngFixes.class.getSimpleName(), "SHA1PRNG not available", e);
return;
} else {
throw new SecurityException("SHA1PRNG not available", e);
}
}
if (!rng2.getProvider().getClass().getSimpleName().equals("LinuxPRNGSecureRandomProvider")) {
if (ALLOW_BROKEN_PRNG) {
Log.w(PrngFixes.class.getSimpleName(),
"SecureRandom.getInstance(\"SHA1PRNG\") backed by wrong" + " Provider: "
+ rng2.getProvider().getClass());
return;
} else {
throw new SecurityException(
"SecureRandom.getInstance(\"SHA1PRNG\") backed by wrong" + " Provider: "
+ rng2.getProvider().getClass());
}
}
}
}
/**
* {@code Provider} of {@code SecureRandom} engines which pass through
* all requests to the Linux PRNG.
*/
private static class LinuxPRNGSecureRandomProvider extends Provider {
public LinuxPRNGSecureRandomProvider() {
super("LinuxPRNG", 1.0, "A Linux-specific random number provider that uses"
+ " /dev/urandom");
// Although /dev/urandom is not a SHA-1 PRNG, some apps
// explicitly request a SHA1PRNG SecureRandom and we thus need
// to prevent them from getting the default implementation whose
// output may have low entropy.
put("SecureRandom.SHA1PRNG", PrngFixes.LinuxPRNGSecureRandom.class.getName());
put("SecureRandom.SHA1PRNG ImplementedIn", "Software");
}
}
/**
* {@link SecureRandomSpi} which passes all requests to the Linux PRNG (
* {@code /dev/urandom}).
*/
public static class LinuxPRNGSecureRandom extends SecureRandomSpi {
/*
* IMPLEMENTATION NOTE: Requests to generate bytes and to mix in a
* seed are passed through to the Linux PRNG (/dev/urandom).
* Instances of this class seed themselves by mixing in the current
* time, PID, UID, build fingerprint, and hardware serial number
* (where available) into Linux PRNG.
*
* Concurrency: Read requests to the underlying Linux PRNG are
* serialized (on sLock) to ensure that multiple threads do not get
* duplicated PRNG output.
*/
private static final File URANDOM_FILE = new File("/dev/urandom");
private static final Object sLock = new Object();
/**
* Input stream for reading from Linux PRNG or {@code null} if not
* yet opened.
*
* @GuardedBy("sLock")
*/
private static DataInputStream sUrandomIn;
/**
* Output stream for writing to Linux PRNG or {@code null} if not
* yet opened.
*
* @GuardedBy("sLock")
*/
private static OutputStream sUrandomOut;
/**
* Whether this engine instance has been seeded. This is needed
* because each instance needs to seed itself if the client does not
* explicitly seed it.
*/
private boolean mSeeded;
@Override
protected void engineSetSeed(byte[] bytes) {
try {
OutputStream out;
synchronized (sLock) {
out = getUrandomOutputStream();
}
out.write(bytes);
out.flush();
} catch (IOException e) {
// On a small fraction of devices /dev/urandom is not
// writable. Log and ignore.
Log.w(PrngFixes.class.getSimpleName(), "Failed to mix seed into "
+ URANDOM_FILE);
} finally {
mSeeded = true;
}
}
@Override
protected void engineNextBytes(byte[] bytes) {
if (!mSeeded) {
// Mix in the device- and invocation-specific seed.
engineSetSeed(generateSeed());
}
try {
DataInputStream in;
synchronized (sLock) {
in = getUrandomInputStream();
}
synchronized (in) {
in.readFully(bytes);
}
} catch (IOException e) {
throw new SecurityException("Failed to read from " + URANDOM_FILE, e);
}
}
@Override
protected byte[] engineGenerateSeed(int size) {
byte[] seed = new byte[size];
engineNextBytes(seed);
return seed;
}
private DataInputStream getUrandomInputStream() {
synchronized (sLock) {
if (sUrandomIn == null) {
// NOTE: Consider inserting a BufferedInputStream
// between DataInputStream and FileInputStream if you need
// higher PRNG output performance and can live with future PRNG
// output being pulled into this process prematurely.
try {
sUrandomIn = new DataInputStream(new FileInputStream(URANDOM_FILE));
} catch (IOException e) {
throw new SecurityException("Failed to open " + URANDOM_FILE
+ " for reading", e);
}
}
return sUrandomIn;
}
}
private OutputStream getUrandomOutputStream() throws IOException {
synchronized (sLock) {
if (sUrandomOut == null) {
sUrandomOut = new FileOutputStream(URANDOM_FILE);
}
return sUrandomOut;
}
}
}
/**
* Generates a device- and invocation-specific seed to be mixed into the
* Linux PRNG.
*/
private static byte[] generateSeed() {
try {
ByteArrayOutputStream seedBuffer = new ByteArrayOutputStream();
DataOutputStream seedBufferOut = new DataOutputStream(seedBuffer);
seedBufferOut.writeLong(System.currentTimeMillis());
seedBufferOut.writeLong(System.nanoTime());
seedBufferOut.writeInt(Process.myPid());
seedBufferOut.writeInt(Process.myUid());
seedBufferOut.write(BUILD_FINGERPRINT_AND_DEVICE_SERIAL);
seedBufferOut.close();
return seedBuffer.toByteArray();
} catch (IOException e) {
throw new SecurityException("Failed to generate seed", e);
}
}
/**
* Gets the hardware serial number of this device.
*
* @return serial number or {@code null} if not available.
*/
private static String getDeviceSerialNumber() {
// We're using the Reflection API because Build.SERIAL is only
// available since API Level 9 (Gingerbread, Android 2.3).
try {
return (String) Build.class.getField("SERIAL").get(null);
} catch (Exception ignored) {
return null;
}
}
private static byte[] getBuildFingerprintAndDeviceSerial() {
StringBuilder result = new StringBuilder();
String fingerprint = Build.FINGERPRINT;
if (fingerprint != null) {
result.append(fingerprint);
}
String serial = getDeviceSerialNumber();
if (serial != null) {
result.append(serial);
}
try {
return result.toString().getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("UTF-8 encoding not supported");
}
}
}
}
Any idea what I'm doing wrong? Thank you in advance.
Finally, I solved it myself. I'm posting the solution to help anyone who runs into a similar situation in the future. My mistake was in retrieving the IV array in the encrypt method: I was generating a new IV instead of using the one contained in secretKeys.
private static String encrypt(InputStream inputStream, SecretKeys secretKeys, File fileToWrite)
throws GeneralSecurityException {
Cipher aesCipherForEncryption = Cipher.getInstance(CIPHER_TRANSFORMATION);
aesCipherForEncryption.init(Cipher.ENCRYPT_MODE, secretKeys.getConfidentialityKey(), new IvParameterSpec(secretKeys.getIv()));
saveFile(inputStream, aesCipherForEncryption, fileToWrite);
return secretKeys.toString();
}
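The symptom in the question (only the first 16 characters, "Concentration - ", come out garbled) is typical of CBC decryption with the wrong IV: each plaintext block is the decrypted ciphertext block XORed with the previous ciphertext block, and for the first block that "previous block" is the IV. A small sketch (the class name is hypothetical; the fixed IVs are for illustration only and must never be used in real code) encrypts with one IV and decrypts with another that differs in a single bit:

```java
import java.nio.charset.StandardCharsets;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;

public class CbcIvMismatchDemo {

    static byte[] crypt(int mode, SecretKey key, byte[] iv, byte[] input) throws Exception {
        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        cipher.init(mode, key, new IvParameterSpec(iv));
        return cipher.doFinal(input);
    }

    public static void main(String[] args) throws Exception {
        KeyGenerator keyGen = KeyGenerator.getInstance("AES");
        keyGen.init(128);
        SecretKey key = keyGen.generateKey();

        byte[] ivUsedToEncrypt = new byte[16]; // all zeros (demo only)
        byte[] ivStoredInKeys = new byte[16];
        ivStoredInKeys[0] = 1;                 // differs from the first IV in one bit

        byte[] plain = "Concentration - Programming Music 0100 (Part 4)".getBytes(StandardCharsets.UTF_8);
        byte[] cipherText = crypt(Cipher.ENCRYPT_MODE, key, ivUsedToEncrypt, plain);
        byte[] decrypted = crypt(Cipher.DECRYPT_MODE, key, ivStoredInKeys, cipherText);

        // First block = original XOR ivUsedToEncrypt XOR ivStoredInKeys; later blocks are unaffected.
        System.out.println(new String(decrypted, StandardCharsets.UTF_8));
        // Boncentration - Programming Music 0100 (Part 4)
    }
}
```

Flipping one IV bit flips the same bit of the first plaintext byte ('C' becomes 'B'); with two unrelated random IVs, as in the original encrypt method, the whole first 16-byte block turns into garbage while the rest decrypts correctly.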
Like the JDK Deflater/Inflater classes, which let you pass byte[] chunks and get the compressed/uncompressed value back as a byte[] chunk (no need for input or output streams), does anyone know of a way to do the same for Zip files?
The idea is to be able to read an input stream in chunks and run a kind of transformation pipeline:
- Inbound: Encrypt and compress
- Outbound: Decrypt and decompress
With the ZipInputStream/ZipOutputStream classes, I would need to buffer all the bytes before encrypting/decrypting in order to do that.
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
public class Compression {
public static void main(String[] args) throws IOException, DataFormatException {
final int bufferSize = 1024;
byte[] uncompressedChunkBuffer = new byte[bufferSize];
int uncompressedChunkLength = 0;
byte[] compressedChunkBuffer = new byte[bufferSize];
int compressedChunkLength = 0;
//Compression
Deflater deflater = new Deflater();
String uncompressedText = randomText();
byte[] expectedUncompressedBytes = uncompressedText.getBytes();
System.out.println("Bytes Length: " + expectedUncompressedBytes.length);
ByteArrayInputStream uncompressedBytesInStream = new ByteArrayInputStream(expectedUncompressedBytes);
ByteArrayOutputStream compressedBytesOutStream = new ByteArrayOutputStream();
while ((uncompressedChunkLength = uncompressedBytesInStream.read(uncompressedChunkBuffer)) != -1) {
//This part allows to set and get byte[] chunks
deflater.setInput(uncompressedChunkBuffer, 0, uncompressedChunkLength);
while (!deflater.needsInput()) {
compressedChunkLength = deflater.deflate(compressedChunkBuffer);
if (compressedChunkLength > 0) {
compressedBytesOutStream.write(compressedChunkBuffer, 0, compressedChunkLength);
}
}
}
deflater.finish();
while (!deflater.finished()) {
compressedChunkLength = deflater.deflate(compressedChunkBuffer);
if (compressedChunkLength > 0) {
compressedBytesOutStream.write(compressedChunkBuffer, 0, compressedChunkLength);
}
}
deflater.end();
uncompressedBytesInStream.close();
compressedBytesOutStream.flush();
compressedBytesOutStream.close();
byte[] compressedBytes = compressedBytesOutStream.toByteArray();
System.out.println("Compressed Bytes Length: " + compressedBytes.length);
//Decompression
Inflater inflater = new Inflater();
ByteArrayInputStream compressedBytesInStream = new ByteArrayInputStream(compressedBytes);
ByteArrayOutputStream uncompressedBytesOutStream = new ByteArrayOutputStream();
while ((compressedChunkLength = compressedBytesInStream.read(compressedChunkBuffer)) != -1) {
//This part allows to set and get byte[] chunks
inflater.setInput(compressedChunkBuffer, 0, compressedChunkLength);
while ((uncompressedChunkLength = inflater.inflate(uncompressedChunkBuffer)) > 0) {
uncompressedBytesOutStream.write(uncompressedChunkBuffer, 0, uncompressedChunkLength);
}
}
while ((uncompressedChunkLength = inflater.inflate(uncompressedChunkBuffer)) > 0) {
uncompressedBytesOutStream.write(uncompressedChunkBuffer, 0, uncompressedChunkLength);
}
inflater.end();
compressedBytesInStream.close();
uncompressedBytesOutStream.flush();
uncompressedBytesOutStream.close();
byte[] actualUncompressedBytes = uncompressedBytesOutStream.toByteArray();
System.out.println("Uncompressed Bytes Length: Expected[" + expectedUncompressedBytes.length + "], Actual [" + actualUncompressedBytes.length + "]");
}
public static String randomText() {
StringBuilder sb = new StringBuilder();
int textLength = rnd(100, 999);
for (int i = 0; i < textLength; i++) {
if (rnd(0, 1) == 0) {
sb.append((char) rnd(65, 90));
} else {
sb.append((char) rnd(49, 57));
}
}
return sb.toString();
}
public static int rnd(int min, int max) {
return min + (int) (Math.random() * ((max - min) + 1));
}
}
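Since ZipOutputStream writes through to whatever stream it wraps, one way to get the same chunk-in/chunk-out shape for the Zip format is to point it at an in-memory ByteArrayOutputStream and drain that buffer after every write. The ChunkedZipper class below is a hypothetical helper sketched under that assumption, not a JDK API; each returned chunk could be handed straight to Cipher.update(), so the whole archive never has to be held before encryption.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/** Hypothetical push-style Zip encoder: byte[] chunks in, zipped byte[] chunks out. */
public class ChunkedZipper {
    private final ByteArrayOutputStream sink = new ByteArrayOutputStream();
    private final ZipOutputStream zip = new ZipOutputStream(sink);

    public ChunkedZipper(String entryName) {
        try {
            zip.putNextEntry(new ZipEntry(entryName)); // local file header goes into the sink
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Compresses one chunk and returns the zipped bytes produced so far (possibly empty). */
    public byte[] update(byte[] chunk, int off, int len) {
        try {
            zip.write(chunk, off, len);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return drain();
    }

    /** Closes the entry and the archive; returns the tail (data descriptor, central directory). */
    public byte[] finish() {
        try {
            zip.closeEntry();
            zip.close(); // closing a ByteArrayOutputStream has no effect, so the sink stays readable
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return drain();
    }

    private byte[] drain() {
        byte[] out = sink.toByteArray();
        sink.reset(); // keep memory bounded to what was produced since the last call
        return out;
    }

    /** Reads the first entry of a complete archive back, for checking the round trip. */
    public static byte[] unzipFirstEntry(byte[] zipBytes) {
        try (ZipInputStream in = new ZipInputStream(new ByteArrayInputStream(zipBytes))) {
            if (in.getNextEntry() == null) {
                throw new IOException("No entry in archive");
            }
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[1024];
            for (int n; (n = in.read(buf)) != -1; ) {
                out.write(buf, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Concatenating every array returned by update() plus the one returned by finish() yields a normal Zip archive. Individual update() results may be empty, because the Deflater holds input back until it has enough to emit.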
Thanks to @rob's suggestion, I finally reached a solution:
private static final String SECRET_KEY_ALGO = "AES";
private static final int SECRET_KEY_SIZE_IN_BITS = 256;
private static final String AES_TRANSFORMATION = "AES/CBC/PKCS5Padding";
private static final int DEFAULT_BUFFERSIZE = 8 * 1024;
public static void main(String[] args) throws IOException {
String expected = randomText();
byte[] textBytes = expected.getBytes();
EncryptedOutputStreamWrapper enc = new EncryptedOutputStreamWrapper();
{
InputStream in = new ByteArrayInputStream(textBytes);
ZipOutputStream out = new ZipOutputStream(enc.wrap(new FileOutputStream("f.zip")));
out.putNextEntry(new ZipEntry("_"));
IOUtils.copy(in, out);
in.close();
out.closeEntry();
out.close();
}
//
DecryptedInputStreamWrapper dec = new DecryptedInputStreamWrapper(enc.getSKey(), enc.getIv());
{
ZipInputStream in = new ZipInputStream(dec.wrap(new FileInputStream("f.zip")));
OutputStream out = new FileOutputStream("f.txt");
in.getNextEntry();
IOUtils.copy(in, out);
in.closeEntry();
in.close();
out.close();
}
//
String actual = new String(IOUtils.toByteArray(new FileInputStream("f.txt")));
if (!expected.equals(actual)) {
System.out.println("Fail!");
System.out.println("Expected '" + expected + "'");
System.out.println();
System.out.println("Actual: '" + actual + "'");
} else {
System.out.println("Success!");
}
}
public static class EncryptedOutputStreamWrapper {
private Cipher cipher;
private SecretKey sKey;
private byte[] iv;
public EncryptedOutputStreamWrapper() {
try {
KeyGenerator generator = KeyGenerator.getInstance(SECRET_KEY_ALGO);
generator.init(SECRET_KEY_SIZE_IN_BITS);
this.sKey = generator.generateKey();
this.cipher = Cipher.getInstance(AES_TRANSFORMATION);
this.cipher.init(Cipher.ENCRYPT_MODE, sKey);
this.iv = cipher.getIV();
} catch (Exception e) {
throw new CipherException("Error encrypting", e);
}
}
public OutputStream wrap(final OutputStream out) {
return new BufferedOutputStream(new OutputStream() {
@Override
public void write(int b) throws IOException {
// Route single bytes through the array variant so none are silently dropped
write(new byte[] {(byte) b}, 0, 1);
}
@Override
public void write(byte[] plainBytes, int off, int len) throws IOException {
// update() may return null while it waits for a full block
byte[] encryptedBytes = cipher.update(plainBytes, off, len);
if (encryptedBytes != null) {
out.write(encryptedBytes, 0, encryptedBytes.length);
}
}
@Override
public void flush() throws IOException {
out.flush();
}
@Override
public void close() throws IOException {
try {
// doFinal() pads and emits the last (partial) block
byte[] encryptedBytes = cipher.doFinal();
if (encryptedBytes != null) {
out.write(encryptedBytes, 0, encryptedBytes.length);
}
} catch (Exception e) {
throw new IOException("Error encrypting", e);
}
out.close();
}
});
}
public SecretKey getSKey() {
return sKey;
}
public byte[] getIv() {
return iv;
}
}
public static class DecryptedInputStreamWrapper {
private Cipher cipher;
public DecryptedInputStreamWrapper(SecretKey sKey, byte[] iv) {
try {
this.cipher = Cipher.getInstance(AES_TRANSFORMATION);
this.cipher.init(Cipher.DECRYPT_MODE, sKey, new IvParameterSpec(iv));
} catch (Exception e) {
throw new CipherException("Error decrypting", e);
}
}
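public InputStream wrap(final InputStream in) {
return new BufferedInputStream(new InputStream() {
private byte[] buffer = new byte[DEFAULT_BUFFERSIZE];
// Decrypted bytes not yet handed to the caller
private byte[] pending = new byte[0];
private int pendingPos = 0;
private boolean done;
@Override
public int read() throws IOException {
byte[] one = new byte[1];
int n = read(one, 0, 1);
return n == -1 ? -1 : (one[0] & 0xff);
}
@Override
public int read(byte[] bytes, int off, int len) throws IOException {
if (len == 0) {
return 0;
}
// Never return 0 for len > 0: BufferedInputStream would treat it as end of stream
while (pendingPos == pending.length) {
if (done) {
return -1;
}
int encryptedLen = in.read(buffer);
try {
byte[] plainBytes;
if (encryptedLen == -1) {
done = true;
plainBytes = cipher.doFinal();
} else {
plainBytes = cipher.update(buffer, 0, encryptedLen);
}
// update() may return null while it waits for a full block
pending = (plainBytes != null) ? plainBytes : new byte[0];
pendingPos = 0;
} catch (Exception e) {
throw new IOException("Error decrypting", e);
}
}
// Copy at most len bytes; keep the rest for the next call
int n = Math.min(len, pending.length - pendingPos);
System.arraycopy(pending, pendingPos, bytes, off, n);
pendingPos += n;
return n;
}
@Override
public void close() throws IOException {
in.close();
}
});
}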
}
public static class CipherException extends RuntimeException {
private static final long serialVersionUID = 1L;
public CipherException() {
super();
}
public CipherException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
}
public CipherException(String message, Throwable cause) {
super(message, cause);
}
public CipherException(String message) {
super(message);
}
public CipherException(Throwable cause) {
super(cause);
}
}
public static String randomText() {
StringBuilder sb = new StringBuilder();
int textLength = rnd(100000, 999999);
for (int i = 0; i < textLength; i++) {
if (rnd(0, 1) == 0) {
sb.append((char) rnd(65, 90));
} else {
sb.append((char) rnd(49, 57));
}
}
return sb.toString();
}
public static int rnd(int min, int max) {
return min + (int) (Math.random() * ((max - min) + 1));
}
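For comparison, the JDK already ships javax.crypto.CipherOutputStream and CipherInputStream, which call update() as data flows through and doFinal() on close or at end of input, so they can take the place of the hand-written wrappers above. A minimal in-memory sketch, assuming AES/CBC/PKCS5Padding as above and a caller-supplied 16-byte IV (the class and method names are illustrative):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.security.GeneralSecurityException;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;

public class ZipCipherPipeline {
    private static final String AES_TRANSFORMATION = "AES/CBC/PKCS5Padding";

    private static Cipher cipher(int mode, SecretKey key, byte[] iv) throws GeneralSecurityException {
        Cipher c = Cipher.getInstance(AES_TRANSFORMATION);
        c.init(mode, key, new IvParameterSpec(iv));
        return c;
    }

    /** Compresses into a one-entry zip and encrypts the zip bytes as they are produced. */
    public static byte[] zipThenEncrypt(byte[] plain, SecretKey key, byte[] iv) {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        try (ZipOutputStream zip = new ZipOutputStream(
                new CipherOutputStream(sink, cipher(Cipher.ENCRYPT_MODE, key, iv)))) {
            zip.putNextEntry(new ZipEntry("_"));
            zip.write(plain);
            zip.closeEntry();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
        // Closing the ZipOutputStream closed the CipherOutputStream, which ran doFinal()
        return sink.toByteArray();
    }

    /** Decrypts while reading and returns the content of the first zip entry. */
    public static byte[] decryptThenUnzip(byte[] encrypted, SecretKey key, byte[] iv) {
        try (ZipInputStream zip = new ZipInputStream(new CipherInputStream(
                new ByteArrayInputStream(encrypted), cipher(Cipher.DECRYPT_MODE, key, iv)))) {
            if (zip.getNextEntry() == null) {
                throw new IOException("Empty archive");
            }
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[8192];
            for (int n; (n = zip.read(buf)) != -1; ) {
                out.write(buf, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

The stream order matters: the zip layer sits on top, so data is compressed first and encrypted afterwards. Encrypting first would leave the compressor with random-looking bytes it cannot shrink.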
I have two applications that interact using Thrift. They share the same secret key, and I need to encrypt their messages. It makes sense to use a symmetric algorithm (AES, for example), but I haven't found any library that does this. So I did some research and found the following options:
Use built-in SSL support
I can use the built-in SSL support, establish a secure connection and use my secret key just as an authentication token. This requires installing certificates in addition to the secret key both applications already have, but I don't need to implement anything except checking that the secret key received from the client is the same as the one stored locally.
Implement symmetric encryption
So far, I see the following options:
Extend TSocket, override the write() and read() methods and encrypt/decrypt data in them. This increases traffic on small writes: for example, if TBinaryProtocol writes a 4-byte integer, it takes a whole block (16 bytes) once encrypted.
Extend TSocket and wrap its InputStream and OutputStream with CipherInputStream and CipherOutputStream. CipherOutputStream does not encrypt small byte arrays immediately; it only updates the Cipher with them. Once enough data has accumulated, it is encrypted and written to the underlying OutputStream, so it waits until you have written four 4-byte ints and then encrypts them. This avoids wasting traffic, but it also causes a problem: if the last value does not fill a block, it is not encrypted and written to the underlying stream until the stream is closed (close() calls doFinal()). It expects me to write a number of bytes divisible by its block size (16 bytes), but I can't guarantee that with TBinaryProtocol.
Re-implement TBinaryProtocol so that it caches all writes instead of writing them to the stream, encrypts them in writeMessageEnd() and decrypts in readMessageBegin(). However, I think encryption should be performed at the transport layer, not the protocol layer.
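The size trade-off between the first two options can be measured directly. This sketch uses AES/ECB/PKCS5Padding only because it is the simplest way to count bytes (ECB is not recommended for real traffic), and the all-zero key is a placeholder: encrypting a 4-byte write on its own yields a full 16-byte block, while CipherOutputStream emits nothing for it until close() runs doFinal().

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.CipherOutputStream;
import javax.crypto.spec.SecretKeySpec;

public class BlockOverheadDemo {
    // Placeholder all-zero key: fine for measuring sizes, never for real traffic
    private static final SecretKeySpec KEY = new SecretKeySpec(new byte[16], "AES");

    private static Cipher newEncryptor() throws GeneralSecurityException {
        Cipher c = Cipher.getInstance("AES/ECB/PKCS5Padding");
        c.init(Cipher.ENCRYPT_MODE, KEY);
        return c;
    }

    /** Option 1: each write encrypted on its own is padded up to whole 16-byte blocks. */
    public static int perWriteCiphertextLength(int plainLen) {
        try {
            return newEncryptor().doFinal(new byte[plainLen]).length;
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
    }

    /** Option 2: returns {bytes on the wire after write() + flush(), bytes after close()}. */
    public static int[] cipherStreamVisibleBytes(int plainLen) {
        ByteArrayOutputStream wire = new ByteArrayOutputStream();
        try {
            CipherOutputStream out = new CipherOutputStream(wire, newEncryptor());
            out.write(new byte[plainLen]);
            out.flush();                 // cannot push out a partial block
            int beforeClose = wire.size();
            out.close();                 // calls doFinal(), which pads and emits the tail
            return new int[] {beforeClose, wire.size()};
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(perWriteCiphertextLength(4));                      // 16
        System.out.println(Arrays.toString(cipherStreamVisibleBytes(4)));     // [0, 16]
    }
}
```

Buffering a whole message and encrypting it once per flush, as a framed transport does, pays the padding cost only once per message.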
Please share your thoughts with me.
UPDATE
Java Implementation on Top of TFramedTransport
TEncryptedFramedTransport.java
package tutorial;
import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.TTransportFactory;
import javax.crypto.Cipher;
import java.security.Key;
/**
* TEncryptedFramedTransport is a buffered TTransport. It buffers a complete written message, encrypts it
* with the "AES/ECB/PKCS5Padding" symmetric algorithm and sends it, preceded by a 4-byte frame size.
*/
public class TEncryptedFramedTransport extends TTransport {
public static final String ALGORITHM = "AES/ECB/PKCS5Padding";
private Cipher encryptingCipher;
private Cipher decryptingCipher;
protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;
private int maxLength_;
private TTransport transport_ = null;
private final TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(1024);
private TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);
public static class Factory extends TTransportFactory {
private int maxLength_;
private Key secretKey_;
public Factory(Key secretKey) {
this(secretKey, DEFAULT_MAX_LENGTH);
}
public Factory(Key secretKey, int maxLength) {
maxLength_ = maxLength;
secretKey_ = secretKey;
}
@Override
public TTransport getTransport(TTransport base) {
return new TEncryptedFramedTransport(base, secretKey_, maxLength_);
}
}
/**
* Constructor wraps around another transport
*/
public TEncryptedFramedTransport(TTransport transport, Key secretKey, int maxLength) {
transport_ = transport;
maxLength_ = maxLength;
try {
encryptingCipher = Cipher.getInstance(ALGORITHM);
encryptingCipher.init(Cipher.ENCRYPT_MODE, secretKey);
decryptingCipher = Cipher.getInstance(ALGORITHM);
decryptingCipher.init(Cipher.DECRYPT_MODE, secretKey);
} catch (Exception e) {
throw new RuntimeException("Unable to initialize ciphers.", e);
}
}
public TEncryptedFramedTransport(TTransport transport, Key secretKey) {
this(transport, secretKey, DEFAULT_MAX_LENGTH);
}
public void open() throws TTransportException {
transport_.open();
}
public boolean isOpen() {
return transport_.isOpen();
}
public void close() {
transport_.close();
}
public int read(byte[] buf, int off, int len) throws TTransportException {
if (readBuffer_ != null) {
int got = readBuffer_.read(buf, off, len);
if (got > 0) {
return got;
}
}
// Read another frame of data
readFrame();
return readBuffer_.read(buf, off, len);
}
@Override
public byte[] getBuffer() {
return readBuffer_.getBuffer();
}
@Override
public int getBufferPosition() {
return readBuffer_.getBufferPosition();
}
@Override
public int getBytesRemainingInBuffer() {
return readBuffer_.getBytesRemainingInBuffer();
}
@Override
public void consumeBuffer(int len) {
readBuffer_.consumeBuffer(len);
}
private final byte[] i32buf = new byte[4];
private void readFrame() throws TTransportException {
transport_.readAll(i32buf, 0, 4);
int size = decodeFrameSize(i32buf);
if (size < 0) {
throw new TTransportException("Read a negative frame size (" + size + ")!");
}
if (size > maxLength_) {
throw new TTransportException("Frame size (" + size + ") larger than max length (" + maxLength_ + ")!");
}
byte[] buff = new byte[size];
transport_.readAll(buff, 0, size);
try {
buff = decryptingCipher.doFinal(buff);
} catch (Exception e) {
throw new TTransportException(0, e);
}
readBuffer_.reset(buff);
}
public void write(byte[] buf, int off, int len) throws TTransportException {
writeBuffer_.write(buf, off, len);
}
@Override
public void flush() throws TTransportException {
byte[] buf = writeBuffer_.get();
int len = writeBuffer_.len();
writeBuffer_.reset();
try {
buf = encryptingCipher.doFinal(buf, 0, len);
} catch (Exception e) {
throw new TTransportException(0, e);
}
encodeFrameSize(buf.length, i32buf);
transport_.write(i32buf, 0, 4);
transport_.write(buf);
transport_.flush();
}
public static void encodeFrameSize(final int frameSize, final byte[] buf) {
buf[0] = (byte) (0xff & (frameSize >> 24));
buf[1] = (byte) (0xff & (frameSize >> 16));
buf[2] = (byte) (0xff & (frameSize >> 8));
buf[3] = (byte) (0xff & (frameSize));
}
public static int decodeFrameSize(final byte[] buf) {
return
((buf[0] & 0xff) << 24) |
((buf[1] & 0xff) << 16) |
((buf[2] & 0xff) << 8) |
((buf[3] & 0xff));
}
}
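One caveat with this transport: AES/ECB encrypts identical 16-byte blocks to identical ciphertext, so repeated messages (and repeated blocks inside a message) are recognizable on the wire. A possible variant, sketched here as an assumption rather than part of the original answer, seals each frame with AES/GCM/NoPadding and a fresh random 12-byte IV that is prepended to the frame. flush() and readFrame() could call sealFrame()/openFrame() in place of their doFinal() calls, and the Python side would have to change to match.

```java
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.SecureRandom;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;

/** Hypothetical per-frame AES-GCM helper: frame = 12-byte IV || ciphertext || 16-byte tag. */
public final class GcmFrameCodec {
    private static final String TRANSFORMATION = "AES/GCM/NoPadding";
    private static final int IV_LEN = 12;     // standard GCM nonce length
    private static final int TAG_BITS = 128;  // authentication tag length
    private static final SecureRandom RANDOM = new SecureRandom();

    private GcmFrameCodec() {
    }

    /** Encrypts buf[off, off + len) under a fresh random IV and prepends the IV. */
    public static byte[] sealFrame(Key key, byte[] buf, int off, int len) {
        byte[] iv = new byte[IV_LEN];
        RANDOM.nextBytes(iv); // an IV must never repeat under the same key in GCM
        try {
            Cipher cipher = Cipher.getInstance(TRANSFORMATION);
            cipher.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(TAG_BITS, iv));
            byte[] sealed = cipher.doFinal(buf, off, len); // ciphertext followed by the tag
            byte[] frame = new byte[IV_LEN + sealed.length];
            System.arraycopy(iv, 0, frame, 0, IV_LEN);
            System.arraycopy(sealed, 0, frame, IV_LEN, sealed.length);
            return frame;
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("Unable to encrypt frame", e);
        }
    }

    /** Verifies and decrypts a frame produced by sealFrame(). */
    public static byte[] openFrame(Key key, byte[] frame) {
        if (frame.length < IV_LEN) {
            throw new IllegalArgumentException("Frame shorter than its IV");
        }
        try {
            Cipher cipher = Cipher.getInstance(TRANSFORMATION);
            cipher.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(TAG_BITS, frame, 0, IV_LEN));
            // doFinal() checks the tag and throws AEADBadTagException if the frame was altered
            return cipher.doFinal(frame, IV_LEN, frame.length - IV_LEN);
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("Unable to decrypt frame", e);
        }
    }
}
```

Each frame grows by 28 bytes (12-byte IV plus 16-byte tag), and in exchange identical messages no longer produce identical frames and tampering is detected instead of yielding garbage.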
MultiplicationServer.java
package tutorial;
import co.runit.prototype.CryptoTool;
import org.apache.thrift.server.TNonblockingServer;
import org.apache.thrift.server.TServer;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingServerTransport;
import java.security.Key;
public class MultiplicationServer {
public static MultiplicationHandler handler;
public static MultiplicationService.Processor processor;
public static void main(String[] args) {
try {
handler = new MultiplicationHandler();
processor = new MultiplicationService.Processor(handler);
Runnable simple = () -> startServer(processor);
new Thread(simple).start();
} catch (Exception x) {
x.printStackTrace();
}
}
public static void startServer(MultiplicationService.Processor processor) {
try {
Key key = CryptoTool.decodeKeyBase64("1OUXS3MczVFp3SdfX41U0A==");
TNonblockingServerTransport serverTransport = new TNonblockingServerSocket(9090);
TServer server = new TNonblockingServer(new TNonblockingServer.Args(serverTransport)
.transportFactory(new TEncryptedFramedTransport.Factory(key))
.processor(processor));
System.out.println("Starting the simple server...");
server.serve();
} catch (Exception e) {
e.printStackTrace();
}
}
}
MultiplicationClient.java
package tutorial;
import co.runit.prototype.CryptoTool;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import java.security.Key;
public class MultiplicationClient {
public static void main(String[] args) {
Key key = CryptoTool.decodeKeyBase64("1OUXS3MczVFp3SdfX41U0A==");
try {
TSocket baseTransport = new TSocket("localhost", 9090);
TTransport transport = new TEncryptedFramedTransport(baseTransport, key);
transport.open();
TProtocol protocol = new TBinaryProtocol(transport);
MultiplicationService.Client client = new MultiplicationService.Client(protocol);
perform(client);
transport.close();
} catch (TException x) {
x.printStackTrace();
}
}
private static void perform(MultiplicationService.Client client) throws TException {
int product = client.multiply(3, 5);
System.out.println("3*5=" + product);
}
}
Of course, the key must be the same on the client and the server. To generate one and store it in Base64:
public static String generateKey() throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
KeyGenerator generator = KeyGenerator.getInstance("AES");
generator.init(128);
Key key = generator.generateKey();
return encodeKeyBase64(key);
}
public static String encodeKeyBase64(Key key) {
return Base64.getEncoder().encodeToString(key.getEncoded());
}
public static Key decodeKeyBase64(String encodedKey) {
byte[] keyBytes = Base64.getDecoder().decode(encodedKey);
return new SecretKeySpec(keyBytes, "AES"); // the key algorithm is "AES", not the full "AES/ECB/PKCS5Padding" transformation
}
UPDATE 2
Python Implementation on Top of TFramedTransport
TEncryptedTransport.py
from cStringIO import StringIO
from struct import pack, unpack
from Crypto.Cipher import AES
from thrift.transport.TTransport import TTransportBase, CReadableTransport
__author__ = 'Marboni'
BLOCK_SIZE = 16
pad = lambda s: s + (BLOCK_SIZE - len(s) % BLOCK_SIZE) * chr(BLOCK_SIZE - len(s) % BLOCK_SIZE)
unpad = lambda s: '' if not s else s[0:-ord(s[-1])]
class TEncryptedFramedTransportFactory:
def __init__(self, key):
self.__key = key
def getTransport(self, trans):
return TEncryptedFramedTransport(trans, self.__key)
class TEncryptedFramedTransport(TTransportBase, CReadableTransport):
def __init__(self, trans, key):
self.__trans = trans
self.__rbuf = StringIO()
self.__wbuf = StringIO()
self.__cipher = AES.new(key, AES.MODE_ECB)  # ECB, to match AES/ECB/PKCS5Padding on the Java side
def isOpen(self):
return self.__trans.isOpen()
def open(self):
return self.__trans.open()
def close(self):
return self.__trans.close()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self.readFrame()
return self.__rbuf.read(sz)
def readFrame(self):
buff = self.__trans.readAll(4)
sz, = unpack('!i', buff)
encrypted = StringIO(self.__trans.readAll(sz)).getvalue()
decrypted = unpad(self.__cipher.decrypt(encrypted))
self.__rbuf = StringIO(decrypted)
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
wout = self.__wbuf.getvalue()
self.__wbuf = StringIO()
encrypted = self.__cipher.encrypt(pad(wout))
encrypted_len = len(encrypted)
buf = pack("!i", encrypted_len) + encrypted
self.__trans.write(buf)
self.__trans.flush()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
while len(prefix) < reqlen:
self.readFrame()
prefix += self.__rbuf.getvalue()
self.__rbuf = StringIO(prefix)
return self.__rbuf
MultiplicationClient.py
import base64
from thrift import Thrift
from thrift.transport import TSocket
from thrift.protocol import TBinaryProtocol
from tutorial import MultiplicationService, TEncryptedTransport
key = base64.b64decode("1OUXS3MczVFp3SdfX41U0A==")
try:
transport = TSocket.TSocket('localhost', 9090)
transport = TEncryptedTransport.TEncryptedFramedTransport(transport, key)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
client = MultiplicationService.Client(protocol)
transport.open()
product = client.multiply(4, 5)
print '4*5=%d' % product
transport.close()
except Thrift.TException, tx:
print tx.message
As stated by JensG, sending an externally encrypted binary or supplying a layered cipher transport are the two best options. If you need a template, take a look at TFramedTransport. It is a simple layered transport and could easily be used as a starting point for creating a TCipherTransport.
I'm trying to send an object over UDP by first serializing it and then deserializing it on the other end. I thought this would be trivial, since I have sent other data over UDP before and serialized things to files etc.
I have been debugging this for some time now, and I keep getting an EOFException on the receiving end. Packets arrive properly, but somehow deserialization fails. I'm not sure whether the mistake is in the sender or the receiver. I suppose the problem might be that the receiver doesn't know the size of the packet.
Here is my sender class:
package com.machinedata.sensordata;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import android.content.Context;
import android.util.Log;
import com.machinedata.io.DataSerializer;
import com.machinedata.io.ManagerUdpPacket;
/**
* This class sends udp-packets. It is used to send driver's information to the manager tablet.
* @author tuomas
*
*/
public class UdpSender
{
private final int MANAGER_PORT = 1234;
private String ip = "192.168.11.50"; //tablet's IP
private DatagramSocket sock = null;
private InetAddress host;
private String mType;
private DataSerializer dataser;
public UdpSender(Context context)
{
try
{
sock = new DatagramSocket();
host = InetAddress.getByName(ip); //tablet's IP
}
catch(Exception e)
{
System.err.println("Exception while initializing the sender: " + e);
}
dataser = new DataSerializer(context);
}
/**
* With this function we can send packets about our machine to the manager to
* see in the fleet-view.
*/
public void sendToManager(ManagerUdpPacket managerUdp)
{
//serialize
Log.v("sendudp", "Send a packet: " + managerUdp.getDriver());
//serialize
byte[] data = dataser.serializeManagerPacket(managerUdp);
//send
try
{
DatagramPacket dp = new DatagramPacket(data , data.length , host , MANAGER_PORT);
sock.send(dp);
}
catch(IOException e)
{
System.err.println("IOException in sender: " + e);
}
}
public void close()
{
sock.close();
}
}
Here is the serialization function:
/**
* Serializes packet to be sent over udp to the manager tablet.
*/
public byte[] serializeManagerPacket(ManagerUdpPacket mp)
{
try
{
ByteArrayOutputStream baos = new ByteArrayOutputStream(2048);
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(mp);
oos.close();
// get the byte array of the object
byte[] obj= baos.toByteArray();
baos.close();
return obj;
}
catch(Exception e) {
e.printStackTrace();
}
return null;
}
Packet receiver class
public class UdpReceiver {
private DatagramSocket clientSocket;
private byte[] receiveData;
private final int timeout = 1;
/**
* Create a receiver.
* @param port Port to receive from.
* @param signCount Number of signals in a packet
*/
public UdpReceiver(int port)
{
//receiveData = serializeManagerPacket(new ManagerUdpPacket("asd", new MachineData(1, 2, "asd", "modelName"), 1,2,3,4,5.0,null));
try{
clientSocket=new DatagramSocket(port);
clientSocket.setReceiveBufferSize(2048);
clientSocket.setSoTimeout(timeout);
}catch(SocketException e){
Log.e("ERR", "SocketException in UdpReceiver()");
}
}
public void close()
{
clientSocket.close();
}
/**
* Receive a data packet and split it into array.
* @param data Array to put data in, must be correct size
* @return True on successful read, false otherwise
*/
public ManagerUdpPacket receive()
{
//receive a packet
DatagramPacket recvPacket = new DatagramPacket(receiveData, receiveData.length);
try{
clientSocket.receive(recvPacket);
}catch(IOException e){
Log.e("ERR", "IOException in UdpReceiver.receive");
return null;
}
ManagerUdpPacket obj = deserializeManagerPacket(receiveData);
if (obj != null)
Log.v("udpPacket", "UDP received: " + obj.getDriver());
return obj;
}
/**
* Deserialize the udp-packet back to readable data.
* @param data
* @return
*/
public ManagerUdpPacket deserializeManagerPacket(byte[] data)
{
try
{
ObjectInputStream iStream = new ObjectInputStream(new ByteArrayInputStream(data));
ManagerUdpPacket obj = (ManagerUdpPacket) iStream.readObject();
iStream.close();
return obj;
}
catch(Exception e)
{
e.printStackTrace();
}
return null;
}
}
Thread which listens packets in receiving end:
dataStreamTask = new TimerTask()
{
public void run()
{
if (currentlyStreaming)
{
ManagerUdpPacket mp = udpReceiver.receive();
if(mp != null)
{
Log.v("log", "Packet received! " + mp.getDriver());
}
//stop thread until next query
try {
synchronized(this){
this.wait(queryInterval);
}
} catch (InterruptedException e) {
Log.e("ERR", "InterruptedException in TimerTask.run");
}
}
}
};
And finally the class I'm sending over the UDP:
public class ManagerUdpPacket implements Serializable
{
private static final long serialVersionUID = 9169314425496496555L;
private Location gpsLocation;
private double totalFuelConsumption;
private long operationTime;
//workload distribution
private long idleTime = 0;
private long normalTime = 0;
private long fullTime = 0;
private int currentTaskId;
private String driverName;
String machineModelName = "";
String machineName = "";
int machineIconId = -1;
int machinePort = -1;
public ManagerUdpPacket(String driver, MachineData machine, int currentTaskId, long idleTime, long fullTime, long operationTime, double fuelConsumption, Location location)
{
driverName = driver;
this.currentTaskId = currentTaskId;
this.idleTime = idleTime;
this.fullTime = fullTime;
this.operationTime = operationTime;
this.totalFuelConsumption = fuelConsumption;
this.gpsLocation = location;
machineModelName = machine.getModelName();
machineName = machine.getName();
machineIconId = machine.getIconId();
machinePort = machine.getPort();
}
public String getDriver()
{
return driverName;
}
public int getCurrentTaskId()
{
return currentTaskId;
}
public long getIdleTime()
{
return idleTime;
}
public long getFullTime()
{
return fullTime;
}
public long getOperationTime()
{
return operationTime;
}
public double getTotalFuelConsumption()
{
return totalFuelConsumption;
}
public double getLocation()
{
return gpsLocation.getLatitude();
}
public String getMachineModelName()
{
return machineModelName;
}
public String getMachineName()
{
return machineName;
}
public int getMachineIconId()
{
return machineIconId;
}
public int getMachinePort()
{
return machinePort;
}
}
I tried setting the packet size from the size of the serialized packet, and also using an arbitrary 2048 based on some examples on the internet. I couldn't get it to work, though.
As far as I know, receive() itself doesn't return anything; after it returns, the DatagramPacket's getLength() gives the number of bytes actually received. Your buffer, however, keeps its full size:
Example:
int buffersize = 1024;
You send 8 bytes over UDP.
Your byte[] then contains your 8 bytes, but the rest of the 1024 will be 0.
Save the size you get from recvPacket.getLength() after the .receive() call, copy that many bytes of your buffer into another byte[], and you should get your object.
For your example:
public ManagerUdpPacket receive()
{
int receivedBytes = 0;
//receive a packet
DatagramPacket recvPacket = new DatagramPacket(receiveData, receiveData.length);
try{
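clientSocket.receive(recvPacket);
// receive() returns void; the number of bytes received is stored in the packet
receivedBytes = recvPacket.getLength();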
}catch(IOException e){
Log.e("ERR", "IOException in UdpReceiver.receive");
return null;
}
byte[] myObject = new byte[receivedBytes];
for(int i = 0; i < receivedBytes; i++)
{
myObject[i] = receiveData[i];
}
ManagerUdpPacket obj = deserializeManagerPacket(myObject);
if (obj != null)
Log.v("udpPacket", "UDP received: " + obj.getDriver());
return obj;
}
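A self-contained sketch of the same idea, assuming plain java.io serialization as in the question: the receive buffer has to be allocated (in the question, receiveData is never initialized), and only getLength() bytes starting at getOffset() belong to the datagram. ByteArrayInputStream can take that range directly, so no copy is needed. The class and method names are illustrative; 65507 bytes is the largest UDP payload over IPv4.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;

public class UdpObjectCodec {
    /** Largest payload a single UDP datagram can carry over IPv4. */
    public static final int MAX_UDP_PAYLOAD = 65507;

    public static byte[] toBytes(Serializable obj) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return baos.toByteArray();
    }

    /** Deserializes only the bytes that belong to this datagram. */
    public static Object fromPacket(DatagramPacket packet) {
        ByteArrayInputStream in = new ByteArrayInputStream(
                packet.getData(), packet.getOffset(), packet.getLength());
        try (ObjectInputStream ois = new ObjectInputStream(in)) {
            return ois.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        InetAddress loopback = InetAddress.getLoopbackAddress();
        try (DatagramSocket receiver = new DatagramSocket(0, loopback);
             DatagramSocket sender = new DatagramSocket()) {
            receiver.setSoTimeout(2000); // milliseconds; a 1 ms timeout expires almost immediately
            byte[] data = toBytes("driver-1");
            sender.send(new DatagramPacket(data, data.length, loopback, receiver.getLocalPort()));
            DatagramPacket packet = new DatagramPacket(new byte[MAX_UDP_PAYLOAD], MAX_UDP_PAYLOAD);
            receiver.receive(packet); // fills the buffer and sets the packet's length
            System.out.println(packet.getLength() + " bytes -> " + fromPacket(packet));
        }
    }
}
```

Note that every field of the sent object must itself be Serializable for writeObject() to succeed.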
When receiving data over UDP, always use java.net.DatagramSocket.getReceiveBufferSize(). This is the actual size of the platform's buffer, or SO_RCVBUF for the socket. Since UDP is a datagram-based protocol, unlike TCP, which is a streaming protocol, receive buffers become critical for data sanity. Usually, receive and send buffers are equal in size, but you don't need to worry about them when sending with DatagramSocket.send(DatagramPacket); alternatively, you can use DatagramSocket.setSendBufferSize(DatagramSocket.getSendBufferSize()) to set the SO_SNDBUF option for this socket. Keep in mind that in UDP, if you use an SO_SNDBUF size greater than the platform's, the packet can be discarded.
public class ServerPipelineFactory implements ChannelPipelineFactory {
@Override
public ChannelPipeline getPipeline() throws Exception {
PacketFrameDecoder decoder = new PacketFrameDecoder();
PacketFrameEncoder encoder = new PacketFrameEncoder();
return Channels.pipeline(decoder, encoder, new Handler());
}
}
And my decoder:
public class PacketFrameDecoder extends
ReplayingDecoder<PacketFrameDecoder.DecoderState> {
public enum DecoderState {
READ_CONTENT;
}
private int length;
public PacketFrameDecoder() {
super(DecoderState.READ_CONTENT);
}
@Override
protected Object decode(ChannelHandlerContext chc, Channel chnl,
ChannelBuffer cb, DecoderState state) throws Exception {
switch (state) {
case READ_CONTENT:
for (int i = 0; i < cb.capacity(); i ++) {
byte b = cb.getByte(i);
System.out.println((char) b);
}
return null;
default:
throw new Error("Shouldn't reach here.");
}
}
}
And I send messages like this:
Socket fromserver = new Socket("localhost", 7283);
PrintWriter out = new PrintWriter(fromserver.getOutputStream(), true);
int data = 12;
out.write(data);
out.flush();
out.close();
fromserver.close();
But when I receive the bytes, cb.capacity() is 256, and System.out.println((char) b) prints "?0?0".
Please help.
Using capacity() is wrong, as it is the maximum number of bytes the buffer can hold, not the number of bytes that were received. Starting at index 0 is also wrong, as the readerIndex could be at another position. Please read the API docs of ChannelBuffer, which explain all of this in great detail.
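The same distinction exists in the JDK's java.nio.ByteBuffer, which makes it easy to try without Netty. This is an analogy, not Netty's API: capacity() is the allocated size, while the readable bytes lie between position() and limit() (in Netty 3's ChannelBuffer, between readerIndex() and writerIndex(), with readableBytes() giving their count). The sample buffer contents are made up for illustration.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class ReadableBytesDemo {
    /** Decodes only the readable region, never the whole capacity. */
    public static String readableAsText(ByteBuffer buf) {
        StringBuilder sb = new StringBuilder();
        for (int i = buf.position(); i < buf.limit(); i++) {
            sb.append((char) buf.get(i)); // absolute get: does not move the position
        }
        return sb.toString();
    }

    /** A 256-byte buffer holding 4 bytes, 2 of which were already consumed. */
    public static ByteBuffer sample() {
        ByteBuffer buf = ByteBuffer.allocate(256); // capacity 256, like cb.capacity() in the question
        buf.put("xx12".getBytes(StandardCharsets.US_ASCII));
        buf.flip();  // position 0, limit 4
        buf.get();   // consume 'x'
        buf.get();   // consume 'x'; the readable region is now "12"
        return buf;
    }

    public static void main(String[] args) {
        ByteBuffer buf = sample();
        System.out.println(buf.capacity());      // 256
        System.out.println(buf.remaining());     // 2
        System.out.println(readableAsText(buf)); // 12
    }
}
```

Looping from 0 to capacity() would instead visit the two consumed bytes and 252 bytes of zeros that were never written.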