diff --git a/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/impl/symbols/BallerinaUnionTypeSymbol.java b/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/impl/symbols/BallerinaUnionTypeSymbol.java index e6be4d916456..388b1d0948d5 100644 --- a/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/impl/symbols/BallerinaUnionTypeSymbol.java +++ b/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/impl/symbols/BallerinaUnionTypeSymbol.java @@ -19,10 +19,15 @@ import io.ballerina.compiler.api.ModuleID; import io.ballerina.compiler.api.SymbolTransformer; import io.ballerina.compiler.api.SymbolVisitor; +import io.ballerina.compiler.api.impl.SymbolFactory; +import io.ballerina.compiler.api.symbols.EnumSymbol; import io.ballerina.compiler.api.symbols.TypeDescKind; import io.ballerina.compiler.api.symbols.TypeSymbol; import io.ballerina.compiler.api.symbols.UnionTypeSymbol; +import org.ballerinalang.model.symbols.SymbolKind; import org.ballerinalang.model.types.TypeKind; +import org.wso2.ballerinalang.compiler.semantics.model.symbols.BEnumSymbol; +import org.wso2.ballerinalang.compiler.semantics.model.symbols.BTypeSymbol; import org.wso2.ballerinalang.compiler.semantics.model.symbols.Symbols; import org.wso2.ballerinalang.compiler.semantics.model.types.BFiniteType; import org.wso2.ballerinalang.compiler.semantics.model.types.BType; @@ -36,6 +41,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.StringJoiner; import java.util.regex.Pattern; @@ -58,6 +64,7 @@ public class BallerinaUnionTypeSymbol extends AbstractTypeSymbol implements Unio private List memberTypes; private List originalMemberTypes; private String signature; + private EnumSymbol enumSymbol; public BallerinaUnionTypeSymbol(CompilerContext context, BUnionType unionType) { super(context, TypeDescKind.UNION, unionType); @@ -147,6 +154,32 @@ public String signature() { return this.signature; } + @Override + public boolean isEnum() { + if (this.enumSymbol != null) { + return true; + } + + return this.getBType().tsymbol.getKind() == SymbolKind.ENUM; + } + + @Override + public Optional getEnumSymbol() { + if (this.enumSymbol != null) { + return Optional.of(this.enumSymbol); + } + + BTypeSymbol tsymbol = this.getBType().tsymbol; + if (tsymbol.getKind() != SymbolKind.ENUM) { + return Optional.empty(); + } + + SymbolFactory symbolFactory = SymbolFactory.getInstance(this.context); + this.enumSymbol = symbolFactory.createEnumSymbol((BEnumSymbol) tsymbol, tsymbol.getName().value);; + + return Optional.of(this.enumSymbol); + } + @Override public void accept(SymbolVisitor visitor) { visitor.visit(this); diff --git a/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/symbols/UnionTypeSymbol.java b/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/symbols/UnionTypeSymbol.java index 9e8985549960..ce1b2a033a38 100644 --- a/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/symbols/UnionTypeSymbol.java +++ b/compiler/ballerina-lang/src/main/java/io/ballerina/compiler/api/symbols/UnionTypeSymbol.java @@ -17,6 +17,7 @@ package io.ballerina.compiler.api.symbols; import java.util.List; +import java.util.Optional; /** * Represents an union type descriptor. @@ -39,4 +40,18 @@ public interface UnionTypeSymbol extends TypeSymbol { * @return {@link List} of expanded member types */ List memberTypeDescriptors(); + + /** + * Check whether the union type is an enum. + * + * @return {@code true} if the union type is an enum, {@code false} otherwise + */ + boolean isEnum(); + + /** + * Get the enum symbol if the union type is an enum. + * + * @return Optional of {@link EnumSymbol} enum symbol + */ + Optional getEnumSymbol(); } diff --git a/tests/ballerina-compiler-api-test/src/test/java/io/ballerina/semantic/api/test/symbols/UnionTypeSymbolTest.java b/tests/ballerina-compiler-api-test/src/test/java/io/ballerina/semantic/api/test/symbols/UnionTypeSymbolTest.java index 10ea35627803..3c87970bbb01 100644 --- a/tests/ballerina-compiler-api-test/src/test/java/io/ballerina/semantic/api/test/symbols/UnionTypeSymbolTest.java +++ b/tests/ballerina-compiler-api-test/src/test/java/io/ballerina/semantic/api/test/symbols/UnionTypeSymbolTest.java @@ -26,6 +26,7 @@ import io.ballerina.compiler.api.symbols.UnionTypeSymbol; import io.ballerina.projects.Document; import io.ballerina.projects.Project; +import io.ballerina.semantic.api.test.util.SemanticAPITestUtils; import org.ballerinalang.test.BCompileUtil; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; @@ -116,6 +117,34 @@ public void testSingletonMembers() { assertList(signatures, List.of("\"int\"", "\"string\"", "100", "\"200\"", "true")); } + @Test(dataProvider = "DataForEnumInUnionTypeSymbol") + public void testEnumSymbolInUnionTypeSymbols(int line, int col, String name, boolean expIsEnum, + String typeName, List expEnumMembers) { + TypeDefinitionSymbol symbol = (TypeDefinitionSymbol) assertBasicsAndGetSymbol(model, srcFile, line, col, name, + SymbolKind.TYPE_DEFINITION); + assertEquals(symbol.typeDescriptor().typeKind(), TypeDescKind.UNION); + UnionTypeSymbol typeSymbol = (UnionTypeSymbol) symbol.typeDescriptor(); + assertEquals(typeSymbol.isEnum(), expIsEnum); + assertEquals(typeSymbol.getEnumSymbol().isPresent(), expIsEnum); + typeSymbol.getEnumSymbol().ifPresent(enumSymbol -> { + assertTrue(enumSymbol.getName().isPresent()); + assertEquals(typeSymbol.isEnum(), expIsEnum); // Check `isEnum()` when enumSymbol is not null. + assertEquals(enumSymbol.kind(), SymbolKind.ENUM); + assertEquals(enumSymbol.getName().get(), typeName); + assertEquals(enumSymbol.typeDescriptor().typeKind(), TypeDescKind.UNION); + SemanticAPITestUtils.assertList(enumSymbol.members(), expEnumMembers); + }); + } + + @DataProvider(name = "DataForEnumInUnionTypeSymbol") + public Object[][] getEnumDataInUnionType() { + return new Object[][]{ + {47, 5, "FooState", true, "State", List.of("OPEN", "CLOSED")}, + {52, 5, "BarState", false, null, List.of()}, + {59, 5, "BazConnectionState", true, "ConnectionState", List.of("OK", "ERROR")}, + }; + } + public static void assertList(List actualValues, List expectedValues) { assertEquals(actualValues.size(), expectedValues.size()); diff --git a/tests/ballerina-compiler-api-test/src/test/resources/test-src/symbols/union_type_symbol_test.bal b/tests/ballerina-compiler-api-test/src/test/resources/test-src/symbols/union_type_symbol_test.bal index d50ce0699351..b86db234628c 100644 --- a/tests/ballerina-compiler-api-test/src/test/resources/test-src/symbols/union_type_symbol_test.bal +++ b/tests/ballerina-compiler-api-test/src/test/resources/test-src/symbols/union_type_symbol_test.bal @@ -39,3 +39,22 @@ type T9 int[]|T8; type Keyword KEY | boolean | "string" | 100 | "200" | true; const KEY = "int"; + +enum State { + OPEN, + CLOSED +} + +type FooState State; + +const ENABLED = "enabled"; +const DISABLED = "disabled"; + +type BarState ENABLED | DISABLED; + +enum ConnectionState { + OK = "1", + ERROR = "0" +} + +type BazConnectionState ConnectionState;