Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix StackOverflow with equals check of maps with cyclic references #41586

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import static io.ballerina.runtime.api.constants.RuntimeConstants.ARRAY_LANG_LIB;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.errors.ErrorCodes.INVALID_READONLY_VALUE_UPDATE;
import static io.ballerina.runtime.internal.errors.ErrorReasons.INVALID_UPDATE_ERROR_IDENTIFIER;
import static io.ballerina.runtime.internal.errors.ErrorReasons.getModulePrefixedReason;
Expand Down Expand Up @@ -73,6 +75,28 @@ public void append(Object value) {
add(size, value);
}

@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ValuePair compValuePair = new ValuePair(this, o);
for (ValuePair valuePair : visitedValues) {
if (valuePair.equals(compValuePair)) {
return true;
}
}
visitedValues.add(compValuePair);

gabilang marked this conversation as resolved.
Show resolved Hide resolved
ArrayValue arrayValue = (ArrayValue) o;
if (arrayValue.size() != this.size()) {
return false;
}
for (int i = 0; i < this.size(); i++) {
if (!isEqual(this.get(i), arrayValue.get(i), visitedValues)) {
return false;
}
}
return true;
}

@Override
public Object reverse() {
throw new UnsupportedOperationException("reverse for tuple types is not supported directly.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
Expand Down Expand Up @@ -1285,12 +1287,33 @@ private int getCurrentArrayLength() {
@Override
public int hashCode() {
int result = Objects.hash(type, elementType);
result = 31 * result + Arrays.hashCode(refValues);
result = 31 * result + calculateHashCode(new ArrayList<>());
result = 31 * result + Arrays.hashCode(intValues);
result = 31 * result + Arrays.hashCode(booleanValues);
result = 31 * result + Arrays.hashCode(byteValues);
result = 31 * result + Arrays.hashCode(floatValues);
result = 31 * result + Arrays.hashCode(bStringValues);
return result;
}

private int calculateHashCode(List<Object> visited) {
if (refValues == null) {
return 0;
}

int result = 1;
if (visited.contains(refValues)) {
return 31 * result + System.identityHashCode(refValues);
}
visited.add(refValues);

for (Object ref : refValues) {
if (ref instanceof ArrayValueImpl) {
result = 31 * result + calculateHashCode(visited);
} else {
result = 31 * result + (ref == null ? 0 : ref.hashCode());
}
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,4 @@ private String getNonBmpCharWithSurrogates(long currentIndex) {
public boolean hasNext() {
return cursor < length;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;

import static io.ballerina.runtime.api.PredefinedTypes.TYPE_MAP;
import static io.ballerina.runtime.api.constants.RuntimeConstants.BLANG_SRC_FILE_SUFFIX;
import static io.ballerina.runtime.api.constants.RuntimeConstants.DOT;
import static io.ballerina.runtime.api.constants.RuntimeConstants.MODULE_INIT_CLASS_NAME;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.util.StringUtils.getExpressionStringVal;
import static io.ballerina.runtime.internal.util.StringUtils.getStringVal;

Expand Down Expand Up @@ -448,4 +450,19 @@
private boolean isCompilerAddedName(String name) {
return name != null && name.startsWith("$") && name.endsWith("$");
}

/**
* Deep equality check for error values.
*
* @param o The error value to be compared
* @param visitedValues Visited values due to circular references
* @return True if the error values are equal, false otherwise
*/
@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ErrorValue errorValue = (ErrorValue) o;

Check warning on line 463 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/internal/values/ErrorValue.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/internal/values/ErrorValue.java#L463

Added line #L463 was not covered by tests
return isEqual(this.getMessage(), errorValue.getMessage(), visitedValues) &&
((MapValueImpl<?, ?>) this.getDetails()).equals(errorValue.getDetails(), visitedValues) &&
isEqual(this.getCause(), errorValue.getCause(), visitedValues);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import static io.ballerina.runtime.api.constants.RuntimeConstants.MAP_LANG_LIB;
import static io.ballerina.runtime.api.utils.TypeUtils.getImpliedType;
import static io.ballerina.runtime.internal.JsonInternalUtils.mergeJson;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.ValueUtils.getTypedescValue;
import static io.ballerina.runtime.internal.errors.ErrorCodes.INVALID_READONLY_VALUE_UPDATE;
import static io.ballerina.runtime.internal.errors.ErrorReasons.INVALID_UPDATE_ERROR_IDENTIFIER;
Expand Down Expand Up @@ -354,30 +355,35 @@ public boolean containsKey(Object key) {
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}

if (o == null || getClass() != o.getClass()) {
return false;
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ValuePair compValuePair = new ValuePair(this, o);
for (ValuePair valuePair : visitedValues) {
if (valuePair.equals(compValuePair)) {
return true;
}
}
visitedValues.add(compValuePair);

MapValueImpl<?, ?> mapValue = (MapValueImpl<?, ?>) o;

if (mapValue.type.getTag() != this.type.getTag()) {
if (!(o instanceof MapValueImpl<?, ?> mapValue)) {
return false;
}

if (mapValue.referredType.getTag() != this.referredType.getTag()) {
if (this.entrySet().size() != mapValue.entrySet().size()) {
return false;
}

if (this.entrySet().size() != mapValue.entrySet().size()) {
if (!this.keySet().containsAll(mapValue.keySet())) {
return false;
}

return entrySet().equals(mapValue.entrySet());
Iterator<Map.Entry<K, V>> mapIterator = this.entrySet().iterator();
while (mapIterator.hasNext()) {
Map.Entry<K, V> lhsMapEntry = mapIterator.next();
if (!isEqual(lhsMapEntry.getValue(), mapValue.get(lhsMapEntry.getKey()), visitedValues)) {
return false;
}
}
return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import io.ballerina.runtime.api.values.BRefValue;

import java.util.Set;

/**
* <p>
* Interface to be implemented by all the reference types.
Expand All @@ -31,4 +33,7 @@
*/
public interface RefValue extends SimpleValue, BRefValue {

default boolean equals(Object o, Set<ValuePair> visitedValues) {
return o.equals(this);

Check warning on line 37 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/internal/values/RefValue.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/internal/values/RefValue.java#L37

Added line #L37 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void setReQuantifier(RegExpQuantifier reQuantifier) {
}

private Object getValidReAtom(Object reAtom) {
// If reAtom is an instance of BString it's an insertion. Hence we need to parse it and check whether it's a
// If reAtom is an instance of BString it's an insertion. Hence, we need to parse it and check whether it's a
// valid insertion.
if (reAtom instanceof BString) {
validateInsertion((BString) reAtom);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
* @since 2201.3.0
*/
public class RegExpCapturingGroup extends RegExpCommonValue implements RegExpAtom {
private String openParen;
private RegExpFlagExpression flagExpr;
private RegExpDisjunction reDisjunction;
private String closeParen;
private final String openParen;
private final RegExpFlagExpression flagExpr;
private final RegExpDisjunction reDisjunction;
private final String closeParen;

public RegExpCapturingGroup(String openParen, RegExpFlagExpression flagExpr,
RegExpDisjunction reDisjunction, String closeParen) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
* @since 2201.3.0
*/
public class RegExpFlagExpression extends RegExpCommonValue {
private String questionMark;
private RegExpFlagOnOff flagsOnOff;
private String colon;
private final String questionMark;
private final RegExpFlagOnOff flagsOnOff;
private final String colon;

public RegExpFlagExpression(String questionMark, RegExpFlagOnOff flagsOnOff, String colon) {
this.questionMark = questionMark;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* @since 2201.3.0
*/
public class RegExpFlagOnOff extends RegExpCommonValue {
private String flags;
private final String flags;

public RegExpFlagOnOff(String flags) {
this.flags = flags;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
* @since 2201.3.0
*/
public class RegExpQuantifier extends RegExpCommonValue {
private String quantifier;
private String nonGreedyChar;
private final String quantifier;
private final String nonGreedyChar;

public RegExpQuantifier(String quantifier, String nonGreedyChar) {
this.quantifier = quantifier;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* @since 2201.3.0
*/
public class RegExpSequence extends RegExpCommonValue {
private RegExpTerm[] termsList;
private final RegExpTerm[] termsList;

public RegExpSequence(ArrayValue termsList) {
this.termsList = getRegExpSeqList(termsList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.util.Map;
import java.util.Objects;
import java.util.Set;

import static io.ballerina.runtime.internal.ValueUtils.getTypedescValue;

Expand Down Expand Up @@ -106,4 +107,19 @@ public void freezeDirect() {
public String toString() {
return this.stringValue(null);
}

/**
* Deep equality check for regular expression.
*
* @param o The regular expression on the right hand side
* @param visitedValues Visited values in order to break cyclic references.
* @return True if the regular expressions are equal, else false.
*/
@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
if (!(o instanceof RegExpValue rhsRegExpValue)) {
return false;
}
return this.stringValue(null).equals(rhsRegExpValue.stringValue(null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import java.util.concurrent.ConcurrentHashMap;

import static io.ballerina.runtime.api.constants.RuntimeConstants.TABLE_LANG_LIB;
import static io.ballerina.runtime.api.utils.TypeUtils.getImpliedType;
import static io.ballerina.runtime.internal.TypeChecker.isEqual;
import static io.ballerina.runtime.internal.ValueUtils.getTypedescValue;
import static io.ballerina.runtime.internal.errors.ErrorReasons.INHERENT_TYPE_VIOLATION_ERROR_IDENTIFIER;
import static io.ballerina.runtime.internal.errors.ErrorReasons.OPERATION_NOT_SUPPORTED_ERROR;
Expand Down Expand Up @@ -461,6 +463,47 @@
return iteratorNextReturnType;
}

/**
* Check whether the given table value is equal to the current value.
*
* @param o the value to check equality with
* @param visitedValues the values that have already been visited
* @return true if the current value is equal to the given value
*/
@Override
public boolean equals(Object o, Set<ValuePair> visitedValues) {
ValuePair compValuePair = new ValuePair(this, o);
for (ValuePair valuePair : visitedValues) {
if (valuePair.equals(compValuePair)) {
return true;
}
}
visitedValues.add(compValuePair);

if (!(o instanceof TableValueImpl<?, ?> table)) {
return false;

Check warning on line 484 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/internal/values/TableValueImpl.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/internal/values/TableValueImpl.java#L484

Added line #L484 was not covered by tests
}
if (this.size() != table.size()) {
return false;
}

boolean isLhsKeyedTable =
((BTableType) getImpliedType(this.getType())).getFieldNames().length > 0;
boolean isRhsKeyedTable =
((BTableType) getImpliedType(table.getType())).getFieldNames().length > 0;
Object[] lhsTableValues = this.values().toArray();
Object[] rhsTableValues = table.values().toArray();
if (isLhsKeyedTable != isRhsKeyedTable) {
return false;
}
for (int i = 0; i < lhsTableValues.length; i++) {
if (!isEqual(lhsTableValues[i], rhsTableValues[i], visitedValues)) {
return false;
}
}
return true;
}

private class TableIterator implements IteratorValue {
private long cursor;

Expand Down Expand Up @@ -595,7 +638,7 @@
return null;
}
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(key, entry.getKey())) {
if (isEqual(key, entry.getKey())) {
return entry.getValue();
}
}
Expand Down Expand Up @@ -647,7 +690,7 @@
List<Map.Entry<K, V>> entryList = entries.get(hash);
if (entryList != null && entryList.size() > 1) {
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(key, entry.getKey())) {
if (isEqual(key, entry.getKey())) {
List<V> valueList = values.get(hash);
valueList.remove(entry.getValue());
entryList.remove(entry);
Expand Down Expand Up @@ -679,7 +722,7 @@
if (entries.containsKey(TableUtils.hash(key, null))) {
List<Map.Entry<K, V>> entryList = entries.get(TableUtils.hash(key, null));
for (Map.Entry<K, V> entry: entryList) {
if (TypeChecker.isEqual(entry.getKey(), key)) {
if (isEqual(entry.getKey(), key)) {
return true;
}
}
Expand Down
Loading
Loading