Skip to content

Commit

Permalink
Merge pull request #41586 from gabilang/fix-so-with-map-equals
Browse files Browse the repository at this point in the history
Fix StackOverflow with equals check of maps with cyclic references
  • Loading branch information
warunalakshitha authored Mar 27, 2024
2 parents 23620b3 + f5ab4a1 commit 9dd67e9
Show file tree
Hide file tree
Showing 24 changed files with 409 additions and 374 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import static io.ballerina.runtime.api.constants.RuntimeConstants.ARRAY_LANG_LIB;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.errors.ErrorCodes.INVALID_READONLY_VALUE_UPDATE;
import static io.ballerina.runtime.internal.errors.ErrorReasons.INVALID_UPDATE_ERROR_IDENTIFIER;
import static io.ballerina.runtime.internal.errors.ErrorReasons.getModulePrefixedReason;
Expand Down Expand Up @@ -73,6 +75,28 @@ public void append(Object value) {
add(size, value);
}

@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ValuePair compValuePair = new ValuePair(this, o);
for (ValuePair valuePair : visitedValues) {
if (valuePair.equals(compValuePair)) {
return true;
}
}
visitedValues.add(compValuePair);

ArrayValue arrayValue = (ArrayValue) o;
if (arrayValue.size() != this.size()) {
return false;
}
for (int i = 0; i < this.size(); i++) {
if (!isEqual(this.get(i), arrayValue.get(i), visitedValues)) {
return false;
}
}
return true;
}

