Skip to content

Commit

Permalink
Merge pull request #41722 from nipunayf/fix-hash-collisions
Browse files Browse the repository at this point in the history
Fix invalid behavior in `TableIterator` with hash-collided keys
  • Loading branch information
KavinduZoysa authored Jan 2, 2024
2 parents 617e5e4 + bb13478 commit a09cc05
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.TreeMap;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

Expand Down Expand Up @@ -88,15 +89,14 @@ public class TableValueImpl<K, V> implements TableValue<K, V> {
private Type iteratorNextReturnType;
private ConcurrentHashMap<Long, List<Map.Entry<K, V>>> entries;
private LinkedHashMap<Long, List<V>> values;
private LinkedHashMap<Long, K> keys;
private String[] fieldNames;
private ValueHolder valueHolder;
private long maxIntKey = 0;

//These are required to achieve the iterator behavior
private LinkedHashMap<Long, Long> indexToKeyMap;
private LinkedHashMap<Long, Long> keyToIndexMap;
private LinkedHashMap<Long, KeyValuePair<K, V>> keyValues;
private Map<Long, K> indexToKeyMap;
private Map<K, Long> keyToIndexMap;
private Map<K, V> keyValues;
private long noOfAddedEntries = 0;

private boolean nextKeySupported;
Expand All @@ -108,10 +108,9 @@ public TableValueImpl(TableType tableType) {
this.type = this.tableType = tableType;

this.entries = new ConcurrentHashMap<>();
this.keys = new LinkedHashMap<>();
this.values = new LinkedHashMap<>();
this.keyToIndexMap = new LinkedHashMap<>();
this.indexToKeyMap = new LinkedHashMap<>();
this.indexToKeyMap = new TreeMap<>();
this.fieldNames = tableType.getFieldNames();
this.keyValues = new LinkedHashMap<>();
if (tableType.getFieldNames().length > 0) {
Expand Down Expand Up @@ -157,7 +156,7 @@ private void addData(ArrayValue data) {

@Override
public IteratorValue getIterator() {
return new TableIterator<K, V>();
return new TableIterator();
}

@Override
Expand Down Expand Up @@ -270,7 +269,6 @@ public Collection<V> values() {
public void clear() {
handleFrozenTableValue();
entries.clear();
keys.clear();
values.clear();
keyToIndexMap.clear();
indexToKeyMap.clear();
Expand Down Expand Up @@ -303,7 +301,7 @@ public long getNextKey() {
+ "The key sequence should only have an " +
"Integer field."));
}
return keys.size() == 0 ? 0 : (this.maxIntKey + 1);
return indexToKeyMap.size() == 0 ? 0 : (this.maxIntKey + 1);
}

public Type getKeyType() {
Expand Down Expand Up @@ -331,13 +329,7 @@ public V fillAndGet(Object key) {

@Override
public K[] getKeys() {
Object[] keyArr = new Object[keys.size()];
int i = 0;
for (K key : keys.values()) {
keyArr[i] = key;
i++;
}
return (K[]) keyArr;
return (K[]) indexToKeyMap.values().toArray();
}

@Override
Expand Down Expand Up @@ -469,7 +461,7 @@ public Type getIteratorNextReturnType() {
return iteratorNextReturnType;
}

private class TableIterator<K, V> implements IteratorValue {
private class TableIterator implements IteratorValue {
private long cursor;

TableIterator() {
Expand All @@ -478,11 +470,9 @@ private class TableIterator<K, V> implements IteratorValue {

@Override
public Object next() {
Long hash = indexToKeyMap.get(cursor);
if (hash != null) {
KeyValuePair<K, V> keyValuePair = (KeyValuePair<K, V>) keyValues.get(hash);
K key = keyValuePair.getKey();
V value = keyValuePair.getValue();
if (indexToKeyMap.containsKey(cursor)) {
K key = indexToKeyMap.get(cursor);
V value = keyValues.get(key);

List<Type> types = new ArrayList<>();
types.add(TypeChecker.getType(key));
Expand Down Expand Up @@ -533,10 +523,9 @@ public V putData(V data) {
entryList.add(entry);
UUID uuid = UUID.randomUUID();
Long hash = (long) uuid.hashCode();
updateIndexKeyMappings(hash, (K) data, data);
entries.put(hash, entryList);
updateIndexKeyMappings((K) data, hash);
values.put(hash, newData);
keyValues.put(hash, KeyValuePair.of((K) data, data));
return data;
}

Expand Down Expand Up @@ -578,20 +567,19 @@ public void addData(V data) {
ErrorHelper.getErrorDetails(ErrorCodes.TABLE_HAS_A_VALUE_FOR_KEY, key));
}

if (nextKeySupported && (keys.size() == 0 || maxIntKey < TypeChecker.anyToInt(key))) {
if (nextKeySupported && (indexToKeyMap.size() == 0 || maxIntKey < TypeChecker.anyToInt(key))) {
maxIntKey = ((Long) TypeChecker.anyToInt(key)).intValue();
}

Long hash = TableUtils.hash(key, null);

if (keys.containsKey(hash)) {
if (entries.containsKey(hash)) {
updateIndexKeyMappings(hash, key, data);
List<Map.Entry<K, V>> extEntries = entries.get(hash);
Map.Entry<K, V> entry = new AbstractMap.SimpleEntry(key, data);
extEntries.add(entry);
List<V> extValues = values.get(hash);
extValues.add(data);
keyValues.put(hash, KeyValuePair.of(key, data));
updateIndexKeyMappings(key, hash);
return;
}

Expand Down Expand Up @@ -632,13 +620,11 @@ public V putData(K key, V data) {
}

private V putData(K key, V value, List<V> data, Map.Entry<K, V> entry, Long hash) {
updateIndexKeyMappings(hash, key, value);
List<Map.Entry<K, V>> entryList = new ArrayList<>();
entryList.add(entry);
entries.put(hash, entryList);
keys.put(hash, key);
updateIndexKeyMappings(key, hash);
values.put(hash, data);
keyValues.put(hash, KeyValuePair.of(key, value));
return data.get(0);
}

Expand All @@ -656,16 +642,16 @@ public V putData(V data) {
}

public V remove(K key) {
keyValues.remove(key);
Long hash = TableUtils.hash(key, null);
keyValues.remove(hash);
List<Map.Entry<K, V>> entryList = entries.get(hash);
if (entryList != null && entryList.size() > 1) {
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(key, entry.getKey())) {
List<V> valueList = values.get(hash);
valueList.remove(entry.getValue());
entryList.remove(entry);
Long index = keyToIndexMap.remove(hash);
Long index = keyToIndexMap.remove(key);
indexToKeyMap.remove(index);
if (index != null && index == noOfAddedEntries - 1) {
noOfAddedEntries--;
Expand All @@ -674,13 +660,14 @@ public V remove(K key) {
}
}
}
entries.remove(hash);
keys.remove(hash);
Long index = keyToIndexMap.remove(hash);
indexToKeyMap.remove(index);
if (index != null && index == noOfAddedEntries - 1) {
noOfAddedEntries--;
if (entryList != null) {
Long index = keyToIndexMap.remove(entryList.get(0).getKey());
indexToKeyMap.remove(index);
if (index != null && index == noOfAddedEntries - 1) {
noOfAddedEntries--;
}
}
entries.remove(hash);
List<V> removedValue = values.remove(hash);
if (removedValue == null) {
return null;
Expand All @@ -689,7 +676,7 @@ public V remove(K key) {
}

public boolean containsKey(K key) {
if (keys.containsKey(TableUtils.hash(key, null))) {
if (entries.containsKey(TableUtils.hash(key, null))) {
List<Map.Entry<K, V>> entryList = entries.get(TableUtils.hash(key, null));
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(entry.getKey(), key)) {
Expand Down Expand Up @@ -748,35 +735,25 @@ public K wrapKey(MapValue data) {
}
}

/**
 * Immutable holder that pairs a key with the value associated with it.
 *
 * @param <K> type of the key
 * @param <V> type of the value
 */
private static final class KeyValuePair<K, V> {
    // Both fields are assigned exactly once, in the constructor, so they are final.
    private final K key;
    private final V value;

    /**
     * Creates a pair from the given key and value.
     *
     * @param key   the key of the pair
     * @param value the value of the pair
     */
    public KeyValuePair(K key, V value) {
        this.key = key;
        this.value = value;
    }

    /**
     * Static factory equivalent to calling the constructor.
     *
     * @param key   the key of the pair
     * @param value the value of the pair
     * @param <K>   type of the key
     * @param <V>   type of the value
     * @return a new pair holding {@code key} and {@code value}
     */
    public static <K, V> KeyValuePair<K, V> of(K key, V value) {
        return new KeyValuePair<>(key, value);
    }

    /**
     * @return the key given at construction
     */
    public K getKey() {
        return key;
    }

    /**
     * @return the value given at construction
     */
    public V getValue() {
        return value;
    }
}

// This method updates the indexes and the order required by the iterators
private void updateIndexKeyMappings(K key, Long hash) {
if (!keyToIndexMap.containsKey(hash)) {
keyToIndexMap.put(hash, noOfAddedEntries);
indexToKeyMap.put(noOfAddedEntries, hash);
noOfAddedEntries++;
private void updateIndexKeyMappings(Long hash, K key, V value) {
if (entries.containsKey(hash)) {
List<Map.Entry<K, V>> entryList = entries.get(hash);
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(entry.getKey(), key)) {
long index = keyToIndexMap.remove(entry.getKey());
keyToIndexMap.put(key, index);
indexToKeyMap.put(index, key);
keyValues.remove(entry.getKey());
keyValues.put(key, value);
return;
}
}
}
keyToIndexMap.put(key, noOfAddedEntries);
indexToKeyMap.put(noOfAddedEntries, key);
keyValues.put(key, value);
noOfAddedEntries++;
}

// This method checks for inherent table type violation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ public void testHashCollisionHandlingScenarios() {
BRunUtil.invoke(compileResult, "testHashCollisionHandlingScenarios");
}

/**
 * Invokes the function named {@code testHashCollisionInQuery} on {@code compileResult}
 * through {@code BRunUtil.invoke}; the value returned by the call is discarded.
 * NOTE(review): presumably the assertions live in the invoked test source and a failure
 * there surfaces as an exception from {@code invoke} -- confirm against BRunUtil.
 */
@Test
public void testHashCollisionInQuery() {
BRunUtil.invoke(compileResult, "testHashCollisionInQuery");
}

/**
 * Invokes the function named {@code testGetKeysOfHashCollidedKeys} on {@code compileResult}
 * through {@code BRunUtil.invoke}; the value returned by the call is discarded.
 * NOTE(review): presumably the assertions live in the invoked test source and a failure
 * there surfaces as an exception from {@code invoke} -- confirm against BRunUtil.
 */
@Test
public void testGetKeysOfHashCollidedKeys() {
BRunUtil.invoke(compileResult, "testGetKeysOfHashCollidedKeys");
}

@Test
public void testGetKeyList() {
Object result = BRunUtil.invoke(compileResult, "testGetKeyList");
Expand Down
32 changes: 32 additions & 0 deletions langlib/langlib-test/src/test/resources/test-src/tablelib_test.bal
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,38 @@ function testHashCollisionHandlingScenarios() {

}

// Checks that a query over a keyed table selects exactly the row whose key is 0,
// both before and after the row with the nil key is removed.
// NOTE(review): going by the function name, some of the keys added below are expected
// to share a hash bucket (presumably 0 and ()) -- confirm against the runtime's key hashing.
function testHashCollisionInQuery() {
// tbl1 is keyed on `k` and starts with a single string key.
table<record {readonly int|string|float? k;}> key(k) tbl1 = table [{k: "10"}];
// tbl2 is the expected query result: one row with k == 0 (declared without a key specifier).
table<record {readonly int|string|float? k;}> tbl2 = table [{k: 0}];

// Add int, nil and float keys; the values and their order are part of the scenario.
tbl1.add({k: 5});
tbl1.add({k: ()});
tbl1.add({k: -31});
tbl1.add({k: 0});
tbl1.add({k: 100.05});
tbl1.add({k: 30});
// Select the rows whose key equals 0 and compare with the expected table.
table<record {|readonly int|string|float? k; anydata...;|}> tbl3 =
from var tid in tbl1
where tid["k"] == 0
select tid;
assertEquals(tbl2, tbl3);

// Remove the row with the nil key, then repeat the same query and comparison.
_ = tbl1.remove(());
table<record {|readonly int|string|float? k; anydata...;|}> tbl4 =
from var tid in tbl1
where tid["k"] == 0
select tid;
assertEquals(tbl2, tbl4);
}

// Asserts that keys() on a keyed table returns the keys 5, 0, () and 2 in the same
// order in which the rows appear in the table constructor.
// NOTE(review): going by the function name, some of these keys are expected to share
// a hash bucket (presumably 0 and ()) -- confirm against the runtime's key hashing.
public function testGetKeysOfHashCollidedKeys() {
table<record {readonly int? k;}> key(k) tbl1 = table [
{k: 5}, {k: 0}, {k: ()}, {k: 2}
];

assertEquals(tbl1.keys(), [5, 0, (), 2]);
}

// Returns the result of keys() on `tab`, which is declared outside this view.
function testGetKeyList() returns any[] {
return tab.keys();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ function testStringAsKeyValue() {
table<Row12> tbl1 = from Row12 r in tbl
where r.key1 == "k1" && r.key2 == "k2"
select r;
assertEqual(tbl1, tbl2);
assertEqual(tbl2, tbl1);

// Test the get method of the table
readonly & string keyString = "k1";
Expand Down Expand Up @@ -639,7 +639,7 @@ function testStringAsCompositeKeyValue() {
table<Row12> tbl1 = from Row12 r in tbl
where r.key1 == "k1" && r.key2 == "k2"
select r;
assertEqual(tbl1, tbl2);
assertEqual(tbl2, tbl1);

// Test the get method of the table
[string & readonly, string & readonly] keyTuple = ["k1", "k2"];
Expand Down Expand Up @@ -679,7 +679,7 @@ function testMapAsCompositeKeyValue() {
table<Row13> tbl1 = from Row13 r in tbl
where r.keys == {"k1": "v1", "k2": "v2"} && r.values == {"k1": 1, "k2": 2}
select r;
assertEqual(tbl1, tbl2);
assertEqual(tbl2, tbl1);

// Test the get method of the table
[map<string> & readonly, map<int> & readonly] keyTuple = [{"k1": "v1", "k2": "v2"}, {"k1": 1, "k2": 2}];
Expand Down Expand Up @@ -719,7 +719,7 @@ function testArrayAsCompositeKeyValue() {
table<Row14> tbl1 = from Row14 r in tbl
where r.keys == ["k1", "k2"] && r.values == [1, 2]
select r;
assertEqual(tbl1, tbl2);
assertEqual(tbl2, tbl1);

// Test the get method of the table
[string[] & readonly, int[] & readonly] keyTuple = [["k1", "k2"], [1, 2]];
Expand Down

0 comments on commit a09cc05

Please sign in to comment.