Skip to content

Commit

Permalink
Improve fast-run LS command to support user env variables
Browse files Browse the repository at this point in the history
  • Loading branch information
NipunaRanasinghe committed Dec 13, 2024
1 parent cf57786 commit 4cad384
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package org.ballerinalang.langserver.command.executors;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import org.ballerinalang.annotation.JavaSPIService;
import org.ballerinalang.langserver.commons.ExecuteCommandContext;
import org.ballerinalang.langserver.commons.client.ExtendedLanguageClient;
import org.ballerinalang.langserver.commons.command.CommandArgument;
import org.ballerinalang.langserver.commons.command.LSCommandExecutorException;
import org.ballerinalang.langserver.commons.command.spi.LSCommandExecutor;
import org.ballerinalang.langserver.commons.workspace.RunContext;
Expand All @@ -28,11 +31,17 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.InvalidPathException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
* Command executor for running a Ballerina file. Each project at most has a single instance running at a time.
Expand All @@ -43,52 +52,108 @@
@JavaSPIService("org.ballerinalang.langserver.commons.command.spi.LSCommandExecutor")
public class RunExecutor implements LSCommandExecutor {

private static final String RUN_COMMAND = "RUN";

// commands arg names
private static final String ARG_PATH = "path";
private static final String ARG_PROGRAM_ARGS = "programArgs";
private static final String ARG_ENV = "env";
private static final String ARG_DEBUG_PORT = "debugPort";

// output channels
private static final String ERROR_CHANNEL = "err";
private static final String OUT_CHANNEL = "out";

@Override
public Boolean execute(ExecuteCommandContext context) throws LSCommandExecutorException {
try {
RunContext.Builder builder = new RunContext.Builder(extractPath(context));
builder.withProgramArgs(extractProgramArgs(context));
int debugPort = extractDebugArgs(context);
if (debugPort >= 0) {
builder.withDebugPort(debugPort);
}
// TODO: handle env vars

RunContext runContext = builder.build();
Optional<Process> processOpt = context.workspace().run(runContext);
RunContext workspaceRunContext = getWorkspaceRunContext(context);
Optional<Process> processOpt = context.workspace().run(workspaceRunContext);
if (processOpt.isEmpty()) {
return false;
}
Process process = processOpt.get();
listenOutputAsync(context.getLanguageClient(), process::getInputStream, "out");
listenOutputAsync(context.getLanguageClient(), process::getErrorStream, "err");
listenOutputAsync(context.getLanguageClient(), process::getInputStream, OUT_CHANNEL);
listenOutputAsync(context.getLanguageClient(), process::getErrorStream, ERROR_CHANNEL);
return true;
} catch (IOException e) {
LogTraceParams error = new LogTraceParams("Error while running the program in fast-run mode: " +
e.getMessage(), ERROR_CHANNEL);
context.getLanguageClient().logTrace(error);
throw new LSCommandExecutorException(e);
} catch (Exception e) {
LogTraceParams error = new LogTraceParams("Unexpected error while executing the fast-run: " +
e.getMessage(), ERROR_CHANNEL);
context.getLanguageClient().logTrace(error);
throw new LSCommandExecutorException(e);
}
}

private static Path extractPath(ExecuteCommandContext context) {
return Path.of(context.getArguments().getFirst().<JsonPrimitive>value().getAsString());
private RunContext getWorkspaceRunContext(ExecuteCommandContext context) {
RunContext.Builder builder = new RunContext.Builder(extractPath(context));
builder.withProgramArgs(extractProgramArgs(context));
builder.withEnv(extractEnvVariables(context));
builder.withDebugPort(extractDebugArgs(context));

return builder.build();
}

private Path extractPath(ExecuteCommandContext context) {
return getCommandArgWithName(context, ARG_PATH)
.map(CommandArgument::<JsonPrimitive>value)
.map(JsonPrimitive::getAsString)
.map(pathStr -> {
try {
Path path = Path.of(pathStr);
if (!Files.exists(path)) {
throw new IllegalArgumentException("Specified path does not exist: " + pathStr);
}
return path;
} catch (InvalidPathException e) {
throw new IllegalArgumentException("Invalid path: " + pathStr, e);
}
})
.orElseThrow(() -> new IllegalArgumentException("Path argument is required"));
}

private int extractDebugArgs(ExecuteCommandContext context) {
return context.getArguments().stream()
.filter(commandArg -> commandArg.key().equals("debugPort"))
.map(commandArg -> commandArg.<JsonPrimitive>value().getAsInt())
.findAny()
return getCommandArgWithName(context, ARG_DEBUG_PORT)
.map(CommandArgument::<JsonPrimitive>value)
.map(JsonPrimitive::getAsInt)
.orElse(-1);
}

private static List<String> extractProgramArgs(ExecuteCommandContext context) {
List<String> args = new ArrayList<>();
if (context.getArguments().size() <= 2) {
return args;
}
context.getArguments().get(2).<JsonArray>value().getAsJsonArray().iterator()
.forEachRemaining(arg -> args.add(arg.getAsString()));
private List<String> extractProgramArgs(ExecuteCommandContext context) {
return getCommandArgWithName(context, ARG_PROGRAM_ARGS)
.map(arg -> arg.<JsonArray>value().getAsJsonArray())
.map(jsonArray -> StreamSupport.stream(jsonArray.spliterator(), false)
.filter(JsonElement::isJsonPrimitive)
.map(JsonElement::getAsJsonPrimitive)
.filter(JsonPrimitive::isString)
.map(JsonPrimitive::getAsString)
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
}

return args;
private Map<String, String> extractEnvVariables(ExecuteCommandContext context) {
return getCommandArgWithName(context, ARG_ENV)
.map(CommandArgument::<JsonObject>value)
.map(jsonObject -> {
Map<String, String> envMap = new HashMap<>();
for (Map.Entry<String, JsonElement> entry : jsonObject.entrySet()) {
if (entry.getValue().isJsonPrimitive() && entry.getValue().getAsJsonPrimitive().isString()) {
envMap.put(entry.getKey(), entry.getValue().getAsString());
}
}
return Collections.unmodifiableMap(envMap);
})
.orElse(Map.of());
}

private static Optional<CommandArgument> getCommandArgWithName(ExecuteCommandContext context, String name) {
return context.getArguments().stream()
.filter(commandArg -> commandArg.key().equals(name))
.findAny();
}

public void listenOutputAsync(ExtendedLanguageClient client, Supplier<InputStream> getInputStream, String channel) {
Expand All @@ -111,6 +176,6 @@ private static void listenOutput(ExtendedLanguageClient client, Supplier<InputSt

@Override
public String getCommand() {
return "RUN";
return RUN_COMMAND;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import static io.ballerina.projects.util.ProjectConstants.BALLERINA_TOML;
import static io.ballerina.projects.util.ProjectConstants.USER_DIR;
import static io.ballerina.runtime.api.constants.RuntimeConstants.MODULE_INIT_CLASS_NAME;

/**
Expand All @@ -117,6 +115,13 @@
*/
public class BallerinaWorkspaceManager implements WorkspaceManager {

// workspace run related constants
private static final String JAVA_COMMAND = System.getProperty("java.command");
private static final String USER_DIR = System.getProperty("user.dir");
private static final String HEAP_DUMP_FLAG = "-XX:+HeapDumpOnOutOfMemoryError";
private static final String HEAP_DUMP_PATH_FLAG = "-XX:HeapDumpPath=";
private static final String DEBUG_SOCKET_CONFIG = "-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:";

/**
* Cache mapping of document path to source root.
*/
Expand Down Expand Up @@ -589,67 +594,74 @@ public String uriScheme() {
}

@Override
public Optional<Process> run(RunContext context) throws IOException {
Path projectRoot = projectRoot(context.balSourcePath());
Optional<ProjectContext> projectPairOpt = projectContext(projectRoot);
if (projectPairOpt.isEmpty()) {
String msg = "Run command execution aborted because project is not loaded";
UserErrorException e = new UserErrorException(msg);
clientLogger.logError(LSContextOperation.WS_EXEC_CMD, msg, e, null, (Position) null);
public Optional<Process> run(RunContext executionContext) throws IOException {
Path projectRoot = projectRoot(executionContext.balSourcePath());
Optional<ProjectContext> projectContext = validateProjectContext(projectRoot);
if (projectContext.isEmpty()) {
return Optional.empty();
}
ProjectContext projectContext = projectPairOpt.get();
if (!stopProject(projectContext)) {
String msg = "Run command execution aborted because couldn't stop the previous run";
UserErrorException e = new UserErrorException(msg);
clientLogger.logError(LSContextOperation.WS_EXEC_CMD, msg, e, null, (Position) null);

if (!prepareProjectForExecution(projectContext.get())) {
return Optional.empty();
}

return executeProject(projectContext.get(), executionContext);
}

private Optional<ProjectContext> validateProjectContext(Path projectRoot) {
Optional<ProjectContext> projectContextOpt = projectContext(projectRoot);
if (projectContextOpt.isEmpty()) {
logError("Run command execution aborted because project is not loaded");
return Optional.empty();
}

return projectContextOpt;
}

private boolean prepareProjectForExecution(ProjectContext projectContext) {
// stop previous project run
if (!stopProject(projectContext)) {
logError("Run command execution aborted because couldn't stop the previous run");
return false;
}

Project project = projectContext.project();
Package pkg = project.currentPackage();
Module executableModule = pkg.getDefaultModule();
Optional<PackageCompilation> packageCompilation = waitAndGetPackageCompilation(project.sourceRoot(), true);
if (packageCompilation.isEmpty()) {
return Optional.empty();
logError("Run command execution aborted because package compilation failed");
return false;
}

// check for compilation errors
JBallerinaBackend jBallerinaBackend = execBackend(projectContext, packageCompilation.get());
Collection<Diagnostic> diagnostics = jBallerinaBackend.diagnosticResult().diagnostics(false);
if (diagnostics.stream().anyMatch(BallerinaWorkspaceManager::isError)) {
String msg = "Run command execution aborted due to compilation errors: " + diagnostics;
UserErrorException e = new UserErrorException(msg);
clientLogger.logError(LSContextOperation.WS_EXEC_CMD, msg, e, null, (Position) null);
return Optional.empty();
logError("Run command execution aborted due to compilation errors: " + diagnostics);
return false;
}

return true;
}

private Optional<Process> executeProject(ProjectContext projectContext, RunContext context) throws IOException {
Project project = projectContext.project();
Package pkg = project.currentPackage();
Module executableModule = pkg.getDefaultModule();
JBallerinaBackend jBallerinaBackend = execBackend(projectContext, pkg.getCompilation());
JarResolver jarResolver = jBallerinaBackend.jarResolver();
String initClassName = JarResolver.getQualifiedClassName(
executableModule.packageInstance().packageOrg().toString(),
executableModule.packageInstance().packageName().toString(),
executableModule.packageInstance().packageVersion().toString(),
MODULE_INIT_CLASS_NAME);
List<String> commands = new ArrayList<>();
commands.add(System.getProperty("java.command"));
commands.add("-XX:+HeapDumpOnOutOfMemoryError");
commands.add("-XX:HeapDumpPath=" + System.getProperty(USER_DIR));
if (context.debugPort() > 0) {
commands.add("-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:" + context.debugPort());
}
commands.add("-cp");
commands.add(getAllClassPaths(jarResolver));
commands.add(initClassName);
commands.addAll(context.programArgs());

List<String> commands = prepareExecutionCommands(context, executableModule, jarResolver);
ProcessBuilder pb = new ProcessBuilder(commands);
pb.environment().putAll(context.env());

Lock lock = projectContext.lockAndGet();
try {
Optional<Process> existing = projectContext.process();
if (existing.isPresent()) {
// We just removed this in above `stopProject`. This means there is a parallel command running.
String msg = "Run command execution aborted because another run is in progress";
UserErrorException e = new UserErrorException(msg);
clientLogger.logError(LSContextOperation.WS_EXEC_CMD, msg, e, null, (Position) null);
logError("Run command execution aborted because another run is in progress");
return Optional.empty();
}

Process ps = pb.start();
projectContext.setProcess(ps);
return Optional.of(ps);
Expand All @@ -658,6 +670,29 @@ public Optional<Process> run(RunContext context) throws IOException {
}
}

private List<String> prepareExecutionCommands(RunContext context, Module executableModule, JarResolver jarResolver) {
List<String> commands = new ArrayList<>();
commands.add(JAVA_COMMAND);
commands.add(HEAP_DUMP_FLAG);
commands.add(HEAP_DUMP_PATH_FLAG + USER_DIR);
if (context.debugPort() > 0) {
commands.add(DEBUG_SOCKET_CONFIG + context.debugPort());
}

commands.add("-cp");
commands.add(getAllClassPaths(jarResolver));

String initClassName = JarResolver.getQualifiedClassName(
executableModule.packageInstance().packageOrg().toString(),
executableModule.packageInstance().packageName().toString(),
executableModule.packageInstance().packageVersion().toString(),
MODULE_INIT_CLASS_NAME
);
commands.add(initClassName);
commands.addAll(context.programArgs());
return commands;
}

private static JBallerinaBackend execBackend(ProjectContext projectContext,
PackageCompilation packageCompilation) {
Lock lock = projectContext.lockAndGet();
Expand All @@ -675,6 +710,11 @@ private static JBallerinaBackend execBackend(ProjectContext projectContext,
}
}

private void logError(String message) {
UserErrorException e = new UserErrorException(message);
clientLogger.logError(LSContextOperation.WS_EXEC_CMD, message, e, null, (Position) null);
}

@Override
public boolean stop(Path filePath) {
Optional<ProjectContext> projectPairOpt = projectContext(projectRoot(filePath));
Expand Down Expand Up @@ -1348,7 +1388,7 @@ public void didClose(Path filePath, DidCloseTextDocumentParams params) {
}
}

// ============================================================================================================== //
// ============================================================================================================== //

private Path computeProjectRoot(Path path) {
return computeProjectKindAndProjectRoot(path).getRight();
Expand Down

0 comments on commit 4cad384

Please sign in to comment.