Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#132 Change C++ port public API #146

Merged
merged 6 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ports/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.7)
project(antlr4-c3 VERSION 1.1.0)
project(antlr4-c3 VERSION 2.0.0)

option(ANTLR4C3_DEVELOPER "Enable ${PROJECT_NAME} developer mode" OFF)

Expand Down
130 changes: 61 additions & 69 deletions ports/cpp/source/antlr4-c3/CodeCompletionCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <iostream>
#include <iterator>
#include <ranges>
#include <ratio>
#include <set>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -82,28 +81,28 @@ std::vector<std::string> CodeCompletionCore::atnStateTypeMap // NOLINT

CodeCompletionCore::CodeCompletionCore(antlr4::Parser* parser)
: parser(parser)
, atn(parser->getATN())
, vocabulary(parser->getVocabulary())
, ruleNames(parser->getRuleNames())
, timeoutMS(0)
, atn(&parser->getATN())
, vocabulary(&parser->getVocabulary())
, ruleNames(&parser->getRuleNames())
, timeout(0)
, cancel(nullptr) {
}

CandidatesCollection CodeCompletionCore::collectCandidates( // NOLINT
size_t caretTokenIndex,
antlr4::ParserRuleContext* context,
size_t timeoutMS,
std::atomic<bool>* cancel
CandidatesCollection CodeCompletionCore::collectCandidates(
size_t caretTokenIndex, Parameters parameters
) {
const auto* context = parameters.context;

timeout = parameters.timeout;
cancel = parameters.isCancelled;
timeoutStart = std::chrono::steady_clock::now();

shortcutMap.clear();
candidates.rules.clear();
candidates.tokens.clear();
candidates.cancelled = false;
candidates.isCancelled = false;
statesProcessed = 0;
precedenceStack = {};
timeoutStart = std::chrono::steady_clock::now();
this->cancel = cancel;
this->timeoutMS = timeoutMS;

tokenStartIndex = (context != nullptr) ? context->start->getTokenIndex() : 0;
auto* const tokenStream = parser->getTokenStream();
Expand All @@ -130,31 +129,31 @@ CandidatesCollection CodeCompletionCore::collectCandidates( // NOLINT
RuleWithStartTokenList callStack = {};
const size_t startRule = (context != nullptr) ? context->getRuleIndex() : 0;

processRule(atn.ruleToStartState[startRule], 0, callStack, 0, 0, candidates.cancelled);
processRule(atn->ruleToStartState[startRule], 0, callStack, 0, 0, candidates.isCancelled);

if (showResult) {
if (candidates.cancelled) {
if (debugOptions.showResult) {
if (candidates.isCancelled) {
std::cout << "*** TIMED OUT ***\n";
}

std::cout << "States processed: " << statesProcessed << "\n\n";

std::cout << "Collected rules:\n";
for (const auto& [tokenIndex, rule] : candidates.rules) {
std::cout << ruleNames[tokenIndex];
std::cout << ruleNames->at(tokenIndex);
std::cout << ", path: ";

for (const size_t token : rule.ruleList) {
std::cout << ruleNames[token] << " ";
std::cout << ruleNames->at(token) << " ";
}
}
std::cout << "\n\n";

std::set<std::string> sortedTokens;
for (const auto& [token, tokenList] : candidates.tokens) {
std::string value = vocabulary.getDisplayName(token);
std::string value = vocabulary->getDisplayName(token);
for (const size_t following : tokenList) {
value += " " + vocabulary.getDisplayName(following);
value += " " + vocabulary->getDisplayName(following);
}
sortedTokens.emplace(value);
}
Expand Down Expand Up @@ -267,8 +266,8 @@ bool CodeCompletionCore::translateToRuleIndex(
.startTokenIndex = rwst.startTokenIndex,
.ruleList = path,
};
if (showDebugOutput) {
std::cout << "=====> collected: " << ruleNames[rwst.ruleIndex] << "\n";
if (debugOptions.showDebugOutput) {
std::cout << "=====> collected: " << ruleNames->at(rwst.ruleIndex) << "\n";
}
}

Expand Down Expand Up @@ -414,25 +413,22 @@ bool CodeCompletionCore::collectFollowSets( // NOLINT
collectFollowSets(transition->target, stopState, followSets, stateStack, ruleStack);
isExhaustive = isExhaustive && nextStateFollowSetsIsExhaustive;
} else if (transition->getTransitionType() == antlr4::atn::TransitionType::WILDCARD) {
FollowSetWithPath set;
set.intervals = antlr4::misc::IntervalSet::of(
antlr4::Token::MIN_USER_TOKEN_TYPE, static_cast<ptrdiff_t>(atn.maxTokenType)
);
set.path = ruleStack;
followSets.emplace_back(std::move(set));
followSets.push_back({
.intervals = allUserTokens(),
.path = ruleStack,
.following = {},
});
} else {
antlr4::misc::IntervalSet label = transition->label();
if (!label.isEmpty()) {
if (transition->getTransitionType() == antlr4::atn::TransitionType::NOT_SET) {
label = label.complement(antlr4::misc::IntervalSet::of(
antlr4::Token::MIN_USER_TOKEN_TYPE, static_cast<ptrdiff_t>(atn.maxTokenType)
));
label = label.complement(allUserTokens());
}
FollowSetWithPath set;
set.intervals = label;
set.path = ruleStack;
set.following = getFollowingTokens(transition);
followSets.emplace_back(std::move(set));
followSets.push_back({
.intervals = label,
.path = ruleStack,
.following = getFollowingTokens(transition),
});
}
}
}
Expand Down Expand Up @@ -472,20 +468,17 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT

// Check for timeout
timedOut = false;
if (timeoutMS > 0) {
const std::chrono::duration<size_t, std::milli> timeout(timeoutMS);
if (std::chrono::steady_clock::now() - timeoutStart > timeout) {
timedOut = true;
return {};
}
if (timeout.has_value() && std::chrono::steady_clock::now() - timeoutStart > timeout) {
timedOut = true;
return {};
}

// Start with rule specific handling before going into the ATN walk.

// Check first if we've taken this path with the same input before.
std::unordered_map<size_t, RuleEndStatus>& positionMap = shortcutMap[startState->ruleIndex];
if (positionMap.contains(tokenListIndex)) {
if (showDebugOutput) {
if (debugOptions.showDebugOutput) {
std::cout << "=====> shortcut" << "\n";
}
return positionMap[tokenListIndex];
Expand All @@ -505,7 +498,7 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
FollowSetsPerState& setsPerState = followSetsByATN[typeid(parser)];

if (!setsPerState.contains(startState->stateNumber)) {
antlr4::atn::RuleStopState* stop = atn.ruleToStopState[startState->ruleIndex];
antlr4::atn::RuleStopState* stop = atn->ruleToStopState[startState->ruleIndex];
setsPerState[startState->stateNumber] = determineFollowSets(startState, stop);
}
const FollowSetsHolder& followSets = setsPerState[startState->stateNumber];
Expand Down Expand Up @@ -542,8 +535,8 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
if (!translateStackToRuleIndex(fullPath)) {
for (const size_t symbol : set.intervals.toList()) {
if (!ignoredTokens.contains(symbol)) {
if (showDebugOutput) {
std::cout << "=====> collected: " << vocabulary.getDisplayName(symbol) << "\n";
if (debugOptions.showDebugOutput) {
std::cout << "=====> collected: " << vocabulary->getDisplayName(symbol) << "\n";
}
if (!candidates.tokens.contains(symbol)) {
// Following is empty if there is more than one entry in the
Expand Down Expand Up @@ -606,14 +599,14 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
const size_t currentSymbol = tokens[currentEntry.tokenListIndex]->getType();

const bool atCaret = currentEntry.tokenListIndex >= tokens.size() - 1;
if (showDebugOutput) {
if (debugOptions.showDebugOutput) {
printDescription(
indentation,
currentEntry.state,
generateBaseDescription(currentEntry.state),
currentEntry.tokenListIndex
);
if (showRuleStack) {
if (debugOptions.showRuleStack) {
printRuleState(callStack);
}
}
Expand Down Expand Up @@ -685,7 +678,7 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
if (atCaret) {
if (!translateStackToRuleIndex(callStack)) {
for (const auto token :
std::views::iota(antlr4::Token::MIN_USER_TOKEN_TYPE, atn.maxTokenType + 1)) {
std::views::iota(antlr4::Token::MIN_USER_TOKEN_TYPE, atn->maxTokenType + 1)) {
if (!ignoredTokens.contains(token)) {
candidates.tokens[token] = {};
}
Expand Down Expand Up @@ -714,18 +707,16 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
antlr4::misc::IntervalSet set = transition->label();
if (!set.isEmpty()) {
if (transition->getTransitionType() == antlr4::atn::TransitionType::NOT_SET) {
set = set.complement(antlr4::misc::IntervalSet::of(
antlr4::Token::MIN_USER_TOKEN_TYPE, static_cast<ptrdiff_t>(atn.maxTokenType)
));
set = set.complement(allUserTokens());
}
if (atCaret) {
if (!translateStackToRuleIndex(callStack)) {
const std::vector<ptrdiff_t> list = set.toList();
const bool hasTokenSequence = list.size() == 1;
for (const size_t symbol : list) {
if (!ignoredTokens.contains(symbol)) {
if (showDebugOutput) {
std::cout << "=====> collected: " << vocabulary.getDisplayName(symbol)
if (debugOptions.showDebugOutput) {
std::cout << "=====> collected: " << vocabulary->getDisplayName(symbol)
<< "\n";
}

Expand All @@ -745,8 +736,8 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
}
} else {
if (set.contains(currentSymbol)) {
if (showDebugOutput) {
std::cout << "=====> consumed: " << vocabulary.getDisplayName(currentSymbol)
if (debugOptions.showDebugOutput) {
std::cout << "=====> consumed: " << vocabulary->getDisplayName(currentSymbol)
<< "\n";
}
statePipeline.push_back({
Expand All @@ -772,9 +763,11 @@ CodeCompletionCore::RuleEndStatus CodeCompletionCore::processRule( // NOLINT
return result;
}

// ----------------------------------------------------------------------------
// MARK: - Debug
// ----------------------------------------------------------------------------
/// Builds the interval set that spans the user token types, i.e. the range
/// from antlr4::Token::MIN_USER_TOKEN_TYPE up to the ATN's maxTokenType.
///
/// Used as the universe for WILDCARD transitions and as the set against which
/// NOT_SET transition labels are complemented.
///
/// @return An IntervalSet created via IntervalSet::of(min, max).
antlr4::misc::IntervalSet CodeCompletionCore::allUserTokens() const {
  // maxTokenType is cast to ptrdiff_t to match the signed bounds IntervalSet::of is given.
  return antlr4::misc::IntervalSet::of(
      antlr4::Token::MIN_USER_TOKEN_TYPE, static_cast<ptrdiff_t>(atn->maxTokenType)
  );
}

std::string CodeCompletionCore::generateBaseDescription(antlr4::atn::ATNState* state) {
const std::string stateValue = (state->stateNumber == antlr4::atn::ATNState::INVALID_STATE_NUMBER)
Expand All @@ -785,7 +778,7 @@ std::string CodeCompletionCore::generateBaseDescription(antlr4::atn::ATNState* s
output << "[" << stateValue << " " << atnStateTypeMap[static_cast<size_t>(state->getStateType())]
<< "]";
output << " in ";
output << ruleNames[state->ruleIndex];
output << ruleNames->at(state->ruleIndex);
return output.str();
}

Expand All @@ -798,22 +791,21 @@ void CodeCompletionCore::printDescription(
const auto indent = std::string(indentation * 2, ' ');

std::string transitionDescription;
if (debugOutputWithTransitions) {
if (debugOptions.showTransitions) {
for (const antlr4::atn::ConstTransitionPtr& transition : state->transitions) {
std::string labels;
std::vector<ptrdiff_t> symbols = transition->label().toList();

if (symbols.size() > 2) {
// Only print start and end symbols to avoid large lists in debug
// output.
labels = vocabulary.getDisplayName(static_cast<size_t>(symbols[0])) + " .. " +
vocabulary.getDisplayName(static_cast<size_t>(symbols[symbols.size() - 1]));
// Only print start and end symbols to avoid large lists in debug output.
labels = vocabulary->getDisplayName(static_cast<size_t>(symbols[0])) + " .. " +
vocabulary->getDisplayName(static_cast<size_t>(symbols[symbols.size() - 1]));
} else {
for (const size_t symbol : symbols) {
if (!labels.empty()) {
labels += ", ";
}
labels += vocabulary.getDisplayName(symbol);
labels += vocabulary->getDisplayName(symbol);
}
}
if (labels.empty()) {
Expand All @@ -830,7 +822,7 @@ void CodeCompletionCore::printDescription(
transitionDescription +=
atnStateTypeMap[static_cast<size_t>(transition->target->getStateType())];
transitionDescription += "] in ";
transitionDescription += ruleNames[transition->target->ruleIndex];
transitionDescription += ruleNames->at(transition->target->ruleIndex);
}
}

Expand All @@ -852,7 +844,7 @@ void CodeCompletionCore::printRuleState(RuleWithStartTokenList const& stack) {
}

for (const RuleWithStartToken& rule : stack) {
std::cout << ruleNames[rule.ruleIndex];
std::cout << ruleNames->at(rule.ruleIndex);
}
std::cout << "\n";
}
Expand Down
Loading
Loading