Skip to content

Commit

Permalink
Merge pull request #42294 from ballerina-platform/worker_change
Browse files Browse the repository at this point in the history
Send worker changes to master
  • Loading branch information
lochana-chathura authored Mar 27, 2024
2 parents 9dd67e9 + 6d20660 commit 505f257
Show file tree
Hide file tree
Showing 125 changed files with 9,380 additions and 566 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.ballerina.runtime.internal.values.MapValueImpl;
import io.ballerina.runtime.internal.values.MappingInitialValueEntry;

import static io.ballerina.runtime.api.constants.RuntimeConstants.BALLERINA_LANG_ERROR_PKG_ID;
import static io.ballerina.runtime.api.constants.RuntimeConstants.FLOAT_LANG_LIB;
import static io.ballerina.runtime.api.creators.ErrorCreator.createError;
import static io.ballerina.runtime.internal.errors.ErrorCodes.INCOMPATIBLE_CONVERT_OPERATION;
Expand Down Expand Up @@ -187,4 +188,11 @@ public static BError createInvalidFractionDigitsError() {
ErrorReasons.INVALID_FRACTION_DIGITS_ERROR),
ErrorHelper.getErrorDetails(ErrorCodes.INVALID_FRACTION_DIGITS));
}

public static BError createNoMessageError(String chnlName) {
String[] splitWorkers = chnlName.split(":")[0].split("->");
return createError(BALLERINA_LANG_ERROR_PKG_ID, "NoMessage", ErrorReasons.NO_MESSAGE_ERROR,
null, ErrorHelper.getErrorDetails(ErrorCodes.NO_MESSAGE_ERROR,
StringUtils.fromString(splitWorkers[0]), StringUtils.fromString(splitWorkers[1])));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ public enum ErrorCodes implements DiagnosticCode {
REGEXP_INVALID_HEX_DIGIT("regexp.invalid.hex.digit", "RUNTIME_0120"),
CONFIG_TOML_INVALID_MODULE_STRUCTURE_WITH_VARIABLE("config.toml.invalid.module.structure.with.variable",
"RUNTIME_0121"),
EMPTY_XML_SEQUENCE_HAS_NO_ATTRIBUTES("empty.xml.sequence.no.attributes", "RUNTIME_0122");
EMPTY_XML_SEQUENCE_HAS_NO_ATTRIBUTES("empty.xml.sequence.no.attributes", "RUNTIME_0122"),
NO_MESSAGE_ERROR("no.worker.message.received", "RUNTIME_0123");

private final String errorMsgKey;
private final String errorCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ private ErrorReasons() {}
public static final BString REGEXP_OPERATION_ERROR = getModulePrefixedReason(REGEXP_LANG_LIB,
"RegularExpressionOperationError");

public static final BString NO_MESSAGE_ERROR = StringUtils.fromString("NoMessage");

public static BString getModulePrefixedReason(String moduleName, String identifier) {
return StringUtils.fromString(BALLERINA_ORG_PREFIX.concat(moduleName)
.concat(CLOSING_CURLY_BRACE).concat(identifier));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.ballerina.runtime.api.creators.ErrorCreator;
import io.ballerina.runtime.api.utils.StringUtils;
import io.ballerina.runtime.api.values.BError;
import io.ballerina.runtime.api.values.BMap;
import io.ballerina.runtime.api.values.BString;
import io.ballerina.runtime.internal.TypeChecker;
import io.ballerina.runtime.internal.values.ChannelDetails;
Expand Down Expand Up @@ -90,6 +91,8 @@ public class Strand {
public Stack<TransactionLocalContext> trxContexts;
private State state;
private final ReentrantLock strandLock;
public BMap<BString, Object> workerReceiveMap = null;
public int channelCount = 0;

public Strand(String name, StrandMetadata metadata, Scheduler scheduler, Strand parent,
Map<String, Object> properties) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,22 @@
*/
package io.ballerina.runtime.internal.scheduling;

import io.ballerina.runtime.api.creators.ValueCreator;
import io.ballerina.runtime.api.types.Type;
import io.ballerina.runtime.api.utils.StringUtils;
import io.ballerina.runtime.api.values.BMap;
import io.ballerina.runtime.api.values.BString;
import io.ballerina.runtime.internal.ErrorUtils;
import io.ballerina.runtime.internal.values.ChannelDetails;
import io.ballerina.runtime.internal.values.ErrorValue;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static io.ballerina.runtime.internal.scheduling.State.BLOCK_AND_YIELD;

/**
* This represents a worker data channel holder that is created for each strand to hold channels required.
*
Expand All @@ -28,11 +41,13 @@
public class WDChannels {

private Map<String, WorkerDataChannel> wDChannels;
private final List<ErrorValue> errors = new ArrayList<>();

//TODO try to generalize this to a normal data channel, in that case we won't need these classes.
public WDChannels() {
// A worker receive field for multiple receive action.
public record ReceiveField(String fieldName, String channelName) {
}

//TODO try to generalize this to a normal data channel, in that case we won't need these classes.
public synchronized WorkerDataChannel getWorkerDataChannel(String name) {
if (this.wDChannels == null) {
this.wDChannels = new HashMap<>();
Expand All @@ -44,4 +59,122 @@ public synchronized WorkerDataChannel getWorkerDataChannel(String name) {
}
return channel;
}

public Object receiveDataMultipleChannels(Strand strand, ReceiveField[] receiveFields, Type targetType)
throws Throwable {
if (strand.workerReceiveMap == null) {
strand.workerReceiveMap = ValueCreator.createMapValue(targetType);
}
for (ReceiveField field : receiveFields) {
WorkerDataChannel channel = getWorkerDataChannel(field.channelName());
WorkerDataChannel.State state = channel.getState();
Object result = null;
switch (state) {
case OPEN:
result = channel.tryTakeData(strand, true);
break;
case AUTO_CLOSED:
result = ErrorUtils.createNoMessageError(field.channelName());
break;
case CLOSED:
continue;
}
checkAndPopulateResult(strand, field, result, channel);
}
return clearResultCache(strand, receiveFields);
}

private void checkAndPopulateResult(Strand strand, ReceiveField field, Object result, WorkerDataChannel channel) {
if (result == null) {
strand.setState(BLOCK_AND_YIELD);
return;
}
result = getResultValue(result);
strand.workerReceiveMap.populateInitialValue(StringUtils.fromString(field.fieldName()), result);
channel.close();
++strand.channelCount;
}

private Object clearResultCache(Strand strand, ReceiveField[] receiveFields) {
if (strand.channelCount != receiveFields.length) {
return null;
}
BMap<BString, Object> map = strand.workerReceiveMap;
strand.workerReceiveMap = null;
strand.channelCount = 0;
strand.setState(State.RUNNABLE);
return map;
}

public Object receiveDataAlternateChannels(Strand strand, String[] channels) throws Throwable {
Object result = null;
boolean allChannelsClosed = true;
for (String channelName : channels) {
WorkerDataChannel channel = getWorkerDataChannel(channelName);
WorkerDataChannel.State state = channel.getState();
if (state == WorkerDataChannel.State.OPEN) {
allChannelsClosed = false;
result = handleResultForOpenChannel(strand, channels, channel);
} else if (state == WorkerDataChannel.State.AUTO_CLOSED) {
errors.add((ErrorValue) ErrorUtils.createNoMessageError(channelName));
}
}
return processResulAndError(strand, channels, result, allChannelsClosed);
}

private Object handleResultForOpenChannel(Strand strand, String[] channels, WorkerDataChannel channel)
throws Throwable {
Object result = channel.tryTakeData(strand, true);
if (result == null) {
return null;
}
Object resultValue = getResultValue(result);
if (resultValue instanceof ErrorValue errorValue) {
errors.add(errorValue);
channel.close();
return null;
}
closeChannels(channels);
return result;
}

private static Object getResultValue(Object result) {
if (result instanceof WorkerDataChannel.WorkerResult workerResult) {
return workerResult.value;
}
return result;
}

private Object processResulAndError(Strand strand, String[] channels, Object result, boolean allChannelsClosed) {
if (result == null) {
if (errors.size() == channels.length) {
result = errors.get(errors.size() - 1);
} else if (!allChannelsClosed) {
strand.setState(BLOCK_AND_YIELD);
}
} else {
strand.setState(State.RUNNABLE);
}
return getResultValue(result);
}

private void closeChannels(String[] channels) {
for (String channelName : channels) {
WorkerDataChannel channel = getWorkerDataChannel(channelName);
channel.close();
channel.callCount = 2;
}
}

public synchronized void removeCompletedChannels(Strand strand, String channelName) {
if (this.wDChannels != null) {
WorkerDataChannel channel = this.wDChannels.get(channelName);
// callCount is incremented to 2 when the message passing is completed.
if (channel != null && channel.callCount == 2) {
this.wDChannels.remove(channelName);
strand.channelDetails.remove(new ChannelDetails(channelName, true, false));
}
}
}

}
Loading

0 comments on commit 505f257

Please sign in to comment.