diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/RecipientDatabase.java b/app/src/main/java/org/thoughtcrime/securesms/database/RecipientDatabase.java index 8c82df623..2c6e46e31 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/RecipientDatabase.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/RecipientDatabase.java @@ -120,6 +120,7 @@ public class RecipientDatabase extends Database { private static final String PROFILE_FAMILY_NAME = "profile_family_name"; private static final String PROFILE_JOINED_NAME = "profile_joined_name"; private static final String MENTION_SETTING = "mention_setting"; + private static final String STORAGE_PROTO = "storage_proto"; public static final String SEARCH_PROFILE_NAME = "search_signal_profile"; private static final String SORT_NAME = "sort_name"; @@ -150,7 +151,8 @@ public class RecipientDatabase extends Database { private static final String[] MENTION_SEARCH_PROJECTION = new String[]{ID, removeWhitespace("COALESCE(" + nullIfEmpty(SYSTEM_DISPLAY_NAME) + ", " + nullIfEmpty(PROFILE_JOINED_NAME) + ", " + nullIfEmpty(PROFILE_GIVEN_NAME) + ", " + nullIfEmpty(USERNAME) + ", " + nullIfEmpty(PHONE) + ")") + " AS " + SORT_NAME}; private static final String[] RECIPIENT_FULL_PROJECTION = ArrayUtils.concat( - new String[] { TABLE_NAME + "." + ID }, + new String[] { TABLE_NAME + "." + ID, + TABLE_NAME + "." + STORAGE_PROTO }, TYPED_RECIPIENT_PROJECTION, new String[] { IdentityDatabase.TABLE_NAME + "." + IdentityDatabase.VERIFIED + " AS " + IDENTITY_STATUS, @@ -336,7 +338,8 @@ public class RecipientDatabase extends Database { GROUPS_V2_CAPABILITY + " INTEGER DEFAULT " + Recipient.Capability.UNKNOWN.serialize() + ", " + STORAGE_SERVICE_ID + " TEXT UNIQUE DEFAULT NULL, " + DIRTY + " INTEGER DEFAULT " + DirtyState.CLEAN.getId() + ", " + - MENTION_SETTING + " INTEGER DEFAULT " + MentionSetting.ALWAYS_NOTIFY.getId() + ");"; + MENTION_SETTING + " INTEGER DEFAULT " + MentionSetting.ALWAYS_NOTIFY.getId() + + STORAGE_PROTO + " TEXT DEFAULT NULL);"; private static final String INSIGHTS_INVITEE_LIST = "SELECT " + TABLE_NAME + "." + ID + " FROM " + TABLE_NAME + @@ -907,6 +910,12 @@ public class RecipientDatabase extends Database { values.put(STORAGE_SERVICE_ID, Base64.encodeBytes(update.getId().getRaw())); values.put(DIRTY, DirtyState.CLEAN.getId()); + if (update.hasUnknownFields()) { + values.put(STORAGE_PROTO, Base64.encodeBytes(update.serializeUnknownFields())); + } else { + values.putNull(STORAGE_PROTO); + } + int updateCount = db.update(TABLE_NAME, values, STORAGE_SERVICE_ID + " = ?", new String[]{Base64.encodeBytes(storageId.getRaw())}); if (updateCount < 1) { throw new AssertionError("Account update didn't match any rows!"); @@ -981,6 +990,12 @@ public class RecipientDatabase extends Database { values.put(COLOR, ContactColors.generateFor(profileName.toString()).serialize()); } + if (contact.hasUnknownFields()) { + values.put(STORAGE_PROTO, Base64.encodeBytes(contact.serializeUnknownFields())); + } else { + values.putNull(STORAGE_PROTO); + } + return values; } @@ -992,6 +1007,13 @@ public class RecipientDatabase extends Database { values.put(BLOCKED, groupV1.isBlocked() ? "1" : "0"); values.put(STORAGE_SERVICE_ID, Base64.encodeBytes(groupV1.getId().getRaw())); values.put(DIRTY, DirtyState.CLEAN.getId()); + + if (groupV1.hasUnknownFields()) { + values.put(STORAGE_PROTO, Base64.encodeBytes(groupV1.serializeUnknownFields())); + } else { + values.putNull(STORAGE_PROTO); + } + return values; } @@ -1003,6 +1025,13 @@ public class RecipientDatabase extends Database { values.put(BLOCKED, groupV2.isBlocked() ? "1" : "0"); values.put(STORAGE_SERVICE_ID, Base64.encodeBytes(groupV2.getId().getRaw())); values.put(DIRTY, DirtyState.CLEAN.getId()); + + if (groupV2.hasUnknownFields()) { + values.put(STORAGE_PROTO, Base64.encodeBytes(groupV2.serializeUnknownFields())); + } else { + values.putNull(STORAGE_PROTO); + } + return values; } @@ -1113,6 +1142,7 @@ public class RecipientDatabase extends Database { int groupsV2CapabilityValue = CursorUtil.requireInt(cursor, GROUPS_V2_CAPABILITY); String storageKeyRaw = CursorUtil.requireString(cursor, STORAGE_SERVICE_ID); int mentionSettingId = CursorUtil.requireInt(cursor, MENTION_SETTING); + String storageProtoRaw = CursorUtil.getString(cursor, STORAGE_PROTO).orNull(); Optional identityKeyRaw = CursorUtil.getString(cursor, IDENTITY_KEY); Optional identityStatusRaw = CursorUtil.getInt(cursor, IDENTITY_STATUS); @@ -1159,8 +1189,9 @@ public class RecipientDatabase extends Database { } } - byte[] storageKey = storageKeyRaw != null ? Base64.decodeOrThrow(storageKeyRaw) : null; - byte[] identityKey = identityKeyRaw.transform(Base64::decodeOrThrow).orNull(); + byte[] storageKey = storageKeyRaw != null ? Base64.decodeOrThrow(storageKeyRaw) : null; + byte[] identityKey = identityKeyRaw.transform(Base64::decodeOrThrow).orNull(); + byte[] storageProto = storageProtoRaw != null ? Base64.decodeOrThrow(storageProtoRaw) : null; IdentityDatabase.VerifiedStatus identityStatus = identityStatusRaw.transform(IdentityDatabase.VerifiedStatus::forState).or(IdentityDatabase.VerifiedStatus.DEFAULT); @@ -1180,7 +1211,8 @@ public class RecipientDatabase extends Database { Recipient.Capability.deserialize(uuidCapabilityValue), Recipient.Capability.deserialize(groupsV2CapabilityValue), InsightsBannerTier.fromId(insightsBannerTier), - storageKey, identityKey, identityStatus, MentionSetting.fromId(mentionSettingId)); + storageKey, identityKey, identityStatus, MentionSetting.fromId(mentionSettingId), + storageProto); } public BulkOperationsHandle beginBulkSystemContactUpdate() { @@ -2512,6 +2544,7 @@ public class RecipientDatabase extends Database { private final byte[] identityKey; private final IdentityDatabase.VerifiedStatus identityStatus; private final MentionSetting mentionSetting; + private final byte[] storageProto; RecipientSettings(@NonNull RecipientId id, @Nullable UUID uuid, @@ -2551,7 +2584,8 @@ public class RecipientDatabase extends Database { @Nullable byte[] storageId, @Nullable byte[] identityKey, @NonNull IdentityDatabase.VerifiedStatus identityStatus, - @NonNull MentionSetting mentionSetting) + @NonNull MentionSetting mentionSetting, + @Nullable byte[] storageProto) { this.id = id; this.uuid = uuid; @@ -2592,6 +2626,7 @@ public class RecipientDatabase extends Database { this.identityKey = identityKey; this.identityStatus = identityStatus; this.mentionSetting = mentionSetting; + this.storageProto = storageProto; } public RecipientId getId() { @@ -2752,6 +2787,10 @@ public class RecipientDatabase extends Database { public @NonNull MentionSetting getMentionSetting() { return mentionSetting; } + + public @Nullable byte[] getStorageProto() { + return storageProto; + } } public static class RecipientReader implements Closeable { diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java index 5cd53e894..0407b4fdc 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java +++ b/app/src/main/java/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java @@ -143,8 +143,9 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper { private static final int MENTIONS = 68; private static final int PINNED_CONVERSATIONS = 69; private static final int MENTION_GLOBAL_SETTING_MIGRATION = 70; + private static final int UNKNOWN_STORAGE_FIELDS = 71; - private static final int DATABASE_VERSION = 70; + private static final int DATABASE_VERSION = 71; private static final String DATABASE_NAME = "signal.db"; private final Context context; @@ -1008,6 +1009,10 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper { db.update("recipient", updateNever, "mention_setting = 2", null); } + if (oldVersion < UNKNOWN_STORAGE_FIELDS) { + db.execSQL("ALTER TABLE recipient ADD COLUMN storage_proto TEXT DEFAULT NULL"); + } + db.setTransactionSuccessful(); } finally { db.endTransaction(); diff --git a/app/src/main/java/org/thoughtcrime/securesms/storage/AccountConflictMerger.java b/app/src/main/java/org/thoughtcrime/securesms/storage/AccountConflictMerger.java index 62dbf3613..eba2ca391 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/storage/AccountConflictMerger.java +++ b/app/src/main/java/org/thoughtcrime/securesms/storage/AccountConflictMerger.java @@ -3,23 +3,15 @@ package org.thoughtcrime.securesms.storage; import androidx.annotation.NonNull; import androidx.annotation.Nullable; -import com.annimon.stream.Stream; - import org.thoughtcrime.securesms.logging.Log; import org.whispersystems.libsignal.util.guava.Optional; -import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.storage.SignalAccountRecord; -import org.whispersystems.signalservice.api.storage.SignalContactRecord; -import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashSet; -import java.util.List; import java.util.Objects; import java.util.Set; -import java.util.UUID; class AccountConflictMerger implements StorageSyncHelper.ConflictMerger { @@ -63,6 +55,7 @@ class AccountConflictMerger implements StorageSyncHelper.ConflictMerger givenName; private final Optional familyName; @@ -19,8 +21,9 @@ public final class SignalAccountRecord implements SignalRecord { private final Optional profileKey; public SignalAccountRecord(StorageId id, AccountRecord proto) { - this.id = id; - this.proto = proto; + this.id = id; + this.proto = proto; + this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto); this.givenName = OptionalUtil.absentIfEmpty(proto.getGivenName()); this.familyName = OptionalUtil.absentIfEmpty(proto.getFamilyName()); @@ -33,6 +36,14 @@ public final class SignalAccountRecord implements SignalRecord { return id; } + public boolean hasUnknownFields() { + return hasUnknownFields; + } + + public byte[] serializeUnknownFields() { + return hasUnknownFields ? proto.toByteArray() : null; + } + public Optional getGivenName() { return givenName; } @@ -91,11 +102,18 @@ public final class SignalAccountRecord implements SignalRecord { private final StorageId id; private final AccountRecord.Builder builder; + private byte[] unknownFields; + public Builder(byte[] rawId) { this.id = StorageId.forAccount(rawId); this.builder = AccountRecord.newBuilder(); } + public Builder setUnknownFields(byte[] serializedUnknowns) { + this.unknownFields = serializedUnknowns; + return this; + } + public Builder setGivenName(String givenName) { builder.setGivenName(givenName == null ? "" : givenName); return this; @@ -142,7 +160,13 @@ public final class SignalAccountRecord implements SignalRecord { } public SignalAccountRecord build() { - return new SignalAccountRecord(id, builder.build()); + AccountRecord proto = builder.build(); + + if (unknownFields != null) { + proto = ProtoUtil.combineWithUnknownFields(proto, unknownFields); + } + + return new SignalAccountRecord(id, proto); } } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java index f67025717..7d12a4e08 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalContactRecord.java @@ -5,6 +5,7 @@ import com.google.protobuf.ByteString; import org.whispersystems.libsignal.util.guava.Optional; import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.util.OptionalUtil; +import org.whispersystems.signalservice.api.util.ProtoUtil; import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.internal.storage.protos.ContactRecord; import org.whispersystems.signalservice.internal.storage.protos.ContactRecord.IdentityState; @@ -15,6 +16,7 @@ public final class SignalContactRecord implements SignalRecord { private final StorageId id; private final ContactRecord proto; + private final boolean hasUnknownFields; private final SignalServiceAddress address; private final Optional givenName; @@ -24,8 +26,9 @@ public final class SignalContactRecord implements SignalRecord { private final Optional identityKey; public SignalContactRecord(StorageId id, ContactRecord proto) { - this.id = id; - this.proto = proto; + this.id = id; + this.proto = proto; + this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto); this.address = new SignalServiceAddress(UuidUtil.parseOrNull(proto.getServiceUuid()), proto.getServiceE164()); this.givenName = OptionalUtil.absentIfEmpty(proto.getGivenName()); @@ -40,6 +43,14 @@ public final class SignalContactRecord implements SignalRecord { return id; } + public boolean hasUnknownFields() { + return hasUnknownFields; + } + + public byte[] serializeUnknownFields() { + return hasUnknownFields ? proto.toByteArray() : null; + } + public SignalServiceAddress getAddress() { return address; } @@ -102,6 +113,8 @@ public final class SignalContactRecord implements SignalRecord { private final StorageId id; private final ContactRecord.Builder builder; + private byte[] unknownFields; + public Builder(byte[] rawId, SignalServiceAddress address) { this.id = StorageId.forContact(rawId); this.builder = ContactRecord.newBuilder(); @@ -110,6 +123,11 @@ public final class SignalContactRecord implements SignalRecord { builder.setServiceE164(address.getNumber().or("")); } + public Builder setUnknownFields(byte[] serializedUnknowns) { + this.unknownFields = serializedUnknowns; + return this; + } + public Builder setGivenName(String givenName) { builder.setGivenName(givenName == null ? "" : givenName); return this; @@ -156,7 +174,13 @@ public final class SignalContactRecord implements SignalRecord { } public SignalContactRecord build() { - return new SignalContactRecord(id, builder.build()); + ContactRecord proto = builder.build(); + + if (unknownFields != null) { + proto = ProtoUtil.combineWithUnknownFields(proto, unknownFields); + } + + return new SignalContactRecord(id, proto); } } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV1Record.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV1Record.java index 028d68ce7..c0db0b45c 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV1Record.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV1Record.java @@ -2,6 +2,7 @@ package org.whispersystems.signalservice.api.storage; import com.google.protobuf.ByteString; +import org.whispersystems.signalservice.api.util.ProtoUtil; import org.whispersystems.signalservice.internal.storage.protos.GroupV1Record; import java.util.Objects; @@ -11,11 +12,13 @@ public final class SignalGroupV1Record implements SignalRecord { private final StorageId id; private final GroupV1Record proto; private final byte[] groupId; + private final boolean hasUnknownFields; public SignalGroupV1Record(StorageId id, GroupV1Record proto) { - this.id = id; - this.proto = proto; - this.groupId = proto.getId().toByteArray(); + this.id = id; + this.proto = proto; + this.groupId = proto.getId().toByteArray(); + this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto); } @Override @@ -23,6 +26,14 @@ public final class SignalGroupV1Record implements SignalRecord { return id; } + public boolean hasUnknownFields() { + return hasUnknownFields; + } + + public byte[] serializeUnknownFields() { + return hasUnknownFields ? proto.toByteArray() : null; + } + public byte[] getGroupId() { return groupId; } @@ -61,6 +72,8 @@ public final class SignalGroupV1Record implements SignalRecord { private final StorageId id; private final GroupV1Record.Builder builder; + private byte[] unknownFields; + public Builder(byte[] rawId, byte[] groupId) { this.id = StorageId.forGroupV1(rawId); this.builder = GroupV1Record.newBuilder(); @@ -68,6 +81,11 @@ public final class SignalGroupV1Record implements SignalRecord { builder.setId(ByteString.copyFrom(groupId)); } + public Builder setUnknownFields(byte[] serializedUnknowns) { + this.unknownFields = serializedUnknowns; + return this; + } + public Builder setBlocked(boolean blocked) { builder.setBlocked(blocked); return this; @@ -84,7 +102,13 @@ public final class SignalGroupV1Record implements SignalRecord { } public SignalGroupV1Record build() { - return new SignalGroupV1Record(id, builder.build()); + GroupV1Record proto = builder.build(); + + if (unknownFields != null) { + proto = ProtoUtil.combineWithUnknownFields(proto, unknownFields); + } + + return new SignalGroupV1Record(id, proto); } } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java index cd5a5d4ac..cb23c0afe 100644 --- a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/storage/SignalGroupV2Record.java @@ -4,7 +4,7 @@ import com.google.protobuf.ByteString; import org.signal.zkgroup.InvalidInputException; import org.signal.zkgroup.groups.GroupMasterKey; -import org.whispersystems.signalservice.internal.storage.protos.GroupV1Record; +import org.whispersystems.signalservice.api.util.ProtoUtil; import org.whispersystems.signalservice.internal.storage.protos.GroupV2Record; import java.util.Objects; @@ -14,10 +14,12 @@ public final class SignalGroupV2Record implements SignalRecord { private final StorageId id; private final GroupV2Record proto; private final GroupMasterKey masterKey; + private final boolean hasUnknownFields; public SignalGroupV2Record(StorageId id, GroupV2Record proto) { - this.id = id; - this.proto = proto; + this.id = id; + this.proto = proto; + this.hasUnknownFields = ProtoUtil.hasUnknownFields(proto); try { this.masterKey = new GroupMasterKey(proto.getMasterKey().toByteArray()); } catch (InvalidInputException e) { @@ -30,6 +32,14 @@ public final class SignalGroupV2Record implements SignalRecord { return id; } + public boolean hasUnknownFields() { + return hasUnknownFields; + } + + public byte[] serializeUnknownFields() { + return hasUnknownFields ? proto.toByteArray() : null; + } + public GroupMasterKey getMasterKey() { return masterKey; } @@ -68,6 +78,8 @@ public final class SignalGroupV2Record implements SignalRecord { private final StorageId id; private final GroupV2Record.Builder builder; + private byte[] unknownFields; + public Builder(byte[] rawId, GroupMasterKey masterKey) { this.id = StorageId.forGroupV2(rawId); this.builder = GroupV2Record.newBuilder(); @@ -75,6 +87,11 @@ public final class SignalGroupV2Record implements SignalRecord { builder.setMasterKey(ByteString.copyFrom(masterKey.serialize())); } + public Builder setUnknownFields(byte[] serializedUnknowns) { + this.unknownFields = serializedUnknowns; + return this; + } + public Builder setBlocked(boolean blocked) { builder.setBlocked(blocked); return this; @@ -91,7 +108,13 @@ public final class SignalGroupV2Record implements SignalRecord { } public SignalGroupV2Record build() { - return new SignalGroupV2Record(id, builder.build()); + GroupV2Record proto = builder.build(); + + if (unknownFields != null) { + proto = ProtoUtil.combineWithUnknownFields(proto, unknownFields); + } + + return new SignalGroupV2Record(id, proto); } } } diff --git a/libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/ProtoUtil.java b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/ProtoUtil.java new file mode 100644 index 000000000..9a5ce9325 --- /dev/null +++ b/libsignal/service/src/main/java/org/whispersystems/signalservice/api/util/ProtoUtil.java @@ -0,0 +1,134 @@ +package org.whispersystems.signalservice.api.util; + +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.GeneratedMessageLite; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.UnknownFieldSetLite; + +import org.whispersystems.libsignal.logging.Log; +import org.whispersystems.libsignal.util.ByteUtil; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.LinkedList; +import java.util.List; + +public final class ProtoUtil { + + private static final String TAG = ProtoUtil.class.getSimpleName(); + + private static final String DEFAULT_INSTANCE = "DEFAULT_INSTANCE"; + + private ProtoUtil() { } + + /** + * True if there are unknown fields anywhere inside the proto or its nested protos. + */ + @SuppressWarnings("rawtypes") + public static boolean hasUnknownFields(GeneratedMessageLite rootProto) { + try { + List allProtos = getInnerProtos(rootProto); + allProtos.add(rootProto); + + for (GeneratedMessageLite proto : allProtos) { + Field field = GeneratedMessageLite.class.getDeclaredField("unknownFields"); + field.setAccessible(true); + + UnknownFieldSetLite unknownFields = (UnknownFieldSetLite) field.get(proto); + + if (unknownFields != null && unknownFields.getSerializedSize() > 0) { + return true; + } + } + } catch (NoSuchFieldException | IllegalAccessException e) { + Log.w(TAG, "Failed to read proto private fields! Assuming no unknown fields."); + } + + return false; + } + + /** + * This takes two arguments: A proto model, and the bytes of another proto model of the same type. + * This will take the proto model and append onto it any unknown fields from the serialized proto + * model. Why is this useful? Well, if you do {@code myProto.parseFrom(data).toBuilder().build()}, + * you will lose any unknown fields that were in {@code data}. This lets you create a new model + * and plop the unknown fields back on from some other instance. + * + * A notable limitation of the current implementation is, however, that it does not support adding + * back unknown fields to *inner* messages. Unknown fields on inner messages will simply not be + * acknowledged. + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + public static Proto combineWithUnknownFields(Proto proto, byte[] serializedWithUnknownFields) { + if (serializedWithUnknownFields == null) { + return proto; + } + + try { + Proto protoWithUnknownFields = (Proto) proto.getParserForType().parseFrom(serializedWithUnknownFields); + byte[] unknownFields = getUnknownFields(protoWithUnknownFields); + + if (unknownFields == null) { + return proto; + } + + byte[] combined = ByteUtil.combine(proto.toByteArray(), unknownFields); + + return (Proto) proto.getParserForType().parseFrom(combined); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException(); + } + } + + @SuppressWarnings("rawtypes") + private static byte[] getUnknownFields(GeneratedMessageLite proto) { + try { + Field field = GeneratedMessageLite.class.getDeclaredField("unknownFields"); + field.setAccessible(true); + UnknownFieldSetLite unknownFields = (UnknownFieldSetLite) field.get(proto); + + if (unknownFields == null || unknownFields.getSerializedSize() == 0) { + return null; + } + + ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); + CodedOutputStream outputStream = CodedOutputStream.newInstance(byteStream); + + unknownFields.writeTo(outputStream); + outputStream.flush(); + + return byteStream.toByteArray(); + } catch (NoSuchFieldException | IllegalAccessException | IOException e) { + Log.w(TAG, "Failed to retrieve unknown fields.", e); + return null; + } + } + + /** + * Recursively retrieves all inner complex proto types inside a given proto. + */ + @SuppressWarnings("rawtypes") + private static List getInnerProtos(GeneratedMessageLite proto) { + List innerProtos = new LinkedList<>(); + + try { + Field[] fields = proto.getClass().getDeclaredFields(); + + for (Field field : fields) { + if (!field.getName().equals(DEFAULT_INSTANCE) && GeneratedMessageLite.class.isAssignableFrom(field.getType())) { + field.setAccessible(true); + + GeneratedMessageLite inner = (GeneratedMessageLite) field.get(proto); + innerProtos.add(inner); + innerProtos.addAll(getInnerProtos(inner)); + } + } + + } catch (IllegalAccessException e) { + Log.w(TAG, "Failed to get inner protos!", e); + } + + return innerProtos; + } +} diff --git a/libsignal/service/src/test/java/org/whispersystems/signalservice/api/util/ProtoUtilTest.java b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/util/ProtoUtilTest.java new file mode 100644 index 000000000..fd3f2718d --- /dev/null +++ b/libsignal/service/src/test/java/org/whispersystems/signalservice/api/util/ProtoUtilTest.java @@ -0,0 +1,229 @@ +package org.whispersystems.signalservice.api.util; + +import com.google.protobuf.InvalidProtocolBufferException; + +import org.junit.Assert; +import org.junit.Test; +import org.thoughtcrime.securesms.util.testprotos.TestInnerMessage; +import org.thoughtcrime.securesms.util.testprotos.TestInnerMessageWithNewString; +import org.thoughtcrime.securesms.util.testprotos.TestPerson; +import org.thoughtcrime.securesms.util.testprotos.TestPersonWithNewFieldOnMessage; +import org.thoughtcrime.securesms.util.testprotos.TestPersonWithNewMessage; +import org.thoughtcrime.securesms.util.testprotos.TestPersonWithNewRepeatedString; +import org.thoughtcrime.securesms.util.testprotos.TestPersonWithNewString; +import org.thoughtcrime.securesms.util.testprotos.TestPersonWithNewStringAndInt; +import org.whispersystems.signalservice.api.util.ProtoUtil; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class ProtoUtilTest { + + @Test + public void hasUnknownFields_noUnknowns() { + TestPerson person = TestPerson.newBuilder() + .setName("Peter Parker") + .setAge(23) + .build(); + + assertFalse(ProtoUtil.hasUnknownFields(person)); + } + + @Test + public void hasUnknownFields_unknownString() throws InvalidProtocolBufferException { + TestPersonWithNewString person = TestPersonWithNewString.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob("Reporter") + .build(); + + TestPerson personWithUnknowns = TestPerson.parseFrom(person.toByteArray()); + + assertTrue(ProtoUtil.hasUnknownFields(personWithUnknowns)); + } + + @Test + public void hasUnknownFields_multipleUnknowns() throws InvalidProtocolBufferException { + TestPersonWithNewStringAndInt person = TestPersonWithNewStringAndInt.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob("Reporter") + .setSalary(75_000) + .build(); + + TestPerson personWithUnknowns = TestPerson.parseFrom(person.toByteArray()); + + assertTrue(ProtoUtil.hasUnknownFields(personWithUnknowns)); + } + + @Test + public void hasUnknownFields_unknownMessage() throws InvalidProtocolBufferException { + TestPersonWithNewMessage person = TestPersonWithNewMessage.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob(TestPersonWithNewMessage.Job.newBuilder() + .setTitle("Reporter") + .setSalary(75_000)) + .build(); + + TestPerson personWithUnknowns = TestPerson.parseFrom(person.toByteArray()); + + assertTrue(ProtoUtil.hasUnknownFields(personWithUnknowns)); + } + + @Test + public void hasUnknownFields_unknownInsideMessage() throws InvalidProtocolBufferException { + TestPersonWithNewFieldOnMessage person = TestPersonWithNewFieldOnMessage.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob(TestPersonWithNewFieldOnMessage.Job.newBuilder() + .setTitle("Reporter") + .setSalary(75_000) + .setStartDate(100)) + .build(); + + TestPersonWithNewMessage personWithUnknowns = TestPersonWithNewMessage.parseFrom(person.toByteArray()); + + assertTrue(ProtoUtil.hasUnknownFields(personWithUnknowns)); + } + + @Test + public void combineWithUnknownFields_noUnknowns() throws InvalidProtocolBufferException { + TestPerson personWithUnknowns = TestPerson.newBuilder() + .setName("Peter Parker") + .setAge(23) + .build(); + + TestPerson localRepresentation = TestPerson.newBuilder() + .setName("Spider-Man") + .setAge(23) + .build(); + + TestPerson combinedWithUnknowns = ProtoUtil.combineWithUnknownFields(localRepresentation, personWithUnknowns.toByteArray()); + TestPersonWithNewString reparsedPerson = TestPersonWithNewString.parseFrom(combinedWithUnknowns.toByteArray()); + + Assert.assertEquals("Spider-Man", reparsedPerson.getName()); + Assert.assertEquals(23, reparsedPerson.getAge()); + } + + @Test + public void combineWithUnknownFields_appendedString() throws InvalidProtocolBufferException { + TestPersonWithNewString personWithUnknowns = TestPersonWithNewString.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob("Reporter") + .build(); + + TestPerson localRepresentation = TestPerson.newBuilder() + .setName("Spider-Man") + .setAge(23) + .build(); + + TestPerson combinedWithUnknowns = ProtoUtil.combineWithUnknownFields(localRepresentation, personWithUnknowns.toByteArray()); + TestPersonWithNewString reparsedPerson = TestPersonWithNewString.parseFrom(combinedWithUnknowns.toByteArray()); + + Assert.assertEquals("Spider-Man", reparsedPerson.getName()); + Assert.assertEquals(23, reparsedPerson.getAge()); + Assert.assertEquals("Reporter", reparsedPerson.getJob()); + } + + @Test + public void combineWithUnknownFields_appendedRepeatedString() throws InvalidProtocolBufferException { + TestPersonWithNewRepeatedString personWithUnknowns = TestPersonWithNewRepeatedString.newBuilder() + .setName("Peter Parker") + .setAge(23) + .addJobs("Reporter") + .addJobs("Super Hero") + .build(); + + TestPerson localRepresentation = TestPerson.newBuilder() + .setName("Spider-Man") + .setAge(23) + .build(); + + TestPerson combinedWithUnknowns = ProtoUtil.combineWithUnknownFields(localRepresentation, personWithUnknowns.toByteArray()); + TestPersonWithNewRepeatedString reparsedPerson = TestPersonWithNewRepeatedString.parseFrom(combinedWithUnknowns.toByteArray()); + + Assert.assertEquals("Spider-Man", reparsedPerson.getName()); + Assert.assertEquals(23, reparsedPerson.getAge()); + Assert.assertEquals(2, reparsedPerson.getJobsCount()); + Assert.assertEquals("Reporter", reparsedPerson.getJobs(0)); + Assert.assertEquals("Super Hero", reparsedPerson.getJobs(1)); + } + + @Test + public void combineWithUnknownFields_appendedStringAndInt() throws InvalidProtocolBufferException { + TestPersonWithNewStringAndInt personWithUnknowns = TestPersonWithNewStringAndInt.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob("Reporter") + .setSalary(75_000) + .build(); + + TestPerson localRepresentation = TestPerson.newBuilder() + .setName("Spider-Man") + .setAge(23) + .build(); + + TestPerson combinedWithUnknowns = ProtoUtil.combineWithUnknownFields(localRepresentation, personWithUnknowns.toByteArray()); + TestPersonWithNewStringAndInt reparsedPerson = TestPersonWithNewStringAndInt.parseFrom(combinedWithUnknowns.toByteArray()); + + Assert.assertEquals("Spider-Man", reparsedPerson.getName()); + Assert.assertEquals(23, reparsedPerson.getAge()); + Assert.assertEquals("Reporter", reparsedPerson.getJob()); + Assert.assertEquals(75_000, reparsedPerson.getSalary()); + } + + @Test + public void combineWithUnknownFields_appendedMessage() throws InvalidProtocolBufferException { + TestPersonWithNewMessage personWithUnknowns = TestPersonWithNewMessage.newBuilder() + .setName("Peter Parker") + .setAge(23) + .setJob(TestPersonWithNewMessage.Job.newBuilder() + .setTitle("Reporter") + .setSalary(75_000)) + .build(); + + TestPerson localRepresentation = TestPerson.newBuilder() + .setName("Spider-Man") + .setAge(23) + .build(); + + TestPerson combinedWithUnknowns = ProtoUtil.combineWithUnknownFields(localRepresentation, personWithUnknowns.toByteArray()); + TestPersonWithNewMessage reparsedPerson = TestPersonWithNewMessage.parseFrom(combinedWithUnknowns.toByteArray()); + + Assert.assertEquals("Spider-Man", reparsedPerson.getName()); + Assert.assertEquals(23, reparsedPerson.getAge()); + Assert.assertEquals("Reporter", reparsedPerson.getJob().getTitle()); + Assert.assertEquals(75_000, reparsedPerson.getJob().getSalary()); + } + + /** + * This isn't ideal behavior. This is more to show how something works. In the future, it'd be + * nice to support inner unknown fields. + */ + @Test + public void combineWithUnknownFields_innerMessagesUnknownsIgnored() throws InvalidProtocolBufferException { + TestInnerMessageWithNewString test = TestInnerMessageWithNewString.newBuilder() + .setInner(TestInnerMessageWithNewString.Inner.newBuilder() + .setA("a1") + .setB("b1") + .build()) + .build(); + + TestInnerMessage localRepresentation = TestInnerMessage.newBuilder() + .setInner(TestInnerMessage.Inner.newBuilder() + .setA("a2") + .build()) + .build(); + + TestInnerMessage combined = ProtoUtil.combineWithUnknownFields(localRepresentation, test.toByteArray()); + TestInnerMessageWithNewString reparsedTest = TestInnerMessageWithNewString.parseFrom(combined.toByteArray()); + + Assert.assertEquals("a2", reparsedTest.getInner().getA()); + Assert.assertEquals("", reparsedTest.getInner().getB()); + } +} diff --git a/libsignal/service/src/test/proto/Test.proto b/libsignal/service/src/test/proto/Test.proto new file mode 100644 index 000000000..a4543b47a --- /dev/null +++ b/libsignal/service/src/test/proto/Test.proto @@ -0,0 +1,70 @@ +syntax = "proto3"; + +package signal; + +option java_package = "org.thoughtcrime.securesms.util.testprotos"; +option java_multiple_files = true; + +message TestPerson { + string name = 1; + int32 age = 2; +} + +message TestPersonWithNewString { + string name = 1; + int32 age = 2; + string job = 3; +} + +message TestPersonWithNewRepeatedString { + string name = 1; + int32 age = 2; + repeated string jobs = 3; +} + +message TestPersonWithNewStringAndInt { + string name = 1; + int32 age = 2; + string job = 3; + int32 salary = 4; +} + +message TestPersonWithNewMessage { + message Job { + string title = 1; + uint32 salary = 2; + } + + string name = 1; + int32 age = 2; + Job job = 3; +} + +message TestPersonWithNewFieldOnMessage { + message Job { + string title = 1; + uint32 salary = 2; + uint64 startDate = 3; + } + + string name = 1; + int32 age = 2; + Job job = 3; +} + +message TestInnerMessage { + message Inner { + string a = 1; + } + + Inner inner = 1; +} + +message TestInnerMessageWithNewString { + message Inner { + string a = 1; + string b = 2; + } + + Inner inner = 1; +} \ No newline at end of file