Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor device data #437

Merged
merged 9 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 19 additions & 83 deletions internal/device_data.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package internal

import (
"sync"
)

const (
bitOTKCount int = iota
bitFallbackKeyTypes
Expand All @@ -18,105 +14,45 @@ func isBitSet(n int, bit int) bool {
return val > 0
}

// DeviceData contains useful data for this user's device. This list can be expanded without prompting
// schema changes. These values are upserted into the database and persisted forever.
// DeviceData contains useful data for this user's device.
type DeviceData struct {
DeviceListChanges
DeviceKeyData
UserID string
DeviceID string
}

// This is calculated from device_lists table
type DeviceListChanges struct {
DeviceListChanged []string
DeviceListLeft []string
}

// This gets serialised as CBOR in device_data table
type DeviceKeyData struct {
// Contains the latest device_one_time_keys_count values.
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
OTKCounts MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`

DeviceLists DeviceLists `json:"dl"`

// bitset for which device data changes are present. They accumulate until they get swapped over
// when they get reset
ChangedBits int `json:"c"`

UserID string
DeviceID string
}

func (dd *DeviceData) SetOTKCountChanged() {
func (dd *DeviceKeyData) SetOTKCountChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
}

func (dd *DeviceData) SetFallbackKeysChanged() {
func (dd *DeviceKeyData) SetFallbackKeysChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
}

func (dd *DeviceData) OTKCountChanged() bool {
func (dd *DeviceKeyData) OTKCountChanged() bool {
return isBitSet(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) FallbackKeysChanged() bool {
func (dd *DeviceKeyData) FallbackKeysChanged() bool {
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
}

type UserDeviceKey struct {
UserID string
DeviceID string
}

type DeviceDataMap struct {
deviceDataMu *sync.Mutex
deviceDataMap map[UserDeviceKey]*DeviceData
Pos int64
}

func NewDeviceDataMap(startPos int64, devices []DeviceData) *DeviceDataMap {
ddm := &DeviceDataMap{
deviceDataMu: &sync.Mutex{},
deviceDataMap: make(map[UserDeviceKey]*DeviceData),
Pos: startPos,
}
for i, dd := range devices {
ddm.deviceDataMap[UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}] = &devices[i]
}
return ddm
}

func (d *DeviceDataMap) Get(userID, deviceID string) *DeviceData {
key := UserDeviceKey{
UserID: userID,
DeviceID: deviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
dd, ok := d.deviceDataMap[key]
if !ok {
return nil
}
return dd
}

func (d *DeviceDataMap) Update(dd DeviceData) DeviceData {
key := UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
existing, ok := d.deviceDataMap[key]
if !ok {
existing = &DeviceData{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
}
if dd.OTKCounts != nil {
existing.OTKCounts = dd.OTKCounts
}
if dd.FallbackKeyTypes != nil {
existing.FallbackKeyTypes = dd.FallbackKeyTypes
}
existing.DeviceLists = existing.DeviceLists.Combine(dd.DeviceLists)

d.deviceDataMap[key] = existing

return *existing
}
88 changes: 44 additions & 44 deletions state/device_data_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ type DeviceDataRow struct {
ID int64 `db:"id"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
// This will contain internal.DeviceKeyData serialised as JSON. It's stored in a single column as we don't
// need to perform searches on this data.
Data []byte `db:"data"`
KeyData []byte `db:"data"`
}

type DeviceDataTable struct {
db *sqlx.DB
db *sqlx.DB
deviceListTable *DeviceListTable
}

func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
Expand All @@ -37,14 +38,16 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
ALTER TABLE syncv3_device_data SET (fillfactor = 90);
`)
return &DeviceDataTable{
db: db,
db: db,
deviceListTable: NewDeviceListTable(db),
}
}

// Atomically select the device data for this user|device and then swap DeviceLists around if set.
// This should only be called by the v3 HTTP APIs when servicing an E2EE extension request.
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// grab otk counts and fallback key types
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
Expand All @@ -54,32 +57,38 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
}
return err
}
result = &internal.DeviceData{}
var keyData *internal.DeviceKeyData
// unmarshal to swap
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
decMode, err := opts.DecMode()
result.UserID = userID
result.DeviceID = deviceID
if keyData != nil {
result.DeviceKeyData = *keyData
}

deviceListChanges, err := t.deviceListTable.SelectTx(txn, userID, deviceID, swap)
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &result); err != nil {
return err
for targetUserID, targetState := range deviceListChanges {
switch targetState {
case internal.DeviceListChanged:
result.DeviceListChanged = append(result.DeviceListChanged, targetUserID)
case internal.DeviceListLeft:
result.DeviceListLeft = append(result.DeviceListLeft, targetUserID)
}
}
result.UserID = userID
result.DeviceID = deviceID
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New

// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
writeBack.DeviceLists.New = make(map[string]int)
writeBack := *keyData
writeBack.ChangedBits = 0

if reflect.DeepEqual(result, &writeBack) {
if reflect.DeepEqual(keyData, &writeBack) {
// The update to the DB would be a no-op; don't bother with it.
// This helps reduce write usage and the contention on the unique index for
// the device_data table.
Expand All @@ -97,52 +106,43 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
return
}

func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_device_data WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
return err
}

// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
func (t *DeviceDataTable) Upsert(userID, deviceID string, keys internal.DeviceKeyData, deviceListChanges map[string]int) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// Update device lists
if err = t.deviceListTable.UpsertTx(txn, userID, deviceID, deviceListChanges); err != nil {
return err
}
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
// unmarshal and combine
var tempDD internal.DeviceData
if len(row.Data) > 0 {
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &tempDD); err != nil {
var keyData internal.DeviceKeyData
if len(row.KeyData) > 0 {
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
}
if dd.FallbackKeyTypes != nil {
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
tempDD.SetFallbackKeysChanged()
if keys.FallbackKeyTypes != nil {
keyData.FallbackKeyTypes = keys.FallbackKeyTypes
keyData.SetFallbackKeysChanged()
}
if dd.OTKCounts != nil {
tempDD.OTKCounts = dd.OTKCounts
tempDD.SetOTKCountChanged()
if keys.OTKCounts != nil {
keyData.OTKCounts = keys.OTKCounts
keyData.SetOTKCountChanged()
}
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)

data, err := cbor.Marshal(tempDD)
data, err := cbor.Marshal(keyData)
if err != nil {
return err
}
_, err = txn.Exec(
`INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3)
ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3`,
dd.UserID, dd.DeviceID, data,
userID, deviceID, data,
)
return err
})
Expand Down
Loading
Loading