@Override
public Object reverse() {
throw new UnsupportedOperationException("reverse for tuple types is not supported directly.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
Expand Down Expand Up @@ -1285,12 +1287,33 @@ private int getCurrentArrayLength() {
@Override
public int hashCode() {
int result = Objects.hash(type, elementType);
result = 31 * result + Arrays.hashCode(refValues);
result = 31 * result + calculateHashCode(new ArrayList<>());
result = 31 * result + Arrays.hashCode(intValues);
result = 31 * result + Arrays.hashCode(booleanValues);
result = 31 * result + Arrays.hashCode(byteValues);
result = 31 * result + Arrays.hashCode(floatValues);
result = 31 * result + Arrays.hashCode(bStringValues);
return result;
}

private int calculateHashCode(List<Object> visited) {
if (refValues == null) {
return 0;
}

int result = 1;
if (visited.contains(refValues)) {
return 31 * result + System.identityHashCode(refValues);
}
visited.add(refValues);

for (Object ref : refValues) {
if (ref instanceof ArrayValueImpl) {
result = 31 * result + calculateHashCode(visited);
} else {
result = 31 * result + (ref == null ? 0 : ref.hashCode());
}
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,4 @@ private String getNonBmpCharWithSurrogates(long currentIndex) {
public boolean hasNext() {
return cursor < length;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;

import static io.ballerina.runtime.api.PredefinedTypes.TYPE_MAP;
import static io.ballerina.runtime.api.constants.RuntimeConstants.BLANG_SRC_FILE_SUFFIX;
import static io.ballerina.runtime.api.constants.RuntimeConstants.DOT;
import static io.ballerina.runtime.api.constants.RuntimeConstants.MODULE_INIT_CLASS_NAME;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.util.StringUtils.getExpressionStringVal;
import static io.ballerina.runtime.internal.util.StringUtils.getStringVal;

Expand Down Expand Up @@ -448,4 +450,19 @@ private String cleanupClassName(String className) {
private boolean isCompilerAddedName(String name) {
return name != null && name.startsWith("$") && name.endsWith("$");
}

/**
* Deep equality check for error values.
*
* @param o The error value to be compared
* @param visitedValues Visited values due to circular references
* @return True if the error values are equal, false otherwise
*/
@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ErrorValue errorValue = (ErrorValue) o;
return isEqual(this.getMessage(), errorValue.getMessage(), visitedValues) &&
((MapValueImpl<?, ?>) this.getDetails()).equals(errorValue.getDetails(), visitedValues) &&
isEqual(this.getCause(), errorValue.getCause(), visitedValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import static io.ballerina.runtime.api.constants.RuntimeConstants.MAP_LANG_LIB;
import static io.ballerina.runtime.api.utils.TypeUtils.getImpliedType;
import static io.ballerina.runtime.internal.JsonInternalUtils.mergeJson;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.ValueUtils.getTypedescValue;
import static io.ballerina.runtime.internal.errors.ErrorCodes.INVALID_READONLY_VALUE_UPDATE;
import static io.ballerina.runtime.internal.errors.ErrorReasons.INVALID_UPDATE_ERROR_IDENTIFIER;
Expand Down Expand Up @@ -354,30 +355,35 @@ public boolean containsKey(Object key) {
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}

if (o == null || getClass() != o.getClass()) {
return false;
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ValuePair compValuePair = new ValuePair(this, o);
for (ValuePair valuePair : visitedValues) {
if (valuePair.equals(compValuePair)) {
return true;
}
}
visitedValues.add(compValuePair);

MapValueImpl<?, ?> mapValue = (MapValueImpl<?, ?>) o;

if (mapValue.type.getTag() != this.type.getTag()) {
if (!(o instanceof MapValueImpl<?, ?> mapValue)) {
return false;
}

if (mapValue.referredType.getTag() != this.referredType.getTag()) {
if (this.entrySet().size() != mapValue.entrySet().size()) {
return false;
}

if (this.entrySet().size() != mapValue.entrySet().size()) {
if (!this.keySet().containsAll(mapValue.keySet())) {
return false;
}

return entrySet().equals(mapValue.entrySet());
Iterator<Map.Entry<K, V>> mapIterator = this.entrySet().iterator();
while (mapIterator.hasNext()) {
Map.Entry<K, V> lhsMapEntry = mapIterator.next();
if (!isEqual(lhsMapEntry.getValue(), mapValue.get(lhsMapEntry.getKey()), visitedValues)) {
return false;
}
}
return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import io.ballerina.runtime.api.values.BRefValue;

import java.util.Set;

/**
* <p>
* Interface to be implemented by all the reference types.
Expand All @@ -31,4 +33,7 @@
*/
public interface RefValue extends SimpleValue, BRefValue {

default boolean equals(Object o, Set<ValuePair> visitedValues) {
return o.equals(this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void setReQuantifier(RegExpQuantifier reQuantifier) {
}

private Object getValidReAtom(Object reAtom) {
// If reAtom is an instance of BString it's an insertion. Hence we need to parse it and check whether it's a
// If reAtom is an instance of BString it's an insertion. Hence, we need to parse it and check whether it's a
// valid insertion.
if (reAtom instanceof BString) {
validateInsertion((BString) reAtom);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
* @since 2201.3.0
*/
public class RegExpCapturingGroup extends RegExpCommonValue implements RegExpAtom {
private String openParen;
private RegExpFlagExpression flagExpr;
private RegExpDisjunction reDisjunction;
private String closeParen;
private final String openParen;
private final RegExpFlagExpression flagExpr;
private final RegExpDisjunction reDisjunction;
private final String closeParen;

public RegExpCapturingGroup(String openParen, RegExpFlagExpression flagExpr,
RegExpDisjunction reDisjunction, String closeParen) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
* @since 2201.3.0
*/
public class RegExpFlagExpression extends RegExpCommonValue {
private String questionMark;
private RegExpFlagOnOff flagsOnOff;
private String colon;
private final String questionMark;
private final RegExpFlagOnOff flagsOnOff;
private final String colon;

public RegExpFlagExpression(String questionMark, RegExpFlagOnOff flagsOnOff, String colon) {
this.questionMark = questionMark;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* @since 2201.3.0
*/
public class RegExpFlagOnOff extends RegExpCommonValue {
private String flags;
private final String flags;

public RegExpFlagOnOff(String flags) {
this.flags = flags;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
* @since 2201.3.0
*/
public class RegExpQuantifier extends RegExpCommonValue {
private String quantifier;
private String nonGreedyChar;
private final String quantifier;
private final String nonGreedyChar;

public RegExpQuantifier(String quantifier, String nonGreedyChar) {
this.quantifier = quantifier;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* @since 2201.3.0
*/
public class RegExpSequence extends RegExpCommonValue {
private RegExpTerm[] termsList;
private final RegExpTerm[] termsList;

public RegExpSequence(ArrayValue termsList) {
this.termsList = getRegExpSeqList(termsList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.util.Map;
import java.util.Objects;
import java.util.Set;

import static io.ballerina.runtime.internal.ValueUtils.getTypedescValue;

Expand Down Expand Up @@ -106,4 +107,19 @@ public void freezeDirect() {
public String toString() {
return this.stringValue(null);
}

/**
* Deep equality check for regular expression.
*
* @param o The regular expression on the right hand side
* @param visitedValues Visited values in order to break cyclic references.
* @return True if the regular expressions are equal, else false.
*/
@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
if (!(o instanceof RegExpValue rhsRegExpValue)) {
return false;
}
return this.stringValue(null).equals(rhsRegExpValue.stringValue(null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import java.util.concurrent.ConcurrentHashMap;

import static io.ballerina.runtime.api.constants.RuntimeConstants.TABLE_LANG_LIB;
import static io.ballerina.runtime.api.utils.TypeUtils.getImpliedType;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.ValueUtils.getTypedescValue;
import static io.ballerina.runtime.internal.errors.ErrorReasons.INHERENT_TYPE_VIOLATION_ERROR_IDENTIFIER;
import static io.ballerina.runtime.internal.errors.ErrorReasons.OPERATION_NOT_SUPPORTED_ERROR;
Expand Down Expand Up @@ -461,6 +463,47 @@ public Type getIteratorNextReturnType() {
return iteratorNextReturnType;
}

/**
* Check whether the given table value is equal to the current value.
*
* @param o the value to check equality with
* @param visitedValues the values that have already been visited
* @return true if the current value is equal to the given value
*/
@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ValuePair compValuePair = new ValuePair(this, o);
for (ValuePair valuePair : visitedValues) {
if (valuePair.equals(compValuePair)) {
return true;
}
}
visitedValues.add(compValuePair);

if (!(o instanceof TableValueImpl<?, ?> table)) {
return false;
}
if (this.size() != table.size()) {
return false;
}

boolean isLhsKeyedTable =
((BTableType) getImpliedType(this.getType())).getFieldNames().length > 0;
boolean isRhsKeyedTable =
((BTableType) getImpliedType(table.getType())).getFieldNames().length > 0;
Object[] lhsTableValues = this.values().toArray();
Object[] rhsTableValues = table.values().toArray();
if (isLhsKeyedTable != isRhsKeyedTable) {
return false;
}
for (int i = 0; i < lhsTableValues.length; i++) {
if (!isEqual(lhsTableValues[i], rhsTableValues[i], visitedValues)) {
return false;
}
}
return true;
}

private class TableIterator implements IteratorValue {
private long cursor;

Expand Down Expand Up @@ -592,7 +635,7 @@ public V getData(K key) {
return null;
}
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(key, entry.getKey())) {
if (isEqual(key, entry.getKey())) {
return entry.getValue();
}
}
Expand Down Expand Up @@ -663,7 +706,7 @@ public V remove(K key) {
List<Map.Entry<K, V>> entryList = entries.get(hash);
if (entryList != null && entryList.size() > 1) {
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(key, entry.getKey())) {
if (isEqual(key, entry.getKey())) {
List<V> valueList = values.get(hash);
valueList.remove(entry.getValue());
entryList.remove(entry);
Expand Down Expand Up @@ -695,7 +738,7 @@ public boolean containsKey(K key) {
if (entries.containsKey(TableUtils.hash(key, null))) {
List<Map.Entry<K, V>> entryList = entries.get(TableUtils.hash(key, null));
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(entry.getKey(), key)) {
if (isEqual(entry.getKey(), key)) {
return true;
}
}
Expand Down
Loading

0 comments on commit 9dd67e9

Please sign in to comment.